Coverage for integrations / channels / memory / embeddings.py: 74.6%

299 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2Embedding Cache - Cache embeddings with TTL. 

3 

4Provides persistent caching of text embeddings with configurable TTL, 

5batch operations, and optional Redis backend for distributed setups. 

6Designed for Docker environments with container-compatible storage. 

7""" 

8 

9import asyncio 

10import hashlib 

11import json 

12import os 

13import sqlite3 

14import threading 

15import time 

16from dataclasses import dataclass, field 

17from datetime import datetime, timedelta 

18from typing import Any, Callable, Dict, List, Optional, Tuple, Union 

19 

20try: 

21 import redis 

22 REDIS_AVAILABLE = True 

23except ImportError: 

24 REDIS_AVAILABLE = False 

25 

26 

@dataclass
class EmbeddingConfig:
    """Configuration for the embedding cache.

    Controls storage backend selection (SQLite by default; Redis when
    ``redis_url`` is set and the ``redis`` package is importable), entry
    TTL, and periodic-cleanup cadence.
    """

    # Storage settings
    # Explicit SQLite file path; when None, SQLiteBackend picks a default dir.
    db_path: Optional[str] = None
    # When set (and redis is installed), the cache uses the Redis backend.
    redis_url: Optional[str] = None  # e.g., "redis://localhost:6379/0"

    # Cache settings
    # TTL applied to every stored embedding, in days.
    ttl_days: int = 30
    # NOTE(review): not enforced anywhere in this module — confirm whether
    # an eviction policy was intended.
    max_entries: int = 100000

    # Embedding settings
    # Model identifier used when callers don't pass one explicitly.
    default_model: str = "default"
    # NOTE(review): dimensionality is never validated against stored
    # embeddings in this module — informational only, confirm intent.
    embedding_dims: int = 384

    # Performance settings
    # NOTE(review): not referenced by EmbeddingCache.get_batch in this
    # module — confirm whether batching was meant to honor it.
    batch_size: int = 32
    # Minimum hours between automatic expired-entry sweeps (_maybe_cleanup).
    cleanup_interval_hours: int = 24

46 

47 

@dataclass
class EmbeddingResult:
    """Result of an embedding lookup or computation."""

    text_hash: str
    embedding: List[float]
    model: str
    cached: bool = False
    created_at: Optional[datetime] = None
    expires_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this result into a plain dictionary.

        Timestamp fields are rendered as ISO-8601 strings, or ``None``
        when unset; all other fields are passed through unchanged.
        """

        def _iso(ts: Optional[datetime]) -> Optional[str]:
            # datetime objects are always truthy, so this mirrors a
            # plain None-check.
            return ts.isoformat() if ts else None

        payload: Dict[str, Any] = {
            "text_hash": self.text_hash,
            "embedding": self.embedding,
            "model": self.model,
            "cached": self.cached,
            "created_at": _iso(self.created_at),
            "expires_at": _iso(self.expires_at),
        }
        return payload

69 

70 

@dataclass
class CacheStats:
    """Statistics about the embedding cache."""

    total_entries: int = 0
    cache_hits: int = 0
    cache_misses: int = 0
    total_lookups: int = 0
    expired_entries: int = 0
    storage_bytes: int = 0
    oldest_entry: Optional[datetime] = None
    newest_entry: Optional[datetime] = None

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups served from cache; 0.0 before any lookup."""
        # Guard against division by zero on a fresh stats object.
        return self.cache_hits / self.total_lookups if self.total_lookups else 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize statistics to a plain dict (datetimes as ISO-8601)."""

        def _fmt(dt: Optional[datetime]) -> Optional[str]:
            return dt.isoformat() if dt else None

        return {
            "total_entries": self.total_entries,
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "total_lookups": self.total_lookups,
            "hit_rate": self.hit_rate,
            "expired_entries": self.expired_entries,
            "storage_bytes": self.storage_bytes,
            "oldest_entry": _fmt(self.oldest_entry),
            "newest_entry": _fmt(self.newest_entry),
        }

104 

105 

