Coverage for src/dataknobs_data/backends/sqlite_async.py: 44%

234 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-08 15:25 -0600

1"""Async SQLite backend implementation using aiosqlite.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6from pathlib import Path 

7from typing import Any, TYPE_CHECKING 

8 

9import aiosqlite 

10from dataknobs_config import ConfigurableBase 

11 

12from ..database import AsyncDatabase 

13from ..pooling import ConnectionPoolManager 

14from ..query import Query 

15from ..query_logic import ComplexQuery 

16from ..vector import VectorOperationsMixin 

17from ..vector.bulk_embed_mixin import BulkEmbedMixin 

18from ..vector.python_vector_search import PythonVectorSearchMixin 

19from .sql_base import SQLQueryBuilder, SQLTableManager 

20from .sqlite_mixins import SQLiteVectorSupport 

21from .vector_config_mixin import VectorConfigMixin 

22 

23if TYPE_CHECKING: 

24 from collections.abc import AsyncIterator 

25 from ..records import Record 

26 from ..streaming import StreamConfig, StreamResult 

27 

28 

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Global pool manager for SQLite connections.
# NOTE(review): no code in the visible part of this file references
# _pool_manager; confirm whether it is still needed or used elsewhere.
_pool_manager = ConnectionPoolManager()

33 

34 

35class AsyncSQLiteDatabase( # type: ignore[misc] 

36 AsyncDatabase, 

37 ConfigurableBase, 

38 VectorConfigMixin, 

39 SQLiteVectorSupport, 

40 PythonVectorSearchMixin, # Provides python_vector_search_async 

41 BulkEmbedMixin, # Must come before VectorOperationsMixin to override bulk_embed_and_store 

42 VectorOperationsMixin 

43): 

44 """Asynchronous SQLite database backend using aiosqlite.""" 

45 

def __init__(self, config: dict[str, Any] | None = None):
    """Set up an unconnected async SQLite backend from a config mapping.

    No connection is opened here; ``self.db`` starts as ``None`` and
    ``self._connected`` as ``False``.

    Args:
        config: Optional settings. Recognised keys:
            - path: Database file path (default: ":memory:")
            - table: Table name (default: "records")
            - timeout: Connection timeout in seconds (default: 5.0)
            - journal_mode: Journal mode (WAL, DELETE, etc.); defaults to
              "WAL" for any path other than ":memory:", otherwise None
            - synchronous: Synchronous mode (NORMAL, FULL, OFF)
              (default: NORMAL)
            - pool_size: Number of connections in pool (default: 5);
              stored on the instance
    """
    super().__init__(config)
    settings = config or {}

    # Connection-level settings.
    db_path = settings.get("path", ":memory:")
    default_journal_mode = None if db_path == ":memory:" else "WAL"
    self.db_path = db_path
    self.timeout = settings.get("timeout", 5.0)
    self.journal_mode = settings.get("journal_mode", default_journal_mode)
    self.synchronous = settings.get("synchronous", "NORMAL")
    self.pool_size = settings.get("pool_size", 5)

    # SQL helpers bound to the configured table.
    table = settings.get("table", "records")
    self.table_name = table
    self.query_builder = SQLQueryBuilder(table, dialect="sqlite", param_style="qmark")
    self.table_manager = SQLTableManager(table, dialect="sqlite")

    # Connection state; populated by connect().
    self.db: aiosqlite.Connection | None = None
    self._connected = False

    # Vector support provided by the mixins.
    self._parse_vector_config(settings)
    self._init_vector_state()

77 

@classmethod
def from_config(cls, config: dict) -> AsyncSQLiteDatabase:
    """Alternate constructor: build an instance from a config dictionary.

    Args:
        config: Settings passed unchanged to the class constructor.

    Returns:
        A new instance of ``cls``.
    """
    instance = cls(config)
    return instance

82 

async def connect(self) -> None:
    """Open the SQLite connection, apply PRAGMAs and ensure the table exists.

    Calling this on an already-connected instance is a no-op. For any path
    other than ":memory:" the parent directory is created first.

    Raises:
        Exception: Whatever ``aiosqlite.connect`` or the setup helpers raise.
            If setup fails after the connection was opened, that connection
            is closed and ``self.db`` is reset to ``None`` before the error
            propagates, so a later retry starts from a clean state.
    """
    if self._connected:
        return

    # Create directory if needed for file-based database
    if self.db_path != ":memory:":
        Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

    # Connect to database
    self.db = await aiosqlite.connect(
        self.db_path,
        timeout=self.timeout,
    )

    try:
        # Enable row factory so rows can be converted with dict(row)
        self.db.row_factory = aiosqlite.Row

        # Configure SQLite for better performance
        await self._configure_sqlite()

        # Create table if it doesn't exist
        await self._ensure_table()
    except BaseException:
        # Setup failed (or was cancelled): do not leave a half-initialised
        # connection behind, otherwise a retry would open a second
        # connection and leak this one.
        failed_db, self.db = self.db, None
        try:
            await failed_db.close()
        except Exception:
            logger.warning(
                "Failed to close SQLite connection after setup error",
                exc_info=True,
            )
        raise

    self._connected = True
    logger.info(f"Connected to async SQLite database: {self.db_path}")

110 

async def close(self) -> None:
    """Close the database connection, if one is open.

    Does nothing when ``self.db`` is unset. The instance state (``self.db``
    and ``self._connected``) is reset even when closing raises, so a later
    ``connect()`` opens a fresh connection instead of silently returning.

    Raises:
        Exception: Whatever the underlying connection's ``close()`` raises.
    """
    if not self.db:
        return

    try:
        await self.db.close()
    finally:
        # Reset unconditionally: a connection whose close() failed must not
        # keep the instance flagged as connected.
        self.db = None
        self._connected = False

    logger.info(f"Disconnected from async SQLite database: {self.db_path}")

118 

async def _configure_sqlite(self) -> None:
    """Apply the connection-level PRAGMA settings and commit.

    Returns immediately when no connection object is set. Otherwise it
    issues, in order: ``journal_mode`` (only when one is configured),
    ``synchronous``, ``foreign_keys = ON``, ``temp_store = MEMORY`` and
    ``mmap_size``, followed by a commit.
    """
    conn = self.db
    if not conn:
        return

    # Journal mode is optional (None is the default for in-memory databases).
    journal = self.journal_mode
    if journal:
        await conn.execute(f"PRAGMA journal_mode = {journal}")
        logger.debug(f"Set journal_mode to {journal}")

    sync_level = self.synchronous
    await conn.execute(f"PRAGMA synchronous = {sync_level}")
    logger.debug(f"Set synchronous to {sync_level}")

    # Fixed settings: foreign-key enforcement plus performance tuning.
    fixed_pragmas = (
        "PRAGMA foreign_keys = ON",
        "PRAGMA temp_store = MEMORY",
        "PRAGMA mmap_size = 30000000000",
    )
    for statement in fixed_pragmas:
        await conn.execute(statement)

    await conn.commit()

141 

142 async def _ensure_table(self) -> None: 

143 """Ensure the table exists.""" 

144 if not self.db: 

145 raise RuntimeError("Database not connected. Call connect() first.") 

146 

147 await self.db.executescript(self.table_manager.get_create_table_sql()) 

148 await self.db.commit() 

149 

150 def _check_connection(self) -> None: 

151 """Check if database is connected.""" 

152 if not self._connected or not self.db: 

153 raise RuntimeError("Database not connected. Call connect() first.") 

154 

async def create(self, record: Record) -> str:
    """Insert a single record and return its ID.

    Args:
        record: The record to store.

    Returns:
        The record ID, taken from the first bound parameter of the insert
        statement produced by the query builder.

    Raises:
        RuntimeError: If the database is not connected.
        ValueError: If the insert violates an integrity constraint; this is
            reported as a duplicate ID.
        Exception: Any other database error, re-raised after a rollback.
    """
    self._check_connection()

    query, params = self.query_builder.build_create_query(record)

    try:
        await self.db.execute(query, params)
        await self.db.commit()
    except aiosqlite.IntegrityError as e:
        await self.db.rollback()
        raise ValueError(f"Record with ID {params[0]} already exists") from e
    except Exception:
        # Roll back on every failure, not only integrity errors, so a failed
        # insert never leaves a transaction open on the shared connection
        # (same policy as the *_batch methods).
        await self.db.rollback()
        raise

    # No RETURNING clause is used; the ID is the first parameter.
    return params[0]

171 

async def read(self, id: str) -> Record | None:
    """Fetch a single record by its ID.

    Args:
        id: ID of the record to load.

    Returns:
        The record built from the matching row, or ``None`` when the query
        returns no row.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    sql, bind_values = self.query_builder.build_read_query(id)
    async with self.db.execute(sql, bind_values) as cur:
        found = await cur.fetchone()

    return SQLQueryBuilder.row_to_record(dict(found)) if found else None

