Coverage for src/dataknobs_data/backends/postgres.py: 13%

728 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-08 15:25 -0600

1"""PostgreSQL backend implementation with proper connection management and vector support.""" 

2 

3from __future__ import annotations 

4 

5import json 

6import logging 

7import time 

8import uuid 

9from typing import TYPE_CHECKING, Any, cast 

10 

11import asyncpg 

12from dataknobs_config import ConfigurableBase 

13 

14from dataknobs_utils.sql_utils import DotenvPostgresConnector, PostgresDB 

15 

16from ..database import AsyncDatabase, SyncDatabase 

17from ..pooling import ConnectionPoolManager 

18from ..pooling.postgres import PostgresPoolConfig, create_asyncpg_pool, validate_asyncpg_pool 

19from ..query import Operator, Query 

20from ..query_logic import ComplexQuery 

21from ..streaming import ( 

22 StreamConfig, 

23 StreamResult, 

24 async_process_batch_with_fallback, 

25 process_batch_with_fallback, 

26) 

27from ..vector.mixins import VectorOperationsMixin 

28from .postgres_mixins import ( 

29 PostgresBaseConfig, 

30 PostgresConnectionValidator, 

31 PostgresErrorHandler, 

32 PostgresTableManager, 

33 PostgresVectorSupport, 

34) 

35from .sql_base import SQLQueryBuilder, SQLRecordSerializer 

36 

37if TYPE_CHECKING: 

38 import numpy as np 

39 

40 from collections.abc import AsyncIterator, Iterator, Callable, Awaitable 

41 from ..fields import VectorField 

42 from ..records import Record 

43 from ..vector.types import DistanceMetric, VectorSearchResult 

44 

# Module-level logger; the backend classes below report vector-support
# detection results and connection-pool errors through it.
logger = logging.getLogger(__name__)

46 

47 

