Coverage for integrations / channels / memory / search.py: 53.5%

331 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2Memory Search - Unified search across memory sources. 

3 

4Provides a unified interface for searching across multiple memory sources 

5including file content, embeddings, session history, and custom sources. 

6Designed for Docker environments with container-compatible storage. 

7""" 

8 

9import asyncio 

10import hashlib 

11import json 

12import os 

13import sqlite3 

14import threading 

15import time 

16from abc import ABC, abstractmethod 

17from dataclasses import dataclass, field 

18from datetime import datetime 

19from enum import Enum 

20from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union 

21 

22from .memory_store import MemoryStore, MemoryItem, SearchResult 

23from .embeddings import EmbeddingCache, EmbeddingConfig 

24 

25 

class SearchMode(Enum):
    """Enumeration of the supported search strategies."""

    TEXT = "text"          # keyword search backed by SQLite FTS5
    SEMANTIC = "semantic"  # embedding-based vector similarity
    HYBRID = "hybrid"      # weighted blend of TEXT and SEMANTIC scores
    EXACT = "exact"        # literal string matching

32 

33 

@dataclass
class SearchConfig:
    """Tunable settings for memory search.

    Mirrors the knobs consumed by ``MemorySearch``: default mode,
    result limits, hybrid weighting, context window size, and
    execution behaviour.
    """

    # Mode used when the caller does not pick one explicitly.
    default_mode: SearchMode = SearchMode.HYBRID
    # Cap on the number of matches returned.
    max_results: int = 20
    # Matches scoring below this threshold are dropped.
    min_score: float = 0.1

    # Relative weights applied when merging FTS and semantic hits.
    fts_weight: float = 0.3
    semantic_weight: float = 0.7

    # Number of neighbouring messages fetched around a context match.
    context_window: int = 5
    # Whether match metadata is carried into results.
    include_metadata: bool = True

    # Overall time budget and source fan-out behaviour.
    timeout_seconds: float = 30.0
    parallel_sources: bool = True

54 

55 

@dataclass
class SearchMatch:
    """One hit produced by a memory source."""

    source: str                # name of the source that produced the hit
    content: str               # full matched content
    score: float               # relevance score (higher is better)
    match_type: str = "text"   # e.g. "text", "fts", "semantic", "hybrid"
    snippet: str = ""          # short highlighted excerpt, if available
    metadata: Dict[str, Any] = field(default_factory=dict)
    timestamp: Optional[datetime] = None
    item_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation."""
        ts = self.timestamp.isoformat() if self.timestamp else None
        return {
            "source": self.source,
            "content": self.content,
            "score": self.score,
            "match_type": self.match_type,
            "snippet": self.snippet,
            "metadata": self.metadata,
            "timestamp": ts,
            "item_id": self.item_id,
        }

81 

82 

@dataclass
class SearchResults:
    """Aggregated outcome of one search operation."""

    query: str                                    # query that was executed
    matches: List[SearchMatch] = field(default_factory=list)
    total_count: int = 0                          # hits found before truncation
    sources_searched: List[str] = field(default_factory=list)
    duration_ms: float = 0.0                      # wall-clock search time
    mode: SearchMode = SearchMode.HYBRID
    errors: List[str] = field(default_factory=list)

    @property
    def has_results(self) -> bool:
        """True when at least one match was found."""
        return bool(self.matches)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation (matches expanded)."""
        return {
            "query": self.query,
            "matches": [match.to_dict() for match in self.matches],
            "total_count": self.total_count,
            "sources_searched": self.sources_searched,
            "duration_ms": self.duration_ms,
            "mode": self.mode.value,
            "errors": self.errors,
        }

111 

112 

@dataclass
class ContextMatch:
    """A search hit bundled with the messages surrounding it."""

    match: SearchMatch                                          # the hit itself
    before: List[Dict[str, Any]] = field(default_factory=list)  # preceding messages
    after: List[Dict[str, Any]] = field(default_factory=list)   # following messages
    session_id: str = ""                                        # owning session
    position: int = 0                                           # index within the session

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation (match expanded)."""
        return {
            "match": self.match.to_dict(),
            "before": self.before,
            "after": self.after,
            "session_id": self.session_id,
            "position": self.position,
        }

132 

133 

