Coverage for src/dataknobs_data/backends/postgres.py: 13%
728 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-08 15:25 -0600
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-08 15:25 -0600
1"""PostgreSQL backend implementation with proper connection management and vector support."""
3from __future__ import annotations
5import json
6import logging
7import time
8import uuid
9from typing import TYPE_CHECKING, Any, cast
11import asyncpg
12from dataknobs_config import ConfigurableBase
14from dataknobs_utils.sql_utils import DotenvPostgresConnector, PostgresDB
16from ..database import AsyncDatabase, SyncDatabase
17from ..pooling import ConnectionPoolManager
18from ..pooling.postgres import PostgresPoolConfig, create_asyncpg_pool, validate_asyncpg_pool
19from ..query import Operator, Query
20from ..query_logic import ComplexQuery
21from ..streaming import (
22 StreamConfig,
23 StreamResult,
24 async_process_batch_with_fallback,
25 process_batch_with_fallback,
26)
27from ..vector.mixins import VectorOperationsMixin
28from .postgres_mixins import (
29 PostgresBaseConfig,
30 PostgresConnectionValidator,
31 PostgresErrorHandler,
32 PostgresTableManager,
33 PostgresVectorSupport,
34)
35from .sql_base import SQLQueryBuilder, SQLRecordSerializer
37if TYPE_CHECKING:
38 import numpy as np
40 from collections.abc import AsyncIterator, Iterator, Callable, Awaitable
41 from ..fields import VectorField
42 from ..records import Record
43 from ..vector.types import DistanceMetric, VectorSearchResult
45logger = logging.getLogger(__name__)
48class SyncPostgresDatabase(
49 SyncDatabase,
50 ConfigurableBase,
51 VectorOperationsMixin,
52 SQLRecordSerializer,
53 PostgresBaseConfig,
54 PostgresTableManager,
55 PostgresVectorSupport,
56 PostgresConnectionValidator,
57 PostgresErrorHandler,
58):
59 """Synchronous PostgreSQL database backend with proper connection management."""
def __init__(self, config: dict[str, Any] | None = None):
    """Store the PostgreSQL settings; no connection is opened here.

    Args:
        config: Configuration with the following optional keys:
            - host: PostgreSQL host (default: from env/localhost)
            - port: PostgreSQL port (default: 5432)
            - database: Database name (default: from env/postgres)
            - user: Username (default: from env/postgres)
            - password: Password (default: from env)
            - table: Table name (default: "records")
            - schema: Schema name (default: "public")
            - enable_vector: Enable vector support (default: False)
    """
    super().__init__(config)

    # The config mixin splits the settings into table, schema and
    # connection parts.
    settings = config or {}
    table, schema, connection_settings = self._parse_postgres_config(settings)
    self._init_postgres_attributes(table, schema)

    # Kept for connect(), which is where the handles below get created.
    self._conn_config = connection_settings
    self.db = None
    self.query_builder = None
@classmethod
def from_config(cls, config: dict) -> SyncPostgresDatabase:
    """Alternate constructor: build an instance from a config dictionary."""
    instance = cls(config)
    return instance
def connect(self) -> None:
    """Create the query builder and PostgresDB handle, then prepare the table.

    Returns immediately when the instance is already flagged as connected.
    The connected flag is only set after table setup (and, if requested,
    vector detection) has run.
    """
    if self._connected:
        return

    # psycopg2 uses pyformat (%(name)s) placeholders.
    self.query_builder = SQLQueryBuilder(
        self.table_name,
        self.schema_name,
        dialect="postgres",
        param_style="pyformat",
    )

    conn_config = self._conn_config
    has_explicit_target = any(key in conn_config for key in ["host", "database", "user"])
    if has_explicit_target:
        # PostgresDB names these arguments 'db' and 'pwd' rather than
        # 'database' and 'password'.
        self.db = PostgresDB(
            host=conn_config.get("host", "localhost"),
            db=conn_config.get("database", "postgres"),
            user=conn_config.get("user", "postgres"),
            pwd=conn_config.get("password"),
            port=conn_config.get("port", 5432),
        )
    else:
        # No host/database/user supplied: fall back to the dotenv connector.
        connector = DotenvPostgresConnector()
        self.db = PostgresDB(connector)

    self._ensure_table()

    if self.vector_enabled:
        self._detect_vector_support()

    self._connected = True
    self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}")
def close(self) -> None:
    """Flag the backend as disconnected.

    PostgresDB manages its own connections via context managers, so there
    is nothing to release here. The flag is only cleared when a PostgresDB
    handle exists.
    """
    if not self.db:
        return
    self._connected = False  # type: ignore[unreachable]
131 def _initialize(self) -> None:
132 """Initialize method - connection setup moved to connect()."""
133 # Configuration parsing stays here, actual connection in connect()
134 pass
def _detect_vector_support(self) -> None:
    """Enable vector support when pgvector is present or can be installed.

    Best effort: any exception raised by the checks is logged at debug
    level and leaves vector support switched off.
    """
    from .postgres_vector import check_pgvector_extension_sync, install_pgvector_extension_sync

    try:
        if check_pgvector_extension_sync(self.db):
            self._vector_enabled = True
            logger.info("pgvector extension detected and enabled")
        elif install_pgvector_extension_sync(self.db):
            # Not installed yet, but the install attempt succeeded.
            self._vector_enabled = True
            logger.info("pgvector extension installed and enabled")
        else:
            logger.debug("pgvector extension not available")
    except Exception as exc:
        logger.debug(f"Could not enable vector support: {exc}")
        self._vector_enabled = False
156 def _ensure_table(self) -> None:
157 """Ensure the records table exists."""
158 if not self.db:
159 raise RuntimeError("Database not connected. Call connect() first.")
161 create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name) # type: ignore[unreachable]
162 self.db.execute(create_table_sql)
165 def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
166 """Convert a Record to a database row."""
167 return {
168 "id": id or str(uuid.uuid4()),
169 "data": self.record_to_json(record),
170 "metadata": json.dumps(record.metadata) if record.metadata else None,
171 }
173 def _row_to_record(self, row: dict[str, Any]) -> Record:
174 """Convert a database row to a Record."""
175 return self.row_to_record(row)
def create(self, record: Record) -> str:
    """Insert a record and return the ID it was stored under.

    The record's own ID is kept when it has one; otherwise a UUID4 string
    is generated.
    """
    self._check_connection()
    record_id = record.id or str(uuid.uuid4())
    params = self._record_to_row(record, record_id)

    sql = f"""
    INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
    VALUES (%(id)s, %(data)s, %(metadata)s)
    """
    self.db.execute(sql, params)
    return record_id
