Coverage for src/dataknobs_data/backends/sqlite.py: 39%
304 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-08 15:25 -0600
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-08 15:25 -0600
1"""SQLite backend implementation with sync and async support."""
3from __future__ import annotations
5import json
6import logging
7import sqlite3
8import time
9import uuid
10from pathlib import Path
11from typing import Any, TYPE_CHECKING
13import numpy as np
14from dataknobs_config import ConfigurableBase
16from ..database import SyncDatabase
17from ..query import Query
18from ..query_logic import ComplexQuery
19from ..records import Record
20from ..vector.bulk_embed_mixin import BulkEmbedMixin
21from ..vector.mixins import VectorOperationsMixin
22from ..vector.python_vector_search import PythonVectorSearchMixin
23from .sql_base import SQLQueryBuilder, SQLRecordSerializer, SQLTableManager
24from .sqlite_mixins import SQLiteVectorSupport
25from .vector_config_mixin import VectorConfigMixin
27if TYPE_CHECKING:
28 from collections.abc import Iterator
29 from ..streaming import StreamConfig, StreamResult
30 from ..vector.types import DistanceMetric, VectorSearchResult
33logger = logging.getLogger(__name__)
36class SyncSQLiteDatabase( # type: ignore[misc]
37 SyncDatabase,
38 ConfigurableBase,
39 VectorConfigMixin,
40 PythonVectorSearchMixin, # Provides python_vector_search_sync
41 BulkEmbedMixin, # Must come before VectorOperationsMixin to override bulk_embed_and_store
42 VectorOperationsMixin,
43 SQLiteVectorSupport,
44 SQLRecordSerializer, # Use the standard SQL serializer
45):
46 """Synchronous SQLite database backend."""
def __init__(self, config: dict[str, Any] | None = None):
    """Set up an unconnected SQLite backend from ``config``.

    No connection is opened here; ``self.conn`` stays ``None`` until
    ``connect()`` is called.

    Args:
        config: Configuration with the following optional keys:
            - path: Database file path (default: ":memory:")
            - table: Table name (default: "records")
            - timeout: Connection timeout in seconds (default: 5.0)
            - check_same_thread: Forwarded to ``sqlite3.connect``; the
              default ``False`` turns off sqlite3's same-thread check so
              the connection may be used from other threads
            - journal_mode: Journal mode (WAL, DELETE, etc.) (default: None)
            - synchronous: Synchronous mode (NORMAL, FULL, OFF) (default: None)
            - vector_enabled: Enable vector support (default: False)
            - vector_metric: Distance metric for vector search (default: "cosine")
    """
    super().__init__(config)
    SQLiteVectorSupport.__init__(self)

    # Vector-related keys are handled by the vector config mixin
    self._parse_vector_config(config)

    settings = self.config

    # Where the data lives
    self.db_path = settings.get("path", ":memory:")
    self.table_name = settings.get("table", "records")

    # Connection behaviour
    self.timeout = settings.get("timeout", 5.0)
    self.check_same_thread = settings.get("check_same_thread", False)

    # Optional PRAGMA values; None means "leave SQLite's default alone"
    self.journal_mode = settings.get("journal_mode")
    self.synchronous = settings.get("synchronous")

    # SQL helpers bound to the configured table
    self.query_builder = SQLQueryBuilder(self.table_name, dialect="sqlite", param_style="qmark")
    self.table_manager = SQLTableManager(self.table_name, dialect="sqlite")

    # Connection state, filled in by connect()
    self.conn: sqlite3.Connection | None = None
    self._connected = False
@classmethod
def from_config(cls, config: dict) -> SyncSQLiteDatabase:
    """Alternate constructor that builds an instance from a config dict.

    Args:
        config: Settings forwarded unchanged to the class constructor.

    Returns:
        The newly constructed instance of ``cls``.
    """
    database = cls(config)
    return database
def connect(self) -> None:
    """Open the SQLite connection and prepare it for use.

    Does nothing when already connected. For any path other than
    ":memory:" the parent directory is created first. After connecting,
    the row factory is set to ``sqlite3.Row``, the PRAGMA configuration
    is applied and the table is created.

    If configuration or table creation fails, the freshly opened
    connection is closed and ``self.conn`` is reset to ``None`` before
    the error propagates, so a later retry does not leak a connection.
    """
    if self._connected:
        return

    # Create directory if needed for file-based database
    if self.db_path != ":memory:":
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

    conn = sqlite3.connect(
        self.db_path,
        timeout=self.timeout,
        check_same_thread=self.check_same_thread,
    )
    # The setup helpers below read self.conn, so publish it before calling them
    self.conn = conn

    try:
        # Enable row factory for dict-like access
        conn.row_factory = sqlite3.Row
        self._configure_sqlite()
        self._ensure_table()
    except Exception:
        # Do not leave a half-initialised connection behind
        conn.close()
        self.conn = None
        raise

    self._connected = True
    logger.info(f"Connected to SQLite database: {self.db_path}")
