Coverage for integrations / channels / memory / embeddings.py: 74.6%
299 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Embedding Cache - Cache embeddings with TTL.
4Provides persistent caching of text embeddings with configurable TTL,
5batch operations, and optional Redis backend for distributed setups.
6Designed for Docker environments with container-compatible storage.
7"""
import asyncio
import hashlib
import json
import os
import sqlite3
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union
20try:
21 import redis
22 REDIS_AVAILABLE = True
23except ImportError:
24 REDIS_AVAILABLE = False
@dataclass
class EmbeddingConfig:
    """Configuration for the embedding cache.

    Selects the storage backend (Redis when `redis_url` is set and the
    `redis` package imports, SQLite otherwise) and tunes TTL/cleanup.
    """

    # Storage settings
    db_path: Optional[str] = None  # SQLite file path; auto-selected when None
    redis_url: Optional[str] = None  # e.g., "redis://localhost:6379/0"

    # Cache settings
    ttl_days: int = 30  # entries expire this many days after being written
    # NOTE(review): max_entries is never read in this module — appears
    # unenforced; confirm whether eviction was intended.
    max_entries: int = 100000

    # Embedding settings
    default_model: str = "default"  # model key used when callers pass None
    embedding_dims: int = 384  # advisory only; vectors are stored as-is

    # Performance settings
    batch_size: int = 32  # advisory batch size for embedding computation
    cleanup_interval_hours: int = 24  # minimum gap between automatic cleanups
@dataclass
class EmbeddingResult:
    """Result of an embedding lookup or computation."""

    text_hash: str
    embedding: List[float]
    model: str
    cached: bool = False
    created_at: Optional[datetime] = None
    expires_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this result as a plain dict (datetimes → ISO-8601 or None)."""

        def _iso(ts: Optional[datetime]) -> Optional[str]:
            # Timestamps are optional; render missing ones as None.
            return None if ts is None else ts.isoformat()

        return {
            "text_hash": self.text_hash,
            "embedding": self.embedding,
            "model": self.model,
            "cached": self.cached,
            "created_at": _iso(self.created_at),
            "expires_at": _iso(self.expires_at),
        }
@dataclass
class CacheStats:
    """Statistics about the embedding cache."""

    total_entries: int = 0
    cache_hits: int = 0
    cache_misses: int = 0
    total_lookups: int = 0
    expired_entries: int = 0
    storage_bytes: int = 0
    oldest_entry: Optional[datetime] = None
    newest_entry: Optional[datetime] = None

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups served from cache (0.0 when none recorded)."""
        return self.cache_hits / self.total_lookups if self.total_lookups else 0.0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize stats as a plain dict; datetimes become ISO-8601 strings."""
        payload: Dict[str, Any] = {
            "total_entries": self.total_entries,
            "cache_hits": self.cache_hits,
            "cache_misses": self.cache_misses,
            "total_lookups": self.total_lookups,
            "hit_rate": self.hit_rate,
            "expired_entries": self.expired_entries,
            "storage_bytes": self.storage_bytes,
            "oldest_entry": None,
            "newest_entry": None,
        }
        if self.oldest_entry is not None:
            payload["oldest_entry"] = self.oldest_entry.isoformat()
        if self.newest_entry is not None:
            payload["newest_entry"] = self.newest_entry.isoformat()
        return payload