def read(self, id: str) -> Record | None:
    """Fetch one record by ID, or return None when no row matches."""
    self._check_connection()
    sql = f"""
    SELECT id, data, metadata
    FROM {self.schema_name}.{self.table_name}
    WHERE id = %(id)s
    """
    frame = self.db.query(sql, {"id": id})

    if not frame.empty:
        # Only the first row is used.
        first_row = frame.iloc[0].to_dict()
        return self._row_to_record(first_row)
    return None
def update(self, id: str, record: Record) -> bool:
    """Overwrite the data and metadata of the row with the given ID.

    Returns:
        True when execute() reports a positive integer row count; False
        for zero or for any non-integer result.
    """
    self._check_connection()
    params = self._record_to_row(record, id)

    sql = f"""
    UPDATE {self.schema_name}.{self.table_name}
    SET data = %(data)s, metadata = %(metadata)s, updated_at = CURRENT_TIMESTAMP
    WHERE id = %(id)s
    """
    affected = self.db.execute(sql, params)
    # The result is only trusted when it is an integer row count.
    if not isinstance(affected, int):
        return False
    return affected > 0
def delete(self, id: str) -> bool:
    """Remove the row with the given ID.

    Returns:
        True when execute() reports a positive integer row count; False
        for zero or for any non-integer result.
    """
    self._check_connection()
    sql = f"""
    DELETE FROM {self.schema_name}.{self.table_name}
    WHERE id = %(id)s
    """
    affected = self.db.execute(sql, {"id": id})
    if not isinstance(affected, int):
        return False
    return affected > 0
def exists(self, id: str) -> bool:
    """Report whether a row with the given ID is present."""
    self._check_connection()
    sql = f"""
    SELECT 1 FROM {self.schema_name}.{self.table_name}
    WHERE id = %(id)s
    LIMIT 1
    """
    frame = self.db.query(sql, {"id": id})
    found = not frame.empty
    return found
def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
    """Insert a record, or overwrite the stored row when the ID already exists.

    Can be called as:
    - upsert(id, record) - explicit ID and record
    - upsert(record) - the ID is taken from the record; when the record
      has none, a UUID4 is generated and written to record.storage_id

    The write is one ``INSERT ... ON CONFLICT (id) DO UPDATE`` statement
    (the same statement the async backend issues). This replaces an
    exists()-then-update()/insert sequence, which needed two round trips
    and could fail or lose a write if another writer inserted the same ID
    between the check and the insert.

    Args:
        id_or_record: Either the ID string or the record itself.
        record: The record, required when the first argument is an ID.

    Returns:
        The ID the record was stored under.

    Raises:
        ValueError: If an ID string is given without a record.
    """
    self._check_connection()

    # Work out which calling convention was used.
    if isinstance(id_or_record, str):
        id = id_or_record
        if record is None:
            raise ValueError("Record required when ID is provided")
    else:
        record = id_or_record
        id = record.id
        if id is None:
            id = str(uuid.uuid4())
            record.storage_id = id

    row = self._record_to_row(record, id)
    sql = f"""
    INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
    VALUES (%(id)s, %(data)s, %(metadata)s)
    ON CONFLICT (id) DO UPDATE
    SET data = EXCLUDED.data, metadata = EXCLUDED.metadata, updated_at = CURRENT_TIMESTAMP
    """
    self.db.execute(sql, row)
    return id
def search(self, query: Query | ComplexQuery) -> list[Record]:
    """Run a query and return the matching records.

    Each returned record has storage_id set from the row's id column, and
    is projected onto query.fields when the query names any.
    """
    self._check_connection()

    builder = self.query_builder
    if isinstance(query, ComplexQuery):
        sql_query, params_list = builder.build_complex_search_query(query)
    else:
        sql_query, params_list = builder.build_search_query(query)

    # The builder emits %(p0)s, %(p1)s, ... placeholders, so the positional
    # values are keyed by index for psycopg2.
    if params_list:
        params_dict = {f"p{index}": value for index, value in enumerate(params_list)}
    else:
        params_dict = {}

    frame = self.db.query(sql_query, params_dict)

    found = []
    for _, series in frame.iterrows():
        row_dict = series.to_dict()
        item = self._row_to_record(row_dict)
        item.storage_id = str(row_dict['id'])
        if query.fields:
            item = item.project(query.fields)
        found.append(item)
    return found
313 def _count_all(self) -> int:
314 """Count all records in the database."""
315 self._check_connection()
316 sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"
317 df = self.db.query(sql)
318 return int(df.iloc[0]["count"]) if not df.empty else 0
def clear(self) -> int:
    """Truncate the table and return how many rows it held beforehand."""
    self._check_connection()
    # The count has to be taken before the rows are gone.
    removed = self._count_all()
    self.db.execute(f"TRUNCATE TABLE {self.schema_name}.{self.table_name}")
    return removed
def create_batch(self, records: list[Record]) -> list[str]:
    """Insert several records with one multi-value INSERT.

    Args:
        records: List of records to create

    Returns:
        The IDs returned by the statement; when it returns no rows, the
        IDs computed by the query builder. An empty input returns an
        empty list without touching the database.
    """
    if not records:
        return []

    self._check_connection()

    # A pyformat-style builder, matching psycopg2's placeholder syntax.
    from .sql_base import SQLQueryBuilder
    builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")
    query, params_list, ids = builder.build_batch_create_query(records)

    # Positional values become p0, p1, ... named parameters.
    params_dict = {f"p{index}": value for index, value in enumerate(params_list)}

    result_df = self.db.query(query, params_dict)
    if result_df.empty:
        return ids
    return result_df['id'].tolist()
def delete_batch(self, ids: list[str]) -> list[bool]:
    """Delete several records with one statement.

    Args:
        ids: List of record IDs to delete

    Returns:
        One flag per input ID, in input order: True when that ID appears
        among the rows the statement returned. An empty input returns an
        empty list without touching the database.
    """
    if not ids:
        return []

    self._check_connection()

    # A pyformat-style builder, matching psycopg2's placeholder syntax.
    from .sql_base import SQLQueryBuilder
    builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")
    query, params_list = builder.build_batch_delete_query(ids)

    params_dict = {f"p{index}": value for index, value in enumerate(params_list)}
    result_df = self.db.query(query, params_dict)

    # The statement reports back the IDs it removed.
    deleted = set(result_df['id'].tolist()) if not result_df.empty else set()
    return [candidate in deleted for candidate in ids]