@dataclass
class ContextResults:
    """Aggregated outcome of a context-aware session search."""

    query: str
    session_id: str
    context_matches: List[ContextMatch] = field(default_factory=list)
    total_count: int = 0
    duration_ms: float = 0.0
    errors: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary representation (context matches expanded)."""
        return {
            "query": self.query,
            "session_id": self.session_id,
            "context_matches": [entry.to_dict() for entry in self.context_matches],
            "total_count": self.total_count,
            "duration_ms": self.duration_ms,
            "errors": self.errors,
        }

155 

156 

class MemorySource(ABC):
    """
    Interface every searchable memory backend implements.

    Subclass this (and register via ``MemorySearch.add_source``) to make
    a custom store searchable alongside the built-in sources.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique name for this source."""

    @abstractmethod
    async def search(
        self,
        query: str,
        max_results: int = 10,
        min_score: float = 0.0,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[SearchMatch]:
        """
        Run a text search against this source.

        Args:
            query: Search query.
            max_results: Maximum results to return.
            min_score: Minimum score threshold.
            filters: Optional filters (source-specific).

        Returns:
            List of SearchMatch objects.
        """

    @abstractmethod
    async def search_semantic(
        self,
        query: str,
        embedding: List[float],
        max_results: int = 10,
        min_score: float = 0.0,
    ) -> List[SearchMatch]:
        """
        Run a semantic (embedding-similarity) search against this source.

        Args:
            query: Original query text.
            embedding: Query embedding vector.
            max_results: Maximum results to return.
            min_score: Minimum similarity threshold.

        Returns:
            List of SearchMatch objects.
        """

    async def get_context(
        self,
        item_id: str,
        window: int = 5,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Fetch the items surrounding *item_id*.

        Sources without a notion of ordering may keep this default,
        which reports no context.

        Args:
            item_id: Item identifier.
            window: Number of items before/after.

        Returns:
            Tuple of (before, after) context lists.
        """
        return [], []

230 

231 

