Coverage for src/dataknobs_data/backends/sqlite.py: 39%

304 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-08 15:25 -0600

1"""SQLite backend implementation with sync and async support.""" 

2 

3from __future__ import annotations 

4 

5import json 

6import logging 

7import sqlite3 

8import time 

9import uuid 

10from pathlib import Path 

11from typing import Any, TYPE_CHECKING 

12 

13import numpy as np 

14from dataknobs_config import ConfigurableBase 

15 

16from ..database import SyncDatabase 

17from ..query import Query 

18from ..query_logic import ComplexQuery 

19from ..records import Record 

20from ..vector.bulk_embed_mixin import BulkEmbedMixin 

21from ..vector.mixins import VectorOperationsMixin 

22from ..vector.python_vector_search import PythonVectorSearchMixin 

23from .sql_base import SQLQueryBuilder, SQLRecordSerializer, SQLTableManager 

24from .sqlite_mixins import SQLiteVectorSupport 

25from .vector_config_mixin import VectorConfigMixin 

26 

27if TYPE_CHECKING: 

28 from collections.abc import Iterator 

29 from ..streaming import StreamConfig, StreamResult 

30 from ..vector.types import DistanceMetric, VectorSearchResult 

31 

32 

33logger = logging.getLogger(__name__) 

34 

35 

36class SyncSQLiteDatabase( # type: ignore[misc] 

37 SyncDatabase, 

38 ConfigurableBase, 

39 VectorConfigMixin, 

40 PythonVectorSearchMixin, # Provides python_vector_search_sync 

41 BulkEmbedMixin, # Must come before VectorOperationsMixin to override bulk_embed_and_store 

42 VectorOperationsMixin, 

43 SQLiteVectorSupport, 

44 SQLRecordSerializer, # Use the standard SQL serializer 

45): 

46 """Synchronous SQLite database backend.""" 

47 

def __init__(self, config: dict[str, Any] | None = None):
    """Set up an unconnected SQLite backend from a configuration mapping.

    No database connection is opened here; call ``connect()`` for that.

    Args:
        config: Optional settings. Recognised keys:
            - path: Database file path (default: ":memory:")
            - table: Table name (default: "records")
            - timeout: Connection timeout in seconds (default: 5.0)
            - check_same_thread: Forwarded unchanged to ``sqlite3.connect``;
              the default ``False`` permits use of the connection from
              threads other than the one that created it
            - journal_mode: Value for ``PRAGMA journal_mode`` (default: None,
              meaning the pragma is not issued)
            - synchronous: Value for ``PRAGMA synchronous`` (default: None,
              meaning the pragma is not issued)
            - vector_enabled: Enable vector support (default: False)
            - vector_metric: Distance metric for vector search (default: "cosine")
    """
    super().__init__(config)
    SQLiteVectorSupport.__init__(self)

    # Vector-related keys are interpreted by the vector config mixin.
    self._parse_vector_config(config)

    # Connection state: nothing is opened until connect() runs.
    self.conn: sqlite3.Connection | None = None
    self._connected = False

    # Plain connection settings, each with its documented fallback.
    settings = self.config
    self.db_path = settings.get("path", ":memory:")
    self.table_name = settings.get("table", "records")
    self.timeout = settings.get("timeout", 5.0)
    self.check_same_thread = settings.get("check_same_thread", False)
    self.journal_mode = settings.get("journal_mode")
    self.synchronous = settings.get("synchronous")

    # SQL helpers bound to the configured table.
    self.table_manager = SQLTableManager(self.table_name, dialect="sqlite")
    self.query_builder = SQLQueryBuilder(
        self.table_name,
        dialect="sqlite",
        param_style="qmark",
    )

80 

@classmethod
def from_config(cls, config: dict) -> SyncSQLiteDatabase:
    """Build an instance from a configuration dictionary.

    Equivalent to calling the class directly with ``config``.
    """
    instance = cls(config)
    return instance

85 

86 def connect(self) -> None: 

87 """Connect to the SQLite database.""" 

88 if self._connected: 

89 return 

90 

91 # Create directory if needed for file-based database 

92 if self.db_path != ":memory:": 

93 db_file = Path(self.db_path) 

94 db_file.parent.mkdir(parents=True, exist_ok=True) 

95 

96 # Connect to database 