def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
    """Update several records with one statement built by the shared SQL builder.

    Args:
        updates: List of (id, record) tuples to update

    Returns:
        One flag per input tuple, in input order: True when that ID
        appears among the rows the statement returned. An empty input
        returns an empty list without touching the database.
    """
    if not updates:
        return []

    self._check_connection()

    # A pyformat-style builder, matching psycopg2's placeholder syntax.
    from .sql_base import SQLQueryBuilder
    builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres", param_style="pyformat")
    query, params_list = builder.build_batch_update_query(updates)

    params_dict = {f"p{index}": value for index, value in enumerate(params_list)}
    result_df = self.db.query(query, params_dict)

    # The statement reports back the IDs it changed.
    touched = set(result_df['id'].tolist()) if not result_df.empty else set()
    return [record_id in touched for record_id, _ in updates]
def stream_read(
    self,
    query: Query | None = None,
    config: StreamConfig | None = None
) -> Iterator[Record]:
    """Yield records from the table, fetching them one page at a time.

    Pages are fetched with LIMIT/OFFSET using config.batch_size. The
    statement orders by id so that consecutive pages form a stable
    sequence; without an ORDER BY, PostgreSQL may return rows in a
    different order for each page, which can skip or repeat rows.

    Args:
        query: Optional query. Only equality filters are translated to
            SQL; filters with any other operator are ignored here. When
            query.fields is set, each record is projected onto it.
        config: Streaming options; a default StreamConfig is used when
            omitted.

    Yields:
        One Record per row.
    """
    self._check_connection()
    config = config or StreamConfig()

    sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
    params: dict[str, Any] = {}

    if query and query.filters:
        where_clauses = []
        for i, flt in enumerate(query.filters):
            if flt.operator == Operator.EQ:
                # Both the JSON key and the value are bound as parameters,
                # so a field name is never spliced into the SQL text.
                field_param = f"field_{i}"
                value_param = f"param_{i}"
                where_clauses.append(f"data->>%({field_param})s = %({value_param})s")
                params[field_param] = flt.field
                params[value_param] = str(flt.value)

        if where_clauses:
            sql += " WHERE " + " AND ".join(where_clauses)

    # PostgresDB is queried page by page rather than through a cursor.
    sql += f" ORDER BY id LIMIT {config.batch_size} OFFSET %(offset)s"

    offset = 0
    while True:
        params["offset"] = offset
        df = self.db.query(sql, params)

        if df.empty:
            break

        for _, row in df.iterrows():
            record = self._row_to_record(row.to_dict())
            if query and query.fields:
                record = record.project(query.fields)
            yield record

        offset += config.batch_size

        # A short page means the table is exhausted.
        if len(df) < config.batch_size:
            break
def stream_write(
    self,
    records: Iterator[Record],
    config: StreamConfig | None = None
) -> StreamResult:
    """Consume an iterator of records and write them in batches.

    Records are grouped into chunks of config.batch_size and handed to
    process_batch_with_fallback, with _write_batch as the batch writer
    and create as the single-record writer. When that helper returns a
    falsy value, writing stops and the rest of the iterator is left
    unread.

    Returns:
        The StreamResult populated by the batch helper, with duration
        set to the elapsed wall-clock time.
    """
    self._check_connection()
    config = config or StreamConfig()
    result = StreamResult()
    started = time.time()

    def flush(chunk):
        # _write_batch is wrapped in a lambda, as the helper expects a callable.
        return process_batch_with_fallback(
            chunk,
            lambda b: self._write_batch(b),
            self.create,
            result,
            config
        )

    pending = []
    aborted = False
    for record in records:
        pending.append(record)
        if len(pending) < config.batch_size:
            continue

        if not flush(pending):
            aborted = True
            break
        pending = []

    # Whatever is left over forms a final, smaller batch.
    if pending and not aborted:
        flush(pending)

    result.duration = time.time() - started
    return result
547 def _write_batch(self, records: list[Record]) -> list[str]:
548 """Write a batch of records to the database.
550 Returns:
551 List of created record IDs
552 """
553 # Build batch insert SQL
554 values = []
555 params = {}
556 ids = []
558 for i, record in enumerate(records):
559 id = str(uuid.uuid4())
560 ids.append(id)
561 row = self._record_to_row(record, id)
562 values.append(f"(%(id_{i})s, %(data_{i})s, %(metadata_{i})s)")
563 params[f"id_{i}"] = row["id"]
564 params[f"data_{i}"] = row["data"]
565 params[f"metadata_{i}"] = row["metadata"]
567 sql = f"""
568 INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
569 VALUES {', '.join(values)}
570 """
571 self.db.execute(sql, params)
572 return ids
def vector_search(
    self,
    query_vector: np.ndarray | list[float] | VectorField,
    field_name: str,
    k: int = 10,
    filter: Query | None = None,
    metric: DistanceMetric | str = "cosine"
) -> list[VectorSearchResult]:
    """Search for similar vectors using PostgreSQL pgvector.

    Args:
        query_vector: Query vector (numpy array, list, or VectorField)
        field_name: Name of vector field to search (must be in data JSON)
        k: Maximum number of results
        filter: Optional query whose conditions are added to the statement
        metric: Distance metric to use (cosine, euclidean, l2, inner_product)

    Returns:
        List of VectorSearchResult objects in the order the database
        returned them (ascending distance).

    Raises:
        RuntimeError: If vector support is not enabled on this instance.
    """
    if not self._vector_enabled:
        raise RuntimeError("Vector search not available - pgvector not installed")

    self._check_connection()

    from ..fields import VectorField
    from ..vector.types import DistanceMetric, VectorSearchResult
    from .postgres_vector import format_vector_for_postgres, get_vector_operator

    # Convert query vector to proper format
    if isinstance(query_vector, VectorField):
        vector_str = format_vector_for_postgres(query_vector.value)
    else:
        vector_str = format_vector_for_postgres(query_vector)

    # Normalise the metric to its string name before picking the operator
    if isinstance(metric, DistanceMetric):
        metric_str = metric.value
    else:
        metric_str = str(metric).lower()

    operator = get_vector_operator(metric_str)

    # Vectors are stored in the JSON data field; the extraction expression
    # comes from the shared helper.
    vector_expr = self.get_vector_extraction_sql(field_name, dialect="postgres")

    # Base SQL with pyformat placeholders
    sql = f"""
    SELECT
        id,
        data,
        metadata,
        {vector_expr} {operator} %(p0)s::vector AS distance
    FROM {self.schema_name}.{self.table_name}
    WHERE data ? %(p1)s  -- Check field exists
    """

    params: list[Any] = [vector_str, field_name]

    # Add filters if provided using the query builder
    if filter:
        # NOTE(review): the start index handed to the builder and the p{i}
        # numbering applied below must agree - confirm against
        # SQLQueryBuilder.build_where_clause.
        where_clause, filter_params = self.query_builder.build_where_clause(filter, len(params) + 1)
        if where_clause:
            sql += where_clause
            params.extend(filter_params)

    # Order by distance and limit
    next_param = len(params)
    sql += f" ORDER BY distance LIMIT %(p{next_param})s"
    params.append(k)

    # Positional values become p0, p1, ... named parameters for psycopg2
    param_dict = {f"p{i}": param for i, param in enumerate(params)}

    df = self.db.query(sql, param_dict)

    results = []
    for _, row in df.iterrows():
        # Pass a plain dict, as every other caller of _row_to_record does.
        row_dict = row.to_dict()
        record = self._row_to_record(row_dict)

        # Turn the distance into a score where higher means more similar
        distance = row_dict["distance"]
        if metric_str in ["cosine", "cosine_similarity"]:
            score = 1.0 - distance  # Cosine distance to similarity
        elif metric_str in ["euclidean", "l2"]:
            score = 1.0 / (1.0 + distance)  # Convert distance to similarity
        else:
            # inner_product / dot_product and any other metric: negate the
            # distance so that lower distance yields a higher score.
            score = -distance

        results.append(
            VectorSearchResult(
                record=record,
                score=float(score),
                vector_field=field_name
            )
        )

    return results
