Coverage for src/dataknobs_data/backends/sqlite_async.py: 44%
234 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-08 15:25 -0600
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-08 15:25 -0600
1"""Async SQLite backend implementation using aiosqlite."""
3from __future__ import annotations
5import logging
6from pathlib import Path
7from typing import Any, TYPE_CHECKING
9import aiosqlite
10from dataknobs_config import ConfigurableBase
12from ..database import AsyncDatabase
13from ..pooling import ConnectionPoolManager
14from ..query import Query
15from ..query_logic import ComplexQuery
16from ..vector import VectorOperationsMixin
17from ..vector.bulk_embed_mixin import BulkEmbedMixin
18from ..vector.python_vector_search import PythonVectorSearchMixin
19from .sql_base import SQLQueryBuilder, SQLTableManager
20from .sqlite_mixins import SQLiteVectorSupport
21from .vector_config_mixin import VectorConfigMixin
23if TYPE_CHECKING:
24 from collections.abc import AsyncIterator
25 from ..records import Record
26 from ..streaming import StreamConfig, StreamResult
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Global pool manager for SQLite connections.
# NOTE(review): not referenced by any code visible in this module - confirm it is still needed.
_pool_manager = ConnectionPoolManager()
class AsyncSQLiteDatabase(  # type: ignore[misc]
    AsyncDatabase,
    ConfigurableBase,
    VectorConfigMixin,
    SQLiteVectorSupport,
    PythonVectorSearchMixin,  # Provides python_vector_search_async
    BulkEmbedMixin,  # Must come before VectorOperationsMixin to override bulk_embed_and_store
    VectorOperationsMixin,
):
    """Asynchronous SQLite database backend using aiosqlite.

    A single ``aiosqlite`` connection is opened by :meth:`connect` and reused
    for every operation. Each write method commits before returning and rolls
    back if the statement fails, so a failed write never leaves a transaction
    open on the shared connection.

    All data-access methods raise ``RuntimeError`` when called before
    :meth:`connect` (or after :meth:`close`).
    """

    def __init__(self, config: dict[str, Any] | None = None):
        """Initialize async SQLite database.

        No connection is opened here; call :meth:`connect` before use.

        Args:
            config: Configuration with the following optional keys:
                - path: Database file path (default: ":memory:")
                - table: Table name (default: "records")
                - timeout: Connection timeout in seconds (default: 5.0)
                - journal_mode: Journal mode (WAL, DELETE, etc.)
                  (default: WAL for file-based, unset for ":memory:")
                - synchronous: Synchronous mode (NORMAL, FULL, OFF) (default: NORMAL)
                - pool_size: Number of connections in pool (default: 5)
        """
        super().__init__(config)
        config = config or {}
        self.db_path = config.get("path", ":memory:")
        self.table_name = config.get("table", "records")
        self.timeout = config.get("timeout", 5.0)
        self.journal_mode = config.get(
            "journal_mode", "WAL" if self.db_path != ":memory:" else None
        )
        self.synchronous = config.get("synchronous", "NORMAL")
        # NOTE(review): stored but not read by any method in this class.
        self.pool_size = config.get("pool_size", 5)

        # Start with standard query builder, will customize after mixins are initialized
        self.query_builder = SQLQueryBuilder(
            self.table_name, dialect="sqlite", param_style="qmark"
        )
        self.table_manager = SQLTableManager(self.table_name, dialect="sqlite")

        self.db: aiosqlite.Connection | None = None
        self._connected = False

        # Initialize vector support
        self._parse_vector_config(config)
        self._init_vector_state()

    @classmethod
    def from_config(cls, config: dict) -> AsyncSQLiteDatabase:
        """Create an instance from a config dictionary (same keys as ``__init__``)."""
        return cls(config)

    async def connect(self) -> None:
        """Open the connection, apply PRAGMAs and create the table if needed.

        Calling this on an already-connected instance is a no-op. For a
        file-based database the parent directory is created first.

        If configuration or table creation fails, the freshly opened
        connection is closed again before the error propagates, so a later
        retry does not leak the first connection.
        """
        if self._connected:
            return

        # Create directory if needed for file-based database
        if self.db_path != ":memory:":
            Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        self.db = await aiosqlite.connect(self.db_path, timeout=self.timeout)
        try:
            # Enable row factory for dict-like access
            self.db.row_factory = aiosqlite.Row
            await self._configure_sqlite()
            await self._ensure_table()
        except BaseException:
            # BaseException so that task cancellation also releases the handle.
            half_open, self.db = self.db, None
            await half_open.close()
            raise

        self._connected = True
        logger.info("Connected to async SQLite database: %s", self.db_path)

    async def close(self) -> None:
        """Close the database connection; a no-op when not connected.

        The instance is marked as disconnected even if the underlying
        ``close()`` call raises.
        """
        if not self.db:
            return
        try:
            await self.db.close()
        finally:
            self.db = None
            self._connected = False
        logger.info("Disconnected from async SQLite database: %s", self.db_path)

    async def _configure_sqlite(self) -> None:
        """Apply the configured PRAGMA settings to the open connection.

        Does nothing when there is no connection.
        """
        if not self.db:
            return

        # PRAGMA values cannot be bound as parameters, so they are interpolated.
        if self.journal_mode:
            await self.db.execute(f"PRAGMA journal_mode = {self.journal_mode}")
            logger.debug("Set journal_mode to %s", self.journal_mode)

        await self.db.execute(f"PRAGMA synchronous = {self.synchronous}")
        logger.debug("Set synchronous to %s", self.synchronous)

        # Enable foreign keys
        await self.db.execute("PRAGMA foreign_keys = ON")

        # Optimize for performance
        await self.db.execute("PRAGMA temp_store = MEMORY")
        await self.db.execute("PRAGMA mmap_size = 30000000000")

        await self.db.commit()

    async def _ensure_table(self) -> None:
        """Run the table manager's CREATE TABLE script and commit.

        Raises:
            RuntimeError: If there is no open connection.
        """
        if not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

        await self.db.executescript(self.table_manager.get_create_table_sql())
        await self.db.commit()

    def _check_connection(self) -> None:
        """Raise ``RuntimeError`` unless :meth:`connect` has completed."""
        if not self._connected or not self.db:
            raise RuntimeError("Database not connected. Call connect() first.")

    async def _execute_write(
        self, query: str, params: Any, *, begin: bool = False
    ) -> int:
        """Execute one write statement, commit it, and return its rowcount.

        Args:
            query: SQL statement to execute.
            params: Parameters bound to ``query``.
            begin: When true, issue an explicit ``BEGIN TRANSACTION`` first.

        Any failure - including a failing ``BEGIN`` - triggers a rollback
        before the exception is re-raised, so the shared connection is never
        left inside an open transaction.
        """
        try:
            if begin:
                await self.db.execute("BEGIN TRANSACTION")
            cursor = await self.db.execute(query, params)
            await self.db.commit()
        except Exception:
            await self.db.rollback()
            raise
        return cursor.rowcount

    async def create(self, record: Record) -> str:
        """Insert a new record and return its ID.

        Raises:
            RuntimeError: If the database is not connected.
            ValueError: If the insert violates an integrity constraint
                (reported as the record ID already existing).
        """
        self._check_connection()

        query, params = self.query_builder.build_create_query(record)

        try:
            await self.db.execute(query, params)
            await self.db.commit()
        except aiosqlite.IntegrityError as e:
            await self.db.rollback()
            raise ValueError(f"Record with ID {params[0]} already exists") from e
        except Exception:
            # Any other failure must not leave the implicit transaction open.
            await self.db.rollback()
            raise

        # The ID is taken from the bound parameters: it is the first one.
        return params[0]

    async def read(self, id: str) -> Record | None:
        """Return the record with the given ID, or ``None`` if no row matches."""
        self._check_connection()

        query, params = self.query_builder.build_read_query(id)

        async with self.db.execute(query, params) as cursor:
            row = await cursor.fetchone()

        if row:
            return SQLQueryBuilder.row_to_record(dict(row))
        return None

    async def update(self, id: str, record: Record) -> bool:
        """Update an existing record.

        Returns:
            True if at least one row was changed, False otherwise.
        """
        self._check_connection()

        query, params = self.query_builder.build_update_query(id, record)
        return await self._execute_write(query, params) > 0

    async def delete(self, id: str) -> bool:
        """Delete a record by ID.

        Returns:
            True if at least one row was removed, False otherwise.
        """
        self._check_connection()

        query, params = self.query_builder.build_delete_query(id)
        return await self._execute_write(query, params) > 0

    async def exists(self, id: str) -> bool:
        """Return True if the exists-query for ``id`` yields a row."""
        self._check_connection()

        query, params = self.query_builder.build_exists_query(id)

        async with self.db.execute(query, params) as cursor:
            return await cursor.fetchone() is not None

    async def search(self, query: Query | ComplexQuery) -> list[Record]:
        """Return the records matching ``query``.

        Each returned record has ``storage_id`` set from the row's ``id``
        column. When ``query.fields`` is set, records are projected onto
        those fields after loading.
        """
        self._check_connection()

        # Handle ComplexQuery with native SQL support
        if isinstance(query, ComplexQuery):
            sql_query, params = self.query_builder.build_complex_search_query(query)
        else:
            sql_query, params = self.query_builder.build_search_query(query)

        async with self.db.execute(sql_query, params) as cursor:
            rows = await cursor.fetchall()

        records = []
        for row in rows:
            row_dict = dict(row)
            record = SQLQueryBuilder.row_to_record(row_dict)
            # Populate storage_id from database ID
            record.storage_id = str(row_dict["id"])
            records.append(record)

        # Apply field projection if specified
        if query.fields:
            records = [r.project(query.fields) for r in records]

        return records

    async def count(self, query: Query | None = None) -> int:
        """Return the number of records matching ``query``."""
        self._check_connection()

        sql_query, params = self.query_builder.build_count_query(query)

        async with self.db.execute(sql_query, params) as cursor:
            result = await cursor.fetchone()
        return result[0] if result else 0

    async def create_batch(self, records: list[Record]) -> list[str]:
        """Insert several records with one statement inside one transaction.

        Returns:
            The IDs reported by the query builder, or ``[]`` for empty input
            (in which case the connection is not touched at all).
        """
        if not records:
            return []

        self._check_connection()

        query, params, ids = self.query_builder.build_batch_create_query(records)
        await self._execute_write(query, params, begin=True)
        return ids

    async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
        """Update several records with one statement inside one transaction.

        Returns:
            One boolean per input pair, in input order. An entry is True
            when a row with that ID exists after the update has been
            committed; ``[]`` for empty input.
        """
        if not updates:
            return []

        self._check_connection()

        query, params = self.query_builder.build_batch_update_query(updates)
        await self._execute_write(query, params, begin=True)

        # The update statement does not report which IDs matched, so look
        # them up afterwards.
        update_ids = [record_id for record_id, _ in updates]
        placeholders = ", ".join(["?"] * len(update_ids))
        check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

        async with self.db.execute(check_query, update_ids) as cursor:
            existing_ids = {row[0] for row in await cursor.fetchall()}

        return [record_id in existing_ids for record_id in update_ids]

    async def delete_batch(self, ids: list[str]) -> list[bool]:
        """Delete several records with one statement inside one transaction.

        Returns:
            One boolean per input ID, in input order. An entry is True when
            a row with that ID existed immediately before the delete;
            ``[]`` for empty input.
        """
        if not ids:
            return []

        self._check_connection()

        # Record which IDs exist before they are removed.
        placeholders = ", ".join(["?"] * len(ids))
        check_query = f"SELECT id FROM {self.table_name} WHERE id IN ({placeholders})"

        async with self.db.execute(check_query, ids) as cursor:
            existing_ids = {row[0] for row in await cursor.fetchall()}

        query, params = self.query_builder.build_batch_delete_query(ids)
        await self._execute_write(query, params, begin=True)

        return [record_id in existing_ids for record_id in ids]

    def _initialize(self) -> None:
        """Do nothing; connection setup is handled in :meth:`connect`."""

    async def _count_all(self) -> int:
        """Return the total number of rows in the table."""
        self._check_connection()

        async with self.db.execute(f"SELECT COUNT(*) FROM {self.table_name}") as cursor:
            result = await cursor.fetchone()
        return result[0] if result else 0

    async def stream_read(
        self,
        query: Query | None = None,
        config: StreamConfig | None = None
    ) -> AsyncIterator[Record]:
        """Yield matching records one at a time, fetched in pages.

        Pages of ``config.batch_size`` records are loaded through
        :meth:`search` using offset/limit, until a page comes back empty or
        shorter than ``batch_size``.
        """
        from ..streaming import StreamConfig

        config = config or StreamConfig()
        query = query or Query()

        # NOTE(review): offset/limit are set on every page, which presumably
        # overrides any offset/limit already on ``query``; and without an
        # ORDER BY on the query the page boundaries may not be stable -
        # confirm against the Query API.
        offset = 0
        while True:
            page_query = query.copy()
            page_query.offset(offset).limit(config.batch_size)
            batch = await self.search(page_query)

            if not batch:
                break

            for record in batch:
                yield record

            offset += len(batch)

            # A short page means there is nothing left to fetch.
            if len(batch) < config.batch_size:
                break

    async def stream_write(
        self,
        records: AsyncIterator[Record],
        config: StreamConfig | None = None
    ) -> StreamResult:
        """Consume ``records`` and insert them in batches via :meth:`create_batch`.

        A batch is flushed whenever ``config.batch_size`` records have
        accumulated, plus once more for any remainder. A failing batch
        raises out of this method; batches flushed before it stay committed.

        Returns:
            A ``StreamResult`` whose ``total_processed``/``successful`` are the
            number of records written, ``failed`` is 0, ``duration`` is the
            elapsed time in seconds and ``total_batches`` is the number of
            ``create_batch`` calls made.
        """
        import time

        from ..streaming import StreamConfig, StreamResult

        config = config or StreamConfig()
        pending: list[Record] = []
        total_written = 0
        # Counted directly rather than derived from batch_size, so a
        # batch_size of 0 cannot cause a division by zero after the write.
        total_batches = 0
        # Monotonic clock: unaffected by system clock adjustments.
        start_time = time.perf_counter()

        async for record in records:
            pending.append(record)

            if len(pending) >= config.batch_size:
                await self.create_batch(pending)
                total_written += len(pending)
                total_batches += 1
                pending = []

        # Write any remaining records
        if pending:
            await self.create_batch(pending)
            total_written += len(pending)
            total_batches += 1

        elapsed = time.perf_counter() - start_time

        return StreamResult(
            total_processed=total_written,
            successful=total_written,
            failed=0,
            duration=elapsed,
            total_batches=total_batches,
        )

    async def vector_search(
        self,
        query_vector,
        vector_field: str = "embedding",
        k: int = 10,
        filter=None,
        metric=None,
        **kwargs
    ):
        """Perform async vector similarity search using Python-based calculations.

        Delegates to ``python_vector_search_async`` (from
        PythonVectorSearchMixin) after checking the connection.

        Args:
            query_vector: Query vector
            vector_field: Name of the vector field to search
            k: Number of results to return
            filter: Optional filter conditions
            metric: Distance metric (uses instance default if not specified)
            **kwargs: Additional arguments, forwarded unchanged

        Returns:
            Whatever ``python_vector_search_async`` returns (documented
            upstream as a list of VectorSearchResult objects with scores).

        Raises:
            RuntimeError: If the database is not connected.
        """
        self._check_connection()

        return await self.python_vector_search_async(
            query_vector=query_vector,
            vector_field=vector_field,
            k=k,
            filter=filter,
            metric=metric,
            **kwargs
        )