97 self.conn = sqlite3.connect( 

98 self.db_path, 

99 timeout=self.timeout, 

100 check_same_thread=self.check_same_thread 

101 ) 

102 

103 # Enable row factory for dict-like access 

104 self.conn.row_factory = sqlite3.Row 

105 

106 # Configure SQLite for better performance 

107 self._configure_sqlite() 

108 

109 # Create table if it doesn't exist 

110 self._ensure_table() 

111 

112 self._connected = True 

113 logger.info(f"Connected to SQLite database: {self.db_path}") 

114 

115 def close(self) -> None: 

116 """Close the database connection.""" 

117 if self.conn: 

118 self.conn.close() 

119 self.conn = None 

120 self._connected = False 

121 logger.info(f"Disconnected from SQLite database: {self.db_path}") 

122 

123 def _configure_sqlite(self) -> None: 

124 """Configure SQLite settings for performance.""" 

125 if not self.conn: 

126 return 

127 

128 cursor = self.conn.cursor() 

129 

130 # Set journal mode if specified 

131 if self.journal_mode: 

132 cursor.execute(f"PRAGMA journal_mode = {self.journal_mode}") 

133 logger.debug(f"Set journal_mode to {self.journal_mode}") 

134 

135 # Set synchronous mode if specified 

136 if self.synchronous: 

137 cursor.execute(f"PRAGMA synchronous = {self.synchronous}") 

138 logger.debug(f"Set synchronous to {self.synchronous}") 

139 

140 # Enable foreign keys 

141 cursor.execute("PRAGMA foreign_keys = ON") 

142 

143 # Optimize for performance 

144 cursor.execute("PRAGMA temp_store = MEMORY") 

145 cursor.execute("PRAGMA mmap_size = 30000000000") 

146 

147 cursor.close() 

148 

149 def _ensure_table(self) -> None: 

150 """Ensure the table exists.""" 

151 if not self.conn: 

152 raise RuntimeError("Database not connected. Call connect() first.") 

153 

154 cursor = self.conn.cursor() 

155 cursor.executescript(self.table_manager.get_create_table_sql()) 

156 self.conn.commit() 

157 cursor.close() 

158 

159 def _check_connection(self) -> None: 

160 """Check if database is connected.""" 

161 if not self._connected or not self.conn: 

162 raise RuntimeError("Database not connected. Call connect() first.") 

163 

164 def create(self, record: Record) -> str: 

165 """Create a new record.""" 

166 self._check_connection() 

167 

168 # Update vector dimensions tracking if needed 

169 if self._has_vector_fields(record): 

170 self._update_vector_dimensions(record) 

171 

172 # Use centralized method to prepare record 

173 record, storage_id = self._prepare_record_for_storage(record) 

174 

175 # Use the standard SQL serializer 

176 data_json = self.record_to_json(record) 

177 metadata_json = json.dumps(record.metadata) if record.metadata else None 

178 

179 # Build insert query for SQLite's standard table structure 

180 query = f"INSERT INTO {self.table_name} (id, data, metadata) VALUES (?, ?, ?)" 

181 params = [storage_id, data_json, metadata_json] 

182 

183 cursor = self.conn.cursor() 

184 

185 try: 

186 cursor.execute(query, params) 

187 self.conn.commit() 

188 return storage_id 

189 except sqlite3.IntegrityError as e: 

190 self.conn.rollback() 

191 raise ValueError(f"Record with ID {record.id} already exists") from e 

192 finally: 

193 cursor.close() 

194 

195 def read(self, id: str) -> Record | None: 

196 """Read a record by ID.""" 

197 self._check_connection() 

198 

199 query, params = self.query_builder.build_read_query(id) 

200 cursor = self.conn.cursor() 

201 

202 try: 

203 cursor.execute(query, params) 

204 row = cursor.fetchone() 

205 

206 if row: 

207 # Use the standard SQL serializer 

208 record = self.row_to_record(dict(row)) 

209 # Use centralized method to prepare record 

210 return self._prepare_record_from_storage(record, id) 

211 return None 

212 finally: 

213 cursor.close() 

214 

215 def update(self, id: str, record: Record) -> bool: 

216 """Update an existing record.""" 

217 self._check_connection() 

218 

219 # Update vector dimensions tracking if needed 