def close(self) -> None:
    """Shut down the connection if one is open.

    When there is no connection the call is a no-op: no state changes
    and nothing is logged.
    """
    connection = self.conn
    if not connection:
        return

    connection.close()
    self.conn = None
    self._connected = False
    logger.info(f"Disconnected from SQLite database: {self.db_path}")
123 def _configure_sqlite(self) -> None:
124 """Configure SQLite settings for performance."""
125 if not self.conn:
126 return
128 cursor = self.conn.cursor()
130 # Set journal mode if specified
131 if self.journal_mode:
132 cursor.execute(f"PRAGMA journal_mode = {self.journal_mode}")
133 logger.debug(f"Set journal_mode to {self.journal_mode}")
135 # Set synchronous mode if specified
136 if self.synchronous:
137 cursor.execute(f"PRAGMA synchronous = {self.synchronous}")
138 logger.debug(f"Set synchronous to {self.synchronous}")
140 # Enable foreign keys
141 cursor.execute("PRAGMA foreign_keys = ON")
143 # Optimize for performance
144 cursor.execute("PRAGMA temp_store = MEMORY")
145 cursor.execute("PRAGMA mmap_size = 30000000000")
147 cursor.close()
149 def _ensure_table(self) -> None:
150 """Ensure the table exists."""
151 if not self.conn:
152 raise RuntimeError("Database not connected. Call connect() first.")
154 cursor = self.conn.cursor()
155 cursor.executescript(self.table_manager.get_create_table_sql())
156 self.conn.commit()
157 cursor.close()
159 def _check_connection(self) -> None:
160 """Check if database is connected."""
161 if not self._connected or not self.conn:
162 raise RuntimeError("Database not connected. Call connect() first.")
def create(self, record: Record) -> str:
    """Insert ``record`` as a new row and return its storage ID.

    Args:
        record: The record to store.

    Returns:
        The storage ID under which the row was inserted.

    Raises:
        RuntimeError: If the database is not connected.
        ValueError: If the insert hits an integrity error, which is
            reported as a duplicate ID.
    """
    self._check_connection()

    # Keep vector dimension tracking in sync before storing
    if self._has_vector_fields(record):
        self._update_vector_dimensions(record)

    record, storage_id = self._prepare_record_for_storage(record)

    data_json = self.record_to_json(record)
    # Empty or missing metadata is stored as NULL
    metadata_json = None
    if record.metadata:
        metadata_json = json.dumps(record.metadata)

    insert_sql = f"INSERT INTO {self.table_name} (id, data, metadata) VALUES (?, ?, ?)"

    cursor = self.conn.cursor()
    try:
        cursor.execute(insert_sql, [storage_id, data_json, metadata_json])
        self.conn.commit()
    except sqlite3.IntegrityError as e:
        self.conn.rollback()
        raise ValueError(f"Record with ID {record.id} already exists") from e
    else:
        return storage_id
    finally:
        cursor.close()
def read(self, id: str) -> Record | None:
    """Fetch a single record by its ID.

    Args:
        id: ID of the record to load.

    Returns:
        The deserialized record, or None when no row matches.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    sql, params = self.query_builder.build_read_query(id)

    cursor = self.conn.cursor()
    try:
        cursor.execute(sql, params)
        row = cursor.fetchone()
        if not row:
            return None

        # Deserialize, then let the shared helper finish preparing the record
        loaded = self.row_to_record(dict(row))
        return self._prepare_record_from_storage(loaded, id)
    finally:
        cursor.close()
def update(self, id: str, record: Record) -> bool:
    """Overwrite the data and metadata of the row with the given ID.

    Args:
        id: ID of the row to update.
        record: Record whose serialized data and metadata replace the
            stored values.

    Returns:
        True if a row was updated, False if no row has that ID.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    # Update vector dimensions tracking if needed
    if self._has_vector_fields(record):
        self._update_vector_dimensions(record)

    data_json = self.record_to_json(record)
    # Empty or missing metadata is stored as NULL
    metadata_json = json.dumps(record.metadata) if record.metadata else None

    query = f"UPDATE {self.table_name} SET data = ?, metadata = ? WHERE id = ?"
    params = [data_json, metadata_json, id]

    cursor = self.conn.cursor()
    try:
        cursor.execute(query, params)
        self.conn.commit()
        return cursor.rowcount > 0
    except Exception:
        # As in create() and the batch methods: a failed statement must not
        # leave the implicitly opened transaction pending on the connection.
        self.conn.rollback()
        raise
    finally:
        cursor.close()