class MemoryStoreSource(MemorySource):
    """Adapter exposing a MemoryStore as a searchable MemorySource."""

    def __init__(self, store: MemoryStore, source_name: str = "memory_store"):
        """
        Wrap *store* and expose it under *source_name*.

        Args:
            store: MemoryStore instance.
            source_name: Name for this source.
        """
        self._store = store
        self._source_name = source_name

    @property
    def name(self) -> str:
        return self._source_name

    def _to_match(self, hit, match_type: str) -> SearchMatch:
        """Convert a store SearchResult into a SearchMatch."""
        return SearchMatch(
            source=self.name,
            content=hit.item.content,
            score=hit.score,
            match_type=match_type,
            snippet=hit.snippet,
            metadata=hit.item.metadata,
            timestamp=datetime.fromtimestamp(hit.item.created_at),
            item_id=hit.item.id,
        )

    async def search(
        self,
        query: str,
        max_results: int = 10,
        min_score: float = 0.0,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[SearchMatch]:
        """Full-text (FTS5) search delegated to the underlying store."""
        source_filter = None
        if filters:
            source_filter = filters.get("source")

        hits = self._store.search(
            query=query,
            max_results=max_results,
            min_score=min_score,
            source_filter=source_filter,
        )
        return [self._to_match(hit, "fts") for hit in hits]

    async def search_semantic(
        self,
        query: str,
        embedding: List[float],
        max_results: int = 10,
        min_score: float = 0.0,
    ) -> List[SearchMatch]:
        """Embedding-similarity search delegated to the underlying store."""
        hits = self._store.search_semantic(
            query=query,
            max_results=max_results,
            min_score=min_score,
        )
        return [self._to_match(hit, "semantic") for hit in hits]

307 

308 

class MemoryGraphSource(MemorySource):
    """Adapter exposing a MemoryGraph — context comes from the memory chain."""

    def __init__(self, graph, source_name: str = "memory_graph"):
        self._graph = graph
        self._source_name = source_name

    @property
    def name(self) -> str:
        return self._source_name

    def _node_to_match(self, node, match_type: str) -> SearchMatch:
        """Convert a graph node to a SearchMatch (graph recall yields no score)."""
        return SearchMatch(
            source=self.name,
            content=node.content,
            score=1.0,
            match_type=match_type,
            snippet=node.content[:200],
            metadata={"memory_type": node.memory_type, "source_agent": node.source_agent, "session_id": node.session_id},
            timestamp=datetime.fromtimestamp(node.created_at),
            item_id=node.id,
        )

    async def search(
        self,
        query: str,
        max_results: int = 10,
        min_score: float = 0.0,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[SearchMatch]:
        """Text-mode recall over the graph."""
        found = self._graph.recall(query, mode='text', top_k=max_results)
        return [self._node_to_match(node, "fts") for node in found]

    async def search_semantic(
        self,
        query: str,
        embedding: List[float],
        max_results: int = 10,
        min_score: float = 0.0,
    ) -> List[SearchMatch]:
        """Hybrid-mode recall over the graph (embedding handled internally)."""
        found = self._graph.recall(query, mode='hybrid', top_k=max_results)
        return [self._node_to_match(node, "semantic") for node in found]

    async def get_context(
        self,
        item_id: str,
        window: int = 5,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """Use backtrace for 'before' context (parent chain) and children for 'after'."""
        chain = self._graph.get_memory_chain(item_id)
        if "error" in chain:
            return [], []

        def _entry(node: Dict[str, Any]) -> Dict[str, Any]:
            # Normalise a chain record into the shared context shape.
            return {
                "id": node["id"],
                "role": node.get("source_agent", ""),
                "content": node["content"],
                "timestamp": node.get("created_at", 0),
            }

        parents = chain.get("parents", [])[-window:]
        children = chain.get("children", [])[:window]
        return [_entry(p) for p in parents], [_entry(c) for c in children]

383 

384 

class SessionHistorySource(MemorySource):
    """Memory source for session/conversation history.

    Messages are persisted in SQLite. When the FTS5 extension is
    available an external-content index (``messages_fts``) is kept in
    sync with the ``messages`` table via an AFTER INSERT trigger;
    otherwise search degrades to a LIKE scan.
    """

    def __init__(
        self,
        db_path: Optional[str] = None,
        source_name: str = "session_history",
    ):
        """
        Initialize session history source.

        Args:
            db_path: Path to SQLite database. If omitted, a
                container-friendly directory is chosen (/app/data,
                then /tmp/session_history, then ./.session_history).
            source_name: Name for this source.
        """
        self._source_name = source_name
        self._lock = threading.RLock()

        # Determine database path (prefer Docker volume locations).
        if db_path:
            self.db_path = db_path
        else:
            if os.path.exists("/app/data"):
                db_dir = "/app/data"
            elif os.path.exists("/tmp"):
                db_dir = "/tmp/session_history"
            else:
                db_dir = os.path.join(os.path.abspath("."), ".session_history")

            os.makedirs(db_dir, exist_ok=True)
            self.db_path = os.path.join(db_dir, "session_history.db")

        self._conn: Optional[sqlite3.Connection] = None
        self._ensure_schema()

    def _ensure_connection(self) -> sqlite3.Connection:
        """Open the SQLite connection lazily and return it."""
        if self._conn is None:
            os.makedirs(os.path.dirname(self.db_path) or ".", exist_ok=True)
            self._conn = sqlite3.connect(
                self.db_path,
                check_same_thread=False,  # cross-thread use guarded by self._lock
                isolation_level=None,     # autocommit mode
                timeout=30.0,
            )
            self._conn.row_factory = sqlite3.Row
            # WAL allows concurrent readers while writing.
            self._conn.execute("PRAGMA journal_mode=WAL")
        return self._conn

    def _ensure_schema(self) -> None:
        """Create the messages table, FTS index, sync trigger and indexes."""
        conn = self._ensure_connection()
        with self._lock:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS messages (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    timestamp REAL NOT NULL,
                    metadata TEXT DEFAULT '{}'
                )
            """)

            # Create FTS5 external-content index. SQLite does NOT keep
            # an external-content table (content=messages) in sync with
            # the base table automatically; without the trigger below
            # the index stayed empty and FTS queries never matched
            # anything (bug fix).
            try:
                conn.execute("""
                    CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
                        content,
                        session_id UNINDEXED,
                        content=messages,
                        content_rowid=id
                    )
                """)
                conn.execute("""
                    CREATE TRIGGER IF NOT EXISTS messages_fts_ai
                    AFTER INSERT ON messages BEGIN
                        INSERT INTO messages_fts(rowid, content, session_id)
                        VALUES (new.id, new.content, new.session_id);
                    END
                """)
                # Backfill the index for databases written before the
                # trigger existed (no-op when already in sync).
                fts_count = conn.execute("SELECT count(*) FROM messages_fts").fetchone()[0]
                msg_count = conn.execute("SELECT count(*) FROM messages").fetchone()[0]
                if msg_count and not fts_count:
                    conn.execute("INSERT INTO messages_fts(messages_fts) VALUES('rebuild')")
                self._fts_available = True
            except sqlite3.OperationalError:
                # SQLite built without FTS5: degrade to LIKE search.
                self._fts_available = False

            conn.execute("CREATE INDEX IF NOT EXISTS idx_msg_session ON messages(session_id)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_msg_time ON messages(timestamp)")

    @property
    def name(self) -> str:
        return self._source_name

    def add_message(
        self,
        session_id: str,
        role: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> int:
        """
        Add a message to session history.

        The FTS index (when available) is updated automatically by the
        AFTER INSERT trigger created in _ensure_schema.

        Args:
            session_id: Session identifier.
            role: Message role (user, assistant, system).
            content: Message content.
            metadata: Optional metadata.

        Returns:
            Message ID.
        """
        conn = self._ensure_connection()
        with self._lock:
            cursor = conn.execute(
                """
                INSERT INTO messages (session_id, role, content, timestamp, metadata)
                VALUES (?, ?, ?, ?, ?)
                """,
                (session_id, role, content, time.time(), json.dumps(metadata or {}))
            )
            return cursor.lastrowid

    async def search(
        self,
        query: str,
        max_results: int = 10,
        min_score: float = 0.0,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[SearchMatch]:
        """
        Search session history using FTS5 (or LIKE fallback).

        Args:
            query: Search query (FTS5 match expression when available).
            max_results: Maximum results to return.
            min_score: Unused here; scoring is bm25-based.
            filters: Optional dict; "session_id" restricts to a session.

        Returns:
            List of SearchMatch objects (best match first for FTS).
        """
        conn = self._ensure_connection()
        session_filter = filters.get("session_id") if filters else None

        with self._lock:
            if self._fts_available:
                # bm25() returns more-negative values for better matches,
                # so ascending ORDER BY puts the best match first.
                sql = """
                    SELECT m.*, bm25(messages_fts) as score,
                           snippet(messages_fts, 0, '<b>', '</b>', '...', 32) as snippet
                    FROM messages_fts f
                    JOIN messages m ON f.rowid = m.id
                    WHERE messages_fts MATCH ?
                """
                params: List[Any] = [query]

                if session_filter:
                    sql += " AND m.session_id = ?"
                    params.append(session_filter)

                sql += " ORDER BY score LIMIT ?"
                params.append(max_results)

                rows = conn.execute(sql, params).fetchall()
            else:
                # LIKE fallback: no ranking, constant score.
                sql = "SELECT *, 1.0 as score, '' as snippet FROM messages WHERE content LIKE ?"
                params = [f"%{query}%"]

                if session_filter:
                    sql += " AND session_id = ?"
                    params.append(session_filter)

                sql += " LIMIT ?"
                params.append(max_results)

                rows = conn.execute(sql, params).fetchall()

        return [
            SearchMatch(
                source=self.name,
                content=row["content"],
                # bm25 scores are negative; expose magnitude as the score.
                score=abs(row["score"]) if row["score"] else 0.5,
                match_type="fts",
                snippet=row["snippet"] if row["snippet"] else row["content"][:200],
                metadata={
                    "session_id": row["session_id"],
                    "role": row["role"],
                    **json.loads(row["metadata"] or "{}"),
                },
                timestamp=datetime.fromtimestamp(row["timestamp"]),
                item_id=str(row["id"]),
            )
            for row in rows
        ]

    async def search_semantic(
        self,
        query: str,
        embedding: List[float],
        max_results: int = 10,
        min_score: float = 0.0,
    ) -> List[SearchMatch]:
        """Semantic search not directly supported; falls back to FTS."""
        return await self.search(query, max_results, min_score)

    async def get_context(
        self,
        item_id: str,
        window: int = 5,
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Get messages before and after the specified message.

        Args:
            item_id: Message row id (stringified integer).
            window: Number of messages on each side.

        Returns:
            Tuple of (before, after) message lists, oldest first.
        """
        conn = self._ensure_connection()
        try:
            msg_id = int(item_id)
        except ValueError:
            # Robustness fix: ids from other sources may not be numeric.
            return [], []

        with self._lock:
            # Locate the message to find its session and position in time.
            row = conn.execute(
                "SELECT session_id, timestamp FROM messages WHERE id = ?",
                (msg_id,)
            ).fetchone()

            if not row:
                return [], []

            session_id = row["session_id"]
            timestamp = row["timestamp"]

            # Nearest `window` messages before the match (newest-first,
            # reversed below so the caller sees chronological order).
            before_rows = conn.execute(
                """
                SELECT * FROM messages
                WHERE session_id = ? AND timestamp < ?
                ORDER BY timestamp DESC LIMIT ?
                """,
                (session_id, timestamp, window)
            ).fetchall()

            # Nearest `window` messages after the match.
            after_rows = conn.execute(
                """
                SELECT * FROM messages
                WHERE session_id = ? AND timestamp > ?
                ORDER BY timestamp ASC LIMIT ?
                """,
                (session_id, timestamp, window)
            ).fetchall()

        before = [
            {
                "id": r["id"],
                "role": r["role"],
                "content": r["content"],
                "timestamp": r["timestamp"],
            }
            for r in reversed(before_rows)
        ]

        after = [
            {
                "id": r["id"],
                "role": r["role"],
                "content": r["content"],
                "timestamp": r["timestamp"],
            }
            for r in after_rows
        ]

        return before, after

    def close(self) -> None:
        """Close database connection."""
        with self._lock:
            if self._conn:
                self._conn.close()
                self._conn = None

641 

642 

class MemorySearch:
    """
    Unified search across memory sources.

    Provides a single interface for searching across multiple memory backends
    including file content, embeddings, session history, and custom sources.

    Features:
    - Multiple search modes (text, semantic, hybrid)
    - Pluggable memory sources
    - Context-aware search for sessions
    - Parallel source searching with a configurable timeout
    """

    def __init__(
        self,
        config: Optional[SearchConfig] = None,
        embedding_cache: Optional[EmbeddingCache] = None,
        enable_simplemem: Optional[bool] = None,
    ):
        """
        Initialize the memory search.

        Args:
            config: Search configuration.
            embedding_cache: Optional embedding cache for semantic search.
            enable_simplemem: Explicitly enable/disable SimpleMem. If None,
                uses SIMPLEMEM_ENABLED env var (default: false for auto-register).
        """
        self.config = config or SearchConfig()
        self.embedding_cache = embedding_cache

        self._lock = threading.RLock()
        self._sources: Dict[str, MemorySource] = {}

        # Auto-register SimpleMem if available and enabled.
        should_enable = enable_simplemem
        if should_enable is None:
            should_enable = os.getenv("SIMPLEMEM_ENABLED", "false").lower() == "true"

        if should_enable:
            try:
                from .simplemem_store import SimpleMemStore, SimpleMemConfig
                simplemem_config = SimpleMemConfig.from_env()
                if simplemem_config.enabled and simplemem_config.api_key:
                    self.add_source("simplemem", SimpleMemStore(simplemem_config))
            except ImportError:
                # SimpleMem is optional; skip silently when not installed.
                pass
            except Exception as e:
                import logging
                logging.getLogger(__name__).warning(
                    "Failed to auto-register SimpleMem: %s", e
                )

    def add_source(self, name: str, source: MemorySource) -> None:
        """
        Add a memory source.

        Args:
            name: Unique name for the source.
            source: MemorySource implementation.
        """
        with self._lock:
            self._sources[name] = source

    def remove_source(self, name: str) -> bool:
        """
        Remove a memory source.

        Args:
            name: Name of source to remove.

        Returns:
            True if source was removed, False if not found.
        """
        with self._lock:
            if name in self._sources:
                del self._sources[name]
                return True
            return False

    def get_sources(self) -> List[str]:
        """Get list of registered source names."""
        with self._lock:
            return list(self._sources.keys())

    async def search(
        self,
        query: str,
        sources: Optional[List[str]] = None,
        mode: Optional[SearchMode] = None,
        max_results: Optional[int] = None,
        min_score: Optional[float] = None,
        filters: Optional[Dict[str, Any]] = None,
    ) -> SearchResults:
        """
        Search across memory sources.

        Args:
            query: Search query.
            sources: Optional list of sources to search. Searches all if None.
            mode: Search mode (text, semantic, hybrid).
            max_results: Maximum results to return.
            min_score: Minimum score threshold.
            filters: Optional filters passed to sources.

        Returns:
            SearchResults object. Source failures and timeouts are
            reported in ``results.errors`` rather than raised.
        """
        start_time = time.time()
        mode = mode or self.config.default_mode
        max_results = max_results or self.config.max_results
        min_score = min_score if min_score is not None else self.config.min_score

        results = SearchResults(query=query, mode=mode)

        # Determine which sources to search (snapshot under the lock).
        with self._lock:
            if sources:
                target_sources = {k: v for k, v in self._sources.items() if k in sources}
            else:
                target_sources = dict(self._sources)

        if not target_sources:
            results.errors.append("No sources available for search")
            return results

        results.sources_searched = list(target_sources.keys())

        # Get query embedding if needed for semantic/hybrid modes.
        query_embedding: Optional[List[float]] = None
        if mode in (SearchMode.SEMANTIC, SearchMode.HYBRID) and self.embedding_cache:
            try:
                query_embedding = await self.embedding_cache.get_embedding(query)
            except Exception as e:
                results.errors.append(f"Failed to get query embedding: {e}")
                if mode == SearchMode.SEMANTIC:
                    # Pure semantic search is impossible without the embedding;
                    # hybrid can still fall back to FTS results.
                    return results

        # Search each source. Over-fetch (max_results * 2) per source so
        # that merging and score filtering still leave enough candidates.
        all_matches: List[SearchMatch] = []
        source_names = list(target_sources.keys())

        if self.config.parallel_sources:
            # Search sources in parallel, bounded by the configured
            # timeout (fix: config.timeout_seconds was previously unused).
            tasks = []
            for name, source in target_sources.items():
                tasks.append(self._search_source(
                    source, query, query_embedding, mode, max_results * 2, min_score, filters
                ))

            try:
                source_results = await asyncio.wait_for(
                    asyncio.gather(*tasks, return_exceptions=True),
                    timeout=self.config.timeout_seconds,
                )
            except asyncio.TimeoutError:
                results.errors.append(
                    f"Search timed out after {self.config.timeout_seconds}s"
                )
                source_results = []

            for i, result in enumerate(source_results):
                if isinstance(result, Exception):
                    # Include the source name for diagnosability
                    # (consistent with the sequential path).
                    results.errors.append(f"Source {source_names[i]} error: {result}")
                else:
                    all_matches.extend(result)
        else:
            # Search sources sequentially, sharing one overall time budget.
            deadline = start_time + self.config.timeout_seconds
            for name, source in target_sources.items():
                remaining = deadline - time.time()
                if remaining <= 0:
                    results.errors.append(
                        f"Search timed out after {self.config.timeout_seconds}s; remaining sources skipped"
                    )
                    break
                try:
                    matches = await asyncio.wait_for(
                        self._search_source(
                            source, query, query_embedding, mode, max_results * 2, min_score, filters
                        ),
                        timeout=remaining,
                    )
                    all_matches.extend(matches)
                except asyncio.TimeoutError:
                    results.errors.append(f"Source {name} timed out")
                except Exception as e:
                    results.errors.append(f"Source {name} error: {e}")

        # Enforce min_score on the final list: hybrid merging can push a
        # combined score below the threshold even when the per-source
        # results passed it (fix).
        all_matches = [m for m in all_matches if m.score >= min_score]

        # Sort by score and limit results.
        all_matches.sort(key=lambda m: m.score, reverse=True)
        results.matches = all_matches[:max_results]
        results.total_count = len(all_matches)
        results.duration_ms = (time.time() - start_time) * 1000

        return results

    async def _search_source(
        self,
        source: MemorySource,
        query: str,
        embedding: Optional[List[float]],
        mode: SearchMode,
        max_results: int,
        min_score: float,
        filters: Optional[Dict[str, Any]],
    ) -> List[SearchMatch]:
        """Search a single source based on mode.

        NOTE(review): EXACT mode currently delegates to the source's text
        search, which for FTS-backed sources is token matching rather
        than literal matching — confirm whether stricter semantics are
        needed.
        """
        if mode == SearchMode.TEXT or mode == SearchMode.EXACT:
            return await source.search(query, max_results, min_score, filters)

        elif mode == SearchMode.SEMANTIC:
            if embedding:
                return await source.search_semantic(query, embedding, max_results, min_score)
            return []

        elif mode == SearchMode.HYBRID:
            # Get both FTS and semantic results, then blend scores.
            fts_results = await source.search(query, max_results, min_score, filters)

            semantic_results = []
            if embedding:
                semantic_results = await source.search_semantic(query, embedding, max_results, min_score)

            return self._merge_hybrid_results(
                fts_results,
                semantic_results,
                self.config.fts_weight,
                self.config.semantic_weight,
            )

        return []

    def _merge_hybrid_results(
        self,
        fts_results: List[SearchMatch],
        semantic_results: List[SearchMatch],
        fts_weight: float,
        semantic_weight: float,
    ) -> List[SearchMatch]:
        """Merge FTS and semantic results with weighted scores.

        Matches found by both methods are deduplicated (keyed on source +
        item id, or a content prefix when no id is available) and receive
        a weighted combination of both scores. The SearchMatch objects
        are updated in place (score and match_type are overwritten).
        """
        scores: Dict[str, Dict[str, Any]] = {}

        # Process FTS results.
        for match in fts_results:
            key = f"{match.source}:{match.item_id or match.content[:50]}"
            scores[key] = {
                "match": match,
                "fts_score": match.score,
                "semantic_score": 0.0,
            }

        # Process semantic results, merging with FTS hits for the same item.
        for match in semantic_results:
            key = f"{match.source}:{match.item_id or match.content[:50]}"
            if key in scores:
                scores[key]["semantic_score"] = match.score
            else:
                scores[key] = {
                    "match": match,
                    "fts_score": 0.0,
                    "semantic_score": match.score,
                }

        # Compute combined scores.
        merged = []
        for data in scores.values():
            combined = (fts_weight * data["fts_score"]) + (semantic_weight * data["semantic_score"])
            match = data["match"]
            match.score = combined
            match.match_type = "hybrid"
            merged.append(match)

        return merged

    async def search_context(
        self,
        query: str,
        session_id: str,
        sources: Optional[List[str]] = None,
        max_results: Optional[int] = None,
    ) -> ContextResults:
        """
        Search with surrounding context for a specific session.

        Args:
            query: Search query.
            session_id: Session to search within.
            sources: Optional list of sources to search.
            max_results: Maximum results to return.

        Returns:
            ContextResults with matches and surrounding context.
        """
        start_time = time.time()
        max_results = max_results or self.config.max_results

        results = ContextResults(query=query, session_id=session_id)

        # First do a regular search filtered by session.
        filters = {"session_id": session_id}
        search_results = await self.search(
            query=query,
            sources=sources,
            max_results=max_results,
            filters=filters,
        )

        results.errors = search_results.errors

        # Get context for each match.
        for match in search_results.matches:
            if not match.item_id:
                continue

            # Find the source (fix: read the registry under the lock).
            # NOTE(review): match.source is the source's `name` property,
            # which is assumed to equal its registration key — confirm
            # for custom sources registered under a different name.
            with self._lock:
                source = self._sources.get(match.source)
            if not source:
                continue

            try:
                before, after = await source.get_context(
                    match.item_id,
                    window=self.config.context_window,
                )

                context_match = ContextMatch(
                    match=match,
                    before=before,
                    after=after,
                    session_id=session_id,
                )
                results.context_matches.append(context_match)

            except Exception as e:
                results.errors.append(f"Failed to get context: {e}")

        results.total_count = len(results.context_matches)
        results.duration_ms = (time.time() - start_time) * 1000

        return results

    def close(self) -> None:
        """Close all sources that support it."""
        with self._lock:
            for source in self._sources.values():
                if hasattr(source, "close"):
                    try:
                        source.close()
                    except Exception:
                        # Best-effort shutdown; one failing source must
                        # not prevent the others from closing.
                        pass
            self._sources.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
        return False