class RedisBackend:
    """Redis backend for distributed embedding cache.

    Keys are namespaced as ``<prefix><model>:<text_hash>``. Expiration is
    delegated to Redis via SETEX, so no explicit cleanup pass is needed.
    """

    def __init__(self, redis_url: str, ttl_days: int = 30, prefix: str = "emb:"):
        """
        Initialize Redis backend.

        Args:
            redis_url: Redis connection URL.
            ttl_days: TTL for cached embeddings.
            prefix: Key prefix for namespacing.

        Raises:
            ImportError: If the ``redis`` package is not installed.
        """
        if not REDIS_AVAILABLE:
            raise ImportError("redis package is required for Redis backend")

        self.client = redis.from_url(redis_url)
        self.ttl_seconds = ttl_days * 24 * 3600
        self.prefix = prefix

    def _key(self, text_hash: str, model: str) -> str:
        """Generate the namespaced Redis key for a (model, hash) pair."""
        return f"{self.prefix}{model}:{text_hash}"

    def get(self, text_hash: str, model: str) -> Optional[List[float]]:
        """Get a cached embedding, or None if absent/expired."""
        data = self.client.get(self._key(text_hash, model))
        if data:
            return json.loads(data)
        return None

    def set(self, text_hash: str, model: str, embedding: List[float]) -> None:
        """Store an embedding with the configured TTL (SETEX)."""
        self.client.setex(
            self._key(text_hash, model),
            self.ttl_seconds,
            json.dumps(embedding),
        )

    def delete(self, text_hash: str, model: Optional[str] = None) -> bool:
        """Delete cached embedding(s) for a text hash.

        Args:
            text_hash: SHA256 hash of the text.
            model: Specific model to delete for. When None, entries for ALL
                models are removed. (FIX: model was previously required,
                which made EmbeddingCache.invalidate(text_hash) raise a
                TypeError on this backend; SQLiteBackend.delete already
                accepted a hash-only call.)

        Returns:
            True if at least one entry was removed.
        """
        if model is not None:
            return self.client.delete(self._key(text_hash, model)) > 0
        # No model given: scan every model namespace for this hash.
        removed = 0
        for key in self.client.scan_iter(f"{self.prefix}*:{text_hash}"):
            removed += self.client.delete(key)
        return removed > 0

    def exists(self, text_hash: str, model: str) -> bool:
        """Check if a non-expired embedding exists."""
        return self.client.exists(self._key(text_hash, model)) > 0

    def get_batch(self, keys: List[Tuple[str, str]]) -> Dict[str, List[float]]:
        """Get multiple embeddings at once via MGET.

        Args:
            keys: List of (text_hash, model) pairs.

        Returns:
            Mapping of text_hash -> embedding for the keys that were found.
        """
        if not keys:
            return {}

        redis_keys = [self._key(h, m) for h, m in keys]
        values = self.client.mget(redis_keys)

        result = {}
        for (text_hash, model), value in zip(keys, values):
            if value:
                result[text_hash] = json.loads(value)
        return result

    def cleanup(self, max_age_days: Optional[int] = None) -> int:
        """Redis expires keys itself, so this is a no-op.

        Args:
            max_age_days: Accepted (and ignored) for interface parity with
                SQLiteBackend.cleanup — EmbeddingCache.cleanup passes it,
                which previously raised a TypeError on this backend.

        Returns:
            Always 0.
        """
        return 0

169 

170 