class RedisBackend:
    """Redis backend for distributed embedding cache.

    Keys are `{prefix}{model}:{text_hash}`; values are JSON-encoded vectors
    with a native Redis TTL, so expiry needs no explicit cleanup pass.
    """

    def __init__(self, redis_url: str, ttl_days: int = 30, prefix: str = "emb:"):
        """
        Initialize Redis backend.

        Args:
            redis_url: Redis connection URL.
            ttl_days: TTL for cached embeddings.
            prefix: Key prefix for namespacing.

        Raises:
            ImportError: If the `redis` package is not installed.
        """
        if not REDIS_AVAILABLE:
            raise ImportError("redis package is required for Redis backend")

        self.client = redis.from_url(redis_url)
        self.ttl_seconds = ttl_days * 24 * 3600
        self.prefix = prefix

    def _key(self, text_hash: str, model: str) -> str:
        """Generate Redis key for a (model, hash) pair."""
        return f"{self.prefix}{model}:{text_hash}"

    def get(self, text_hash: str, model: str) -> Optional[List[float]]:
        """Get cached embedding, or None if absent/expired."""
        data = self.client.get(self._key(text_hash, model))
        if data:
            return json.loads(data)
        return None

    def set(self, text_hash: str, model: str, embedding: List[float]) -> None:
        """Set cached embedding with the configured TTL."""
        self.client.setex(
            self._key(text_hash, model),
            self.ttl_seconds,
            json.dumps(embedding),
        )

    def delete(self, text_hash: str, model: Optional[str] = None) -> bool:
        """Delete cached embedding.

        When `model` is None, scan for the hash under every model and delete
        all matches. The original required `model`, which made
        `EmbeddingCache.invalidate` (hash-only) raise TypeError on Redis.

        Returns:
            True if at least one entry was removed.
        """
        if model is not None:
            return self.client.delete(self._key(text_hash, model)) > 0

        deleted = 0
        # Key layout is "{prefix}{model}:{hash}", so match any model segment.
        for key in self.client.scan_iter(f"{self.prefix}*:{text_hash}"):
            deleted += self.client.delete(key)
        return deleted > 0

    def exists(self, text_hash: str, model: str) -> bool:
        """Check if a non-expired embedding exists."""
        return self.client.exists(self._key(text_hash, model)) > 0

    def get_batch(self, keys: List[Tuple[str, str]]) -> Dict[str, List[float]]:
        """Get multiple embeddings at once, keyed by text hash."""
        if not keys:
            return {}

        redis_keys = [self._key(h, m) for h, m in keys]
        values = self.client.mget(redis_keys)

        result: Dict[str, List[float]] = {}
        for (text_hash, model), value in zip(keys, values):
            if value:
                result[text_hash] = json.loads(value)
        return result

    def cleanup(self, max_age_days: Optional[int] = None) -> int:
        """No-op: Redis expires keys via native TTL.

        Accepts (and ignores) `max_age_days` so callers that pass an age —
        as `EmbeddingCache.cleanup` does — no longer raise TypeError.
        """
        return 0
