Coverage for integrations / channels / memory / search.py: 53.5%
331 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
"""
Memory Search - Unified search across memory sources.

Provides a unified interface for searching across multiple memory sources
including file content, embeddings, session history, and custom sources.
Designed for Docker environments with container-compatible storage.
"""
9import asyncio
10import hashlib
11import json
12import os
13import sqlite3
14import threading
15import time
16from abc import ABC, abstractmethod
17from dataclasses import dataclass, field
18from datetime import datetime
19from enum import Enum
20from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
22from .memory_store import MemoryStore, MemoryItem, SearchResult
23from .embeddings import EmbeddingCache, EmbeddingConfig
class SearchMode(Enum):
    """Enumerates the supported ways of querying memory sources."""

    TEXT = "text"          # FTS5 full-text matching
    SEMANTIC = "semantic"  # embedding-based vector similarity
    HYBRID = "hybrid"      # weighted blend of text + semantic scores
    EXACT = "exact"        # literal string matching
@dataclass
class SearchConfig:
    """Tunable parameters controlling how memory searches behave."""

    # Core search behaviour
    default_mode: SearchMode = SearchMode.HYBRID  # mode used when caller passes none
    max_results: int = 20                         # cap on returned matches
    min_score: float = 0.1                        # drop matches scoring below this

    # Relative weights when blending FTS and semantic scores in HYBRID mode
    fts_weight: float = 0.3
    semantic_weight: float = 0.7

    # Context-aware search
    context_window: int = 5        # messages fetched before/after a match
    include_metadata: bool = True  # attach item metadata to results

    # Execution
    timeout_seconds: float = 30.0  # overall search time budget
    parallel_sources: bool = True  # fan out to sources concurrently
@dataclass
class SearchMatch:
    """One hit returned by a memory source."""

    source: str               # name of the source that produced the hit
    content: str              # full matched content
    score: float              # relevance score (higher is better)
    match_type: str = "text"  # e.g. "fts", "semantic", "hybrid"
    snippet: str = ""         # short highlighted excerpt
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: Optional[datetime] = None
    item_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the match to a JSON-friendly dictionary."""
        ts = self.timestamp.isoformat() if self.timestamp is not None else None
        return dict(
            source=self.source,
            content=self.content,
            score=self.score,
            match_type=self.match_type,
            snippet=self.snippet,
            metadata=self.metadata,
            timestamp=ts,
            item_id=self.item_id,
        )
@dataclass
class SearchResults:
    """Aggregated outcome of a search across one or more sources."""

    query: str                              # the query that was executed
    matches: List[SearchMatch] = field(default_factory=list)
    total_count: int = 0                    # matches found before truncation
    sources_searched: List[str] = field(default_factory=list)
    duration_ms: float = 0.0                # wall-clock search time
    mode: SearchMode = SearchMode.HYBRID    # mode the search ran in
    errors: List[str] = field(default_factory=list)

    @property
    def has_results(self) -> bool:
        """True when at least one match was found."""
        return bool(self.matches)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the results to a JSON-friendly dictionary."""
        return dict(
            query=self.query,
            matches=[match.to_dict() for match in self.matches],
            total_count=self.total_count,
            sources_searched=self.sources_searched,
            duration_ms=self.duration_ms,
            mode=self.mode.value,
            errors=self.errors,
        )
@dataclass
class ContextMatch:
    """A search match bundled with the items surrounding it."""

    match: SearchMatch
    before: List[Dict[str, Any]] = field(default_factory=list)  # preceding items
    after: List[Dict[str, Any]] = field(default_factory=list)   # following items
    session_id: str = ""
    position: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the context match to a JSON-friendly dictionary."""
        return dict(
            match=self.match.to_dict(),
            before=self.before,
            after=self.after,
            session_id=self.session_id,
            position=self.position,
        )
@dataclass
class ContextResults:
    """Aggregated outcome of a context-aware (session) search."""

    query: str
    session_id: str
    context_matches: List[ContextMatch] = field(default_factory=list)
    total_count: int = 0
    duration_ms: float = 0.0
    errors: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the results to a JSON-friendly dictionary."""
        return dict(
            query=self.query,
            session_id=self.session_id,
            context_matches=[cm.to_dict() for cm in self.context_matches],
            total_count=self.total_count,
            duration_ms=self.duration_ms,
            errors=self.errors,
        )