def has_vector_support(self) -> bool:
    """Report whether vector operations are currently enabled.

    Returns:
        The value of the instance's vector-enabled flag.
    """
    enabled = self._vector_enabled
    return enabled
def enable_vector_support(self) -> bool:
    """Try to switch vector support on.

    Detection is only run when support is not already enabled.

    Returns:
        True if vector support is now enabled
    """
    if not self._vector_enabled:
        self._detect_vector_support()
        return self._vector_enabled
    return True
def bulk_embed_and_store(
    self,
    records: list[Record],
    text_field: str | list[str],
    vector_field: str = "embedding",
    embedding_fn: Any = None,
    batch_size: int = 100,
    model_name: str | None = None,
    model_version: str | None = None,
) -> list[str]:
    """Placeholder that satisfies the abstract method requirement.

    No embedding is performed; every call raises.

    Raises:
        NotImplementedError: Always.
    """
    message = "bulk_embed_and_store requires an embedding function"
    raise NotImplementedError(message)
# Global pool manager instance for async PostgreSQL connections.
# AsyncPostgresDatabase.connect() asks it for the asyncpg pool to use.
_pool_manager = ConnectionPoolManager[asyncpg.Pool]()
721class AsyncPostgresDatabase(
722 AsyncDatabase,
723 VectorOperationsMixin,
724 ConfigurableBase,
725 PostgresBaseConfig,
726 PostgresTableManager,
727 PostgresVectorSupport,
728 PostgresConnectionValidator,
729 PostgresErrorHandler,
730):
731 """Native async PostgreSQL database backend with vector support and event loop-aware connection pooling."""
def __init__(self, config: dict[str, Any] | None = None):
    """Store the settings for the async backend; no pool is opened here.

    Args:
        config: Optional configuration dictionary, parsed by the
            PostgreSQL config mixin.
    """
    super().__init__(config)

    # The config mixin splits the settings into table, schema and
    # connection parts.
    settings = config or {}
    table, schema, connection_settings = self._parse_postgres_config(settings)
    self._init_postgres_attributes(table, schema)

    # Pool options are derived from the connection settings; the pool
    # itself is obtained in connect().
    self._pool_config = PostgresPoolConfig.from_dict(connection_settings)
    self._pool: asyncpg.Pool | None = None
@classmethod
def from_config(cls, config: dict) -> AsyncPostgresDatabase:
    """Alternate constructor: build an instance from a config dictionary."""
    instance = cls(config)
    return instance
async def connect(self) -> None:
    """Obtain a pool from the pool manager and prepare the table.

    Returns immediately when the instance is already flagged as connected.
    """
    if self._connected:
        return

    # The module-level manager hands back the pool to use.
    from ..pooling import BasePoolConfig
    pool_factory = cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_asyncpg_pool)
    self._pool = await _pool_manager.get_pool(
        self._pool_config,
        pool_factory,
        validate_asyncpg_pool
    )

    self.query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")

    await self._ensure_table()

    if self.vector_enabled:
        await self._detect_vector_support()

    self._connected = True
    self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}")
async def close(self) -> None:
    """Close the pool (if any) and flag the backend as disconnected.

    Does nothing when the instance is not connected. A failure while
    closing the pool is logged as a warning and otherwise ignored.
    """
    if not self._connected:
        return

    pool = self._pool
    if pool:
        # NOTE(review): the pool came from the module-level pool manager -
        # confirm closing it directly here is intended.
        try:
            await pool.close()
        except Exception as exc:
            logger.warning(f"Error closing connection pool: {exc}")
        self._pool = None
    self._connected = False
788 def _initialize(self) -> None:
789 """Initialize is handled in connect."""
790 pass
792 async def _ensure_table(self) -> None:
793 """Ensure the records table exists."""
794 if not self._pool:
795 raise RuntimeError("Database not connected. Call connect() first.")
797 create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name)
799 async with self._pool.acquire() as conn:
800 await conn.execute(create_table_sql)
async def _detect_vector_support(self) -> None:
    """Enable vector support when pgvector is present or can be installed.

    Exceptions raised by the checks are not caught here.
    """
    from .postgres_vector import check_pgvector_extension, install_pgvector_extension

    async with self._pool.acquire() as conn:
        if await check_pgvector_extension(conn):
            self._vector_enabled = True
            logger.info("pgvector extension detected and enabled")
        elif await install_pgvector_extension(conn):
            # Not installed yet, but the install attempt succeeded.
            self._vector_enabled = True
            logger.info("pgvector extension installed and enabled")
        else:
            logger.debug("pgvector extension not available")