184 

async def update(self, id: str, record: Record) -> bool:
    """Overwrite the record stored under ``id`` and commit.

    Args:
        id: ID of the record to update.
        record: New record contents.

    Returns:
        ``True`` when the statement affected at least one row, else ``False``.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    sql, bind_values = self.query_builder.build_update_query(id, record)
    outcome = await self.db.execute(sql, bind_values)
    await self.db.commit()

    affected_rows = outcome.rowcount
    return affected_rows > 0

194 

async def delete(self, id: str) -> bool:
    """Remove the record stored under ``id`` and commit.

    Args:
        id: ID of the record to delete.

    Returns:
        ``True`` when the statement affected at least one row, else ``False``.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    sql, bind_values = self.query_builder.build_delete_query(id)
    outcome = await self.db.execute(sql, bind_values)
    await self.db.commit()

    affected_rows = outcome.rowcount
    return affected_rows > 0

204 

async def exists(self, id: str) -> bool:
    """Report whether a record is stored under ``id``.

    Args:
        id: ID to look up.

    Returns:
        ``True`` when the existence query yields a row, else ``False``.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    sql, bind_values = self.query_builder.build_exists_query(id)
    async with self.db.execute(sql, bind_values) as cur:
        match = await cur.fetchone()

    if match is None:
        return False
    return True

214 

async def search(self, query: Query | ComplexQuery) -> list[Record]:
    """Run a query and return the matching records.

    Args:
        query: A ``Query`` or ``ComplexQuery``; each is translated to SQL by
            the corresponding query-builder method.

    Returns:
        One record per returned row, each with ``storage_id`` set to the
        string form of the row's ``id`` column. When ``query.fields`` is
        truthy, every record is projected onto those fields.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    # Pick the SQL builder that matches the query type.
    builder = self.query_builder
    if isinstance(query, ComplexQuery):
        build_sql = builder.build_complex_search_query
    else:
        build_sql = builder.build_search_query
    sql_query, params = build_sql(query)

    async with self.db.execute(sql_query, params) as cursor:
        rows = await cursor.fetchall()

    def to_record(row):
        """Convert one result row into a record carrying its storage ID."""
        columns = dict(row)
        converted = SQLQueryBuilder.row_to_record(columns)
        converted.storage_id = str(columns['id'])
        return converted

    records = [to_record(row) for row in rows]

    # Apply field projection if specified
    if query.fields:
        records = [item.project(query.fields) for item in records]

    return records