220 if self._has_vector_fields(record): 

221 self._update_vector_dimensions(record) 

222 

223 # Use the standard SQL serializer 

224 data_json = self.record_to_json(record) 

225 metadata_json = json.dumps(record.metadata) if record.metadata else None 

226 

227 # Build update query 

228 query = f"UPDATE {self.table_name} SET data = ?, metadata = ? WHERE id = ?" 

229 params = [data_json, metadata_json, id] 

230 

231 cursor = self.conn.cursor() 

232 

233 try: 

234 cursor.execute(query, params) 

235 self.conn.commit() 

236 return cursor.rowcount > 0 

237 finally: 

238 cursor.close() 

239 

240 def delete(self, id: str) -> bool: 

241 """Delete a record by ID.""" 

242 self._check_connection() 

243 

244 query, params = self.query_builder.build_delete_query(id) 

245 cursor = self.conn.cursor() 

246 

247 try: 

248 cursor.execute(query, params) 

249 self.conn.commit() 

250 return cursor.rowcount > 0 

251 finally: 

252 cursor.close() 

253 

254 def exists(self, id: str) -> bool: 

255 """Check if a record exists.""" 

256 self._check_connection() 

257 

258 query, params = self.query_builder.build_exists_query(id) 

259 cursor = self.conn.cursor() 

260 

261 try: 

262 cursor.execute(query, params) 

263 result = cursor.fetchone() 

264 return result is not None 

265 finally: 

266 cursor.close() 

267 

268 def clear(self) -> int: 

269 """Clear all records from the database.""" 

270 self._check_connection() 

271 

272 cursor = self.conn.cursor() 

273 try: 

274 # Get count before clearing 

275 cursor.execute(f"SELECT COUNT(*) FROM {self.table_manager.table_name}") 

276 count = cursor.fetchone()[0] 

277 

278 # Clear the table 

279 cursor.execute(f"DELETE FROM {self.table_manager.table_name}") 

280 self.conn.commit() 

281 

282 return count 

283 finally: 

284 cursor.close() 

285 

def search(self, query: Query | ComplexQuery) -> list[Record]:
    """Run a query and return the matching records.

    ``ComplexQuery`` instances are translated with the builder's complex
    search method; anything else goes through the plain search builder.
    Each returned record gets ``storage_id`` set from the row's ``id``
    column, and records are projected when ``query.fields`` is truthy.

    Args:
        query: The query to execute.

    Returns:
        The matching records.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    builder = self.query_builder
    if isinstance(query, ComplexQuery):
        built = builder.build_complex_search_query(query)
    else:
        built = builder.build_search_query(query)
    sql_query, params = built

    cursor = self.conn.cursor()
    try:
        cursor.execute(sql_query, params)

        records = []
        for row in cursor.fetchall():
            row_dict = dict(row)
            record = self.row_to_record(row_dict)
            # The database ID doubles as the record's storage ID.
            record.storage_id = str(row_dict['id'])
            records.append(record)

        if not query.fields:
            return records
        return [r.project(query.fields) for r in records]
    finally:
        cursor.close()

319 

320 def count(self, query: Query | None = None) -> int: 

321 """Count records matching a query.""" 

322 self._check_connection() 

323 

324 sql_query, params = self.query_builder.build_count_query(query) 

325 cursor = self.conn.cursor() 

326 

327 try: 

328 cursor.execute(sql_query, params) 

329 result = cursor.fetchone() 

330 return result[0] if result else 0 

331 finally: 

332 cursor.close() 

333 

334 def create_batch(self, records: list[Record]) -> list[str]: 

335 """Create multiple records efficiently using a single query. 

336  

337 Uses multi-value INSERT for better performance. 

338 """ 

339 if not records: 

340 return [] 

341 

342 self._check_connection() 

343 

344 # Use the shared batch create query builder 

345 query, params, ids = self.query_builder.build_batch_create_query(records) 

346 

347 cursor = self.conn.cursor() 

348 try: 

349 # Execute the batch insert in a transaction 

350 cursor.execute("BEGIN TRANSACTION") 

351 cursor.execute(query, params) 

352 self.conn.commit() 

353 

354 # Return the generated IDs 

355 return ids 

356 except Exception: 

357 self.conn.rollback() 

358 raise 

359 finally: 

360 cursor.close() 

361 

362 def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]: 

363 """Update multiple records efficiently using a single query. 

