Coverage for integrations / channels / memory / memory_graph.py: 30.1%

183 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

"""
Memory Graph — Framework-agnostic provenance-aware memory layer.

Builds on top of MemoryStore (SQLite FTS5 + embeddings) and adds:
- Provenance tracking via memory_links table (parent/child chains)
- Registration with auto-linking to recent memories
- Semantic and direct backtrace through memory chains
- Lifecycle event recording for agent status transitions
- Context-aware recall from recent conversation

Zero framework dependencies — works with autogen, LangChain, or any agent framework.
"""

13 

14import json 

15import logging 

16import sqlite3 

17import time 

18import uuid 

19from dataclasses import dataclass, field 

20from datetime import datetime 

21from pathlib import Path 

22from typing import Any, Dict, List, Optional, Tuple 

23 

24from .memory_store import MemoryStore, MemoryItem, SearchResult 

25 

26logger = logging.getLogger(__name__) 

27 

28 

@dataclass
class MemoryNode:
    """A memory entry with provenance metadata."""

    id: str
    content: str
    memory_type: str = "fact"  # fact, conversation, decision, insight, lifecycle
    source_agent: str = ""  # Which agent/framework created this
    session_id: str = ""  # user_id + prompt_id (scoping key)
    user_id: str = ""
    parent_ids: List[str] = field(default_factory=list)
    context_snapshot: str = ""
    created_at: float = field(default_factory=time.time)
    accessed_at: float = field(default_factory=time.time)
    access_count: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this node to a plain (JSON-safe) dict."""
        return {
            "id": self.id,
            "content": self.content,
            "memory_type": self.memory_type,
            "source_agent": self.source_agent,
            "session_id": self.session_id,
            "user_id": self.user_id,
            "parent_ids": self.parent_ids,
            "context_snapshot": self.context_snapshot,
            "created_at": self.created_at,
            "accessed_at": self.accessed_at,
            "access_count": self.access_count,
        }

    @classmethod
    def from_memory_item(cls, item: "MemoryItem") -> "MemoryNode":
        """Create a MemoryNode from a MemoryStore MemoryItem.

        ``parent_ids`` may be persisted either as a JSON-encoded string or
        as a list. Malformed JSON, non-string/non-list values, and JSON that
        decodes to something other than a list all degrade to an empty list
        instead of propagating a wrongly-typed value.
        """
        meta = item.metadata or {}
        parent_ids_raw = meta.get("parent_ids", "[]")
        if isinstance(parent_ids_raw, str):
            try:
                decoded = json.loads(parent_ids_raw)
            except (json.JSONDecodeError, TypeError):
                decoded = []
            # Fix: valid JSON that is not a list (e.g. '{"a": 1}') previously
            # leaked through as-is, violating the List[str] field type.
            parent_ids = decoded if isinstance(decoded, list) else []
        else:
            parent_ids = parent_ids_raw if isinstance(parent_ids_raw, list) else []

        return cls(
            id=item.id,
            content=item.content,
            memory_type=meta.get("memory_type", "fact"),
            source_agent=meta.get("source_agent", ""),
            session_id=meta.get("session_id", ""),
            user_id=meta.get("user_id", ""),
            parent_ids=parent_ids,
            context_snapshot=meta.get("context_snapshot", ""),
            created_at=item.created_at,
            # Fall back to creation time when access metadata is absent.
            accessed_at=meta.get("accessed_at", item.created_at),
            access_count=meta.get("access_count", 0),
        )

86 

87 

class MemoryGraph:
    """
    Framework-agnostic provenance-aware memory graph.

    Wraps MemoryStore for persistence and adds:
    - memory_links table for parent/child provenance chains
    - register_conversation() for auto-linked conversation turns
    - register_lifecycle() for agent status transitions
    - backtrace() for direct chain walking
    - backtrace_semantic() for semantic + chain walking
    """

    def __init__(
        self,
        db_path: str,
        user_id: str,
        embedding_fn=None,
    ):
        """
        Args:
            db_path: Directory that holds the backing SQLite file;
                created (with parents) if missing.
            user_id: Owner ID stamped into every registered memory's metadata.
            embedding_fn: Optional embedding callable forwarded to MemoryStore
                for semantic search.
        """
        self._user_id = user_id
        self._db_path = db_path

        # db_path names a directory; the actual DB file lives inside it.
        Path(db_path).mkdir(parents=True, exist_ok=True)
        db_file = str(Path(db_path) / "memory_graph.db")

        self._store = MemoryStore(
            db_path=db_file,
            embedding_fn=embedding_fn,
        )
        self._init_links_table()

    def _init_links_table(self):
        """Create memory_links table and its indexes (idempotent)."""
        conn = self._store._ensure_connection()
        with self._store._lock:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS memory_links (
                    id TEXT PRIMARY KEY,
                    source_id TEXT NOT NULL,
                    target_id TEXT NOT NULL,
                    link_type TEXT DEFAULT 'derived',
                    context TEXT DEFAULT '',
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_links_source ON memory_links(source_id)"
            )
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_links_target ON memory_links(target_id)"
            )

    # =========================================================================
    # Registration
    # =========================================================================

    def register(
        self,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
        parent_ids: Optional[List[str]] = None,
        context_snapshot: str = "",
    ) -> str:
        """
        Store a memory with provenance. Returns memory_id.

        Args:
            content: The memory content.
            metadata: Dict with memory_type, source_agent, etc.
            parent_ids: IDs of memories that led to this one.
            context_snapshot: Summary of context when created.

        Returns:
            The generated memory ID.
        """
        memory_id = uuid.uuid4().hex[:16]
        metadata = metadata or {}
        parent_ids = parent_ids or []
        # Hoisted: used both in full_metadata and as the store's source tag.
        memory_type = metadata.get("memory_type", "fact")

        # Merge provenance into metadata for MemoryStore storage.
        # parent_ids is JSON-encoded here; MemoryNode.from_memory_item decodes it.
        full_metadata = {
            **metadata,
            "memory_type": memory_type,
            "source_agent": metadata.get("source_agent", ""),
            "session_id": metadata.get("session_id", ""),
            "user_id": self._user_id,
            "parent_ids": json.dumps(parent_ids),
            "context_snapshot": context_snapshot,
            "accessed_at": time.time(),
            "access_count": 0,
        }

        # Store in MemoryStore (gets FTS5 + optional embedding)
        self._store.add(
            content=content,
            metadata=full_metadata,
            source=memory_type,
            item_id=memory_id,
        )

        # Insert provenance links: one parent -> child edge per parent id.
        if parent_ids:
            conn = self._store._ensure_connection()
            with self._store._lock:
                for pid in parent_ids:
                    link_id = uuid.uuid4().hex[:16]
                    conn.execute(
                        "INSERT OR IGNORE INTO memory_links (id, source_id, target_id, link_type, context) VALUES (?, ?, ?, ?, ?)",
                        (link_id, pid, memory_id, "derived", context_snapshot[:200]),
                    )

        # Lazy %-args: message is only built when DEBUG logging is enabled.
        logger.debug("Registered memory %s: %s...", memory_id, content[:50])
        return memory_id

    def register_conversation(
        self,
        speaker: str,
        content: str,
        session_id: str,
    ) -> str:
        """
        Auto-register a conversation turn, linking to the previous turn.

        Args:
            speaker: Who said this (agent name, 'user', etc.)
            content: The message content.
            session_id: Session scope (e.g. user_id_prompt_id).

        Returns:
            Memory ID.
        """
        # Find the most recent conversation memory in this session
        recent = self._get_latest_session_memory(session_id)
        parent_ids = [recent.id] if recent else []

        return self.register(
            content=content,
            metadata={
                "memory_type": "conversation",
                "source_agent": speaker,
                "session_id": session_id,
            },
            parent_ids=parent_ids,
            context_snapshot=f"Conversation by {speaker} in session {session_id}",
        )

    def register_lifecycle(
        self,
        event: str,
        agent_id: str,
        session_id: str,
        details: str = "",
    ) -> str:
        """
        Record an agent lifecycle transition.

        Args:
            event: Lifecycle status (e.g. 'Creation Mode', 'Review Mode', 'completed').
            agent_id: The agent/user ID.
            session_id: Session scope.
            details: Additional details about the transition.

        Returns:
            Memory ID.
        """
        # Link to previous lifecycle event in this session (not just any memory)
        recent = self._get_latest_session_memory(session_id, memory_type="lifecycle")
        parent_ids = [recent.id] if recent else []

        return self.register(
            content=f"[LIFECYCLE] {event}: {details}",
            metadata={
                "memory_type": "lifecycle",
                "source_agent": agent_id,
                "session_id": session_id,
                "lifecycle_event": event,
            },
            parent_ids=parent_ids,
            context_snapshot=f"Agent status: {event}",
        )

    # =========================================================================
    # Recall
    # =========================================================================

    def recall(
        self,
        query: str,
        mode: str = "hybrid",
        top_k: int = 5,
        since: Optional[float] = None,
        until: Optional[float] = None,
    ) -> List[MemoryNode]:
        """
        Search memories by text, semantic, or hybrid search, optionally
        filtered to a time window.

        Args:
            query: Search query.
            mode: 'text', 'semantic', or 'hybrid' (any other value falls
                back to hybrid).
            top_k: Max results.
            since: Optional lower bound (UNIX epoch seconds). Memories
                with created_at < since are dropped from the result.
            until: Optional upper bound (UNIX epoch seconds). Memories
                with created_at > until are dropped from the result.

        Returns:
            List of MemoryNode results.
        """
        # When a time range is specified, pull more candidates from the
        # store so post-filtering can still return top_k. SQLite FTS5
        # returns them in relevance order, which is what we want inside
        # the window — we just need a bigger buffer than top_k to not
        # run out after filtering.
        fetch_k = top_k if since is None and until is None else top_k * 6
        if mode == "text":
            results = self._store.search(query, max_results=fetch_k)
        elif mode == "semantic":
            results = self._store.search_semantic(query, max_results=fetch_k)
        else:
            results = self._store.search_hybrid(query, max_results=fetch_k)

        nodes = []
        for sr in results:
            node = MemoryNode.from_memory_item(sr.item)
            # (node.created_at or 0) guards against missing/None timestamps.
            if since is not None and (node.created_at or 0) < since:
                continue
            if until is not None and (node.created_at or 0) > until:
                continue
            # Persist access tracking, then mirror the bump on the in-memory
            # node so the caller sees fresh counters.
            self._update_access(node.id)
            node.access_count += 1
            node.accessed_at = time.time()
            nodes.append(node)
            if len(nodes) >= top_k:
                break

        return nodes

    def context_recall(
        self,
        recent_messages: List[str],
        top_k: int = 3,
    ) -> List[MemoryNode]:
        """
        Auto-recall: combine recent messages into a query and search.

        Args:
            recent_messages: List of recent message strings.
            top_k: Max results.

        Returns:
            List of relevant MemoryNodes.
        """
        if not recent_messages:
            return []

        # Combine the last 3 messages (truncated) into a single query
        combined = " ".join(msg[:200] for msg in recent_messages[-3:])
        if not combined.strip():
            return []

        return self.recall(combined, mode="hybrid", top_k=top_k)

    def get_session_memories(
        self,
        session_id: str,
        limit: int = 50,
    ) -> List[MemoryNode]:
        """Get all memories from a specific session, ordered by creation time."""
        conn = self._store._ensure_connection()
        with self._store._lock:
            rows = conn.execute(
                """
                SELECT * FROM memory_items
                WHERE json_extract(metadata, '$.session_id') = ?
                ORDER BY created_at ASC
                LIMIT ?
                """,
                (session_id, limit),
            ).fetchall()

        nodes = []
        for row in rows:
            item = self._store._row_to_item(row)
            nodes.append(MemoryNode.from_memory_item(item))
        return nodes

    # =========================================================================
    # Backtrace
    # =========================================================================

    def backtrace(self, memory_id: str, depth: int = 10) -> List[MemoryNode]:
        """
        Direct backtrace: walk parent links from memory_id back to origin.

        Args:
            memory_id: Starting memory.
            depth: Max number of hops to walk (cycle-safe via visited set).

        Returns:
            Ordered list: [origin, ..., intermediate, ..., target].
        """
        chain = []
        visited = set()
        current_id = memory_id

        for _ in range(depth):
            if current_id in visited:
                break  # cycle guard
            visited.add(current_id)

            item = self._store.get(current_id)
            if not item:
                break

            node = MemoryNode.from_memory_item(item)
            chain.append(node)

            # Find parent via memory_links
            parent_id = self._get_parent_id(current_id)
            if not parent_id:
                break
            current_id = parent_id

        # Reverse so origin comes first
        chain.reverse()
        return chain

    def backtrace_semantic(
        self,
        query: str,
        depth: int = 5,
        top_k: int = 3,
    ) -> List[List[MemoryNode]]:
        """
        Semantic backtrace: find nearest memories, then trace each one back.

        Args:
            query: Search query for the initial matches.
            depth: Max chain length per match.
            top_k: Max number of matches (and therefore chains).

        Returns:
            List of chains, one per matching memory (empty chains omitted).
        """
        matches = self.recall(query, mode="hybrid", top_k=top_k)
        chains = []

        for node in matches:
            chain = self.backtrace(node.id, depth=depth)
            if chain:
                chains.append(chain)

        return chains

    def get_memory_chain(self, memory_id: str) -> Dict[str, Any]:
        """
        Get full chain: parents -> this -> children.

        Returns:
            Dict with 'target', 'parents', 'children' (each as dicts),
            or {'error': ...} when memory_id is unknown.
        """
        item = self._store.get(memory_id)
        if not item:
            return {"error": f"Memory {memory_id} not found"}

        node = MemoryNode.from_memory_item(item)

        # Walk parents
        parents = self.backtrace(memory_id)
        # Remove the target itself from parents list (backtrace includes it)
        parents = [p for p in parents if p.id != memory_id]

        # Walk children
        children = self._get_children(memory_id)

        return {
            "target": node.to_dict(),
            "parents": [p.to_dict() for p in parents],
            "children": [c.to_dict() for c in children],
        }

    # =========================================================================
    # Internal helpers
    # =========================================================================

    def _get_parent_id(self, memory_id: str) -> Optional[str]:
        """Get the (most recent) parent memory ID from memory_links."""
        conn = self._store._ensure_connection()
        with self._store._lock:
            row = conn.execute(
                "SELECT source_id FROM memory_links WHERE target_id = ? ORDER BY created_at DESC LIMIT 1",
                (memory_id,),
            ).fetchone()
        return row["source_id"] if row else None

    def _get_children(self, memory_id: str) -> List[MemoryNode]:
        """Get direct children of a memory, oldest first."""
        conn = self._store._ensure_connection()
        with self._store._lock:
            rows = conn.execute(
                "SELECT target_id FROM memory_links WHERE source_id = ? ORDER BY created_at ASC",
                (memory_id,),
            ).fetchall()

        children = []
        for row in rows:
            item = self._store.get(row["target_id"])
            if item:  # skip dangling links whose target was deleted
                children.append(MemoryNode.from_memory_item(item))
        return children

    def _get_latest_session_memory(
        self,
        session_id: str,
        memory_type: Optional[str] = None,
    ) -> Optional[MemoryNode]:
        """Get the most recent memory in a session, optionally filtered by type."""
        conn = self._store._ensure_connection()
        with self._store._lock:
            if memory_type:
                row = conn.execute(
                    """
                    SELECT * FROM memory_items
                    WHERE json_extract(metadata, '$.session_id') = ?
                    AND json_extract(metadata, '$.memory_type') = ?
                    ORDER BY created_at DESC LIMIT 1
                    """,
                    (session_id, memory_type),
                ).fetchone()
            else:
                row = conn.execute(
                    """
                    SELECT * FROM memory_items
                    WHERE json_extract(metadata, '$.session_id') = ?
                    ORDER BY created_at DESC LIMIT 1
                    """,
                    (session_id,),
                ).fetchone()

        if not row:
            return None
        item = self._store._row_to_item(row)
        return MemoryNode.from_memory_item(item)

    def _update_access(self, memory_id: str):
        """Update accessed_at and access_count for a memory (best-effort)."""
        conn = self._store._ensure_connection()
        with self._store._lock:
            try:
                row = conn.execute(
                    "SELECT metadata FROM memory_items WHERE id = ?",
                    (memory_id,),
                ).fetchone()
                if row and row["metadata"]:
                    meta = json.loads(row["metadata"])
                    meta["accessed_at"] = time.time()
                    meta["access_count"] = meta.get("access_count", 0) + 1
                    conn.execute(
                        "UPDATE memory_items SET metadata = ?, updated_at = ? WHERE id = ?",
                        (json.dumps(meta), time.time(), memory_id),
                    )
            except Exception:
                # Deliberately non-blocking: access bookkeeping must never
                # break a recall.
                pass

    def close(self):
        """Close the underlying MemoryStore connection."""
        self._store.close()