Coverage for integrations / channels / memory / memory_graph.py: 30.1%
183 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Memory Graph — Framework-agnostic provenance-aware memory layer.
4Builds on top of MemoryStore (SQLite FTS5 + embeddings) and adds:
5- Provenance tracking via memory_links table (parent/child chains)
6- Registration with auto-linking to recent memories
7- Semantic and direct backtrace through memory chains
8- Lifecycle event recording for agent status transitions
9- Context-aware recall from recent conversation
11Zero framework dependencies — works with autogen, LangChain, or any agent framework.
12"""
14import json
15import logging
16import sqlite3
17import time
18import uuid
19from dataclasses import dataclass, field
20from datetime import datetime
21from pathlib import Path
22from typing import Any, Dict, List, Optional, Tuple
24from .memory_store import MemoryStore, MemoryItem, SearchResult
26logger = logging.getLogger(__name__)
@dataclass
class MemoryNode:
    """A memory entry enriched with provenance metadata."""

    id: str
    content: str
    memory_type: str = "fact"  # fact, conversation, decision, insight, lifecycle
    source_agent: str = ""  # Which agent/framework created this
    session_id: str = ""  # user_id + prompt_id (scoping key)
    user_id: str = ""
    parent_ids: List[str] = field(default_factory=list)
    context_snapshot: str = ""
    created_at: float = field(default_factory=time.time)
    accessed_at: float = field(default_factory=time.time)
    access_count: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this node to a plain dict, keys in declaration order."""
        keys = (
            "id",
            "content",
            "memory_type",
            "source_agent",
            "session_id",
            "user_id",
            "parent_ids",
            "context_snapshot",
            "created_at",
            "accessed_at",
            "access_count",
        )
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_memory_item(cls, item: MemoryItem) -> "MemoryNode":
        """Create a MemoryNode from a MemoryStore MemoryItem.

        parent_ids may be stored either as a JSON-encoded string or as an
        actual list; anything else (or unparseable JSON) falls back to [].
        """
        meta = item.metadata or {}
        raw_parents = meta.get("parent_ids", "[]")
        if isinstance(raw_parents, list):
            parents = raw_parents
        elif isinstance(raw_parents, str):
            try:
                parents = json.loads(raw_parents)
            except (json.JSONDecodeError, TypeError):
                parents = []
        else:
            parents = []

        return cls(
            id=item.id,
            content=item.content,
            memory_type=meta.get("memory_type", "fact"),
            source_agent=meta.get("source_agent", ""),
            session_id=meta.get("session_id", ""),
            user_id=meta.get("user_id", ""),
            parent_ids=parents,
            context_snapshot=meta.get("context_snapshot", ""),
            created_at=item.created_at,
            accessed_at=meta.get("accessed_at", item.created_at),
            access_count=meta.get("access_count", 0),
        )