364  

365 Uses CASE expressions for batch updates, similar to PostgreSQL. 

366 """ 

367 if not updates: 

368 return [] 

369 

370 self._check_connection() 

371 

372 # Use the shared batch update query builder 

373 query, params = self.query_builder.build_batch_update_query(updates) 

374 

375 cursor = self.conn.cursor() 

376 try: 

377 # Execute the batch update in a transaction 

378 cursor.execute("BEGIN TRANSACTION") 

379 cursor.execute(query, params) 

380 self.conn.commit() 

381 

382 # Check which records were actually updated 

383 # SQLite doesn't have RETURNING, so we need to verify each ID 

384 update_ids = [record_id for record_id, _ in updates] 

385 placeholders = ", ".join(["?" for _ in update_ids]) 

386 check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})" 

387 cursor.execute(check_query, update_ids) 

388 existing_ids = {row[0] for row in cursor.fetchall()} 

389 

390 # Return results for each update 

391 results = [] 

392 for record_id, _ in updates: 

393 results.append(record_id in existing_ids) 

394 

395 return results 

396 except Exception: 

397 self.conn.rollback() 

398 raise 

399 finally: 

400 cursor.close() 

401 

402 def delete_batch(self, ids: list[str]) -> list[bool]: 

403 """Delete multiple records efficiently using a single query. 

404  

405 Uses single DELETE with IN clause for better performance. 

406 """ 

407 if not ids: 

408 return [] 

409 

410 self._check_connection() 

411 

412 # Check which IDs exist before deletion 

413 placeholders = ", ".join(["?" for _ in ids]) 

414 check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})" 

415 

416 cursor = self.conn.cursor() 

417 try: 

418 cursor.execute(check_query, ids) 

419 existing_ids = {row[0] for row in cursor.fetchall()} 

420 

421 # Use the shared batch delete query builder 

422 query, params = self.query_builder.build_batch_delete_query(ids) 

423 

424 # Execute the batch delete in a transaction 

425 cursor.execute("BEGIN TRANSACTION") 

426 cursor.execute(query, params) 

427 self.conn.commit() 

428 

429 # Return results based on which IDs existed 

430 results = [] 

431 for id in ids: 

432 results.append(id in existing_ids) 

433 

434 return results 

435 except Exception: 

436 self.conn.rollback() 

437 raise 

438 finally: 

439 cursor.close() 

440 

441 def _initialize(self) -> None: 

442 """Initialize method - connection setup handled in connect().""" 

443 pass 

444 

445 def _count_all(self) -> int: 

446 """Count all records in the database.""" 

447 self._check_connection() 

448 cursor = self.conn.cursor() 

449 try: 

450 cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}") 

451 result = cursor.fetchone() 

452 return result[0] if result else 0 

453 finally: 

454 cursor.close() 

455 

def stream_read(
    self,
    query: Query | None = None,
    config: StreamConfig | None = None
) -> Iterator[Record]:
    """Yield matching records one at a time, fetched page by page.

    Each page is obtained through ``search()`` on a copy of the query with
    offset and limit applied. Iteration stops on an empty page or on a
    page shorter than ``config.batch_size``.

    Args:
        query: Query to stream; a default ``Query()`` is used when falsy.
        config: Streaming options; a default ``StreamConfig()`` is used
            when falsy. Only ``batch_size`` is read here.

    Yields:
        The records of each page, in the order ``search()`` returned them.
    """
    from ..streaming import StreamConfig

    config = config or StreamConfig()
    query = query or Query()

    consumed = 0
    while True:
        # Page through the results on a copy of the query with this
        # page's offset and limit applied.
        page_query = query.copy()
        page_query.offset(consumed).limit(config.batch_size)
        page = self.search(page_query)

        if not page:
            return

        for record in page:
            yield record

        consumed += len(page)

        # A short page means the result set is exhausted.
        if len(page) < config.batch_size:
            return

486 

def stream_write(
    self,
    records: Iterator[Record],
    config: StreamConfig | None = None
) -> StreamResult:
    """Consume an iterator of records and insert them in batches.

    Records are buffered until ``config.batch_size`` is reached and then
    written with ``create_batch()``; a final partial buffer is flushed at
    the end.

    Args:
        records: Source of records to insert.
        config: Streaming options; a default ``StreamConfig()`` is used
            when falsy. Only ``batch_size`` is read here.

    Returns:
        A ``StreamResult`` reporting the number of records written (as
        both processed and successful), zero failures, the elapsed wall
        time, and the batch count derived from the written total.
    """
    from ..streaming import StreamConfig, StreamResult

    config = config or StreamConfig()
    started = time.time()
    pending = []
    written = 0

    for record in records:
        pending.append(record)
        if len(pending) < config.batch_size:
            continue

        # Buffer is full: flush it and start a new one.
        self.create_batch(pending)
        written += len(pending)
        pending = []

    # Flush whatever is left over.
    if pending:
        self.create_batch(pending)
        written += len(pending)

    elapsed = time.time() - started

    # Ceiling division: a trailing partial batch still counts as a batch.
    batch_count = (written + config.batch_size - 1) // config.batch_size

    return StreamResult(
        total_processed=written,
        successful=written,
        failed=0,
        duration=elapsed,
        total_batches=batch_count
    )

523 

524 # Vector support methods 

def has_vector_support(self) -> bool:
    """Report whether the backend offers native vector support.

    Returns:
        Always False: this backend has no native vector support and
        relies on Python-based similarity instead.
    """
    native_support = False
    return native_support

532 

533 def enable_vector_support(self) -> bool: 

534 """Enable vector support for this backend. 