819 async def _ensure_vector_column(self, field_name: str, dimensions: int) -> None:
820 """Ensure a vector column exists for the given field.
822 Args:
823 field_name: Name of the vector field
824 dimensions: Number of dimensions
825 """
826 if not self._vector_enabled:
827 return
829 column_name = f"vector_{field_name}"
831 # Check if column already exists
832 check_sql = """
833 SELECT column_name FROM information_schema.columns
834 WHERE table_schema = $1 AND table_name = $2 AND column_name = $3
835 """
837 async with self._pool.acquire() as conn:
838 existing = await conn.fetchval(check_sql, self.schema_name, self.table_name, column_name)
840 if not existing:
841 # Add vector column
842 alter_sql = f"""
843 ALTER TABLE {self.schema_name}.{self.table_name}
844 ADD COLUMN IF NOT EXISTS {column_name} vector({dimensions})
845 """
846 try:
847 await conn.execute(alter_sql)
848 self._vector_dimensions[field_name] = dimensions
849 logger.info(f"Added vector column {column_name} with {dimensions} dimensions")
851 # Create index for the vector column
852 from .postgres_vector import build_vector_index_sql, get_optimal_index_type
854 # Get row count for optimal index selection
855 count_sql = f"SELECT COUNT(*) FROM {self.schema_name}.{self.table_name}"
856 count = await conn.fetchval(count_sql)
858 index_type, index_params = get_optimal_index_type(count)
859 index_sql = build_vector_index_sql(
860 self.table_name,
861 self.schema_name,
862 column_name,
863 dimensions,
864 metric="cosine",
865 index_type=index_type,
866 index_params=index_params
867 )
869 # Note: IVFFlat requires table to have data before creating index
870 if count > 0 or index_type != "ivfflat":
871 await conn.execute(index_sql)
872 logger.info(f"Created {index_type} index for {column_name}")
874 except Exception as e:
875 logger.warning(f"Could not create vector column {column_name}: {e}")
876 else:
877 self._vector_dimensions[field_name] = dimensions
879 def _check_connection(self) -> None:
880 """Check if async database is connected."""
881 self._check_async_connection()
def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
    """Build the id/data/metadata row dictionary for a record.

    Args:
        record: Record to serialize with the shared SQL serializer.
        id: Row ID to use; a fresh UUID4 string is generated when falsy.

    Returns:
        Dict with keys "id", "data" and "metadata" (None when the record
        has no metadata).
    """
    from .sql_base import SQLRecordSerializer

    row_id = id or str(uuid.uuid4())
    payload = SQLRecordSerializer.record_to_json(record)
    metadata = json.dumps(record.metadata) if record.metadata else None
    return {"id": row_id, "data": payload, "metadata": metadata}
def _row_to_record(self, row: asyncpg.Record) -> Record:
    """Rebuild a Record from a fetched row using the shared SQL serializer.

    The serializer is given JSON text, so values that are not already
    strings are encoded with json.dumps first. Falsy metadata is passed
    through unchanged.
    """
    from .sql_base import SQLRecordSerializer

    raw_data = row.get("data", {})
    data_text = raw_data if isinstance(raw_data, str) else json.dumps(raw_data)

    raw_metadata = row.get("metadata")
    needs_encoding = bool(raw_metadata) and not isinstance(raw_metadata, str)
    metadata_text = json.dumps(raw_metadata) if needs_encoding else raw_metadata

    return SQLRecordSerializer.json_to_record(data_text, metadata_text)
async def create(self, record: Record) -> str:
    """Insert a record and return the ID it was stored under.

    When vector support is enabled, a vector column is ensured for every
    VectorField on the record before the insert. The record's own ID is
    kept when it has one; otherwise a UUID4 string is generated.
    """
    self._check_connection()

    from ..fields import VectorField
    for field_name, field_obj in record.fields.items():
        if isinstance(field_obj, VectorField) and self._vector_enabled:
            await self._ensure_vector_column(field_name, field_obj.dimensions)

    record_id = record.id or str(uuid.uuid4())
    row = self._record_to_row(record, record_id)

    # Any "vector_*" entries in the row become extra columns.
    # NOTE(review): _record_to_row in this class only emits id/data/metadata,
    # so this currently adds nothing - confirm whether vector values are
    # meant to be written here.
    vector_items = [(key, value) for key, value in row.items() if key.startswith("vector_")]

    columns = ["id", "data", "metadata"] + [key for key, _ in vector_items]
    values = [row["id"], row["data"], row["metadata"]] + [value for _, value in vector_items]
    placeholders = [f"${position}" for position in range(1, len(columns) + 1)]

    sql = f"""
    INSERT INTO {self.schema_name}.{self.table_name} ({', '.join(columns)})
    VALUES ({', '.join(placeholders)})
    """

    async with self._pool.acquire() as conn:
        await conn.execute(sql, *values)

    return record_id
async def read(self, id: str) -> Record | None:
    """Fetch one record by ID, or return None when no row matches."""
    self._check_connection()
    sql = f"""
    SELECT id, data, metadata
    FROM {self.schema_name}.{self.table_name}
    WHERE id = $1
    """

    async with self._pool.acquire() as conn:
        fetched = await conn.fetchrow(sql, id)

    if fetched:
        return self._row_to_record(fetched)
    return None
async def update(self, id: str, record: Record) -> bool:
    """Overwrite the data and metadata of the row with the given ID.

    Returns:
        True unless the command status reports zero affected rows.
    """
    self._check_connection()
    row = self._record_to_row(record, id)

    sql = f"""
    UPDATE {self.schema_name}.{self.table_name}
    SET data = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP
    WHERE id = $1
    """

    async with self._pool.acquire() as conn:
        status = await conn.execute(sql, row["id"], row["data"], row["metadata"])

    # The status reads "UPDATE n"; the last token is the row count.
    affected = status.split()[-1]
    return affected != "0"
async def delete(self, id: str) -> bool:
    """Remove the row with the given ID.

    Returns:
        True unless the command status reports zero affected rows.
    """
    self._check_connection()
    sql = f"""
    DELETE FROM {self.schema_name}.{self.table_name}
    WHERE id = $1
    """

    async with self._pool.acquire() as conn:
        status = await conn.execute(sql, id)

    # The status reads "DELETE n"; the last token is the row count.
    affected = status.split()[-1]
    return affected != "0"
async def exists(self, id: str) -> bool:
    """Report whether a row with the given ID is present."""
    self._check_connection()
    sql = f"""
    SELECT 1 FROM {self.schema_name}.{self.table_name}
    WHERE id = $1
    LIMIT 1
    """

    async with self._pool.acquire() as conn:
        fetched = await conn.fetchrow(sql, id)

    found = fetched is not None
    return found
async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
    """Insert a record, or overwrite the stored row when the ID already exists.

    Can be called as:
    - upsert(id, record) - explicit ID and record
    - upsert(record) - the ID is taken from the record; when the record
      has none, a UUID4 is generated and written to record.storage_id

    Returns:
        The ID the record was stored under.

    Raises:
        ValueError: If an ID string is given without a record.
    """
    self._check_connection()

    # Work out which calling convention was used.
    if not isinstance(id_or_record, str):
        record = id_or_record
        id = record.id
        if id is None:
            id = str(uuid.uuid4())
            record.storage_id = id
    else:
        id = id_or_record
        if record is None:
            raise ValueError("Record required when ID is provided")

    row = self._record_to_row(record, id)

    # One statement covers both the insert and the overwrite case.
    sql = f"""
    INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
    VALUES ($1, $2, $3)
    ON CONFLICT (id) DO UPDATE
    SET data = EXCLUDED.data, metadata = EXCLUDED.metadata, updated_at = CURRENT_TIMESTAMP
    """

    async with self._pool.acquire() as conn:
        await conn.execute(sql, row["id"], row["data"], row["metadata"])

    return id