def delete(self, id: str) -> bool:
    """Delete the row with the given ID.

    Args:
        id: ID of the row to remove.

    Returns:
        True if a row was deleted, False if no row has that ID.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    query, params = self.query_builder.build_delete_query(id)

    cursor = self.conn.cursor()
    try:
        cursor.execute(query, params)
        self.conn.commit()
        return cursor.rowcount > 0
    except Exception:
        # As in create() and the batch methods: a failed statement must not
        # leave the implicitly opened transaction pending on the connection.
        self.conn.rollback()
        raise
    finally:
        cursor.close()
def exists(self, id: str) -> bool:
    """Tell whether a row with the given ID is stored.

    Args:
        id: ID to look up.

    Returns:
        True when the existence query yields a row, False otherwise.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    sql, params = self.query_builder.build_exists_query(id)

    cursor = self.conn.cursor()
    try:
        cursor.execute(sql, params)
        return cursor.fetchone() is not None
    finally:
        cursor.close()
def clear(self) -> int:
    """Delete every row from the table.

    Returns:
        The number of rows that were in the table before it was cleared.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    table = self.table_manager.table_name

    cursor = self.conn.cursor()
    try:
        # Get count before clearing
        cursor.execute(f"SELECT COUNT(*) FROM {table}")
        count = cursor.fetchone()[0]

        # Clear the table
        cursor.execute(f"DELETE FROM {table}")
        self.conn.commit()

        return count
    except Exception:
        # As in create() and the batch methods: a failed statement must not
        # leave the implicitly opened transaction pending on the connection.
        self.conn.rollback()
        raise
    finally:
        cursor.close()
def search(self, query: Query | ComplexQuery) -> list[Record]:
    """Run a query and return the matching records.

    Args:
        query: A simple or complex query; each kind is translated to SQL
            by its own query-builder method.

    Returns:
        The matching records, each with ``storage_id`` set from the row's
        ``id`` column. When ``query.fields`` is truthy, each record is
        projected onto those fields.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    builder = self.query_builder
    build_sql = (
        builder.build_complex_search_query
        if isinstance(query, ComplexQuery)
        else builder.build_search_query
    )
    sql, params = build_sql(query)

    cursor = self.conn.cursor()
    try:
        cursor.execute(sql, params)

        found = []
        for row in cursor.fetchall():
            columns = dict(row)
            record = self.row_to_record(columns)
            # Populate storage_id from database ID
            record.storage_id = str(columns['id'])
            found.append(record)

        # Projection happens in Python after the rows are loaded
        if query.fields:
            found = [record.project(query.fields) for record in found]

        return found
    finally:
        cursor.close()