535  

536 Returns: 

537 True - Vector support is always available (Python-based) 

538 """ 

539 # SQLite doesn't need any special setup for vector support 

540 # We handle vectors as JSON strings 

541 self.vector_enabled = True 

542 return True 

543 

544 def vector_search( 

545 self, 

546 query_vector: np.ndarray, 

547 field_name: str = "embedding", 

548 k: int = 10, 

549 filter: Query | None = None, 

550 metric: DistanceMetric | None = None, 

551 **kwargs 

552 ) -> list[VectorSearchResult]: 

553 """Perform vector similarity search using Python-based calculations. 

554  

555 Delegates to PythonVectorSearchMixin for the implementation. 

556  

557 Args: 

558 query_vector: Query vector 

559 field_name: Name of the vector field to search 

560 k: Number of results to return 

561 filter: Optional filter conditions 

562 metric: Distance metric (uses instance default if not specified) 

563 **kwargs: Additional arguments for compatibility 

564  

565 Returns: 

566 List of search results with scores 

567 """ 

568 self._check_connection() 

569 

570 # Delegate to the mixin's implementation 

571 return self.python_vector_search_sync( 

572 query_vector=query_vector, 

573 vector_field=field_name, 

574 k=k, 

575 filter=filter, 

576 metric=metric, 

577 **kwargs 

578 ) 

579 

def add_vectors(
    self,
    vectors: list[np.ndarray],
    ids: list[str] | None = None,
    metadata: list[dict[str, Any]] | None = None,
    field_name: str = "embedding",
) -> list[str]:
    """Store vectors as single-field records via ``create_batch()``.

    Args:
        vectors: Vectors to store, one record per vector
        ids: Optional storage IDs, matched to vectors by position; random
            UUID4 strings are generated when omitted
        metadata: Optional metadata dicts, matched to vectors by position;
            positions without an entry get an empty dict
        field_name: Name of the vector field inside each record

    Returns:
        The list returned by ``create_batch()`` for the new records.
    """
    from collections import OrderedDict

    from ..fields import VectorField

    if ids is None:
        ids = [str(uuid.uuid4()) for _ in vectors]

    records = []
    for index, vector in enumerate(vectors):
        # Dimensions are only recorded for lists and numpy arrays.
        if isinstance(vector, (list, np.ndarray)):
            dimensions = len(vector)
        else:
            dimensions = None

        vector_field = VectorField(
            name=field_name,
            value=vector,
            dimensions=dimensions
        )

        # Fall back to empty metadata when none was supplied for this slot.
        if metadata and index < len(metadata):
            record_metadata = metadata[index]
        else:
            record_metadata = {}

        records.append(
            Record(
                data=OrderedDict({field_name: vector_field}),
                metadata=record_metadata,
                storage_id=ids[index]
            )
        )

    return self.create_batch(records)