async def search(self, query: Query | ComplexQuery) -> list[Record]:
    """Run a query against the table and return the matching records.

    Args:
        query: A simple ``Query`` or a ``ComplexQuery``; each is translated
            to SQL by the matching ``SQLQueryBuilder`` method.

    Returns:
        Records built from the fetched rows, each with ``storage_id`` set to
        the row's ID and projected to ``query.fields`` when that is set.
    """
    self._check_connection()

    # The builder is created on first use and cached on the instance.
    if not hasattr(self, "query_builder"):
        self.query_builder = SQLQueryBuilder(
            self.table_name, self.schema_name, dialect="postgres"
        )
    builder = self.query_builder

    if isinstance(query, ComplexQuery):
        build = builder.build_complex_search_query
    else:
        build = builder.build_search_query
    sql, params = build(query)

    # The builder output is passed to asyncpg as positional arguments.
    async with self._pool.acquire() as conn:
        rows = await conn.fetch(sql, *params)

    def _to_record(row):
        rec = self._row_to_record(row)
        rec.storage_id = str(row["id"])
        return rec.project(query.fields) if query.fields else rec

    return [_to_record(row) for row in rows]
1081 async def _count_all(self) -> int:
1082 """Count all records in the database."""
1083 self._check_connection()
1084 sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"
1086 async with self._pool.acquire() as conn:
1087 row = await conn.fetchrow(sql)
1089 return row["count"] if row else 0
async def clear(self) -> int:
    """Remove every row from the table.

    The row count is read first (via ``_count_all``) and the table is then
    truncated in a separate statement.

    Returns:
        The number of rows counted before the truncate.
    """
    self._check_connection()

    removed = await self._count_all()
    truncate_sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"

    async with self._pool.acquire() as conn:
        await conn.execute(truncate_sql)

    return removed
async def create_batch(self, records: list[Record]) -> list[str]:
    """Insert several records with one statement.

    The statement, its arguments and the candidate IDs come from
    ``SQLQueryBuilder.build_batch_create_query``.

    Args:
        records: Records to insert; an empty list is a no-op.

    Returns:
        The ``id`` values of the rows returned by the statement, or the
        builder-generated IDs when the statement returns no rows.
    """
    if not records:
        return []

    self._check_connection()

    builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
    query, params, generated_ids = builder.build_batch_create_query(records)

    async with self._pool.acquire() as conn:
        inserted = await conn.fetch(query, *params)

    # Prefer what the database reports; fall back to the IDs the builder chose.
    if not inserted:
        return generated_ids
    return [row["id"] for row in inserted]
async def delete_batch(self, ids: list[str]) -> list[bool]:
    """Delete several records with one statement.

    The statement comes from ``SQLQueryBuilder.build_batch_delete_query``;
    the rows it returns are read as the IDs that were actually deleted.

    Args:
        ids: IDs to delete; an empty list is a no-op.

    Returns:
        One flag per input ID, in input order: True if that ID appears among
        the returned rows.
    """
    if not ids:
        return []

    self._check_connection()

    builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
    query, params = builder.build_batch_delete_query(ids)

    async with self._pool.acquire() as conn:
        rows = await conn.fetch(query, *params)

    deleted = {row["id"] for row in rows}
    return [candidate in deleted for candidate in ids]
async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
    """Update several records with one statement.

    The statement comes from ``SQLQueryBuilder.build_batch_update_query``;
    ``RETURNING id`` is appended so the database reports which rows changed.

    Args:
        updates: ``(id, record)`` pairs to apply; an empty list is a no-op.

    Returns:
        One flag per input pair, in input order: True if that ID appears
        among the returned rows.
    """
    if not updates:
        return []

    self._check_connection()

    builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
    base_query, params = builder.build_batch_update_query(updates)

    # Ask PostgreSQL to report the IDs it actually touched.
    query = base_query.rstrip() + " RETURNING id"

    async with self._pool.acquire() as conn:
        rows = await conn.fetch(query, *params)

    touched = {row["id"] for row in rows}
    return [record_id in touched for record_id, _ in updates]
async def vector_search(
    self,
    query_vector: np.ndarray | list[float] | VectorField,
    field_name: str,
    k: int = 10,
    filter: Query | None = None,
    metric: DistanceMetric | str = "cosine"
) -> list[VectorSearchResult]:
    """Search for similar vectors using PostgreSQL pgvector.

    Args:
        query_vector: Query vector (numpy array, list, or VectorField)
        field_name: Name of the vector field to search; the column queried
            is ``vector_<field_name>``
        k: Maximum number of results
        filter: Optional query whose conditions are appended to the SQL
        metric: Distance metric to use

    Returns:
        List of VectorSearchResult objects ordered by ascending distance.
        Each result's metadata carries the raw ``distance`` and the ``metric``.

    Raises:
        RuntimeError: If vector support is not enabled.
        ValueError: If ``k`` cannot be converted to an integer.
    """
    if not self._vector_enabled:
        raise RuntimeError("Vector search not available - pgvector not installed")

    self._check_connection()

    from ..fields import VectorField
    from ..vector.types import DistanceMetric, VectorSearchResult
    from .postgres_vector import format_vector_for_postgres, get_vector_operator

    # Convert query vector to proper format
    raw_vector = query_vector.value if isinstance(query_vector, VectorField) else query_vector
    vector_str = format_vector_for_postgres(raw_vector)

    # Get the appropriate operator
    metric_str = metric.value if isinstance(metric, DistanceMetric) else str(metric).lower()
    operator = get_vector_operator(metric_str)

    vector_column = f"vector_{field_name}"

    # LIMIT is interpolated into the SQL text, so force it to be an integer.
    limit = int(k)

    # Build query
    sql = f"""
        SELECT id, data, metadata, {vector_column},
               {vector_column} {operator} $1::vector AS distance
        FROM {self.schema_name}.{self.table_name}
        WHERE {vector_column} IS NOT NULL
    """

    params = [vector_str]
    param_num = 2

    # Add filters if provided using the query builder
    if filter:
        # The builder is created lazily (same pattern as search()), so it may
        # not exist yet when vector_search() is the first query on this instance.
        if not hasattr(self, "query_builder"):
            self.query_builder = SQLQueryBuilder(
                self.table_name, self.schema_name, dialect="postgres"
            )

        where_clause, filter_params = self.query_builder.build_where_clause(filter, param_num)
        if where_clause:
            # Convert any remaining %s placeholders to $N for asyncpg
            for param in filter_params:
                where_clause = where_clause.replace("%s", f"${param_num}", 1)
                params.append(param)
                param_num += 1
            sql += where_clause

    # Order by distance and limit
    sql += f"""
        ORDER BY distance
        LIMIT {limit}
    """

    # Execute query
    async with self._pool.acquire() as conn:
        rows = await conn.fetch(sql, *params)

    # Convert to VectorSearchResult objects
    results = []
    for row in rows:
        record = self._row_to_record(row)

        # Convert distance to a similarity score
        distance = float(row["distance"])
        if metric_str == "cosine":
            # Normalize cosine distance [0,2] to similarity [0,1]
            score = 1.0 - min(distance, 2.0) / 2.0
        elif metric_str in ["euclidean", "l2"]:
            score = 1.0 / (1.0 + distance)
        else:
            score = 1.0 - distance  # Generic conversion

        results.append(
            VectorSearchResult(
                record=record,
                score=score,
                vector_field=field_name,
                metadata={"distance": distance, "metric": metric_str},
            )
        )

    return results