class SQLiteBackend:
    """SQLite backend for local embedding cache.

    A single WAL-mode connection (autocommit, ``isolation_level=None``)
    guarded by an RLock makes the backend safe to share across threads
    within one process.
    """

    # Default paths for Docker environments
    DEFAULT_DATA_DIR = "/app/data"
    DEFAULT_TEMP_DIR = "/tmp/embeddings"

    # SQLite limits bound parameters per statement (999 by default).
    # get_batch uses 2 params per key plus one timestamp, so chunk well
    # below that limit.
    _BATCH_CHUNK = 400

    def __init__(self, db_path: Optional[str] = None, ttl_days: int = 30):
        """
        Initialize SQLite backend.

        Args:
            db_path: Path to SQLite database. When omitted, the persistent
                Docker data dir is preferred, then /tmp, then the CWD.
            ttl_days: TTL for cached embeddings.
        """
        if db_path:
            self.db_path = db_path
        else:
            # FIX: prefer the persistent data directory. The original code
            # tested os.path.exists("/tmp") FIRST, which is true on any
            # Unix system, so DEFAULT_DATA_DIR was unreachable and data
            # always landed in the ephemeral temp dir.
            if os.path.exists(self.DEFAULT_DATA_DIR):
                db_dir = os.path.join(self.DEFAULT_DATA_DIR, ".embeddings")
            elif os.path.exists("/tmp"):
                db_dir = self.DEFAULT_TEMP_DIR
            else:
                db_dir = os.path.join(os.path.abspath("."), ".embeddings")

            os.makedirs(db_dir, exist_ok=True)
            self.db_path = os.path.join(db_dir, "embeddings.db")

        self.ttl_days = ttl_days
        self._lock = threading.RLock()
        self._conn: Optional[sqlite3.Connection] = None
        self._ensure_schema()

    def _ensure_connection(self) -> sqlite3.Connection:
        """Open the database connection lazily and return it."""
        if self._conn is None:
            os.makedirs(os.path.dirname(self.db_path) or ".", exist_ok=True)

            self._conn = sqlite3.connect(
                self.db_path,
                check_same_thread=False,
                isolation_level=None,  # autocommit; writes are not batched
                timeout=30.0,
            )
            self._conn.row_factory = sqlite3.Row

            # WAL allows concurrent readers during writes; busy_timeout
            # makes writers wait instead of failing fast on lock conflicts.
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._conn.execute("PRAGMA busy_timeout=30000")

        return self._conn

    def _ensure_schema(self) -> None:
        """Create the embeddings table and its indexes if missing."""
        conn = self._ensure_connection()
        with self._lock:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS embeddings (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    text_hash TEXT NOT NULL,
                    model TEXT NOT NULL,
                    embedding TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    expires_at REAL NOT NULL,
                    access_count INTEGER DEFAULT 0,
                    last_accessed REAL,
                    UNIQUE(text_hash, model)
                )
            """)

            conn.execute("CREATE INDEX IF NOT EXISTS idx_emb_hash ON embeddings(text_hash)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_emb_model ON embeddings(model)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_emb_expires ON embeddings(expires_at)")

    def get(self, text_hash: str, model: str) -> Optional[List[float]]:
        """Get a cached, non-expired embedding; update access stats on hit.

        Returns:
            The embedding as a list of floats, or None on miss/expiry.
        """
        conn = self._ensure_connection()
        now = time.time()

        with self._lock:
            row = conn.execute(
                """
                SELECT embedding FROM embeddings
                WHERE text_hash = ? AND model = ? AND expires_at > ?
                """,
                (text_hash, model, now)
            ).fetchone()

            if row:
                # Track usage for potential future eviction policies.
                conn.execute(
                    """
                    UPDATE embeddings
                    SET access_count = access_count + 1, last_accessed = ?
                    WHERE text_hash = ? AND model = ?
                    """,
                    (now, text_hash, model)
                )
                return json.loads(row["embedding"])

        return None

    def set(self, text_hash: str, model: str, embedding: List[float]) -> None:
        """Insert or replace an embedding with a fresh TTL."""
        conn = self._ensure_connection()
        now = time.time()
        expires = now + (self.ttl_days * 24 * 3600)

        with self._lock:
            conn.execute(
                """
                INSERT OR REPLACE INTO embeddings
                (text_hash, model, embedding, created_at, expires_at, last_accessed)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (text_hash, model, json.dumps(embedding), now, expires, now)
            )

    def delete(self, text_hash: str, model: Optional[str] = None) -> bool:
        """Delete cached embedding(s) for a text hash.

        Args:
            text_hash: SHA256 hash of the text.
            model: Specific model to delete for; when None, entries for all
                models are removed. (FIX: annotation was ``model: str = None``.)

        Returns:
            True if at least one row was removed.
        """
        conn = self._ensure_connection()
        with self._lock:
            if model:
                cursor = conn.execute(
                    "DELETE FROM embeddings WHERE text_hash = ? AND model = ?",
                    (text_hash, model)
                )
            else:
                cursor = conn.execute(
                    "DELETE FROM embeddings WHERE text_hash = ?",
                    (text_hash,)
                )
            return cursor.rowcount > 0

    def exists(self, text_hash: str, model: str) -> bool:
        """Check if a non-expired embedding exists (no stats update)."""
        conn = self._ensure_connection()
        now = time.time()
        with self._lock:
            row = conn.execute(
                """
                SELECT 1 FROM embeddings
                WHERE text_hash = ? AND model = ? AND expires_at > ?
                """,
                (text_hash, model, now)
            ).fetchone()
            return row is not None

    def get_batch(self, keys: List[Tuple[str, str]]) -> Dict[str, List[float]]:
        """Get multiple embeddings at once.

        Args:
            keys: List of (text_hash, model) pairs.

        Returns:
            Mapping of text_hash -> embedding for the keys found.
        """
        if not keys:
            return {}

        conn = self._ensure_connection()
        now = time.time()
        result: Dict[str, List[float]] = {}

        with self._lock:
            # SQLite can't parameterize a multi-column IN, so use OR'd
            # pairs; chunk the keys to stay under the bound-parameter
            # limit (FIX: a single statement over many keys could exceed it).
            for start in range(0, len(keys), self._BATCH_CHUNK):
                chunk = keys[start:start + self._BATCH_CHUNK]
                conditions = " OR ".join(["(text_hash = ? AND model = ?)"] * len(chunk))
                params: List[Any] = []
                for text_hash, model in chunk:
                    params.extend([text_hash, model])
                params.append(now)

                rows = conn.execute(
                    f"""
                    SELECT text_hash, embedding FROM embeddings
                    WHERE ({conditions}) AND expires_at > ?
                    """,
                    params
                ).fetchall()

                for row in rows:
                    result[row["text_hash"]] = json.loads(row["embedding"])

        return result

    def cleanup(self, max_age_days: Optional[int] = None) -> int:
        """Remove stale entries.

        Args:
            max_age_days: When given, delete entries CREATED earlier than
                this many days ago; when None, delete entries whose TTL
                has expired.

        Returns:
            Number of rows removed.
        """
        conn = self._ensure_connection()
        now = time.time()

        with self._lock:
            if max_age_days is not None:
                cutoff = now - (max_age_days * 24 * 3600)
                cursor = conn.execute(
                    "DELETE FROM embeddings WHERE created_at < ?",
                    (cutoff,)
                )
            else:
                cursor = conn.execute(
                    "DELETE FROM embeddings WHERE expires_at < ?",
                    (now,)
                )
            return cursor.rowcount

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics (counts, storage estimate, age range).

        Returns:
            Dict with total_entries, expired_entries, storage_bytes, and
            oldest/newest entry datetimes (or None when the table is empty).
        """
        conn = self._ensure_connection()
        now = time.time()

        with self._lock:
            total = conn.execute("SELECT COUNT(*) as cnt FROM embeddings").fetchone()["cnt"]
            expired = conn.execute(
                "SELECT COUNT(*) as cnt FROM embeddings WHERE expires_at < ?",
                (now,)
            ).fetchone()["cnt"]

            dates = conn.execute(
                "SELECT MIN(created_at) as oldest, MAX(created_at) as newest FROM embeddings"
            ).fetchone()

            # Storage is estimated from the JSON text length only (ignores
            # row overhead and indexes).
            size_row = conn.execute(
                "SELECT SUM(LENGTH(embedding)) as total FROM embeddings"
            ).fetchone()
            storage = size_row["total"] or 0

            return {
                "total_entries": total,
                "expired_entries": expired,
                "storage_bytes": storage,
                "oldest_entry": datetime.fromtimestamp(dates["oldest"]) if dates["oldest"] else None,
                "newest_entry": datetime.fromtimestamp(dates["newest"]) if dates["newest"] else None,
            }

    def close(self) -> None:
        """Close the database connection (reopened lazily on next use)."""
        with self._lock:
            if self._conn:
                self._conn.close()
                self._conn = None

405 

406 

class EmbeddingCache:
    """
    Cache embeddings with TTL.

    Provides efficient caching of text embeddings with configurable storage
    backend (SQLite for local, Redis for distributed).

    Features:
    - Automatic TTL-based expiration
    - Batch operations for efficiency
    - Support for multiple embedding models
    - Optional Redis backend for distributed setups
    """

    def __init__(
        self,
        config: Optional[EmbeddingConfig] = None,
        embedding_fn: Optional[Callable[[str], List[float]]] = None,
        async_embedding_fn: Optional[Callable[[str], List[float]]] = None,
    ):
        """
        Initialize the embedding cache.

        Args:
            config: Cache configuration.
            embedding_fn: Synchronous function to compute embeddings.
            async_embedding_fn: Async function to compute embeddings
                (its result is awaited; annotation kept loose for
                backward compatibility).
        """
        self.config = config or EmbeddingConfig()
        self.embedding_fn = embedding_fn
        self.async_embedding_fn = async_embedding_fn

        # Choose backend. NOTE: when redis_url is set but the redis package
        # is missing, this silently degrades to a local SQLite cache.
        if self.config.redis_url and REDIS_AVAILABLE:
            self._backend = RedisBackend(
                self.config.redis_url,
                ttl_days=self.config.ttl_days,
            )
            self._backend_type = "redis"
        else:
            self._backend = SQLiteBackend(
                db_path=self.config.db_path,
                ttl_days=self.config.ttl_days,
            )
            self._backend_type = "sqlite"

        # Stats tracking; _lock guards _stats/_last_cleanup only — the
        # backends do their own locking.
        self._lock = threading.RLock()
        self._stats = CacheStats()
        self._last_cleanup = time.time()

    @staticmethod
    def _compute_hash(text: str) -> str:
        """Compute the SHA256 hex digest of the UTF-8 encoded text."""
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    async def get_embedding(
        self,
        text: str,
        model: Optional[str] = None,
    ) -> List[float]:
        """
        Get embedding for text, computing if not cached.

        Args:
            text: Text to get embedding for.
            model: Optional model identifier.

        Returns:
            Embedding vector as list of floats.

        Raises:
            ValueError: If no embedding function is configured.
        """
        model = model or self.config.default_model
        text_hash = self._compute_hash(text)

        # Check cache. FIX: compare against None — a cached empty embedding
        # is falsy and was previously treated as a miss and recomputed.
        cached = self._backend.get(text_hash, model)
        with self._lock:
            self._stats.total_lookups += 1
            if cached is not None:
                self._stats.cache_hits += 1
                return cached
            self._stats.cache_misses += 1

        # Compute embedding, preferring the async function when available.
        if self.async_embedding_fn:
            embedding = await self.async_embedding_fn(text)
        elif self.embedding_fn:
            embedding = self.embedding_fn(text)
        else:
            raise ValueError("No embedding function configured")

        # Store in cache
        self._backend.set(text_hash, model, embedding)

        # Periodic cleanup
        self._maybe_cleanup()

        return embedding

    async def get_batch(
        self,
        texts: List[str],
        model: Optional[str] = None,
    ) -> List[List[float]]:
        """
        Get embeddings for multiple texts.

        Cached entries are fetched in one backend call; only the misses are
        computed (sequentially) and written back.

        Args:
            texts: List of texts to get embeddings for.
            model: Optional model identifier.

        Returns:
            List of embedding vectors, in the same order as `texts`.

        Raises:
            ValueError: If a miss must be computed and no embedding
                function is configured.
        """
        model = model or self.config.default_model

        # Compute hashes
        hashes = [self._compute_hash(text) for text in texts]

        # Batch lookup
        cached = self._backend.get_batch([(h, model) for h in hashes])

        with self._lock:
            self._stats.total_lookups += len(texts)
            self._stats.cache_hits += len(cached)
            self._stats.cache_misses += len(texts) - len(cached)

        # Find missing embeddings
        results: List[Optional[List[float]]] = [None] * len(texts)
        to_compute: List[Tuple[int, str]] = []

        for i, (text, text_hash) in enumerate(zip(texts, hashes)):
            if text_hash in cached:
                results[i] = cached[text_hash]
            else:
                to_compute.append((i, text))

        # Compute missing embeddings
        if to_compute:
            for i, text in to_compute:
                if self.async_embedding_fn:
                    embedding = await self.async_embedding_fn(text)
                elif self.embedding_fn:
                    embedding = self.embedding_fn(text)
                else:
                    raise ValueError("No embedding function configured")

                results[i] = embedding
                self._backend.set(hashes[i], model, embedding)

        # Periodic cleanup
        self._maybe_cleanup()

        return results  # type: ignore

    def get_embedding_sync(
        self,
        text: str,
        model: Optional[str] = None,
    ) -> List[float]:
        """
        Synchronous version of get_embedding.

        Args:
            text: Text to get embedding for.
            model: Optional model identifier.

        Returns:
            Embedding vector as list of floats.

        Raises:
            ValueError: If no synchronous embedding function is configured.
        """
        model = model or self.config.default_model
        text_hash = self._compute_hash(text)

        # Check cache (None-check: empty embeddings are valid hits).
        cached = self._backend.get(text_hash, model)
        with self._lock:
            self._stats.total_lookups += 1
            if cached is not None:
                self._stats.cache_hits += 1
                return cached
            self._stats.cache_misses += 1

        # Compute embedding
        if not self.embedding_fn:
            raise ValueError("No synchronous embedding function configured")

        embedding = self.embedding_fn(text)

        # Store in cache
        self._backend.set(text_hash, model, embedding)

        return embedding

    def invalidate(self, text_hash: str) -> bool:
        """
        Invalidate cached embeddings (all models) for a text hash.

        Args:
            text_hash: SHA256 hash of the text.

        Returns:
            True if at least one entry was removed, False if not found.
        """
        if self._backend_type == "redis":
            # FIX: RedisBackend.delete requires a model argument, so the
            # previous hash-only call raised a TypeError. Scan every model
            # namespace for this hash instead.
            removed = False
            for key in self._backend.client.scan_iter(
                f"{self._backend.prefix}*:{text_hash}"
            ):
                self._backend.client.delete(key)
                removed = True
            return removed
        return self._backend.delete(text_hash)

    def invalidate_text(self, text: str, model: Optional[str] = None) -> bool:
        """
        Invalidate a cached embedding by text content.

        Args:
            text: Original text.
            model: Optional model identifier (defaults to the configured one).

        Returns:
            True if entry was removed, False if not found.
        """
        text_hash = self._compute_hash(text)
        model = model or self.config.default_model
        return self._backend.delete(text_hash, model)

    def cleanup(self, max_age_days: int = 30) -> int:
        """
        Remove expired or old cache entries.

        Args:
            max_age_days: Maximum age of entries to keep (SQLite only;
                Redis expires keys on its own).

        Returns:
            Number of entries removed.
        """
        if self._backend_type == "sqlite":
            count = self._backend.cleanup(max_age_days)
        else:
            # FIX: RedisBackend.cleanup takes no age argument; passing one
            # previously raised a TypeError.
            count = self._backend.cleanup()
        with self._lock:
            self._stats.expired_entries += count
            self._last_cleanup = time.time()
        return count

    def _maybe_cleanup(self) -> None:
        """Run a synchronous expired-entry sweep if the interval elapsed."""
        if time.time() - self._last_cleanup > (self.config.cleanup_interval_hours * 3600):
            # Runs inline on the calling thread (the old comment claimed
            # "in background", which was never true).
            self._backend.cleanup()
            with self._lock:
                self._last_cleanup = time.time()

    def get_stats(self) -> CacheStats:
        """
        Get cache statistics.

        Combines in-process hit/miss counters with backend storage stats
        (SQLite only; Redis reports no per-entry stats here).

        Returns:
            CacheStats object with current statistics.
        """
        if self._backend_type == "sqlite":
            backend_stats = self._backend.get_stats()
        else:
            backend_stats = {}

        with self._lock:
            stats = CacheStats(
                total_entries=backend_stats.get("total_entries", 0),
                cache_hits=self._stats.cache_hits,
                cache_misses=self._stats.cache_misses,
                total_lookups=self._stats.total_lookups,
                expired_entries=self._stats.expired_entries + backend_stats.get("expired_entries", 0),
                storage_bytes=backend_stats.get("storage_bytes", 0),
                oldest_entry=backend_stats.get("oldest_entry"),
                newest_entry=backend_stats.get("newest_entry"),
            )
            return stats

    def clear(self) -> int:
        """
        Clear all cached embeddings.

        Returns:
            Number of entries removed.
        """
        if self._backend_type == "sqlite":
            conn = self._backend._ensure_connection()
            with self._backend._lock:
                cursor = conn.execute("DELETE FROM embeddings")
                return cursor.rowcount
        else:
            # Redis: SCAN-and-delete every key under our prefix.
            count = 0
            for key in self._backend.client.scan_iter(f"{self._backend.prefix}*"):
                self._backend.client.delete(key)
                count += 1
            return count

    def close(self) -> None:
        """Close the cache and release resources."""
        if self._backend_type == "sqlite":
            self._backend.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False