class MemoryGraph:
    """
    Framework-agnostic provenance-aware memory graph.

    Wraps MemoryStore for persistence and adds:
    - memory_links table for parent/child provenance chains
    - register_conversation() for auto-linked conversation turns
    - register_lifecycle() for agent status transitions
    - backtrace() for direct chain walking
    - backtrace_semantic() for semantic + chain walking

    NOTE(review): this class reaches into MemoryStore private internals
    (_ensure_connection, _lock, _row_to_item). It assumes _ensure_connection
    returns a sqlite3 connection whose rows support access by column name
    (sqlite3.Row row_factory) and that writes are committed without an
    explicit commit() (autocommit mode) — confirm against memory_store.py.
    """

    def __init__(
        self,
        db_path: str,
        user_id: str,
        embedding_fn=None,
    ):
        """
        Args:
            db_path: Directory for the database; created if missing. The
                SQLite file itself is stored as "memory_graph.db" inside it.
            user_id: Owner ID stamped onto every memory this graph registers.
            embedding_fn: Optional embedding callable forwarded to
                MemoryStore (enables semantic search there).
        """
        self._user_id = user_id
        self._db_path = db_path

        # db_path is a directory, not a file path: ensure it exists, then
        # place the database file inside it.
        Path(db_path).mkdir(parents=True, exist_ok=True)
        db_file = str(Path(db_path) / "memory_graph.db")

        self._store = MemoryStore(
            db_path=db_file,
            embedding_fn=embedding_fn,
        )
        self._init_links_table()

    def _init_links_table(self):
        """Create the memory_links provenance table and its indexes (idempotent)."""
        conn = self._store._ensure_connection()
        with self._store._lock:
            # One row per parent->child edge: source_id is the parent memory,
            # target_id the derived child.
            conn.execute("""
                CREATE TABLE IF NOT EXISTS memory_links (
                    id TEXT PRIMARY KEY,
                    source_id TEXT NOT NULL,
                    target_id TEXT NOT NULL,
                    link_type TEXT DEFAULT 'derived',
                    context TEXT DEFAULT '',
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)
            # Indexes for both walk directions: child -> parent (backtrace)
            # and parent -> children (get_memory_chain).
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_links_source ON memory_links(source_id)"
            )
            conn.execute(
                "CREATE INDEX IF NOT EXISTS idx_links_target ON memory_links(target_id)"
            )

    # =========================================================================
    # Registration
    # =========================================================================

    def register(
        self,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
        parent_ids: Optional[List[str]] = None,
        context_snapshot: str = "",
    ) -> str:
        """
        Store a memory with provenance. Returns memory_id.

        Args:
            content: The memory content.
            metadata: Dict with memory_type, source_agent, session_id, etc.
                Any caller-supplied user_id is overwritten with this graph's
                own user_id.
            parent_ids: IDs of memories that led to this one; one
                memory_links row is inserted per parent.
            context_snapshot: Summary of context when created (truncated to
                200 chars when stored on the link rows).

        Returns:
            The generated 16-hex-char memory ID.
        """
        memory_id = uuid.uuid4().hex[:16]
        metadata = metadata or {}
        parent_ids = parent_ids or []

        # Merge provenance into metadata for MemoryStore storage.
        # parent_ids is JSON-encoded because metadata values are persisted
        # as a JSON string (see _update_access).
        full_metadata = {
            **metadata,
            "memory_type": metadata.get("memory_type", "fact"),
            "source_agent": metadata.get("source_agent", ""),
            "session_id": metadata.get("session_id", ""),
            "user_id": self._user_id,
            "parent_ids": json.dumps(parent_ids),
            "context_snapshot": context_snapshot,
            "accessed_at": time.time(),
            "access_count": 0,
        }

        # Store in MemoryStore (gets FTS5 + optional embedding). The
        # memory_type doubles as the store-level "source" tag.
        self._store.add(
            content=content,
            metadata=full_metadata,
            source=metadata.get("memory_type", "fact"),
            item_id=memory_id,
        )

        # Insert provenance links (one edge per parent).
        if parent_ids:
            conn = self._store._ensure_connection()
            with self._store._lock:
                for pid in parent_ids:
                    link_id = uuid.uuid4().hex[:16]
                    conn.execute(
                        "INSERT OR IGNORE INTO memory_links (id, source_id, target_id, link_type, context) VALUES (?, ?, ?, ?, ?)",
                        (link_id, pid, memory_id, "derived", context_snapshot[:200]),
                    )

        logger.debug(f"Registered memory {memory_id}: {content[:50]}...")
        return memory_id

    def register_conversation(
        self,
        speaker: str,
        content: str,
        session_id: str,
    ) -> str:
        """
        Auto-register a conversation turn, linking to the previous turn.

        The parent is the most recent memory of ANY type in the session
        (not just conversation turns) — see _get_latest_session_memory.

        Args:
            speaker: Who said this (agent name, 'user', etc.)
            content: The message content.
            session_id: Session scope (e.g. user_id_prompt_id).

        Returns:
            Memory ID.
        """
        # Find the most recent memory in this session to chain onto.
        recent = self._get_latest_session_memory(session_id)
        parent_ids = [recent.id] if recent else []

        return self.register(
            content=content,
            metadata={
                "memory_type": "conversation",
                "source_agent": speaker,
                "session_id": session_id,
            },
            parent_ids=parent_ids,
            context_snapshot=f"Conversation by {speaker} in session {session_id}",
        )

    def register_lifecycle(
        self,
        event: str,
        agent_id: str,
        session_id: str,
        details: str = "",
    ) -> str:
        """
        Record an agent lifecycle transition.

        Args:
            event: Lifecycle status (e.g. 'Creation Mode', 'Review Mode', 'completed').
            agent_id: The agent/user ID.
            session_id: Session scope.
            details: Additional details about the transition.

        Returns:
            Memory ID.
        """
        # Link to the previous lifecycle event in this session, so lifecycle
        # events form their own chain independent of conversation turns.
        recent = self._get_latest_session_memory(session_id, memory_type="lifecycle")
        parent_ids = [recent.id] if recent else []

        return self.register(
            content=f"[LIFECYCLE] {event}: {details}",
            metadata={
                "memory_type": "lifecycle",
                "source_agent": agent_id,
                "session_id": session_id,
                "lifecycle_event": event,
            },
            parent_ids=parent_ids,
            context_snapshot=f"Agent status: {event}",
        )

    # =========================================================================
    # Recall
    # =========================================================================

    def recall(
        self,
        query: str,
        mode: str = "hybrid",
        top_k: int = 5,
        since: Optional[float] = None,
        until: Optional[float] = None,
    ) -> List[MemoryNode]:
        """
        Search memories by text, semantic, or hybrid search, optionally
        filtered to a time window.

        Each returned node has its access tracking (accessed_at,
        access_count) updated both in the store and on the returned object.

        Args:
            query: Search query.
            mode: 'text', 'semantic', or 'hybrid'. Any unrecognized value
                falls through to hybrid.
            top_k: Max results.
            since: Optional lower bound (UNIX epoch seconds). Memories
                with created_at < since are dropped from the result.
            until: Optional upper bound (UNIX epoch seconds). Memories
                with created_at > until are dropped from the result.

        Returns:
            List of MemoryNode results (at most top_k).
        """
        # When a time range is specified, pull more candidates from the
        # store so post-filtering can still return top_k. SQLite FTS5
        # returns them in relevance order, which is what we want inside
        # the window — we just need a bigger buffer than top_k to not
        # run out after filtering. The 6x multiplier is a heuristic.
        fetch_k = top_k if since is None and until is None else top_k * 6
        if mode == "text":
            results = self._store.search(query, max_results=fetch_k)
        elif mode == "semantic":
            results = self._store.search_semantic(query, max_results=fetch_k)
        else:
            results = self._store.search_hybrid(query, max_results=fetch_k)

        nodes = []
        for sr in results:
            node = MemoryNode.from_memory_item(sr.item)
            # created_at may be None/0 on malformed rows; treat as epoch 0.
            if since is not None and (node.created_at or 0) < since:
                continue
            if until is not None and (node.created_at or 0) > until:
                continue
            # Update access tracking: persist in the store, and mirror the
            # change on the in-memory node so callers see fresh values.
            self._update_access(node.id)
            node.access_count += 1
            node.accessed_at = time.time()
            nodes.append(node)
            if len(nodes) >= top_k:
                break

        return nodes

    def context_recall(
        self,
        recent_messages: List[str],
        top_k: int = 3,
    ) -> List[MemoryNode]:
        """
        Auto-recall: combine recent messages into a query and search.

        Only the last 3 messages are used, each truncated to 200 chars,
        to keep the combined query bounded.

        Args:
            recent_messages: List of recent message strings.
            top_k: Max results.

        Returns:
            List of relevant MemoryNodes (empty if no usable message text).
        """
        if not recent_messages:
            return []

        # Combine recent messages into a single query
        combined = " ".join(msg[:200] for msg in recent_messages[-3:])
        if not combined.strip():
            return []

        return self.recall(combined, mode="hybrid", top_k=top_k)

    def get_session_memories(
        self,
        session_id: str,
        limit: int = 50,
    ) -> List[MemoryNode]:
        """Get all memories from a specific session, ordered by creation time.

        Uses SQLite json_extract on the stored metadata JSON, so this
        requires the JSON1 extension.
        """
        conn = self._store._ensure_connection()
        with self._store._lock:
            rows = conn.execute(
                """
                SELECT * FROM memory_items
                WHERE json_extract(metadata, '$.session_id') = ?
                ORDER BY created_at ASC
                LIMIT ?
                """,
                (session_id, limit),
            ).fetchall()

        nodes = []
        for row in rows:
            item = self._store._row_to_item(row)
            nodes.append(MemoryNode.from_memory_item(item))
        return nodes

    # =========================================================================
    # Backtrace
    # =========================================================================

    def backtrace(self, memory_id: str, depth: int = 10) -> List[MemoryNode]:
        """
        Direct backtrace: walk parent links from memory_id back to origin.

        Follows a single chain — when a memory has multiple parents, only
        the most recently linked one is followed (see _get_parent_id).
        Cycles are broken via the visited set; at most `depth` nodes are
        visited.

        Returns ordered list: [origin, ..., intermediate, ..., target].
        """
        chain = []
        visited = set()
        current_id = memory_id

        for _ in range(depth):
            # Guard against cyclic link data.
            if current_id in visited:
                break
            visited.add(current_id)

            item = self._store.get(current_id)
            if not item:
                break

            node = MemoryNode.from_memory_item(item)
            chain.append(node)

            # Find parent via memory_links
            parent_id = self._get_parent_id(current_id)
            if not parent_id:
                break
            current_id = parent_id

        # Reverse so origin comes first
        chain.reverse()
        return chain

    def backtrace_semantic(
        self,
        query: str,
        depth: int = 5,
        top_k: int = 3,
    ) -> List[List[MemoryNode]]:
        """
        Semantic backtrace: find nearest memories, then trace each one back.

        Args:
            query: Search query used for hybrid recall.
            depth: Max chain length per match (passed to backtrace).
            top_k: Max number of matches (and hence chains).

        Returns list of chains, one per matching memory.
        """
        matches = self.recall(query, mode="hybrid", top_k=top_k)
        chains = []

        for node in matches:
            chain = self.backtrace(node.id, depth=depth)
            if chain:
                chains.append(chain)

        return chains

    def get_memory_chain(self, memory_id: str) -> Dict[str, Any]:
        """
        Get full chain: parents -> this -> children.

        Returns a dict with keys "target", "parents", "children" (all as
        plain dicts), or {"error": ...} when memory_id is unknown.
        """
        item = self._store.get(memory_id)
        if not item:
            return {"error": f"Memory {memory_id} not found"}

        node = MemoryNode.from_memory_item(item)

        # Walk parents (backtrace includes the target itself, so filter it out).
        parents = self.backtrace(memory_id)
        # Remove the target itself from parents list
        parents = [p for p in parents if p.id != memory_id]

        # Walk children (direct children only, no recursion).
        children = self._get_children(memory_id)

        return {
            "target": node.to_dict(),
            "parents": [p.to_dict() for p in parents],
            "children": [c.to_dict() for c in children],
        }

    # =========================================================================
    # Internal helpers
    # =========================================================================

    def _get_parent_id(self, memory_id: str) -> Optional[str]:
        """Get the parent memory ID from memory_links.

        Only the most recently created link is returned, so multi-parent
        memories collapse to a single chain during backtrace.
        """
        conn = self._store._ensure_connection()
        with self._store._lock:
            row = conn.execute(
                "SELECT source_id FROM memory_links WHERE target_id = ? ORDER BY created_at DESC LIMIT 1",
                (memory_id,),
            ).fetchone()
            return row["source_id"] if row else None

    def _get_children(self, memory_id: str) -> List[MemoryNode]:
        """Get direct children of a memory, oldest link first."""
        conn = self._store._ensure_connection()
        with self._store._lock:
            rows = conn.execute(
                "SELECT target_id FROM memory_links WHERE source_id = ? ORDER BY created_at ASC",
                (memory_id,),
            ).fetchall()

        children = []
        for row in rows:
            item = self._store.get(row["target_id"])
            # Dangling links (child deleted from the store) are skipped.
            if item:
                children.append(MemoryNode.from_memory_item(item))
        return children

    def _get_latest_session_memory(
        self,
        session_id: str,
        memory_type: Optional[str] = None,
    ) -> Optional[MemoryNode]:
        """Get the most recent memory in a session.

        Args:
            session_id: Session scope key stored in metadata.
            memory_type: Optional filter on metadata memory_type (e.g.
                'lifecycle'); None matches any type.
        """
        conn = self._store._ensure_connection()
        with self._store._lock:
            if memory_type:
                row = conn.execute(
                    """
                    SELECT * FROM memory_items
                    WHERE json_extract(metadata, '$.session_id') = ?
                    AND json_extract(metadata, '$.memory_type') = ?
                    ORDER BY created_at DESC LIMIT 1
                    """,
                    (session_id, memory_type),
                ).fetchone()
            else:
                row = conn.execute(
                    """
                    SELECT * FROM memory_items
                    WHERE json_extract(metadata, '$.session_id') = ?
                    ORDER BY created_at DESC LIMIT 1
                    """,
                    (session_id,),
                ).fetchone()

        if not row:
            return None
        item = self._store._row_to_item(row)
        return MemoryNode.from_memory_item(item)

    def _update_access(self, memory_id: str):
        """Update accessed_at and access_count for a memory.

        Read-modify-write of the metadata JSON column. Deliberately
        best-effort: any failure is swallowed so recall never breaks on
        access-tracking errors.
        """
        conn = self._store._ensure_connection()
        with self._store._lock:
            try:
                row = conn.execute(
                    "SELECT metadata FROM memory_items WHERE id = ?",
                    (memory_id,),
                ).fetchone()
                if row and row["metadata"]:
                    meta = json.loads(row["metadata"])
                    meta["accessed_at"] = time.time()
                    meta["access_count"] = meta.get("access_count", 0) + 1
                    conn.execute(
                        "UPDATE memory_items SET metadata = ?, updated_at = ? WHERE id = ?",
                        (json.dumps(meta), time.time(), memory_id),
                    )
            except Exception:
                pass  # Non-blocking

    def close(self):
        """Close the underlying MemoryStore connection."""
        self._store.close()