class MemorySource(ABC):
    """
    Interface that every searchable memory backend must implement.

    Subclass this to plug a custom backend into MemorySearch.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique name identifying this source."""
        ...

    @abstractmethod
    async def search(
        self,
        query: str,
        max_results: int = 10,
        min_score: float = 0.0,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[SearchMatch]:
        """
        Run a text search against this source.

        Args:
            query: Search query string.
            max_results: Upper bound on returned matches.
            min_score: Discard matches scoring below this threshold.
            filters: Backend-specific filter mapping, if any.

        Returns:
            Matching SearchMatch objects.
        """
        ...

    @abstractmethod
    async def search_semantic(
        self,
        query: str,
        embedding: List[float],
        max_results: int = 10,
        min_score: float = 0.0,
    ) -> List[SearchMatch]:
        """
        Run a vector-similarity search against this source.

        Args:
            query: Original query text.
            embedding: Pre-computed embedding of the query.
            max_results: Upper bound on returned matches.
            min_score: Minimum similarity to keep a match.

        Returns:
            Matching SearchMatch objects.
        """
        ...

    async def get_context(
        self,
        item_id: str,
        window: int = 5,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Fetch items surrounding *item_id*.

        The default implementation returns no context; backends that can
        reconstruct surroundings should override this.

        Args:
            item_id: Identifier of the anchor item.
            window: How many items to fetch on each side.

        Returns:
            (before, after) lists of context dictionaries.
        """
        return [], []
class MemoryStoreSource(MemorySource):
    """Adapter exposing a MemoryStore as a searchable MemorySource."""

    def __init__(self, store: MemoryStore, source_name: str = "memory_store"):
        """
        Initialize with a MemoryStore.

        Args:
            store: MemoryStore instance.
            source_name: Name for this source.
        """
        self._store = store
        self._source_name = source_name

    @property
    def name(self) -> str:
        return self._source_name

    def _to_match(self, result, match_type: str) -> SearchMatch:
        # Convert one store SearchResult into the shared SearchMatch shape.
        item = result.item
        return SearchMatch(
            source=self.name,
            content=item.content,
            score=result.score,
            match_type=match_type,
            snippet=result.snippet,
            metadata=item.metadata,
            timestamp=datetime.fromtimestamp(item.created_at),
            item_id=item.id,
        )

    async def search(
        self,
        query: str,
        max_results: int = 10,
        min_score: float = 0.0,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[SearchMatch]:
        """Full-text (FTS5) search delegated to the underlying store."""
        source_filter = filters.get("source") if filters else None
        hits = self._store.search(
            query=query,
            max_results=max_results,
            min_score=min_score,
            source_filter=source_filter,
        )
        return [self._to_match(hit, "fts") for hit in hits]

    async def search_semantic(
        self,
        query: str,
        embedding: List[float],
        max_results: int = 10,
        min_score: float = 0.0,
    ) -> List[SearchMatch]:
        """Embedding-similarity search delegated to the underlying store."""
        hits = self._store.search_semantic(
            query=query,
            max_results=max_results,
            min_score=min_score,
        )
        return [self._to_match(hit, "semantic") for hit in hits]
class MemoryGraphSource(MemorySource):
    """Memory source backed by MemoryGraph — adds backtrace context to search results."""

    def __init__(self, graph, source_name: str = "memory_graph"):
        self._graph = graph
        self._source_name = source_name

    @property
    def name(self) -> str:
        return self._source_name

    def _node_to_match(self, node, match_type: str) -> SearchMatch:
        # The graph exposes no per-node relevance, so every hit scores 1.0.
        return SearchMatch(
            source=self.name,
            content=node.content,
            score=1.0,
            match_type=match_type,
            snippet=node.content[:200],
            metadata={"memory_type": node.memory_type, "source_agent": node.source_agent, "session_id": node.session_id},
            timestamp=datetime.fromtimestamp(node.created_at),
            item_id=node.id,
        )

    async def search(
        self,
        query: str,
        max_results: int = 10,
        min_score: float = 0.0,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[SearchMatch]:
        """Text-mode recall against the graph."""
        nodes = self._graph.recall(query, mode='text', top_k=max_results)
        return [self._node_to_match(node, "fts") for node in nodes]

    async def search_semantic(
        self,
        query: str,
        embedding: List[float],
        max_results: int = 10,
        min_score: float = 0.0,
    ) -> List[SearchMatch]:
        """Hybrid-mode recall against the graph (the graph supplies its own embeddings)."""
        nodes = self._graph.recall(query, mode='hybrid', top_k=max_results)
        return [self._node_to_match(node, "semantic") for node in nodes]

    async def get_context(
        self,
        item_id: str,
        window: int = 5,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Use backtrace for 'before' context (parent chain) and children for 'after'."""
        chain = self._graph.get_memory_chain(item_id)
        if "error" in chain:
            return [], []

        def _entry(raw):
            # Normalize a chain record into the shared context-dict shape.
            return {
                "id": raw["id"],
                "role": raw.get("source_agent", ""),
                "content": raw["content"],
                "timestamp": raw.get("created_at", 0),
            }

        before = [_entry(parent) for parent in chain.get("parents", [])[-window:]]
        after = [_entry(child) for child in chain.get("children", [])[:window]]
        return before, after
class SessionHistorySource(MemorySource):
    """Memory source for session/conversation history backed by SQLite."""

    def __init__(
        self,
        db_path: Optional[str] = None,
        source_name: str = "session_history",
    ):
        """
        Initialize session history source.

        Args:
            db_path: Path to SQLite database. When omitted, a container-friendly
                location is chosen (/app/data, then /tmp, then CWD).
            source_name: Name for this source.
        """
        self._source_name = source_name
        self._lock = threading.RLock()

        # Determine database path
        if db_path:
            self.db_path = db_path
        else:
            if os.path.exists("/app/data"):
                db_dir = "/app/data"
            elif os.path.exists("/tmp"):
                db_dir = "/tmp/session_history"
            else:
                db_dir = os.path.join(os.path.abspath("."), ".session_history")

            os.makedirs(db_dir, exist_ok=True)
            self.db_path = os.path.join(db_dir, "session_history.db")

        self._conn: Optional[sqlite3.Connection] = None
        self._ensure_schema()

    def _ensure_connection(self) -> sqlite3.Connection:
        """Open (once) and return the SQLite connection."""
        if self._conn is None:
            os.makedirs(os.path.dirname(self.db_path) or ".", exist_ok=True)
            self._conn = sqlite3.connect(
                self.db_path,
                check_same_thread=False,
                isolation_level=None,  # autocommit; WAL handles concurrency
                timeout=30.0,
            )
            self._conn.row_factory = sqlite3.Row
            self._conn.execute("PRAGMA journal_mode=WAL")
        return self._conn

    def _ensure_schema(self) -> None:
        """Create the messages table, FTS index, sync triggers, and indexes."""
        conn = self._ensure_connection()
        with self._lock:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    timestamp REAL NOT NULL,
                    metadata TEXT DEFAULT '{}'
                )
            """)

            # Create FTS5 table (external content backed by `messages`)
            try:
                conn.execute("""
                    CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
                        content,
                        session_id UNINDEXED,
                        content=messages,
                        content_rowid=id
                    )
                """)
                # BUGFIX: an external-content FTS5 table is NOT populated
                # automatically. Without these triggers nothing was ever
                # written to messages_fts, so every FTS search came back
                # empty. Keep the index in sync with `messages`.
                conn.execute("""
                    CREATE TRIGGER IF NOT EXISTS messages_ai AFTER INSERT ON messages BEGIN
                        INSERT INTO messages_fts(rowid, content, session_id)
                        VALUES (new.id, new.content, new.session_id);
                    END
                """)
                conn.execute("""
                    CREATE TRIGGER IF NOT EXISTS messages_ad AFTER DELETE ON messages BEGIN
                        INSERT INTO messages_fts(messages_fts, rowid, content, session_id)
                        VALUES ('delete', old.id, old.content, old.session_id);
                    END
                """)
                conn.execute("""
                    CREATE TRIGGER IF NOT EXISTS messages_au AFTER UPDATE ON messages BEGIN
                        INSERT INTO messages_fts(messages_fts, rowid, content, session_id)
                        VALUES ('delete', old.id, old.content, old.session_id);
                        INSERT INTO messages_fts(rowid, content, session_id)
                        VALUES (new.id, new.content, new.session_id);
                    END
                """)
                self._fts_available = True
            except sqlite3.OperationalError:
                # FTS5 module not compiled into this SQLite build; LIKE fallback.
                self._fts_available = False

            conn.execute("CREATE INDEX IF NOT EXISTS idx_msg_session ON messages(session_id)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_msg_time ON messages(timestamp)")

    @property
    def name(self) -> str:
        return self._source_name

    def add_message(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> int:
        """
        Add a message to session history.

        Args:
            session_id: Session identifier.
            role: Message role (user, assistant, system).
            content: Message content.
            metadata: Optional metadata.

        Returns:
            Message ID.
        """
        conn = self._ensure_connection()
        with self._lock:
            cursor = conn.execute(
                """
                INSERT INTO messages (session_id, role, content, timestamp, metadata)
                VALUES (?, ?, ?, ?, ?)
                """,
                (session_id, role, content, time.time(), json.dumps(metadata or {}))
            )
            return cursor.lastrowid

    async def search(
        self,
        query: str,
        max_results: int = 10,
        min_score: float = 0.0,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[SearchMatch]:
        """
        Search session history using FTS5 (or LIKE when FTS5 is unavailable).

        Args:
            query: Search query.
            max_results: Maximum results to return.
            min_score: Minimum score threshold (applied to the reported score).
            filters: Optional filters; supports {"session_id": ...}.

        Returns:
            List of SearchMatch objects.
        """
        conn = self._ensure_connection()
        session_filter = filters.get("session_id") if filters else None

        with self._lock:
            if self._fts_available:
                # BUGFIX: the FTS table was previously aliased as `f` while
                # MATCH/bm25()/snippet() still referenced `messages_fts`,
                # which SQLite rejects once the alias hides the table name.
                # Reference the table name directly and consistently.
                sql = """
                    SELECT m.*, bm25(messages_fts) as score,
                           snippet(messages_fts, 0, '<b>', '</b>', '...', 32) as snippet
                    FROM messages_fts
                    JOIN messages m ON messages_fts.rowid = m.id
                    WHERE messages_fts MATCH ?
                """
                params: List[Any] = [query]

                if session_filter:
                    sql += " AND m.session_id = ?"
                    params.append(session_filter)

                # bm25 is negative: smaller (more negative) = more relevant,
                # so ascending ORDER BY yields best matches first.
                sql += " ORDER BY score LIMIT ?"
                params.append(max_results)

                rows = conn.execute(sql, params).fetchall()
            else:
                sql = "SELECT *, 1.0 as score, '' as snippet FROM messages WHERE content LIKE ?"
                params = [f"%{query}%"]

                if session_filter:
                    sql += " AND session_id = ?"
                    params.append(session_filter)

                sql += " LIMIT ?"
                params.append(max_results)

                rows = conn.execute(sql, params).fetchall()

        matches = [
            SearchMatch(
                source=self.name,
                content=row["content"],
                score=abs(row["score"]) if row["score"] else 0.5,
                match_type="fts",
                snippet=row["snippet"] if row["snippet"] else row["content"][:200],
                metadata={
                    "session_id": row["session_id"],
                    "role": row["role"],
                    **json.loads(row["metadata"] or "{}"),
                },
                timestamp=datetime.fromtimestamp(row["timestamp"]),
                item_id=str(row["id"]),
            )
            for row in rows
        ]
        # Honor min_score (previously accepted but silently ignored).
        return [m for m in matches if m.score >= min_score]

    async def search_semantic(
        self,
        query: str,
        embedding: List[float],
        max_results: int = 10,
        min_score: float = 0.0,
    ) -> List[SearchMatch]:
        """Semantic search not directly supported; falls back to FTS."""
        return await self.search(query, max_results, min_score)

    async def get_context(
        self,
        item_id: str,
        window: int = 5,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Get messages before and after the specified message.

        Args:
            item_id: Message ID (stringified integer).
            window: Number of messages to fetch on each side.

        Returns:
            (before, after) lists of message dictionaries, each in
            chronological order.
        """
        conn = self._ensure_connection()
        msg_id = int(item_id)

        with self._lock:
            # Get the message to find its session
            row = conn.execute(
                "SELECT session_id, timestamp FROM messages WHERE id = ?",
                (msg_id,)
            ).fetchone()

            if not row:
                return [], []

            session_id = row["session_id"]
            timestamp = row["timestamp"]

            # Get messages before (fetched newest-first, reversed below)
            before_rows = conn.execute(
                """
                SELECT * FROM messages
                WHERE session_id = ? AND timestamp < ?
                ORDER BY timestamp DESC LIMIT ?
                """,
                (session_id, timestamp, window)
            ).fetchall()

            # Get messages after
            after_rows = conn.execute(
                """
                SELECT * FROM messages
                WHERE session_id = ? AND timestamp > ?
                ORDER BY timestamp ASC LIMIT ?
                """,
                (session_id, timestamp, window)
            ).fetchall()

        before = [
            {
                "id": r["id"],
                "role": r["role"],
                "content": r["content"],
                "timestamp": r["timestamp"],
            }
            for r in reversed(before_rows)
        ]

        after = [
            {
                "id": r["id"],
                "role": r["role"],
                "content": r["content"],
                "timestamp": r["timestamp"],
            }
            for r in after_rows
        ]

        return before, after

    def close(self) -> None:
        """Close database connection."""
        with self._lock:
            if self._conn:
                self._conn.close()
                self._conn = None