class SQLiteBackend:
    """SQLite backend for local embedding cache.

    Stores JSON-encoded vectors in a single `embeddings` table keyed by
    (text_hash, model), with absolute expiry timestamps checked on read.
    All statements run under an RLock; the connection uses WAL mode and
    autocommit (isolation_level=None) for cross-thread use.
    """

    # Default paths for Docker environments
    DEFAULT_DATA_DIR = "/app/data"
    DEFAULT_TEMP_DIR = "/tmp/embeddings"

    # SQLite's default host-parameter ceiling is 999; each key pair uses two
    # parameters (plus one for the expiry cutoff), so chunk batch queries.
    _MAX_BATCH_PAIRS = 400

    def __init__(self, db_path: Optional[str] = None, ttl_days: int = 30):
        """
        Initialize SQLite backend.

        Args:
            db_path: Path to SQLite database. When omitted, a directory is
                chosen automatically (persistent Docker volume if present,
                otherwise a temp or local directory).
            ttl_days: TTL for cached embeddings.
        """
        if db_path:
            self.db_path = db_path
        else:
            # Prefer the persistent data volume. The original checked "/tmp"
            # first; since /tmp exists on every POSIX system, the /app/data
            # and cwd branches were unreachable and Docker deployments always
            # wrote to ephemeral temp storage.
            if os.path.isdir(self.DEFAULT_DATA_DIR):
                db_dir = os.path.join(self.DEFAULT_DATA_DIR, ".embeddings")
            elif os.path.isdir("/tmp"):
                db_dir = self.DEFAULT_TEMP_DIR
            else:
                db_dir = os.path.join(os.path.abspath("."), ".embeddings")

            os.makedirs(db_dir, exist_ok=True)
            self.db_path = os.path.join(db_dir, "embeddings.db")

        self.ttl_days = ttl_days
        self._lock = threading.RLock()
        self._conn: Optional[sqlite3.Connection] = None
        self._ensure_schema()

    def _ensure_connection(self) -> sqlite3.Connection:
        """Open (once) and return the shared database connection."""
        if self._conn is None:
            os.makedirs(os.path.dirname(self.db_path) or ".", exist_ok=True)

            self._conn = sqlite3.connect(
                self.db_path,
                check_same_thread=False,  # guarded by self._lock instead
                isolation_level=None,  # autocommit
                timeout=30.0,
            )
            self._conn.row_factory = sqlite3.Row

            # Enable WAL mode for better concurrency
            self._conn.execute("PRAGMA journal_mode=WAL")
            self._conn.execute("PRAGMA busy_timeout=30000")

        return self._conn

    def _ensure_schema(self) -> None:
        """Create the embeddings table and its indexes if missing."""
        conn = self._ensure_connection()
        with self._lock:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS embeddings (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    text_hash TEXT NOT NULL,
                    model TEXT NOT NULL,
                    embedding TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    expires_at REAL NOT NULL,
                    access_count INTEGER DEFAULT 0,
                    last_accessed REAL,
                    UNIQUE(text_hash, model)
                )
            """)

            conn.execute("CREATE INDEX IF NOT EXISTS idx_emb_hash ON embeddings(text_hash)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_emb_model ON embeddings(model)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_emb_expires ON embeddings(expires_at)")

    def get(self, text_hash: str, model: str) -> Optional[List[float]]:
        """Get cached embedding, or None if absent or expired.

        A hit also bumps access_count / last_accessed.
        """
        conn = self._ensure_connection()
        now = time.time()

        with self._lock:
            row = conn.execute(
                """
                SELECT embedding FROM embeddings
                WHERE text_hash = ? AND model = ? AND expires_at > ?
                """,
                (text_hash, model, now),
            ).fetchone()

            if row:
                # Update access stats
                conn.execute(
                    """
                    UPDATE embeddings
                    SET access_count = access_count + 1, last_accessed = ?
                    WHERE text_hash = ? AND model = ?
                    """,
                    (now, text_hash, model),
                )
                return json.loads(row["embedding"])

        return None

    def set(self, text_hash: str, model: str, embedding: List[float]) -> None:
        """Insert or replace a cached embedding with a fresh TTL."""
        conn = self._ensure_connection()
        now = time.time()
        expires = now + (self.ttl_days * 24 * 3600)

        with self._lock:
            conn.execute(
                """
                INSERT OR REPLACE INTO embeddings
                (text_hash, model, embedding, created_at, expires_at, last_accessed)
                VALUES (?, ?, ?, ?, ?, ?)
                """,
                (text_hash, model, json.dumps(embedding), now, expires, now),
            )

    def delete(self, text_hash: str, model: Optional[str] = None) -> bool:
        """Delete cached embedding(s) for a hash.

        Args:
            text_hash: Hash to remove.
            model: When given, remove only that model's entry; otherwise
                remove the hash under every model. (Annotation fixed from
                the original's implicit-Optional `model: str = None`.)

        Returns:
            True if at least one row was removed.
        """
        conn = self._ensure_connection()
        with self._lock:
            if model:
                cursor = conn.execute(
                    "DELETE FROM embeddings WHERE text_hash = ? AND model = ?",
                    (text_hash, model),
                )
            else:
                cursor = conn.execute(
                    "DELETE FROM embeddings WHERE text_hash = ?",
                    (text_hash,),
                )
            return cursor.rowcount > 0

    def exists(self, text_hash: str, model: str) -> bool:
        """Check if a non-expired embedding exists."""
        conn = self._ensure_connection()
        now = time.time()
        with self._lock:
            row = conn.execute(
                """
                SELECT 1 FROM embeddings
                WHERE text_hash = ? AND model = ? AND expires_at > ?
                """,
                (text_hash, model, now),
            ).fetchone()
            return row is not None

    def get_batch(self, keys: List[Tuple[str, str]]) -> Dict[str, List[float]]:
        """Get multiple non-expired embeddings at once, keyed by text hash."""
        if not keys:
            return {}

        conn = self._ensure_connection()
        now = time.time()
        result: Dict[str, List[float]] = {}

        with self._lock:
            # SQLite doesn't support multi-column IN with parameters, so
            # expand into OR'd pair conditions. Chunk the expansion: the
            # original built a single clause for all keys, which exceeds
            # SQLite's host-parameter limit for large batches.
            for start in range(0, len(keys), self._MAX_BATCH_PAIRS):
                chunk = keys[start:start + self._MAX_BATCH_PAIRS]
                conditions = " OR ".join(["(text_hash = ? AND model = ?)"] * len(chunk))
                params: List[Any] = []
                for text_hash, model in chunk:
                    params.extend([text_hash, model])
                params.append(now)

                rows = conn.execute(
                    f"""
                    SELECT text_hash, embedding FROM embeddings
                    WHERE ({conditions}) AND expires_at > ?
                    """,
                    params,
                ).fetchall()

                for row in rows:
                    result[row["text_hash"]] = json.loads(row["embedding"])

        return result

    def cleanup(self, max_age_days: Optional[int] = None) -> int:
        """Remove expired entries.

        Args:
            max_age_days: When given, delete entries *created* more than this
                many days ago; otherwise delete entries past their expiry.

        Returns:
            Number of rows removed.
        """
        conn = self._ensure_connection()
        now = time.time()

        with self._lock:
            if max_age_days is not None:
                cutoff = now - (max_age_days * 24 * 3600)
                cursor = conn.execute(
                    "DELETE FROM embeddings WHERE created_at < ?",
                    (cutoff,),
                )
            else:
                cursor = conn.execute(
                    "DELETE FROM embeddings WHERE expires_at < ?",
                    (now,),
                )
            return cursor.rowcount

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics (counts, byte estimate, entry age range)."""
        conn = self._ensure_connection()
        now = time.time()

        with self._lock:
            total = conn.execute("SELECT COUNT(*) as cnt FROM embeddings").fetchone()["cnt"]
            expired = conn.execute(
                "SELECT COUNT(*) as cnt FROM embeddings WHERE expires_at < ?",
                (now,),
            ).fetchone()["cnt"]

            dates = conn.execute(
                "SELECT MIN(created_at) as oldest, MAX(created_at) as newest FROM embeddings"
            ).fetchone()

            # Estimate storage size from stored JSON lengths (excludes
            # per-row/index overhead).
            size_row = conn.execute(
                "SELECT SUM(LENGTH(embedding)) as total FROM embeddings"
            ).fetchone()
            storage = size_row["total"] or 0

        return {
            "total_entries": total,
            "expired_entries": expired,
            "storage_bytes": storage,
            "oldest_entry": datetime.fromtimestamp(dates["oldest"]) if dates["oldest"] else None,
            "newest_entry": datetime.fromtimestamp(dates["newest"]) if dates["newest"] else None,
        }

    def close(self) -> None:
        """Close database connection (idempotent)."""
        with self._lock:
            if self._conn:
                self._conn.close()
                self._conn = None
class EmbeddingCache:
    """
    Cache embeddings with TTL.

    Provides efficient caching of text embeddings with configurable storage
    backend (SQLite for local, Redis for distributed).

    Features:
    - Automatic TTL-based expiration
    - Batch operations for efficiency
    - Support for multiple embedding models
    - Optional Redis backend for distributed setups
    """

    def __init__(
        self,
        config: Optional[EmbeddingConfig] = None,
        embedding_fn: Optional[Callable[[str], List[float]]] = None,
        async_embedding_fn: Optional[Callable[[str], Awaitable[List[float]]]] = None,
    ):
        """
        Initialize the embedding cache.

        Args:
            config: Cache configuration.
            embedding_fn: Synchronous function to compute embeddings.
            async_embedding_fn: Async function to compute embeddings; it is
                awaited, hence the Awaitable return annotation (the original
                annotated it as returning List[float] directly).
        """
        self.config = config or EmbeddingConfig()
        self.embedding_fn = embedding_fn
        self.async_embedding_fn = async_embedding_fn

        # Initialize backend: Redis when configured AND importable,
        # otherwise local SQLite.
        if self.config.redis_url and REDIS_AVAILABLE:
            self._backend = RedisBackend(
                self.config.redis_url,
                ttl_days=self.config.ttl_days,
            )
            self._backend_type = "redis"
        else:
            self._backend = SQLiteBackend(
                db_path=self.config.db_path,
                ttl_days=self.config.ttl_days,
            )
            self._backend_type = "sqlite"

        # Stats tracking (lock guards the counters, not the backend).
        self._lock = threading.RLock()
        self._stats = CacheStats()
        self._last_cleanup = time.time()

    @staticmethod
    def _compute_hash(text: str) -> str:
        """Compute SHA256 hash of text — the cache key."""
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def _compute_one(self, text: str) -> List[float]:
        """Compute an embedding synchronously, raising if unconfigured."""
        if self.embedding_fn:
            return self.embedding_fn(text)
        raise ValueError("No embedding function configured")

    async def get_embedding(
        self,
        text: str,
        model: Optional[str] = None,
    ) -> List[float]:
        """
        Get embedding for text, computing if not cached.

        Args:
            text: Text to get embedding for.
            model: Optional model identifier.

        Returns:
            Embedding vector as list of floats.

        Raises:
            ValueError: If no embedding function is configured.
        """
        model = model or self.config.default_model
        text_hash = self._compute_hash(text)

        # Check cache. Compare against None explicitly: the original used
        # truthiness, which misclassified a legitimately empty vector as a
        # miss and recomputed it.
        cached = self._backend.get(text_hash, model)
        with self._lock:
            self._stats.total_lookups += 1
            if cached is not None:
                self._stats.cache_hits += 1
                return cached
            self._stats.cache_misses += 1

        # Compute embedding (prefer the async function when both are set).
        if self.async_embedding_fn:
            embedding = await self.async_embedding_fn(text)
        else:
            embedding = self._compute_one(text)

        # Store in cache
        self._backend.set(text_hash, model, embedding)

        # Periodic cleanup
        self._maybe_cleanup()

        return embedding

    async def get_batch(
        self,
        texts: List[str],
        model: Optional[str] = None,
    ) -> List[List[float]]:
        """
        Get embeddings for multiple texts.

        Args:
            texts: List of texts to get embeddings for.
            model: Optional model identifier.

        Returns:
            List of embedding vectors, in the same order as `texts`.

        Raises:
            ValueError: If no embedding function is configured and any
                text is missing from the cache.
        """
        model = model or self.config.default_model

        # Compute hashes
        hashes = [self._compute_hash(text) for text in texts]

        # Batch lookup (backend returns a dict keyed by text hash).
        cached = self._backend.get_batch([(h, model) for h in hashes])

        with self._lock:
            # Count hits per position so duplicate texts are tallied
            # correctly (len(cached) dedupes and under-counted them).
            hits = sum(1 for h in hashes if h in cached)
            self._stats.total_lookups += len(texts)
            self._stats.cache_hits += hits
            self._stats.cache_misses += len(texts) - hits

        # Fill results, computing each missing hash at most once even when
        # the same text appears multiple times in the batch.
        results: List[Optional[List[float]]] = [None] * len(texts)
        computed: Dict[str, List[float]] = {}

        for i, (text, text_hash) in enumerate(zip(texts, hashes)):
            if text_hash in cached:
                results[i] = cached[text_hash]
                continue
            if text_hash not in computed:
                if self.async_embedding_fn:
                    embedding = await self.async_embedding_fn(text)
                else:
                    embedding = self._compute_one(text)
                computed[text_hash] = embedding
                self._backend.set(text_hash, model, embedding)
            results[i] = computed[text_hash]

        # Periodic cleanup
        self._maybe_cleanup()

        return results  # type: ignore

    def get_embedding_sync(
        self,
        text: str,
        model: Optional[str] = None,
    ) -> List[float]:
        """
        Synchronous version of get_embedding.

        Args:
            text: Text to get embedding for.
            model: Optional model identifier.

        Returns:
            Embedding vector as list of floats.

        Raises:
            ValueError: If no synchronous embedding function is configured.
        """
        model = model or self.config.default_model
        text_hash = self._compute_hash(text)

        # Check cache (explicit None check — see get_embedding).
        cached = self._backend.get(text_hash, model)
        with self._lock:
            self._stats.total_lookups += 1
            if cached is not None:
                self._stats.cache_hits += 1
                return cached
            self._stats.cache_misses += 1

        # Compute embedding
        if not self.embedding_fn:
            raise ValueError("No synchronous embedding function configured")

        embedding = self.embedding_fn(text)

        # Store in cache
        self._backend.set(text_hash, model, embedding)

        return embedding

    def invalidate(self, text_hash: str) -> bool:
        """
        Invalidate a cached embedding by its hash.

        For SQLite this removes the hash under every model. Redis keys
        embed the model, so a bare hash can only address the default
        model's entry here; the original passed no model at all and raised
        TypeError on the Redis backend.

        Args:
            text_hash: SHA256 hash of the text.

        Returns:
            True if an entry was removed, False if not found.
        """
        if self._backend_type == "sqlite":
            return self._backend.delete(text_hash)
        return self._backend.delete(text_hash, self.config.default_model)

    def invalidate_text(self, text: str, model: Optional[str] = None) -> bool:
        """
        Invalidate a cached embedding by text content.

        Args:
            text: Original text.
            model: Optional model identifier.

        Returns:
            True if entry was removed, False if not found.
        """
        text_hash = self._compute_hash(text)
        model = model or self.config.default_model
        return self._backend.delete(text_hash, model)

    def cleanup(self, max_age_days: int = 30) -> int:
        """
        Remove expired or old cache entries.

        Args:
            max_age_days: Maximum age of entries to keep (SQLite only;
                Redis expires keys via its own TTL).

        Returns:
            Number of entries removed.
        """
        if self._backend_type == "sqlite":
            count = self._backend.cleanup(max_age_days)
        else:
            # RedisBackend.cleanup takes no age argument; the original
            # passed max_age_days and raised TypeError on Redis.
            count = self._backend.cleanup()
        with self._lock:
            self._stats.expired_entries += count
            self._last_cleanup = time.time()
        return count

    def _maybe_cleanup(self) -> None:
        """Run expiry cleanup if the configured interval has passed.

        NOTE: runs synchronously on the calling thread — the original
        comment claimed "background", which it never was.
        """
        if time.time() - self._last_cleanup > (self.config.cleanup_interval_hours * 3600):
            self._backend.cleanup()
            with self._lock:
                self._last_cleanup = time.time()

    def get_stats(self) -> CacheStats:
        """
        Get cache statistics.

        Returns:
            CacheStats combining in-process counters with backend-reported
            entry counts (backend details are SQLite-only; Redis reports
            nothing here).
        """
        if self._backend_type == "sqlite":
            backend_stats = self._backend.get_stats()
        else:
            backend_stats = {}

        with self._lock:
            return CacheStats(
                total_entries=backend_stats.get("total_entries", 0),
                cache_hits=self._stats.cache_hits,
                cache_misses=self._stats.cache_misses,
                total_lookups=self._stats.total_lookups,
                expired_entries=self._stats.expired_entries + backend_stats.get("expired_entries", 0),
                storage_bytes=backend_stats.get("storage_bytes", 0),
                oldest_entry=backend_stats.get("oldest_entry"),
                newest_entry=backend_stats.get("newest_entry"),
            )

    def clear(self) -> int:
        """
        Clear all cached embeddings.

        Returns:
            Number of entries removed.
        """
        if self._backend_type == "sqlite":
            # Reaches into the backend's connection directly; acceptable as
            # this class owns the backend instance.
            conn = self._backend._ensure_connection()
            with self._backend._lock:
                cursor = conn.execute("DELETE FROM embeddings")
                return cursor.rowcount
        else:
            # Redis: use scan to find and delete keys under our prefix.
            count = 0
            for key in self._backend.client.scan_iter(f"{self._backend.prefix}*"):
                self._backend.client.delete(key)
                count += 1
            return count

    def close(self) -> None:
        """Close the cache and release resources (SQLite connection only)."""
        if self._backend_type == "sqlite":
            self._backend.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always close; never suppress exceptions.
        self.close()
        return False
709 return False