async def enable_vector_support(self) -> bool:
    """Turn on vector support if it is not already on.

    When the flag is not yet set, ``_detect_vector_support`` is awaited and
    the resulting flag value is returned.

    Returns:
        True if vector support is enabled
    """
    already_enabled = self._vector_enabled
    if not already_enabled:
        await self._detect_vector_support()
        return self._vector_enabled
    return True
async def has_vector_support(self) -> bool:
    """Report the current vector-support flag.

    Only the flag on the instance is read; no query is issued here.

    Returns:
        The value of ``self._vector_enabled``.
    """
    enabled = self._vector_enabled
    return enabled
async def bulk_embed_and_store(
    self,
    records: list[Record],
    text_field: str | list[str],
    vector_field: str,
    embedding_fn: Any | None = None,
    batch_size: int = 100,
    model_name: str | None = None,
    model_version: str | None = None,
) -> list[str]:
    """Embed text fields and store the resulting vectors with their records.

    For each batch of records this method:
        1. Builds one text per record from the configured field(s)
        2. Calls ``embedding_fn`` with the list of texts
        3. Attaches each embedding as a ``VectorField`` and saves the record
           (``update`` when it already has a storage ID, ``create`` otherwise)

    Args:
        records: Records to process
        text_field: Field name(s) containing text to embed; several fields
            are joined with a single space, missing fields are skipped
        vector_field: Field name to store vectors in
        embedding_fn: Callable taking a list of texts and returning one
            embedding per text; may be a plain function or a coroutine function
        batch_size: Number of records to process at once (must be positive)
        model_name: Name of the embedding model
        model_version: Version of the embedding model

    Returns:
        List of record IDs that were processed

    Raises:
        ValueError: If ``embedding_fn`` is missing, ``batch_size`` is not
            positive, or a record has no storage ID after being saved.
    """
    if not embedding_fn:
        raise ValueError("embedding_fn is required for bulk_embed_and_store")
    # A non-positive step would either fail inside range() or skip every record.
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    if not records:
        return []

    import inspect

    from ..fields import VectorField

    field_names = text_field if isinstance(text_field, list) else [text_field]
    source_field = text_field if isinstance(text_field, str) else ",".join(text_field)

    processed_ids = []

    # Process in batches
    for start in range(0, len(records), batch_size):
        batch = records[start:start + batch_size]

        # Extract texts; fields absent from a record contribute nothing
        texts = [
            " ".join(
                str(record.fields[name].value)
                for name in field_names
                if name in record.fields
            )
            for record in batch
        ]

        # Generate embeddings; accept both sync and async embedding functions
        embeddings = embedding_fn(texts)
        if inspect.isawaitable(embeddings):
            embeddings = await embeddings

        # Store vectors with records; zip() stops at the shorter sequence, so
        # records without a matching embedding are left untouched
        for record, vector in zip(batch, embeddings):
            record.fields[vector_field] = VectorField(
                name=vector_field,
                value=vector,
                dimensions=len(vector) if hasattr(vector, "__len__") else None,
                source_field=source_field,
                model_name=model_name,
                model_version=model_version,
            )

            # Create or update record
            if record.has_storage_id():
                if record.storage_id is None:
                    raise ValueError("Record has_storage_id() returned True but storage_id is None")
                await self.update(record.storage_id, record)
            else:
                record.storage_id = await self.create(record)

            if record.storage_id is None:
                raise ValueError("Record storage_id is None after create/update")
            processed_ids.append(record.storage_id)

    return processed_ids
async def create_vector_index(
    self,
    vector_field: str,
    dimensions: int,
    metric: DistanceMetric | str = "cosine",
    index_type: str = "ivfflat",
    lists: int | None = None,
) -> bool:
    """Build a similarity-search index over a vector field.

    Args:
        vector_field: Name of the vector field to index
        dimensions: Number of dimensions in the vectors
        metric: Distance metric for the index (enum member or string)
        index_type: Type of index (ivfflat, hnsw)
        lists: Number of lists for an IVFFlat index; when omitted for
            ``ivfflat`` it is derived from the current vector count

    Returns:
        True if the index statement ran successfully; False if vector support
        is disabled or the statement raised.
    """
    from .postgres_vector import (
        build_vector_column_expression,
        build_vector_index_sql,
        get_optimal_index_type,
        get_vector_count_sql,
    )

    self._check_connection()

    if not self._vector_enabled:
        return False

    # Pick a list count from the number of stored vectors when none was given.
    if index_type == "ivfflat" and not lists:
        count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
        async with self._pool.acquire() as conn:
            vector_count = await conn.fetchval(count_sql) or 0
        _, suggested = get_optimal_index_type(vector_count)
        lists = suggested.get("lists", 100)

    # Accept either an enum-like object with ``.value`` or a plain string.
    metric_str = metric.value if hasattr(metric, "value") else str(metric).lower()

    column_expr = build_vector_column_expression(vector_field, dimensions, for_index=True)
    index_params = {"lists": lists} if lists else None

    # field_name is passed separately so the index gets a name based on the field.
    index_sql = build_vector_index_sql(
        table_name=self.table_name,
        schema_name=self.schema_name,
        column_name=column_expr,
        dimensions=dimensions,
        metric=metric_str,
        index_type=index_type,
        index_params=index_params,
        field_name=vector_field,
    )

    try:
        logger.debug(f"Creating vector index with SQL: {index_sql}")
        async with self._pool.acquire() as conn:
            await conn.execute(index_sql)
    except Exception as e:
        # Best effort: report failure to the caller instead of raising.
        logger.warning(f"Failed to create vector index: {e}")
        logger.debug(f"Index SQL was: {index_sql}")
        return False
    else:
        return True