def count(self, query: Query | None = None) -> int:
    """Count the records matching ``query``.

    Args:
        query: Optional query, passed as-is (including None) to the
            query builder's count-query method.

    Returns:
        The first column of the count query's row, or 0 if no row comes back.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    sql, params = self.query_builder.build_count_query(query)

    cursor = self.conn.cursor()
    try:
        cursor.execute(sql, params)
        row = cursor.fetchone()
        if not row:
            return 0
        return row[0]
    finally:
        cursor.close()
def create_batch(self, records: list[Record]) -> list[str]:
    """Insert several records with one statement from the query builder.

    An empty ``records`` list returns ``[]`` immediately, before the
    connection is checked.

    Args:
        records: Records to insert.

    Returns:
        The IDs reported by the query builder for the inserted records.

    Raises:
        RuntimeError: If the database is not connected.
    """
    if not records:
        return []

    self._check_connection()

    # Use the shared batch create query builder
    query, params, ids = self.query_builder.build_batch_create_query(records)

    cursor = self.conn.cursor()
    try:
        # SQLite rejects BEGIN while a transaction is already open on the
        # connection; in that case the insert joins the open transaction.
        if not self.conn.in_transaction:
            cursor.execute("BEGIN TRANSACTION")
        cursor.execute(query, params)
        self.conn.commit()

        return ids
    except Exception:
        self.conn.rollback()
        raise
    finally:
        cursor.close()
def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
    """Update several records with one statement from the query builder.

    An empty ``updates`` list returns ``[]`` immediately, before the
    connection is checked.

    Args:
        updates: ``(record_id, record)`` pairs to apply.

    Returns:
        One boolean per entry in ``updates``, in order: True when a row
        with that ID exists after the update has been committed.

    Raises:
        RuntimeError: If the database is not connected.
    """
    if not updates:
        return []

    self._check_connection()

    # Use the shared batch update query builder
    query, params = self.query_builder.build_batch_update_query(updates)

    cursor = self.conn.cursor()
    try:
        # SQLite rejects BEGIN while a transaction is already open on the
        # connection; in that case the update joins the open transaction.
        if not self.conn.in_transaction:
            cursor.execute("BEGIN TRANSACTION")
        cursor.execute(query, params)
        self.conn.commit()

        # The update statement does not report per-row outcomes, so look up
        # which of the requested IDs are present.
        update_ids = [record_id for record_id, _ in updates]
        placeholders = ", ".join("?" for _ in update_ids)
        check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"
        cursor.execute(check_query, update_ids)
        existing_ids = {row[0] for row in cursor.fetchall()}

        return [record_id in existing_ids for record_id in update_ids]
    except Exception:
        self.conn.rollback()
        raise
    finally:
        cursor.close()
def delete_batch(self, ids: list[str]) -> list[bool]:
    """Delete several records with one statement from the query builder.

    An empty ``ids`` list returns ``[]`` immediately, before the
    connection is checked.

    Args:
        ids: IDs of the rows to delete.

    Returns:
        One boolean per entry in ``ids``, in order: True when a row with
        that ID existed just before the delete ran.

    Raises:
        RuntimeError: If the database is not connected.
    """
    if not ids:
        return []

    self._check_connection()

    # Existence is recorded up front because the rows are gone afterwards
    placeholders = ", ".join("?" for _ in ids)
    check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

    cursor = self.conn.cursor()
    try:
        cursor.execute(check_query, ids)
        existing_ids = {row[0] for row in cursor.fetchall()}

        # Use the shared batch delete query builder
        query, params = self.query_builder.build_batch_delete_query(ids)

        # SQLite rejects BEGIN while a transaction is already open on the
        # connection; in that case the delete joins the open transaction.
        if not self.conn.in_transaction:
            cursor.execute("BEGIN TRANSACTION")
        cursor.execute(query, params)
        self.conn.commit()

        return [record_id in existing_ids for record_id in ids]
    except Exception:
        self.conn.rollback()
        raise
    finally:
        cursor.close()
441 def _initialize(self) -> None:
442 """Initialize method - connection setup handled in connect()."""
443 pass
445 def _count_all(self) -> int:
446 """Count all records in the database."""
447 self._check_connection()
448 cursor = self.conn.cursor()
449 try:
450 cursor.execute(f"SELECT COUNT(*) FROM {self.table_name}")
451 result = cursor.fetchone()
452 return result[0] if result else 0
453 finally:
454 cursor.close()
def stream_read(
    self,
    query: Query | None = None,
    config: StreamConfig | None = None
) -> Iterator[Record]:
    """Yield matching records one at a time, fetched page by page.

    Each page is obtained by copying ``query``, setting offset and limit
    on the copy and passing it to ``search()``. Iteration ends on the
    first empty page or the first page shorter than ``config.batch_size``.

    Args:
        query: Query to stream; a default ``Query()`` is used when falsy.
        config: Streaming options; a default ``StreamConfig()`` is used
            when falsy. Only ``batch_size`` is read here.

    Yields:
        Records in the order ``search()`` returns them.
    """
    from ..streaming import StreamConfig

    config = config or StreamConfig()
    query = query or Query()

    # NOTE(review): offset/limit are set on every page copy, which looks like
    # it replaces any paging the caller put on `query` - confirm.
    # NOTE(review): the return value of offset(...).limit(...) is discarded,
    # so this assumes both methods modify the copy in place - confirm.
    position = 0
    while True:
        page_query = query.copy()
        page_query.offset(position).limit(config.batch_size)
        page = self.search(page_query)

        if not page:
            return

        yield from page

        position += len(page)

        # A short page means there is nothing left to fetch
        if len(page) < config.batch_size:
            return
def stream_write(
    self,
    records: Iterator[Record],
    config: StreamConfig | None = None
) -> StreamResult:
    """Consume ``records`` and store them in batches via ``create_batch``.

    Records are buffered until ``config.batch_size`` is reached, then
    written; a final partial buffer is written at the end. An error from
    ``create_batch`` propagates to the caller.

    Args:
        records: Records to store.
        config: Streaming options; a default ``StreamConfig()`` is used
            when falsy. Only ``batch_size`` is read here.

    Returns:
        A ``StreamResult`` reporting every written record as processed and
        successful, zero failures, the elapsed wall-clock time, and the
        batch count as written records divided by batch size, rounded up.
    """
    from ..streaming import StreamConfig, StreamResult

    config = config or StreamConfig()
    pending = []
    written = 0
    started = time.time()

    for record in records:
        pending.append(record)

        if len(pending) < config.batch_size:
            continue

        # Buffer is full: flush it
        self.create_batch(pending)
        written += len(pending)
        pending = []

    # Flush whatever is left over
    if pending:
        self.create_batch(pending)
        written += len(pending)

    elapsed = time.time() - started
    batch_count = (written + config.batch_size - 1) // config.batch_size

    return StreamResult(
        total_processed=written,
        successful=written,
        failed=0,
        duration=elapsed,
        total_batches=batch_count
    )
524 # Vector support methods
def has_vector_support(self) -> bool:
    """Report whether the backend has native vector support.

    Returns:
        Always False: this backend uses Python-based similarity, not a
        native SQLite vector feature.
    """
    native_support = False
    return native_support
def enable_vector_support(self) -> bool:
    """Switch on vector support for this instance.

    Sets ``self.vector_enabled`` to True. No database setup is performed.

    Returns:
        Always True.
    """
    # Nothing to install or configure: vectors are handled as JSON strings
    enabled = True
    self.vector_enabled = enabled
    return enabled
def vector_search(
    self,
    query_vector: np.ndarray,
    field_name: str = "embedding",
    k: int = 10,
    filter: Query | None = None,
    metric: DistanceMetric | None = None,
    **kwargs
) -> list[VectorSearchResult]:
    """Run a similarity search through ``python_vector_search_sync``.

    This method checks the connection and forwards its arguments; the
    search itself is implemented by the delegate.

    Args:
        query_vector: Query vector
        field_name: Name of the vector field to search (forwarded as
            ``vector_field``)
        k: Number of results to return
        filter: Optional filter conditions
        metric: Distance metric, forwarded as given (None included)
        **kwargs: Extra keyword arguments forwarded unchanged

    Returns:
        Whatever ``python_vector_search_sync`` returns.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    forwarded = {
        "query_vector": query_vector,
        "vector_field": field_name,
        "k": k,
        "filter": filter,
        "metric": metric,
    }
    return self.python_vector_search_sync(**forwarded, **kwargs)