class SyncPostgresDatabase(
    SyncDatabase,
    ConfigurableBase,
    VectorOperationsMixin,
    SQLRecordSerializer,
    PostgresBaseConfig,
    PostgresTableManager,
    PostgresVectorSupport,
    PostgresConnectionValidator,
    PostgresErrorHandler,
):
    """Synchronous PostgreSQL database backend with proper connection management.

    Records are stored in ``<schema>.<table>`` rows with ``id``, ``data`` and
    ``metadata`` columns. All SQL is executed through ``PostgresDB`` using
    pyformat (``%(name)s``) placeholders; positional parameter lists produced
    by ``SQLQueryBuilder`` are mapped onto ``p0``, ``p1``, ... names.
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize PostgreSQL database configuration.

        No connection is opened here; call :meth:`connect` before use.

        Args:
            config: Configuration with the following optional keys:
                - host: PostgreSQL host (default: from env/localhost)
                - port: PostgreSQL port (default: 5432)
                - database: Database name (default: from env/postgres)
                - user: Username (default: from env/postgres)
                - password: Password (default: from env)
                - table: Table name (default: "records")
                - schema: Schema name (default: "public")
                - enable_vector: Enable vector support (default: False)
        """
        super().__init__(config)

        # Parse configuration using mixin
        table_name, schema_name, conn_config = self._parse_postgres_config(config or {})
        self._init_postgres_attributes(table_name, schema_name)

        # Store connection config for later use
        self._conn_config = conn_config
        self.db = None  # Will be initialized in connect()
        self.query_builder: SQLQueryBuilder | None = None  # Will be initialized in connect()

    @classmethod
    def from_config(cls, config: dict) -> SyncPostgresDatabase:
        """Create from config dictionary."""
        return cls(config)

    def connect(self) -> None:
        """Connect to the PostgreSQL database.

        Builds the pyformat query builder and a ``PostgresDB`` handle. The
        handle is configured from environment variables via
        ``DotenvPostgresConnector`` when none of ``host``/``database``/``user``
        is configured. Then it ensures the table exists and, if vector support
        was requested, probes for pgvector. A no-op when already connected.
        """
        if self._connected:
            return  # Already connected

        # Initialize query builder with pyformat style for psycopg2
        self.query_builder = SQLQueryBuilder(
            self.table_name, self.schema_name, dialect="postgres", param_style="pyformat"
        )

        # Create connection using existing utilities
        if not any(key in self._conn_config for key in ["host", "database", "user"]):
            # Use dotenv connector for environment-based config
            connector = DotenvPostgresConnector()
            self.db = PostgresDB(connector)
        else:
            # Direct configuration - PostgresDB expects 'db' (not 'database')
            # and 'pwd' (not 'password')
            self.db = PostgresDB(
                host=self._conn_config.get("host", "localhost"),
                db=self._conn_config.get("database", "postgres"),
                user=self._conn_config.get("user", "postgres"),
                pwd=self._conn_config.get("password"),
                port=self._conn_config.get("port", 5432),
            )

        # Create table if it doesn't exist
        self._ensure_table()

        # Detect and enable vector support if requested
        if self.vector_enabled:
            self._detect_vector_support()

        self._connected = True
        self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}")

    def close(self) -> None:
        """Mark the backend as disconnected.

        PostgresDB manages its own connections via context managers, so no
        connection object is closed explicitly here.
        """
        if self.db:
            self._connected = False  # type: ignore[unreachable]

    def _initialize(self) -> None:
        """Initialize method - connection setup moved to connect()."""
        # Configuration parsing stays in __init__, actual connection in connect()
        pass

    def _detect_vector_support(self) -> None:
        """Detect and enable vector support if pgvector is available.

        Tries to install the extension when it is missing. Any error is logged
        at debug level and leaves vector support disabled.
        """
        from .postgres_vector import check_pgvector_extension_sync, install_pgvector_extension_sync

        try:
            if check_pgvector_extension_sync(self.db):
                self._vector_enabled = True
                logger.info("pgvector extension detected and enabled")
            elif install_pgvector_extension_sync(self.db):
                self._vector_enabled = True
                logger.info("pgvector extension installed and enabled")
            else:
                logger.debug("pgvector extension not available")
        except Exception as e:
            logger.debug(f"Could not enable vector support: {e}")
            self._vector_enabled = False

    def _ensure_table(self) -> None:
        """Ensure the records table exists.

        Raises:
            RuntimeError: If ``self.db`` has not been created by connect().
        """
        if not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

        create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name)  # type: ignore[unreachable]
        self.db.execute(create_table_sql)

    def _get_query_builder(self) -> SQLQueryBuilder:
        """Return the pyformat query builder, creating it if connect() has not set one."""
        if self.query_builder is None:
            self.query_builder = SQLQueryBuilder(
                self.table_name, self.schema_name, dialect="postgres", param_style="pyformat"
            )
        return self.query_builder

    @staticmethod
    def _to_pyformat_params(params: list[Any] | None) -> dict[str, Any]:
        """Map positional query parameters onto the ``p0``, ``p1``, ... names used in pyformat SQL."""
        return {f"p{i}": value for i, value in enumerate(params or [])}

    def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]:
        """Convert a Record to a database row (``id``, JSON ``data``, JSON ``metadata`` or None).

        A new UUID string is used when ``id`` is not given.
        """
        return {
            "id": id or str(uuid.uuid4()),
            "data": self.record_to_json(record),
            "metadata": json.dumps(record.metadata) if record.metadata else None,
        }

    def _row_to_record(self, row: dict[str, Any]) -> Record:
        """Convert a database row to a Record."""
        return self.row_to_record(row)

    def create(self, record: Record) -> str:
        """Create a new record and return its ID.

        The record's own ``id`` is used when set; otherwise a UUID is generated.
        """
        self._check_connection()
        id = record.id if record.id else str(uuid.uuid4())
        row = self._record_to_row(record, id)

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
            VALUES (%(id)s, %(data)s, %(metadata)s)
        """
        self.db.execute(sql, row)
        return id

    def read(self, id: str) -> Record | None:
        """Read a record by ID, returning None when no row matches."""
        self._check_connection()
        sql = f"""
            SELECT id, data, metadata
            FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
        """
        df = self.db.query(sql, {"id": id})

        if df.empty:
            return None

        row = df.iloc[0].to_dict()
        return self._row_to_record(row)

    def update(self, id: str, record: Record) -> bool:
        """Update an existing record; return True if a row was affected."""
        self._check_connection()
        row = self._record_to_row(record, id)

        sql = f"""
            UPDATE {self.schema_name}.{self.table_name}
            SET data = %(data)s, metadata = %(metadata)s, updated_at = CURRENT_TIMESTAMP
            WHERE id = %(id)s
        """
        result = self.db.execute(sql, row)
        # PostgresDB.execute returns number of affected rows
        return result > 0 if isinstance(result, int) else False

    def delete(self, id: str) -> bool:
        """Delete a record by ID; return True if a row was affected."""
        self._check_connection()
        sql = f"""
            DELETE FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
        """
        result = self.db.execute(sql, {"id": id})
        return result > 0 if isinstance(result, int) else False

    def exists(self, id: str) -> bool:
        """Check if a record exists."""
        self._check_connection()
        sql = f"""
            SELECT 1 FROM {self.schema_name}.{self.table_name}
            WHERE id = %(id)s
            LIMIT 1
        """
        df = self.db.query(sql, {"id": id})
        return not df.empty

    def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
        """Update or insert a record and return its ID.

        Can be called as:
        - upsert(id, record) - explicit ID and record
        - upsert(record) - extract ID from record using Record's built-in logic;
          if the record has no ID a UUID is generated and stored as its
          ``storage_id``.

        Raises:
            ValueError: If an ID string is given without a record.
        """
        self._check_connection()

        # Determine ID and record based on arguments
        if isinstance(id_or_record, str):
            id = id_or_record
            if record is None:
                raise ValueError("Record required when ID is provided")
        else:
            record = id_or_record
            id = record.id
            if id is None:
                id = str(uuid.uuid4())  # type: ignore[unreachable]
                record.storage_id = id

        if self.exists(id):
            self.update(id, record)
        else:
            # Insert with specific ID
            row = self._record_to_row(record, id)
            sql = f"""
                INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
                VALUES (%(id)s, %(data)s, %(metadata)s)
            """
            self.db.execute(sql, row)
        return id

    def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Search for records matching the query.

        ``ComplexQuery`` objects use the builder's native complex-query SQL;
        plain ``Query`` objects use ``build_search_query``. Each result has
        ``storage_id`` set from the row ID and is projected to
        ``query.fields`` when that is set.
        """
        self._check_connection()
        builder = self._get_query_builder()

        if isinstance(query, ComplexQuery):
            sql_query, params_list = builder.build_complex_search_query(query)
        else:
            sql_query, params_list = builder.build_search_query(query)

        # The builder emits %(p0)s-style placeholders; name its positional params to match.
        df = self.db.query(sql_query, self._to_pyformat_params(params_list))

        records = []
        for _, row in df.iterrows():
            row_dict = row.to_dict()
            record = self._row_to_record(row_dict)

            # Populate storage_id from database ID
            record.storage_id = str(row_dict["id"])

            if query.fields:
                record = record.project(query.fields)

            records.append(record)

        return records

    def _count_all(self) -> int:
        """Count all records in the database."""
        self._check_connection()
        sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}"
        df = self.db.query(sql)
        return int(df.iloc[0]["count"]) if not df.empty else 0

    def clear(self) -> int:
        """Remove all records with TRUNCATE and return how many there were."""
        self._check_connection()
        count = self._count_all()

        sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"
        self.db.execute(sql)

        return count

    def create_batch(self, records: list[Record]) -> list[str]:
        """Create multiple records efficiently using a single multi-value INSERT.

        Args:
            records: List of records to create

        Returns:
            List of created record IDs. These are the IDs from the query's
            RETURNING rows when any are returned, else the IDs generated by
            the builder.
        """
        if not records:
            return []

        self._check_connection()

        query, params_list, ids = self._get_query_builder().build_batch_create_query(records)
        result_df = self.db.query(query, self._to_pyformat_params(params_list))

        # PostgreSQL RETURNING clause gives us the actual inserted IDs
        if not result_df.empty:
            return result_df["id"].tolist()
        return ids

    def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete multiple records efficiently using a single DELETE ... IN query.

        Args:
            ids: List of record IDs to delete

        Returns:
            One flag per input ID, True when that ID was reported deleted by
            the query's RETURNING clause.
        """
        if not ids:
            return []

        self._check_connection()

        # The shared batch delete query builder includes a RETURNING clause
        query, params_list = self._get_query_builder().build_batch_delete_query(ids)
        result_df = self.db.query(query, self._to_pyformat_params(params_list))

        deleted_ids = set(result_df["id"].tolist()) if not result_df.empty else set()
        return [record_id in deleted_ids for record_id in ids]

    def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update multiple records efficiently using a single query.

        Uses PostgreSQL CASE expressions for batch updates via the shared SQL builder.

        Args:
            updates: List of (id, record) tuples to update

        Returns:
            One flag per update, True when that ID was reported updated by the
            query's RETURNING clause.
        """
        if not updates:
            return []

        self._check_connection()

        query, params_list = self._get_query_builder().build_batch_update_query(updates)
        result_df = self.db.query(query, self._to_pyformat_params(params_list))

        updated_ids = set(result_df["id"].tolist()) if not result_df.empty else set()
        return [record_id in updated_ids for record_id, _ in updates]

    def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> Iterator[Record]:
        """Stream records from PostgreSQL in pages of ``config.batch_size``.

        Only equality (``Operator.EQ``) filters are pushed into SQL. They
        compare ``data->>'<field>'`` with ``str(value)``. Filters with other
        operators are not applied, and a warning is logged for each one.
        Pages are fetched with LIMIT/OFFSET ordered by ``id`` so consecutive
        pages neither repeat nor skip rows.
        """
        self._check_connection()
        config = config or StreamConfig()
        batch_size = config.batch_size

        sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
        params: dict[str, Any] = {}

        if query and query.filters:
            where_clauses = []
            for i, flt in enumerate(query.filters):
                if flt.operator != Operator.EQ:
                    logger.warning(
                        "stream_read only supports EQ filters; ignoring %s filter on field %r",
                        flt.operator,
                        flt.field,
                    )
                    continue
                # The field name sits inside a SQL string literal: double single
                # quotes so it cannot terminate the literal, and double '%' so
                # psycopg2's pyformat parsing does not treat it as a placeholder.
                safe_field = str(flt.field).replace("'", "''").replace("%", "%%")
                param_name = f"param_{i}"
                where_clauses.append(f"data->>'{safe_field}' = %({param_name})s")
                params[param_name] = str(flt.value)

            if where_clauses:
                sql += " WHERE " + " AND ".join(where_clauses)

        # Without ORDER BY, PostgreSQL does not guarantee a stable row order
        # between the separate LIMIT/OFFSET queries below.
        sql += " ORDER BY id LIMIT %(limit)s OFFSET %(offset)s"
        params["limit"] = batch_size

        offset = 0
        while True:
            params["offset"] = offset
            df = self.db.query(sql, params)

            if df.empty:
                break

            for _, row in df.iterrows():
                record = self._row_to_record(row.to_dict())
                if query and query.fields:
                    record = record.project(query.fields)
                yield record

            # A short page means there is nothing left to fetch
            if len(df) < batch_size:
                break
            offset += batch_size

    def stream_write(
        self,
        records: Iterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Stream records into PostgreSQL in batches of ``config.batch_size``.

        Each batch goes through ``process_batch_with_fallback`` with
        ``_write_batch`` as the batch writer and ``self.create`` as the
        per-record fallback. Writing stops early when that helper reports not
        to continue.
        """
        self._check_connection()
        config = config or StreamConfig()
        result = StreamResult()
        start_time = time.time()
        quitting = False

        batch: list[Record] = []
        for record in records:
            batch.append(record)

            if len(batch) >= config.batch_size:
                continue_processing = process_batch_with_fallback(
                    batch,
                    self._write_batch,
                    self.create,
                    result,
                    config
                )

                if not continue_processing:
                    quitting = True
                    break

                batch = []

        # Write remaining batch
        if batch and not quitting:
            process_batch_with_fallback(
                batch,
                self._write_batch,
                self.create,
                result,
                config
            )

        result.duration = time.time() - start_time
        return result

    def _write_batch(self, records: list[Record]) -> list[str]:
        """Insert a batch of records with one multi-row INSERT.

        As in :meth:`create`, a record's own ``id`` is used when set and a new
        UUID is generated otherwise. The batch path therefore assigns the same
        IDs as the per-record ``self.create`` fallback used by stream_write.

        Returns:
            List of created record IDs, in input order (empty for no records).
        """
        if not records:
            return []

        values = []
        params: dict[str, Any] = {}
        ids = []

        for i, record in enumerate(records):
            record_id = record.id if record.id else str(uuid.uuid4())
            ids.append(record_id)
            row = self._record_to_row(record, record_id)
            values.append(f"(%(id_{i})s, %(data_{i})s, %(metadata_{i})s)")
            params[f"id_{i}"] = row["id"]
            params[f"data_{i}"] = row["data"]
            params[f"metadata_{i}"] = row["metadata"]

        sql = f"""
            INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
            VALUES {', '.join(values)}
        """
        self.db.execute(sql, params)
        return ids

    def vector_search(
        self,
        query_vector: np.ndarray | list[float] | VectorField,
        field_name: str,
        k: int = 10,
        filter: Query | None = None,
        metric: DistanceMetric | str = "cosine"
    ) -> list[VectorSearchResult]:
        """Search for similar vectors using PostgreSQL pgvector.

        Args:
            query_vector: Query vector (numpy array, list, or VectorField)
            field_name: Name of vector field to search (must be in data JSON)
            k: Maximum number of results
            filter: Optional query whose conditions are added to the WHERE clause
            metric: Distance metric to use (cosine, euclidean, l2, inner_product)

        Returns:
            List of VectorSearchResult objects ordered by ascending distance

        Raises:
            RuntimeError: If vector support is not enabled.
        """
        if not self._vector_enabled:
            raise RuntimeError("Vector search not available - pgvector not installed")

        self._check_connection()

        from ..fields import VectorField
        from ..vector.types import DistanceMetric, VectorSearchResult
        from .postgres_vector import format_vector_for_postgres, get_vector_operator

        # Convert query vector to proper format
        if isinstance(query_vector, VectorField):
            vector_str = format_vector_for_postgres(query_vector.value)
        else:
            vector_str = format_vector_for_postgres(query_vector)

        # Get the appropriate operator
        if isinstance(metric, DistanceMetric):
            metric_str = metric.value
        else:
            metric_str = str(metric).lower()

        operator = get_vector_operator(metric_str)

        # Vectors are stored in the JSON data field; use centralized extraction logic
        vector_expr = self.get_vector_extraction_sql(field_name, dialect="postgres")

        sql = f"""
            SELECT
                id,
                data,
                metadata,
                {vector_expr} {operator} %(p0)s::vector AS distance
            FROM {self.schema_name}.{self.table_name}
            WHERE data ? %(p1)s -- Check field exists
        """

        params: list[Any] = [vector_str, field_name]

        if filter:
            # The builder was configured for pyformat placeholders
            where_clause, filter_params = self._get_query_builder().build_where_clause(
                filter, len(params) + 1
            )
            if where_clause:
                sql += where_clause
                params.extend(filter_params)

        # Order by distance and limit
        next_param = len(params)
        sql += f" ORDER BY distance LIMIT %(p{next_param})s"
        params.append(k)

        df = self.db.query(sql, self._to_pyformat_params(params))

        results = []
        for _, row in df.iterrows():
            # Pass a plain dict, as the other read paths do
            record = self._row_to_record(row.to_dict())

            # Calculate similarity score from distance
            distance = row["distance"]
            if metric_str in ["cosine", "cosine_similarity"]:
                score = 1.0 - distance  # Cosine distance to similarity
            elif metric_str in ["euclidean", "l2"]:
                score = 1.0 / (1.0 + distance)  # Convert distance to similarity
            elif metric_str in ["inner_product", "dot_product"]:
                score = -distance  # pgvector's inner-product operator returns the negated product
            else:
                score = -distance  # Default: lower distance = better

            results.append(
                VectorSearchResult(
                    record=record,
                    score=float(score),
                    vector_field=field_name
                )
            )

        return results

    def has_vector_support(self) -> bool:
        """Check if this database has vector support enabled.

        Returns:
            True if vector operations are supported
        """
        return self._vector_enabled

    def enable_vector_support(self) -> bool:
        """Enable vector support for this database if possible.

        Returns:
            True if vector support is now enabled
        """
        if self._vector_enabled:
            return True

        self._detect_vector_support()
        return self._vector_enabled

    def bulk_embed_and_store(
        self,
        records: list[Record],
        text_field: str | list[str],
        vector_field: str = "embedding",
        embedding_fn: Any = None,
        batch_size: int = 100,
        model_name: str | None = None,
        model_version: str | None = None,
    ) -> list[str]:
        """Embed text fields and store vectors with records (stub for abstract requirement).

        This is a placeholder implementation to satisfy the abstract method requirement.
        Full implementation would require actual embedding function.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("bulk_embed_and_store requires an embedding function")

715 

716 

# Global pool manager instance for async PostgreSQL connections.
# AsyncPostgresDatabase.connect() obtains its asyncpg pool from this manager
# by passing its PostgresPoolConfig (presumably pools are cached per config and
# event loop -- confirm against ConnectionPoolManager).
_pool_manager = ConnectionPoolManager[asyncpg.Pool]()

719 

720 

721class AsyncPostgresDatabase( 

722 AsyncDatabase, 

723 VectorOperationsMixin, 

724 ConfigurableBase, 

725 PostgresBaseConfig, 

726 PostgresTableManager, 

727 PostgresVectorSupport, 

728 PostgresConnectionValidator, 

729 PostgresErrorHandler, 

730): 

731 """Native async PostgreSQL database backend with vector support and event loop-aware connection pooling.""" 

732 

733 def __init__(self, config: dict[str, Any] | None = None): 

734 """Initialize async PostgreSQL database.""" 

735 super().__init__(config) 

736 

737 # Parse configuration using mixin 

738 table_name, schema_name, conn_config = self._parse_postgres_config(config or {}) 

739 self._init_postgres_attributes(table_name, schema_name) 

740 

741 # Extract pool configuration 

742 self._pool_config = PostgresPoolConfig.from_dict(conn_config) 

743 self._pool: asyncpg.Pool | None = None 

744 

745 @classmethod 

746 def from_config(cls, config: dict) -> AsyncPostgresDatabase: 

747 """Create from config dictionary.""" 

748 return cls(config) 

749 

750 async def connect(self) -> None: 

751 """Connect to the database.""" 

752 if self._connected: 

753 return 

754 

755 # Get or create pool for current event loop 

756 from ..pooling import BasePoolConfig 

757 self._pool = await _pool_manager.get_pool( 

758 self._pool_config, 

759 cast("Callable[[BasePoolConfig], Awaitable[Any]]", create_asyncpg_pool), 

760 validate_asyncpg_pool 

761 ) 

762 

763 # Initialize query builder 

764 self.query_builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres") 

765 

766 # Ensure table exists 

767 await self._ensure_table() 

768 

769 # Check and enable vector support if requested 

770 if self.vector_enabled: 

771 await self._detect_vector_support() 

772 

773 self._connected = True 

774 self.log_operation("connect", f"Connected to table: {self.schema_name}.{self.table_name}") 

775 

776 async def close(self) -> None: 

777 """Close the database connection and properly close the pool.""" 

778 if self._connected: 

779 # Properly close the pool if we have one 

780 if self._pool: 

781 try: 

782 await self._pool.close() 

783 except Exception as e: 

784 logger.warning(f"Error closing connection pool: {e}") 

785 self._pool = None 

786 self._connected = False 

787 

788 def _initialize(self) -> None: 

789 """Initialize is handled in connect.""" 

790 pass 

791 

792 async def _ensure_table(self) -> None: 

793 """Ensure the records table exists.""" 

794 if not self._pool: 

795 raise RuntimeError("Database not connected. Call connect() first.") 

796 

797 create_table_sql = self.get_create_table_sql(self.schema_name, self.table_name) 

798 

799 async with self._pool.acquire() as conn: 

800 await conn.execute(create_table_sql) 

801 

802 async def _detect_vector_support(self) -> None: 

803 """Detect and enable vector support if pgvector is available.""" 

804 from .postgres_vector import check_pgvector_extension, install_pgvector_extension 

805 

806 async with self._pool.acquire() as conn: 

807 # Check if pgvector is available 

808 if await check_pgvector_extension(conn): 

809 self._vector_enabled = True 

810 logger.info("pgvector extension detected and enabled") 

811 else: 

812 # Try to install it 

813 if await install_pgvector_extension(conn): 

814 self._vector_enabled = True 

815 logger.info("pgvector extension installed and enabled") 

816 else: 

817 logger.debug("pgvector extension not available") 

818 

819 async def _ensure_vector_column(self, field_name: str, dimensions: int) -> None: 

820 """Ensure a vector column exists for the given field. 

821  

822 Args: 

823 field_name: Name of the vector field 

824 dimensions: Number of dimensions 

825 """ 

826 if not self._vector_enabled: 

827 return 

828 

829 column_name = f"vector_{field_name}" 

830 

831 # Check if column already exists 

832 check_sql = """ 

833 SELECT column_name FROM information_schema.columns 

834 WHERE table_schema = $1 AND table_name = $2 AND column_name = $3 

835 """ 

836 

837 async with self._pool.acquire() as conn: 

838 existing = await conn.fetchval(check_sql, self.schema_name, self.table_name, column_name) 

839 

840 if not existing: 

841 # Add vector column 

842 alter_sql = f""" 

843 ALTER TABLE {self.schema_name}.{self.table_name} 

844 ADD COLUMN IF NOT EXISTS {column_name} vector({dimensions}) 

845 """ 

846 try: 

847 await conn.execute(alter_sql) 

848 self._vector_dimensions[field_name] = dimensions 

849 logger.info(f"Added vector column {column_name} with {dimensions} dimensions") 

850 

851 # Create index for the vector column 

852 from .postgres_vector import build_vector_index_sql, get_optimal_index_type 

853 

854 # Get row count for optimal index selection 

855 count_sql = f"SELECT COUNT(*) FROM {self.schema_name}.{self.table_name}" 

856 count = await conn.fetchval(count_sql) 

857 

858 index_type, index_params = get_optimal_index_type(count) 

859 index_sql = build_vector_index_sql( 

860 self.table_name, 

861 self.schema_name, 

862 column_name, 

863 dimensions, 

864 metric="cosine", 

865 index_type=index_type, 

866 index_params=index_params 

867 ) 

868 

869 # Note: IVFFlat requires table to have data before creating index 

870 if count > 0 or index_type != "ivfflat": 

871 await conn.execute(index_sql) 

872 logger.info(f"Created {index_type} index for {column_name}") 

873 

874 except Exception as e: 

875 logger.warning(f"Could not create vector column {column_name}: {e}") 

876 else: 

877 self._vector_dimensions[field_name] = dimensions 

878 

879 def _check_connection(self) -> None: 

880 """Check if async database is connected.""" 

881 self._check_async_connection() 

882 

883 def _record_to_row(self, record: Record, id: str | None = None) -> dict[str, Any]: 

884 """Convert a Record to a database row using common serializer.""" 

885 from .sql_base import SQLRecordSerializer 

886 

887 return { 

888 "id": id or str(uuid.uuid4()), 

889 "data": SQLRecordSerializer.record_to_json(record), 

890 "metadata": json.dumps(record.metadata) if record.metadata else None, 

891 } 

892 

893 def _row_to_record(self, row: asyncpg.Record) -> Record: 

894 """Convert a database row to a Record using the common serializer.""" 

895 from .sql_base import SQLRecordSerializer 

896 

897 # Convert asyncpg.Record to dict format expected by SQLRecordSerializer 

898 data_json = row.get("data", {}) 

899 if not isinstance(data_json, str): 

900 data_json = json.dumps(data_json) 

901 

902 metadata_json = row.get("metadata") 

903 if metadata_json and not isinstance(metadata_json, str): 

904 metadata_json = json.dumps(metadata_json) 

905 

906 # Use the common serializer to reconstruct the record 

907 return SQLRecordSerializer.json_to_record(data_json, metadata_json) 

908 

async def create(self, record: Record) -> str:
    """Insert ``record`` as a new row and return the ID it was stored under.

    When vector support is enabled, every ``VectorField`` on the record first
    gets its backing column ensured via ``_ensure_vector_column``. The row ID
    is ``record.id`` when truthy, otherwise a freshly generated UUID4 string.

    NOTE(review): the ``_record_to_row`` visible in this file returns only
    ``id``/``data``/``metadata``, so the ``vector_*`` column extension below
    currently adds nothing and vector values are not written to the columns
    that ``vector_search`` reads. Confirm whether that is intended.

    Args:
        record: The record to insert.

    Returns:
        The ID of the inserted row.
    """
    self._check_connection()

    from ..fields import VectorField

    for name, field_obj in record.fields.items():
        if self._vector_enabled and isinstance(field_obj, VectorField):
            await self._ensure_vector_column(name, field_obj.dimensions)

    new_id = record.id if record.id else str(uuid.uuid4())
    row = self._record_to_row(record, new_id)

    columns = ["id", "data", "metadata"]
    values = [row["id"], row["data"], row["metadata"]]

    vector_items = [(key, val) for key, val in row.items() if key.startswith("vector_")]
    columns.extend(key for key, _ in vector_items)
    values.extend(val for _, val in vector_items)

    # asyncpg uses 1-based positional placeholders: $1, $2, ...
    placeholders = [f"${position}" for position in range(1, len(values) + 1)]

    sql = f"""
        INSERT INTO {self.schema_name}.{self.table_name} ({', '.join(columns)})
        VALUES ({', '.join(placeholders)})
    """

    async with self._pool.acquire() as conn:
        await conn.execute(sql, *values)

    return new_id

946 

async def read(self, id: str) -> Record | None:
    """Fetch a single record by its row ID.

    The returned record has ``storage_id`` populated from the row's ``id``,
    matching what ``search`` does. Callers such as ``bulk_embed_and_store``
    rely on ``has_storage_id()`` to choose update over create; without it, a
    record obtained from ``read`` would be re-inserted instead of updated.

    Args:
        id: The row ID to look up.

    Returns:
        The record, or ``None`` when no row has that ID.
    """
    self._check_connection()
    sql = f"""
        SELECT id, data, metadata
        FROM {self.schema_name}.{self.table_name}
        WHERE id = $1
    """

    async with self._pool.acquire() as conn:
        row = await conn.fetchrow(sql, id)

    if not row:
        return None

    record = self._row_to_record(row)
    # Keep the link to the stored row, consistent with search().
    record.storage_id = str(row["id"])
    return record

963 

async def update(self, id: str, record: Record) -> bool:
    """Overwrite the data and metadata of an existing row.

    Also bumps ``updated_at`` to ``CURRENT_TIMESTAMP``.

    Args:
        id: ID of the row to update.
        record: New contents for the row.

    Returns:
        ``True`` if a row was updated, ``False`` if no row had that ID.
    """
    self._check_connection()
    payload = self._record_to_row(record, id)

    sql = f"""
        UPDATE {self.schema_name}.{self.table_name}
        SET data = $2, metadata = $3, updated_at = CURRENT_TIMESTAMP
        WHERE id = $1
    """

    async with self._pool.acquire() as conn:
        status = await conn.execute(sql, payload["id"], payload["data"], payload["metadata"])

    # The command tag looks like "UPDATE <n>"; its last token is the row count.
    affected = status.rsplit(maxsplit=1)[-1]
    return affected != "0"

980 

async def delete(self, id: str) -> bool:
    """Remove the row with the given ID.

    Args:
        id: ID of the row to delete.

    Returns:
        ``True`` if a row was deleted, ``False`` if none matched.
    """
    self._check_connection()
    sql = f"""
        DELETE FROM {self.schema_name}.{self.table_name}
        WHERE id = $1
    """

    async with self._pool.acquire() as conn:
        status = await conn.execute(sql, id)

    # The command tag looks like "DELETE <n>"; compare the trailing count.
    return status.rsplit(maxsplit=1)[-1] != "0"

994 

async def exists(self, id: str) -> bool:
    """Report whether a row with the given ID is present.

    Args:
        id: The row ID to probe.

    Returns:
        ``True`` if the row exists.
    """
    self._check_connection()
    sql = f"""
        SELECT 1 FROM {self.schema_name}.{self.table_name}
        WHERE id = $1
        LIMIT 1
    """

    async with self._pool.acquire() as conn:
        hit = await conn.fetchrow(sql, id)

    return hit is not None

1008 

async def upsert(self, id_or_record: str | Record, record: Record | None = None) -> str:
    """Insert a record, or replace the existing row with the same ID.

    Two call forms are accepted:

    - ``upsert(id, record)``: store ``record`` under the explicit ``id``.
    - ``upsert(record)``: use ``record.id``; if it is ``None`` a UUID4 is
      generated and assigned to ``record.storage_id``.

    Args:
        id_or_record: Either the target ID or the record itself.
        record: The record, required when ``id_or_record`` is an ID.

    Returns:
        The ID the record was stored under.

    Raises:
        ValueError: If an ID is given without a record.
    """
    self._check_connection()

    if isinstance(id_or_record, str):
        if record is None:
            raise ValueError("Record required when ID is provided")
        target_id = id_or_record
    else:
        record = id_or_record
        target_id = record.id
        if target_id is None:
            target_id = str(uuid.uuid4())
            record.storage_id = target_id

    payload = self._record_to_row(record, target_id)

    sql = f"""
        INSERT INTO {self.schema_name}.{self.table_name} (id, data, metadata)
        VALUES ($1, $2, $3)
        ON CONFLICT (id) DO UPDATE
        SET data = EXCLUDED.data, metadata = EXCLUDED.metadata, updated_at = CURRENT_TIMESTAMP
    """

    async with self._pool.acquire() as conn:
        await conn.execute(sql, payload["id"], payload["data"], payload["metadata"])

    return target_id

1044 

async def search(self, query: Query | ComplexQuery) -> list[Record]:
    """Return all records matching ``query``.

    A ``SQLQueryBuilder`` is created on first use and cached on
    ``self.query_builder``. ``ComplexQuery`` objects go through
    ``build_complex_search_query``; plain ``Query`` objects go through
    ``build_search_query``. Each result has ``storage_id`` set from the row
    ID and is projected to ``query.fields`` when fields are requested.

    Args:
        query: A simple or complex query.

    Returns:
        The matching records.
    """
    self._check_connection()

    if not hasattr(self, 'query_builder'):
        self.query_builder = SQLQueryBuilder(
            self.table_name, self.schema_name, dialect="postgres"
        )

    if isinstance(query, ComplexQuery):
        build = self.query_builder.build_complex_search_query
    else:
        build = self.query_builder.build_search_query
    sql, params = build(query)

    # The builder emits asyncpg-style $N placeholders already.
    async with self._pool.acquire() as conn:
        rows = await conn.fetch(sql, *params)

    found: list[Record] = []
    for row in rows:
        item = self._row_to_record(row)
        item.storage_id = str(row['id'])
        found.append(item.project(query.fields) if query.fields else item)

    return found

1080 

1081 async def _count_all(self) -> int: 

1082 """Count all records in the database.""" 

1083 self._check_connection() 

1084 sql = f"SELECT COUNT(*) as count FROM {self.schema_name}.{self.table_name}" 

1085 

1086 async with self._pool.acquire() as conn: 

1087 row = await conn.fetchrow(sql) 

1088 

1089 return row["count"] if row else 0 

1090 

async def clear(self) -> int:
    """Empty the backing table with ``TRUNCATE``.

    The row count is taken via ``_count_all`` before truncating, on a
    separate pool acquisition.

    Returns:
        The number of rows counted before the truncate.
    """
    self._check_connection()
    removed = await self._count_all()

    truncate_sql = f"TRUNCATE TABLE {self.schema_name}.{self.table_name}"
    async with self._pool.acquire() as conn:
        await conn.execute(truncate_sql)

    return removed

1104 

async def create_batch(self, records: list[Record]) -> list[str]:
    """Insert many records with a single multi-row ``INSERT ... RETURNING``.

    Args:
        records: The records to insert.

    Returns:
        The IDs reported by ``RETURNING``; when no rows come back, the IDs
        generated by the query builder are returned instead. An empty input
        returns ``[]`` without touching the database.
    """
    if not records:
        return []

    self._check_connection()

    from .sql_base import SQLQueryBuilder

    builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
    query, params, planned_ids = builder.build_batch_create_query(records)

    async with self._pool.acquire() as conn:
        returned = await conn.fetch(query, *params)

    return [row["id"] for row in returned] if returned else planned_ids

1136 

async def delete_batch(self, ids: list[str]) -> list[bool]:
    """Delete many rows with a single ``DELETE ... RETURNING id``.

    Args:
        ids: IDs of the rows to delete.

    Returns:
        One flag per input ID, in input order: ``True`` if that ID was among
        the deleted rows. An empty input returns ``[]`` without a query.
    """
    if not ids:
        return []

    self._check_connection()

    from .sql_base import SQLQueryBuilder

    builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
    query, params = builder.build_batch_delete_query(ids)

    async with self._pool.acquire() as conn:
        returned = await conn.fetch(query, *params)

    removed = {row["id"] for row in returned}
    return [record_id in removed for record_id in ids]

1173 

async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
    """Update many rows with one statement built by ``SQLQueryBuilder``.

    ``RETURNING id`` is appended to the builder's statement so the rows that
    were actually touched can be identified.

    Args:
        updates: ``(id, record)`` pairs to apply.

    Returns:
        One flag per pair, in input order: ``True`` if that ID was updated.
        An empty input returns ``[]`` without a query.
    """
    if not updates:
        return []

    self._check_connection()

    from .sql_base import SQLQueryBuilder

    builder = SQLQueryBuilder(self.table_name, self.schema_name, dialect="postgres")
    # The builder already emits $N positional placeholders for PostgreSQL.
    query, params = builder.build_batch_update_query(updates)
    query = f"{query.rstrip()} RETURNING id"

    async with self._pool.acquire() as conn:
        returned = await conn.fetch(query, *params)

    touched = {row["id"] for row in returned}
    return [record_id in touched for record_id, _ in updates]

1214 

async def vector_search(
    self,
    query_vector: np.ndarray | list[float] | VectorField,
    field_name: str,
    k: int = 10,
    filter: Query | None = None,
    metric: DistanceMetric | str = "cosine"
) -> list[VectorSearchResult]:
    """Find the ``k`` rows whose vector column is closest to ``query_vector``.

    Rows are matched on the ``vector_<field_name>`` column using the pgvector
    operator returned by ``get_vector_operator`` for ``metric``. The distance
    is turned into a score as follows:

    - cosine: ``1 - min(d, 2) / 2``
    - euclidean/l2: ``1 / (1 + d)``
    - any other metric: ``1 - d``

    Args:
        query_vector: Query vector (numpy array, list, or VectorField).
        field_name: Name of the vector field to search.
        k: Maximum number of results (bound as a query parameter).
        filter: Optional query whose filters further restrict the rows.
        metric: Distance metric to use.

    Returns:
        VectorSearchResult objects ordered by increasing distance. Each
        record has ``storage_id`` set from the row ID, as ``search`` does.

    Raises:
        RuntimeError: If vector support is not enabled.
    """
    if not self._vector_enabled:
        raise RuntimeError("Vector search not available - pgvector not installed")

    self._check_connection()

    from ..fields import VectorField
    from ..vector.types import DistanceMetric, VectorSearchResult
    from .postgres_vector import format_vector_for_postgres, get_vector_operator

    raw_vector = query_vector.value if isinstance(query_vector, VectorField) else query_vector
    vector_str = format_vector_for_postgres(raw_vector)

    if isinstance(metric, DistanceMetric):
        metric_str = metric.value
    else:
        metric_str = str(metric).lower()
    operator = get_vector_operator(metric_str)

    vector_column = f"vector_{field_name}"

    sql = f"""
        SELECT id, data, metadata, {vector_column},
               {vector_column} {operator} $1::vector AS distance
        FROM {self.schema_name}.{self.table_name}
        WHERE {vector_column} IS NOT NULL
    """

    params: list[Any] = [vector_str]
    param_num = 2

    if filter:
        # query_builder is otherwise only created lazily by search(); create it
        # here too so a filtered vector search on a fresh instance works.
        if not hasattr(self, 'query_builder'):
            self.query_builder = SQLQueryBuilder(
                self.table_name, self.schema_name, dialect="postgres"
            )
        where_clause, filter_params = self.query_builder.build_where_clause(filter, param_num)

        # The statement already has a WHERE, so strip any leading WHERE/AND the
        # builder may emit and attach the condition with AND.
        condition = (where_clause or "").strip()
        for keyword in ("WHERE ", "AND "):
            if condition.upper().startswith(keyword):
                condition = condition[len(keyword):].strip()
                break

        if condition:
            # Convert any %s placeholders to $N for asyncpg (no-op if the
            # builder already emitted $N) and keep params in step.
            for param in filter_params:
                condition = condition.replace("%s", f"${param_num}", 1)
                params.append(param)
                param_num += 1
            sql += f"  AND ({condition})\n"

    # Bind LIMIT as a parameter rather than splicing it into the SQL text.
    params.append(int(k))
    sql += f"""
        ORDER BY distance
        LIMIT ${param_num}
    """

    async with self._pool.acquire() as conn:
        rows = await conn.fetch(sql, *params)

    results: list[VectorSearchResult] = []
    for row in rows:
        record = self._row_to_record(row)
        record.storage_id = str(row['id'])

        distance = float(row['distance'])
        if metric_str == "cosine":
            # Cosine distance lies in [0, 2]; map it onto a [0, 1] similarity.
            score = 1.0 - min(distance, 2.0) / 2.0
        elif metric_str in ["euclidean", "l2"]:
            score = 1.0 / (1.0 + distance)
        else:
            score = 1.0 - distance

        results.append(
            VectorSearchResult(
                record=record,
                score=score,
                vector_field=field_name,
                metadata={"distance": distance, "metric": metric_str},
            )
        )

    return results

1315 

async def enable_vector_support(self) -> bool:
    """Turn on vector support if it is not already active.

    Returns:
        ``self._vector_enabled`` after running ``_detect_vector_support``
        (detection is skipped when support is already enabled).
    """
    if not self._vector_enabled:
        await self._detect_vector_support()
        return self._vector_enabled
    return True

1327 

async def has_vector_support(self) -> bool:
    """Report the current vector-support flag without probing the database.

    Returns:
        The value of ``self._vector_enabled``.
    """
    enabled = self._vector_enabled
    return enabled

1335 

async def bulk_embed_and_store(
    self,
    records: list[Record],
    text_field: str | list[str],
    vector_field: str,
    embedding_fn: Any | None = None,
    batch_size: int = 100,
    model_name: str | None = None,
    model_version: str | None = None,
) -> list[str]:
    """Embed text fields and persist the resulting vectors with the records.

    Records are processed in slices of ``batch_size``. For each slice, the
    text of ``text_field`` (or the space-joined text of each listed field
    that is present) is passed to ``embedding_fn``. Each returned vector is
    attached to its record as a ``VectorField`` named ``vector_field``. The
    record is then updated when it already has a storage ID, or created
    otherwise.

    Args:
        records: Records to process.
        text_field: Field name, or list of field names, whose text is embedded.
        vector_field: Field name under which to store the vector.
        embedding_fn: Callable taking a list of strings and returning one
            vector per string. It may be a plain function or a coroutine
            function; awaitable results are awaited.
        batch_size: Number of records embedded per call (must be positive).
        model_name: Name of the embedding model, recorded on the VectorField.
        model_version: Version of the embedding model, recorded on the VectorField.

    Returns:
        Storage IDs of the records that received a vector, in processing order.

    Raises:
        ValueError: If ``embedding_fn`` is missing, ``batch_size`` is not
            positive, or a record ends up without a storage ID.
    """
    if not embedding_fn:
        raise ValueError("embedding_fn is required for bulk_embed_and_store")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    import inspect

    from ..fields import VectorField

    # Same value for every record, so compute it once.
    source_field = text_field if isinstance(text_field, str) else ",".join(text_field)

    processed_ids: list[str] = []

    for start in range(0, len(records), batch_size):
        batch = records[start:start + batch_size]

        texts = []
        for record in batch:
            if isinstance(text_field, list):
                text = " ".join(
                    str(record.fields[name].value) for name in text_field if name in record.fields
                )
            else:
                text = str(record.fields[text_field].value) if text_field in record.fields else ""
            texts.append(text)

        if not texts:
            continue

        # Accept both synchronous and asynchronous embedding callables.
        embeddings = embedding_fn(texts)
        if inspect.isawaitable(embeddings):
            embeddings = await embeddings

        # zip stops at the shorter side, so records without a returned
        # embedding are skipped, matching the original index bound check.
        for record, vector in zip(batch, embeddings):
            record.fields[vector_field] = VectorField(
                name=vector_field,
                value=vector,
                dimensions=len(vector) if hasattr(vector, '__len__') else None,
                source_field=source_field,
                model_name=model_name,
                model_version=model_version,
            )

            if record.has_storage_id():
                if record.storage_id is None:
                    raise ValueError("Record has_storage_id() returned True but storage_id is None")
                await self.update(record.storage_id, record)
            else:
                record.storage_id = await self.create(record)

            if record.storage_id is None:
                raise ValueError("Record storage_id is None after create/update")
            processed_ids.append(record.storage_id)

    return processed_ids

1418 

async def create_vector_index(
    self,
    vector_field: str,
    dimensions: int,
    metric: DistanceMetric | str = "cosine",
    index_type: str = "ivfflat",
    lists: int | None = None,
) -> bool:
    """Build a pgvector index on a vector field's column.

    For an ``ivfflat`` index with no ``lists`` given, the number of stored
    vectors is counted and ``get_optimal_index_type`` picks ``lists``
    (falling back to 100). Failures while executing the DDL are logged and
    reported as ``False`` rather than raised.

    Args:
        vector_field: Name of the vector field to index.
        dimensions: Number of dimensions in the vectors.
        metric: Distance metric for the index.
        index_type: Type of index (ivfflat, hnsw).
        lists: Number of lists for an IVFFlat index.

    Returns:
        ``True`` if the index statement ran; ``False`` if vector support is
        off or the statement failed.
    """
    from .postgres_vector import (
        build_vector_column_expression,
        build_vector_index_sql,
        get_optimal_index_type,
        get_vector_count_sql,
    )

    self._check_connection()

    if not self._vector_enabled:
        return False

    if index_type == "ivfflat" and not lists:
        count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
        async with self._pool.acquire() as conn:
            vector_count = await conn.fetchval(count_sql) or 0
        _, tuned = get_optimal_index_type(vector_count)
        lists = tuned.get("lists", 100)

    metric_str = metric.value if hasattr(metric, 'value') else str(metric).lower()

    column_expr = build_vector_column_expression(vector_field, dimensions, for_index=True)

    # field_name is passed separately so the index gets a stable name.
    index_sql = build_vector_index_sql(
        table_name=self.table_name,
        schema_name=self.schema_name,
        column_name=column_expr,
        dimensions=dimensions,
        metric=metric_str,
        index_type=index_type,
        index_params={"lists": lists} if lists else None,
        field_name=vector_field,
    )

    try:
        logger.debug(f"Creating vector index with SQL: {index_sql}")
        async with self._pool.acquire() as conn:
            await conn.execute(index_sql)
    except Exception as exc:
        logger.warning(f"Failed to create vector index: {exc}")
        logger.debug(f"Index SQL was: {index_sql}")
        return False
    return True

1491 

async def drop_vector_index(self, vector_field: str, metric: str = "cosine") -> bool:
    """Drop the pgvector index for a field/metric pair, if it exists.

    Failures are logged and reported as ``False`` rather than raised.

    Args:
        vector_field: Name of the vector field.
        metric: Distance metric the index was built for (used in its name).

    Returns:
        ``True`` if the ``DROP INDEX IF EXISTS`` statement ran.
    """
    from .postgres_vector import get_vector_index_name

    self._check_connection()

    qualified_index = f"{self.schema_name}.{get_vector_index_name(self.table_name, vector_field, metric)}"

    try:
        async with self._pool.acquire() as conn:
            await conn.execute(f"DROP INDEX IF EXISTS {qualified_index}")
    except Exception as exc:
        logger.warning(f"Failed to drop vector index: {exc}")
        return False
    return True

1515 

async def get_vector_index_stats(self, vector_field: str) -> dict[str, Any]:
    """Summarize a vector field: how many vectors it holds and whether it is indexed.

    Database errors are logged; the defaults (``indexed=False``,
    ``vector_count=0``) or any values gathered before the error are returned.

    Args:
        vector_field: Name of the vector field.

    Returns:
        A dict with keys ``field``, ``indexed`` and ``vector_count``.
    """
    from .postgres_vector import get_index_check_sql, get_vector_count_sql

    self._check_connection()

    stats: dict[str, Any] = {"field": vector_field, "indexed": False, "vector_count": 0}

    try:
        async with self._pool.acquire() as conn:
            count_sql = get_vector_count_sql(self.schema_name, self.table_name, vector_field)
            stats["vector_count"] = await conn.fetchval(count_sql) or 0

            check_sql, check_params = get_index_check_sql(
                self.schema_name, self.table_name, vector_field
            )
            stats["indexed"] = await conn.fetchval(check_sql, *check_params) or False
    except Exception as exc:
        logger.warning(f"Failed to get vector index stats: {exc}")

    return stats

1548 

async def stream_read(
    self,
    query: Query | None = None,
    config: StreamConfig | None = None
) -> AsyncIterator[Record]:
    """Stream records from PostgreSQL through a server-side cursor.

    Only equality (``Operator.EQ``) filters are applied. They compare the
    text value of a top-level JSON key in ``data``, and both the key and the
    value are bound as parameters. Filters with any other operator are
    skipped with a warning. The cursor prefetches ``config.batch_size`` rows
    per round trip when that is a positive number.

    Args:
        query: Optional query supplying filters and field projection.
        config: Streaming configuration; a default ``StreamConfig`` is used if omitted.

    Yields:
        Records with ``storage_id`` set from the row ID, projected to
        ``query.fields`` when requested.
    """
    self._check_connection()
    config = config or StreamConfig()

    sql = f"SELECT id, data, metadata FROM {self.schema_name}.{self.table_name}"
    params: list[Any] = []

    if query and query.filters:
        where_clauses = []
        for flt in query.filters:
            if flt.operator != Operator.EQ:
                logger.warning(
                    "stream_read ignores unsupported filter operator %s on field %r",
                    flt.operator, flt.field,
                )
                continue
            # Placeholder numbers are derived from params so they always line up.
            params.append(str(flt.field))
            field_pos = len(params)
            params.append(str(flt.value))
            value_pos = len(params)
            where_clauses.append(f"data->>${field_pos}::text = ${value_pos}")

        if where_clauses:
            sql += " WHERE " + " AND ".join(where_clauses)

    batch_size = config.batch_size
    prefetch = batch_size if batch_size and batch_size > 0 else None

    # asyncpg cursors need a transaction. Iterate the CursorFactory directly:
    # the object returned by awaiting conn.cursor() does not support `async for`.
    async with self._pool.acquire() as conn:
        async with conn.transaction():
            async for row in conn.cursor(sql, *params, prefetch=prefetch):
                record = self._row_to_record(row)
                record.storage_id = str(row["id"])
                if query and query.fields:
                    record = record.project(query.fields)
                yield record

1598 

async def stream_write(
    self,
    records: AsyncIterator[Record],
    config: StreamConfig | None = None
) -> StreamResult:
    """Stream records into PostgreSQL in batches.

    Each full batch (and the final partial one) is written with
    ``_write_batch``. ``async_process_batch_with_fallback`` may fall back to
    per-record ``create`` calls when a batch write fails. Processing stops
    early when that helper reports it should not continue.

    Args:
        records: Async iterator of records to write.
        config: Streaming configuration; a default ``StreamConfig`` is used if omitted.

    Returns:
        A StreamResult populated by the batch helper, with ``duration`` set
        to the wall-clock seconds spent.
    """
    self._check_connection()
    config = config or StreamConfig()
    result = StreamResult()
    start_time = time.time()
    quitting = False

    # _write_batch returns the IDs it actually stored, which are reported
    # instead of the records' own ``id`` attributes.
    write_batch = self._write_batch

    batch: list[Record] = []
    async for record in records:
        batch.append(record)
        if len(batch) < config.batch_size:
            continue

        continue_processing = await async_process_batch_with_fallback(
            batch,
            write_batch,
            self.create,
            result,
            config
        )
        if not continue_processing:
            quitting = True
            break
        batch = []

    if batch and not quitting:
        await async_process_batch_with_fallback(
            batch,
            write_batch,
            self.create,
            result,
            config
        )

    result.duration = time.time() - start_time
    return result

1652 

1653 async def _write_batch(self, records: list[Record]) -> list[str]: 

1654 """Write a batch of records using COPY for performance. 

1655  

1656 Returns: 

1657 List of created record IDs 

1658 """ 

1659 if not records: 

1660 return [] 

1661 

1662 # Prepare data for COPY 

1663 rows = [] 

1664 ids = [] 

1665 for record in records: 

1666 row_data = self._record_to_row(record) 

1667 ids.append(row_data["id"]) 

1668 rows.append(( 

1669 row_data["id"], 

1670 row_data["data"], 

1671 row_data["metadata"] 

1672 )) 

1673 

1674 # Use COPY for efficient bulk insert 

1675 async with self._pool.acquire() as conn: 

1676 await conn.copy_records_to_table( 

1677 f"{self.schema_name}.{self.table_name}", 

1678 records=rows, 

1679 columns=["id", "data", "metadata"] 

1680 ) 

1681 

1682 return ids