243 

async def count(self, query: Query | None = None) -> int:
    """Count the records matching ``query``.

    Args:
        query: Optional query, passed as-is to the query builder's count
            method.

    Returns:
        The first column of the first result row, or ``0`` when the
        statement yields no row.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    sql, bind_values = self.query_builder.build_count_query(query)
    async with self.db.execute(sql, bind_values) as cur:
        first_row = await cur.fetchone()

    if first_row:
        return first_row[0]
    return 0

253 

async def create_batch(self, records: list[Record]) -> list[str]:
    """Insert several records with one statement inside a transaction.

    Args:
        records: Records to insert. An empty list returns ``[]`` without
            touching the database (and without a connection check).

    Returns:
        The IDs reported by the query builder for the inserted records.

    Raises:
        RuntimeError: If the database is not connected.
        Exception: Any error from the insert or commit, re-raised after a
            rollback.
    """
    if not records:
        return []

    self._check_connection()

    # The shared builder returns the statement, its parameters and the IDs.
    insert_sql, bind_values, new_ids = self.query_builder.build_batch_create_query(records)

    await self.db.execute("BEGIN TRANSACTION")
    try:
        await self.db.execute(insert_sql, bind_values)
        await self.db.commit()
    except Exception:
        await self.db.rollback()
        raise
    else:
        return new_ids

279 

async def update_batch(self, updates: list[tuple[str, Record]]) -> list[bool]:
    """Update several records with one statement inside a transaction.

    After the commit, a follow-up SELECT determines which of the requested
    IDs are present in the table.

    Args:
        updates: ``(record_id, record)`` pairs. An empty list returns ``[]``
            without touching the database.

    Returns:
        One boolean per input pair, in input order: ``True`` when that ID is
        present in the table after the update.

    Raises:
        RuntimeError: If the database is not connected.
        Exception: Any error from the update, commit or follow-up lookup,
            re-raised after a rollback call.
    """
    if not updates:
        return []

    self._check_connection()

    update_sql, bind_values = self.query_builder.build_batch_update_query(updates)

    await self.db.execute("BEGIN TRANSACTION")
    try:
        await self.db.execute(update_sql, bind_values)
        await self.db.commit()

        # No RETURNING clause is used, so look the requested IDs up again.
        requested = [pair[0] for pair in updates]
        marks = ", ".join("?" * len(requested))
        lookup_sql = f"SELECT id FROM {self.table_name} WHERE id IN ({marks})"

        async with self.db.execute(lookup_sql, requested) as lookup:
            found_rows = await lookup.fetchall()
        present = {found[0] for found in found_rows}

        return [wanted in present for wanted in requested]
    except Exception:
        await self.db.rollback()
        raise

319 

async def delete_batch(self, ids: list[str]) -> list[bool]:
    """Delete several records with one statement inside a transaction.

    The IDs present in the table are looked up before the delete, because
    that information is gone afterwards.

    Args:
        ids: IDs to delete. An empty list returns ``[]`` without touching
            the database.

    Returns:
        One boolean per input ID, in input order: ``True`` when that ID was
        present before the delete ran.

    Raises:
        RuntimeError: If the database is not connected.
        Exception: Any error from the delete or commit, re-raised after a
            rollback.
    """
    if not ids:
        return []

    self._check_connection()

    # Record which IDs are present before they are removed.
    marks = ", ".join("?" * len(ids))
    lookup_sql = f"SELECT id FROM {self.table_name} WHERE id IN ({marks})"
    async with self.db.execute(lookup_sql, ids) as lookup:
        found_rows = await lookup.fetchall()
    present = {found[0] for found in found_rows}

    delete_sql, bind_values = self.query_builder.build_batch_delete_query(ids)

    await self.db.execute("BEGIN TRANSACTION")
    try:
        await self.db.execute(delete_sql, bind_values)
        await self.db.commit()
        return [record_id in present for record_id in ids]
    except Exception:
        await self.db.rollback()
        raise

357 

358 def _initialize(self) -> None: 

359 """Initialize method - connection setup handled in connect().""" 

360 pass 

361 

362 async def _count_all(self) -> int: 

363 """Count all records in the database.""" 

364 self._check_connection() 

365 

366 async with self.db.execute(f"SELECT COUNT(*) FROM {self.table_name}") as cursor: 

367 result = await cursor.fetchone() 

368 return result[0] if result else 0 

369 

async def stream_read(
    self,
    query: Query | None = None,
    config: StreamConfig | None = None
) -> AsyncIterator[Record]:
    """Yield records one at a time, fetching them page by page via search().

    Each page is requested with a copy of the query on which ``offset`` and
    ``limit(config.batch_size)`` are set.
    NOTE(review): this presumably replaces any offset/limit already present
    on the caller's query - confirm against the Query API.

    Args:
        query: Query to stream; a default ``Query()`` is used when omitted.
        config: Stream settings; a default ``StreamConfig()`` is used when
            omitted. Only ``batch_size`` is read here.

    Yields:
        Records in the order returned by successive ``search()`` calls.
    """
    from ..streaming import StreamConfig

    config = config or StreamConfig()
    query = query or Query()

    position = 0
    exhausted = False
    while not exhausted:
        # Request the next page, starting where the previous one ended.
        page_query = query.copy()
        page_query.offset(position).limit(config.batch_size)
        page = await self.search(page_query)

        for item in page:
            yield item

        position += len(page)

        # An empty or short page means there is nothing left to fetch.
        exhausted = not page or len(page) < config.batch_size

400 

async def stream_write(
    self,
    records: AsyncIterator[Record],
    config: StreamConfig | None = None
) -> StreamResult:
    """Consume an async stream of records and insert them in batches.

    Records are buffered until ``config.batch_size`` is reached and then
    written with ``create_batch()``; any remainder is written at the end.
    Errors from ``create_batch()`` are not caught and propagate to the
    caller.

    Args:
        records: Async iterator producing the records to store.
        config: Stream settings; a default ``StreamConfig()`` is used when
            omitted. Only ``batch_size`` is read here.

    Returns:
        A ``StreamResult`` with the number of records written (reported as
        both processed and successful), ``failed=0``, the elapsed time in
        seconds and the number of batches actually written.
    """
    import time

    from ..streaming import StreamConfig, StreamResult

    config = config or StreamConfig()
    batch: list[Record] = []
    total_written = 0
    # Count flushes directly instead of deriving them from batch_size:
    # the division raised ZeroDivisionError for batch_size == 0 and gave a
    # meaningless value for negative sizes.
    total_batches = 0
    # Monotonic clock: wall-clock adjustments must not distort the duration.
    start_time = time.perf_counter()

    async for record in records:
        batch.append(record)

        if len(batch) >= config.batch_size:
            # Write the batch
            await self.create_batch(batch)
            total_written += len(batch)
            total_batches += 1
            batch = []

    # Write any remaining records
    if batch:
        await self.create_batch(batch)
        total_written += len(batch)
        total_batches += 1

    elapsed = time.perf_counter() - start_time

    return StreamResult(
        total_processed=total_written,
        successful=total_written,
        failed=0,
        duration=elapsed,
        total_batches=total_batches,
    )

439 

async def vector_search(
    self,
    query_vector,
    vector_field: str = "embedding",
    k: int = 10,
    filter=None,
    metric=None,
    **kwargs
):
    """Run a vector similarity search through the Python search mixin.

    After a connection check, all arguments are passed on to
    ``python_vector_search_async`` and its result is returned unchanged.

    Args:
        query_vector: Query vector
        vector_field: Name of the vector field to search
        k: Number of results to return
        filter: Optional filter conditions
        metric: Distance metric, passed through as given
        **kwargs: Additional keyword arguments passed to the mixin

    Returns:
        Whatever ``python_vector_search_async`` returns.

    Raises:
        RuntimeError: If the database is not connected.
    """
    self._check_connection()

    # Collect the named arguments, then delegate to the mixin.
    search_args = {
        "query_vector": query_vector,
        "vector_field": vector_field,
        "k": k,
        "filter": filter,
        "metric": metric,
    }
    hits = await self.python_vector_search_async(**search_args, **kwargs)
    return hits