class MemorySearch:
    """
    Unified search across memory sources.

    Provides a single interface for searching across multiple memory backends
    including file content, embeddings, session history, and custom sources.

    Features:
    - Multiple search modes (text, semantic, hybrid)
    - Pluggable memory sources
    - Context-aware search for sessions
    - Parallel source searching with a configurable timeout
    """

    def __init__(
        self,
        config: Optional[SearchConfig] = None,
        embedding_cache: Optional[EmbeddingCache] = None,
        enable_simplemem: Optional[bool] = None,
    ):
        """
        Initialize the memory search.

        Args:
            config: Search configuration.
            embedding_cache: Optional embedding cache for semantic search.
            enable_simplemem: Explicitly enable/disable SimpleMem. If None,
                uses SIMPLEMEM_ENABLED env var (default: false for auto-register).
        """
        self.config = config or SearchConfig()
        self.embedding_cache = embedding_cache

        self._lock = threading.RLock()
        self._sources: Dict[str, MemorySource] = {}

        # Auto-register SimpleMem if available and enabled
        should_enable = enable_simplemem
        if should_enable is None:
            should_enable = os.getenv("SIMPLEMEM_ENABLED", "false").lower() == "true"

        if should_enable:
            try:
                from .simplemem_store import SimpleMemStore, SimpleMemConfig
                simplemem_config = SimpleMemConfig.from_env()
                if simplemem_config.enabled and simplemem_config.api_key:
                    self.add_source("simplemem", SimpleMemStore(simplemem_config))
            except ImportError:
                # Optional dependency; absence is not an error.
                pass
            except Exception as e:
                import logging
                logging.getLogger(__name__).warning(
                    "Failed to auto-register SimpleMem: %s", e
                )

    def add_source(self, name: str, source: MemorySource) -> None:
        """
        Add a memory source.

        Args:
            name: Unique name for the source.
            source: MemorySource implementation.
        """
        with self._lock:
            self._sources[name] = source

    def remove_source(self, name: str) -> bool:
        """
        Remove a memory source.

        Args:
            name: Name of source to remove.

        Returns:
            True if source was removed, False if not found.
        """
        with self._lock:
            if name in self._sources:
                del self._sources[name]
                return True
            return False

    def get_sources(self) -> List[str]:
        """Get list of registered source names."""
        with self._lock:
            return list(self._sources.keys())

    async def search(
        self,
        query: str,
        sources: Optional[List[str]] = None,
        mode: Optional[SearchMode] = None,
        max_results: Optional[int] = None,
        min_score: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
    ) -> SearchResults:
        """
        Search across memory sources.

        Args:
            query: Search query.
            sources: Optional list of sources to search. Searches all if None.
            mode: Search mode (text, semantic, hybrid).
            max_results: Maximum results to return.
            min_score: Minimum score threshold.
            filters: Optional filters passed to sources.

        Returns:
            SearchResults object. Per-source failures and timeouts are
            reported in ``errors`` rather than raised.
        """
        start_time = time.time()
        mode = mode or self.config.default_mode
        max_results = max_results or self.config.max_results
        min_score = min_score if min_score is not None else self.config.min_score

        results = SearchResults(query=query, mode=mode)

        # Snapshot target sources under the lock so concurrent
        # add_source/remove_source calls cannot disturb this search.
        with self._lock:
            if sources:
                target_sources = {k: v for k, v in self._sources.items() if k in sources}
            else:
                target_sources = dict(self._sources)

        if not target_sources:
            results.errors.append("No sources available for search")
            return results

        results.sources_searched = list(target_sources.keys())

        # Get query embedding if needed
        query_embedding: Optional[List[float]] = None
        if mode in (SearchMode.SEMANTIC, SearchMode.HYBRID) and self.embedding_cache:
            try:
                query_embedding = await self.embedding_cache.get_embedding(query)
            except Exception as e:
                results.errors.append(f"Failed to get query embedding: {e}")
                if mode == SearchMode.SEMANTIC:
                    # Semantic-only search cannot proceed without an embedding.
                    return results

        # Search each source, honoring the configured time budget.
        # BUGFIX: config.timeout_seconds existed but was never applied,
        # so a hung source could block a search indefinitely.
        all_matches: List[SearchMatch] = []
        timeout = self.config.timeout_seconds

        if self.config.parallel_sources:
            # Search sources in parallel, with one overall timeout.
            tasks = [
                self._search_source(
                    source, query, query_embedding, mode, max_results * 2, min_score, filters
                )
                for source in target_sources.values()
            ]
            try:
                source_results = await asyncio.wait_for(
                    asyncio.gather(*tasks, return_exceptions=True),
                    timeout=timeout,
                )
            except asyncio.TimeoutError:
                source_results = []
                results.errors.append(f"Search timed out after {timeout}s")

            for result in source_results:
                if isinstance(result, Exception):
                    results.errors.append(f"Source error: {result}")
                else:
                    all_matches.extend(result)
        else:
            # Search sources sequentially against a shared deadline.
            deadline = start_time + timeout
            for name, source in target_sources.items():
                remaining = deadline - time.time()
                if remaining <= 0:
                    results.errors.append(f"Search timed out after {timeout}s")
                    break
                try:
                    matches = await asyncio.wait_for(
                        self._search_source(
                            source, query, query_embedding, mode, max_results * 2, min_score, filters
                        ),
                        timeout=remaining,
                    )
                    all_matches.extend(matches)
                except asyncio.TimeoutError:
                    results.errors.append(f"Source {name} timed out")
                except Exception as e:
                    results.errors.append(f"Source {name} error: {e}")

        # Sort by score and limit results
        all_matches.sort(key=lambda m: m.score, reverse=True)
        results.matches = all_matches[:max_results]
        results.total_count = len(all_matches)
        results.duration_ms = (time.time() - start_time) * 1000

        return results

    async def _search_source(
        self,
        source: MemorySource,
        query: str,
        embedding: Optional[List[float]],
        mode: SearchMode,
        max_results: int,
        min_score: float,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchMatch]:
        """Search a single source based on mode."""
        if mode == SearchMode.TEXT or mode == SearchMode.EXACT:
            return await source.search(query, max_results, min_score, filters)

        elif mode == SearchMode.SEMANTIC:
            if embedding:
                return await source.search_semantic(query, embedding, max_results, min_score)
            return []

        elif mode == SearchMode.HYBRID:
            # Get both FTS and semantic results
            fts_results = await source.search(query, max_results, min_score, filters)

            semantic_results = []
            if embedding:
                semantic_results = await source.search_semantic(query, embedding, max_results, min_score)

            # Merge results with configured weights
            return self._merge_hybrid_results(
                fts_results,
                semantic_results,
                self.config.fts_weight,
                self.config.semantic_weight,
            )

        return []

    def _merge_hybrid_results(
        self,
        fts_results: List[SearchMatch],
        semantic_results: List[SearchMatch],
        fts_weight: float,
        semantic_weight: float,
    ) -> List[SearchMatch]:
        """Merge FTS and semantic results with weighted scores.

        Matches are keyed by (source, item_id) — or a content prefix when no
        item_id exists — so the same item found by both passes is merged.
        """
        scores: Dict[str, Dict[str, Any]] = {}

        # Process FTS results
        for match in fts_results:
            key = f"{match.source}:{match.item_id or match.content[:50]}"
            scores[key] = {
                "match": match,
                "fts_score": match.score,
                "semantic_score": 0.0,
            }

        # Process semantic results
        for match in semantic_results:
            key = f"{match.source}:{match.item_id or match.content[:50]}"
            if key in scores:
                scores[key]["semantic_score"] = match.score
            else:
                scores[key] = {
                    "match": match,
                    "fts_score": 0.0,
                    "semantic_score": match.score,
                }

        # Compute combined scores (mutates the surviving match objects)
        merged = []
        for data in scores.values():
            combined = (fts_weight * data["fts_score"]) + (semantic_weight * data["semantic_score"])
            match = data["match"]
            match.score = combined
            match.match_type = "hybrid"
            merged.append(match)

        return merged

    async def search_context(
        self,
        query: str,
        session_id: str,
        sources: Optional[List[str]] = None,
        max_results: Optional[int] = None,
    ) -> ContextResults:
        """
        Search with surrounding context for a specific session.

        Args:
            query: Search query.
            session_id: Session to search within.
            sources: Optional list of sources to search.
            max_results: Maximum results to return.

        Returns:
            ContextResults with matches and surrounding context.
        """
        start_time = time.time()
        max_results = max_results or self.config.max_results

        results = ContextResults(query=query, session_id=session_id)

        # First do a regular search filtered by session
        filters = {"session_id": session_id}
        search_results = await self.search(
            query=query,
            sources=sources,
            max_results=max_results,
            filters=filters,
        )

        results.errors = search_results.errors

        # BUGFIX: snapshot sources under the lock; the previous code read
        # self._sources unguarded while every other access is lock-protected.
        with self._lock:
            source_map = dict(self._sources)

        # Get context for each match
        for match in search_results.matches:
            if not match.item_id:
                continue

            # Find the source (match.source is the source's own name)
            source = source_map.get(match.source)
            if not source:
                continue

            try:
                before, after = await source.get_context(
                    match.item_id,
                    window=self.config.context_window,
                )

                context_match = ContextMatch(
                    match=match,
                    before=before,
                    after=after,
                    session_id=session_id,
                )
                results.context_matches.append(context_match)

            except Exception as e:
                results.errors.append(f"Failed to get context: {e}")

        results.total_count = len(results.context_matches)
        results.duration_ms = (time.time() - start_time) * 1000

        return results

    def close(self) -> None:
        """Close all sources that support it."""
        with self._lock:
            for source in self._sources.values():
                if hasattr(source, "close"):
                    try:
                        source.close()
                    except Exception:
                        # Best-effort shutdown; one failing source must not
                        # prevent the others from closing.
                        pass
            self._sources.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False