async def drop_vector_index(self, vector_field: str, metric: str = "cosine") -> bool:
    """Remove the index belonging to a vector field, if it exists.

    Args:
        vector_field: Name of the vector field
        metric: Distance metric used in the index (part of the index name)

    Returns:
        True if the DROP statement ran without error, False if it raised.
    """
    from .postgres_vector import get_vector_index_name

    self._check_connection()

    index_name = get_vector_index_name(self.table_name, vector_field, metric)

    try:
        async with self._pool.acquire() as conn:
            await conn.execute(f"DROP INDEX IF EXISTS {self.schema_name}.{index_name}")
    except Exception as e:
        # Best effort: report failure to the caller instead of raising.
        logger.warning(f"Failed to drop vector index: {e}")
        return False
    else:
        return True
async def get_vector_index_stats(self, vector_field: str) -> dict[str, Any]:
    """Collect basic statistics for a vector field and its index.

    Args:
        vector_field: Name of the vector field

    Returns:
        Dictionary with the keys ``field``, ``indexed`` and ``vector_count``.
        If a query fails, a warning is logged and whatever was gathered so far
        (defaults otherwise) is returned.
    """
    from .postgres_vector import get_index_check_sql, get_vector_count_sql

    self._check_connection()

    stats: dict[str, Any] = {"field": vector_field, "indexed": False, "vector_count": 0}

    try:
        async with self._pool.acquire() as conn:
            # How many vectors are stored for this field
            count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
            vector_count = await conn.fetchval(count_sql)
            stats["vector_count"] = vector_count or 0

            # Whether an index exists for this field
            index_sql, params = get_index_check_sql(self.schema_name, self.table_name, vector_field)
            indexed = await conn.fetchval(index_sql, *params)
            stats["indexed"] = indexed or False
    except Exception as e:
        logger.warning(f"Failed to get vector index stats: {e}")

    return stats
async def stream_read(
    self,
    query: Query | None = None,
    config: StreamConfig | None = None
) -> AsyncIterator[Record]:
    """Stream records from PostgreSQL using a server-side cursor.

    Only equality filters are translated to SQL; filters with any other
    operator are skipped with a warning, so the stream may contain more
    records than the query describes.

    Args:
        query: Optional query supplying equality filters and a field projection.
        config: Stream settings; ``batch_size`` is used as the cursor's
            prefetch size. Defaults to ``StreamConfig()``.

    Yields:
        One record per fetched row, projected to ``query.fields`` when set.
    """
    self._check_connection()
    config = config or StreamConfig()

    # Build SQL query
    sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
    params = []

    if query and query.filters:
        where_clauses = []
        for flt in query.filters:
            if flt.operator != Operator.EQ:
                logger.warning(
                    "stream_read ignores unsupported filter operator %s on field %s",
                    flt.operator,
                    flt.field,
                )
                continue
            # Number placeholders from the arguments actually bound, so
            # skipped filters cannot leave gaps such as "$2" with one argument.
            params.append(str(flt.value))
            # The field name sits inside a SQL string literal; double any quotes.
            field_literal = str(flt.field).replace("'", "''")
            where_clauses.append(f"data->>'{field_literal}' = ${len(params)}")

        if where_clauses:
            sql += " WHERE " + " AND ".join(where_clauses)

    # Fall back to asyncpg's default prefetch when batch_size is not usable.
    batch_size = config.batch_size
    prefetch = batch_size if batch_size and batch_size > 0 else None

    async with self._pool.acquire() as conn:
        # asyncpg requires cursors to be created and used inside a transaction.
        async with conn.transaction():
            # Iterate the object returned by conn.cursor() directly: awaiting
            # it first yields a Cursor, which does not support ``async for``.
            async for row in conn.cursor(sql, *params, prefetch=prefetch):
                record = self._row_to_record(row)
                if query and query.fields:
                    record = record.project(query.fields)
                yield record
async def stream_write(
    self,
    records: AsyncIterator[Record],
    config: StreamConfig | None = None
) -> StreamResult:
    """Stream records into PostgreSQL using batch inserts.

    Records are buffered until ``config.batch_size`` is reached and then
    handed to ``async_process_batch_with_fallback``, which receives
    ``_write_batch`` as the batch writer and ``create`` as the per-record
    fallback.

    Args:
        records: Asynchronous iterator of records to write.
        config: Stream settings; defaults to ``StreamConfig()``.

    Returns:
        The ``StreamResult`` populated by the batch helper, with ``duration``
        set to the elapsed wall-clock seconds.
    """
    self._check_connection()
    config = config or StreamConfig()
    result = StreamResult()
    start_time = time.time()

    async def flush(pending: list) -> bool:
        # _write_batch already returns the IDs it wrote, so it is passed as
        # the batch function directly; re-deriving them from ``r.id`` would
        # report None for records that had no ID of their own.
        return await async_process_batch_with_fallback(
            pending,
            self._write_batch,
            self.create,
            result,
            config
        )

    batch = []
    quitting = False
    async for record in records:
        batch.append(record)

        if len(batch) >= config.batch_size:
            if not await flush(batch):
                # The helper asked to stop; the current batch is not retried.
                quitting = True
                break
            batch = []

    # Write remaining batch
    if batch and not quitting:
        await flush(batch)

    result.duration = time.time() - start_time
    return result
1653 async def _write_batch(self, records: list[Record]) -> list[str]:
1654 """Write a batch of records using COPY for performance.
1656 Returns:
1657 List of created record IDs
1658 """
1659 if not records:
1660 return []
1662 # Prepare data for COPY
1663 rows = []
1664 ids = []
1665 for record in records:
1666 row_data = self._record_to_row(record)
1667 ids.append(row_data["id"])
1668 rows.append((
1669 row_data["id"],
1670 row_data["data"],
1671 row_data["metadata"]
1672 ))
1674 # Use COPY for efficient bulk insert
1675 async with self._pool.acquire() as conn:
1676 await conn.copy_records_to_table(
1677 f"{self.schema_name}.{self.table_name}",
1678 records=rows,
1679 columns=["id", "data", "metadata"]
1680 )
1682 return ids