def add_vectors(
    self,
    vectors: list[np.ndarray],
    ids: list[str] | None = None,
    metadata: list[dict[str, Any]] | None = None,
    field_name: str = "embedding",
) -> list[str]:
    """Store each vector as its own record and return the created IDs.

    Args:
        vectors: List of vectors to add
        ids: Optional list of IDs, matched to ``vectors`` by position.
            Random UUID4 strings are generated when None. A list shorter
            than ``vectors`` raises IndexError.
        metadata: Optional list of metadata dicts, matched by position;
            vectors without a corresponding entry get ``{}``
        field_name: Name of the vector field

    Returns:
        List of created record IDs, as returned by ``create_batch``
    """
    from collections import OrderedDict

    from ..fields import VectorField

    # Generate IDs if not provided
    if ids is None:
        ids = [str(uuid.uuid4()) for _ in vectors]

    def build_record(position, vector):
        """Wrap one vector in a single-field record."""
        # Dimensions are only recorded for lists and numpy arrays
        dimensions = len(vector) if isinstance(vector, (list, np.ndarray)) else None
        vector_field = VectorField(
            name=field_name,
            value=vector,
            dimensions=dimensions
        )

        has_entry = metadata and position < len(metadata)
        record_metadata = metadata[position] if has_entry else {}

        return Record(
            data=OrderedDict({field_name: vector_field}),
            metadata=record_metadata,
            storage_id=ids[position]
        )

    records = [build_record(position, vector) for position, vector in enumerate(vectors)]

    # Use batch create for efficiency
    return self.create_batch(records)