Coverage for integrations/channels/session_manager.py: 98.5%
205 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Channel Session Manager
4Provides isolated session management for multi-channel messaging.
5Each channel/user combination gets its own conversation context,
6preventing cross-channel data leakage.
8Features:
9- Per-channel conversation history
10- Session state isolation
11- Conversation context management
12- Session timeout/cleanup
13- Memory limits
14"""
16from __future__ import annotations
18import json
19import logging
20import os
21import threading
22from collections import OrderedDict
23from dataclasses import dataclass, field
24from datetime import datetime, timedelta
25from typing import Optional, Dict, List, Any, Tuple
26from pathlib import Path
# Module-level logger; the session manager reports save/load outcomes through it.
logger = logging.getLogger(__name__)
@dataclass
class ConversationMessage:
    """One conversation turn that round-trips through a plain dict."""
    role: str  # "user" or "assistant"
    content: str
    timestamp: datetime = field(default_factory=datetime.now)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict; ``timestamp`` is rendered as ISO-8601 text."""
        stamp = self.timestamp.isoformat()
        return dict(
            role=self.role,
            content=self.content,
            timestamp=stamp,
            metadata=self.metadata,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> ConversationMessage:
        """
        Rebuild a message from the output of :meth:`to_dict`.

        ``role``, ``content`` and ``timestamp`` are required keys; a missing
        ``metadata`` key yields an empty dict.
        """
        role, content = data["role"], data["content"]
        stamp = datetime.fromisoformat(data["timestamp"])
        extra = data.get("metadata", {})
        return cls(role=role, content=content, timestamp=stamp, metadata=extra)
@dataclass
class ChannelSession:
    """
    Represents an isolated session for a channel/user combination.

    Contains conversation history, state, and metadata for a single
    user's interaction through a specific channel. The history is capped at
    ``max_messages`` entries; when the cap is exceeded the oldest messages
    are dropped first.
    """
    channel: str
    sender_id: str
    user_id: Optional[int] = None
    prompt_id: Optional[int] = None
    created_at: datetime = field(default_factory=datetime.now)
    last_active: datetime = field(default_factory=datetime.now)
    messages: List[ConversationMessage] = field(default_factory=list)
    state: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Limits
    # Upper bound on len(messages); a value <= 0 keeps no history at all.
    # NOTE: not written by to_dict(), so from_dict() restores the default.
    max_messages: int = 100

    @property
    def session_key(self) -> Tuple[str, str]:
        """Get unique session identifier: ``(channel, sender_id)``."""
        return (self.channel, self.sender_id)

    @property
    def message_count(self) -> int:
        """Get number of messages in history."""
        return len(self.messages)

    @property
    def context_window(self) -> List[Dict[str, str]]:
        """Get the history as ``{"role": ..., "content": ...}`` dicts for LLM context."""
        return [
            {"role": msg.role, "content": msg.content}
            for msg in self.messages
        ]

    def add_message(self, role: str, content: str, metadata: Optional[Dict] = None) -> None:
        """
        Append a message to the conversation history, then enforce the cap.

        Also refreshes ``last_active``.

        Args:
            role: "user" or "assistant"
            content: Message content
            metadata: Optional message metadata (a new empty dict if omitted)
        """
        msg = ConversationMessage(
            role=role,
            content=content,
            metadata=metadata or {},
        )
        self.messages.append(msg)
        self.last_active = datetime.now()

        # Trim if over limit. The overflow is computed explicitly because the
        # slice ``messages[-max_messages:]`` keeps the WHOLE list when
        # max_messages == 0 (``-0 == 0``), silently disabling the cap.
        overflow = len(self.messages) - max(self.max_messages, 0)
        if overflow > 0:
            self.messages = self.messages[overflow:]

    def add_user_message(self, content: str, metadata: Optional[Dict] = None) -> None:
        """Add a user message."""
        self.add_message("user", content, metadata)

    def add_assistant_message(self, content: str, metadata: Optional[Dict] = None) -> None:
        """Add an assistant message."""
        self.add_message("assistant", content, metadata)

    def get_state(self, key: str, default: Any = None) -> Any:
        """Get a state value, or ``default`` if the key is absent."""
        return self.state.get(key, default)

    def set_state(self, key: str, value: Any) -> None:
        """Set a state value and refresh ``last_active``."""
        self.state[key] = value
        self.last_active = datetime.now()

    def clear_state(self) -> None:
        """Clear all state (does not touch ``last_active``)."""
        self.state = {}

    def clear_history(self) -> None:
        """Clear conversation history (does not touch ``last_active``)."""
        self.messages = []

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dictionary (datetimes as ISO-8601 text)."""
        return {
            "channel": self.channel,
            "sender_id": self.sender_id,
            "user_id": self.user_id,
            "prompt_id": self.prompt_id,
            "created_at": self.created_at.isoformat(),
            "last_active": self.last_active.isoformat(),
            "messages": [m.to_dict() for m in self.messages],
            "state": self.state,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> ChannelSession:
        """
        Deserialize from a dictionary produced by :meth:`to_dict`.

        ``user_id``, ``prompt_id``, ``state``, ``metadata`` and ``messages``
        are optional. A missing or null ``state``/``metadata``/``messages``
        yields an empty container. ``state`` and ``metadata`` are shallow-copied
        so the session does not alias the caller's payload.

        Raises:
            KeyError: if ``channel``, ``sender_id``, ``created_at`` or
                ``last_active`` is missing.
            ValueError: if a timestamp string is not valid ISO-8601.
        """
        session = cls(
            channel=data["channel"],
            sender_id=data["sender_id"],
            user_id=data.get("user_id"),
            prompt_id=data.get("prompt_id"),
            created_at=datetime.fromisoformat(data["created_at"]),
            last_active=datetime.fromisoformat(data["last_active"]),
            state=dict(data.get("state") or {}),
            metadata=dict(data.get("metadata") or {}),
        )
        session.messages = [
            ConversationMessage.from_dict(m) for m in data.get("messages") or []
        ]
        return session
class LRUSessionCache(OrderedDict):
    """
    Size-bounded LRU cache for sessions, keyed by ``(channel, sender_id)``.

    ``get`` and ``put`` refresh an entry's recency and are serialized by an
    internal lock. ``put`` evicts least-recently-used entries when the cache
    is full. Plain mapping operations inherited from OrderedDict (``del``,
    iteration, ``in``) do NOT take the lock.
    """

    def __init__(self, maxsize: int = 1000):
        super().__init__()
        # Capacity. A value <= 0 still retains the most recently put entry.
        self.maxsize = maxsize
        self._lock = threading.Lock()

    def get(self, key: Tuple[str, str], default=None) -> Optional[ChannelSession]:
        """Return the value for ``key`` (marking it most recently used), else ``default``."""
        with self._lock:
            if key in self:
                self.move_to_end(key)
                return self[key]
            return default

    def put(self, key: Tuple[str, str], value: ChannelSession) -> None:
        """
        Insert or replace ``key``, marking it most recently used.

        Inserting a new key into a full cache first evicts least-recently-used
        entries so the size stays within ``maxsize`` (at least one entry, the
        new one, is always kept).
        """
        with self._lock:
            if key in self:
                self.move_to_end(key)
            else:
                # Evict oldest. The emptiness check keeps a non-positive maxsize
                # from trying to evict from an empty cache, which would raise.
                while self and len(self) >= self.maxsize:
                    self.popitem(last=False)
            self[key] = value
class ChannelSessionManager:
    """
    Manages isolated sessions for multi-channel messaging.

    Provides session isolation, ensuring that each channel/user
    combination has its own conversation context. Sessions are held in an
    in-memory LRU cache bounded by ``max_sessions``. They are treated as
    expired after ``session_timeout_hours`` without activity and are
    persisted as JSON at ``storage_path``.

    With ``auto_persist`` enabled, the file is rewritten in two cases:
      - when a session is created by :meth:`get_session`;
      - when sessions are removed via :meth:`delete_session`, the ``clear_*``
        methods or :meth:`cleanup_expired`.

    Messages and state added to an existing session are only written on the
    next such event or an explicit :meth:`persist`.

    Usage:
        manager = ChannelSessionManager()

        # Get or create session
        session = manager.get_session("telegram", "user123")

        # Add messages
        session.add_user_message("Hello!")
        session.add_assistant_message("Hi there!")

        # Get context for LLM
        context = session.context_window

        # Store session state
        session.set_state("language", "en")
    """

    def __init__(
        self,
        storage_path: Optional[str] = None,
        max_sessions: int = 1000,
        session_timeout_hours: int = 24,
        auto_persist: bool = True,
    ):
        # Default location: <grandparent of this module's directory>/agent_data/channel_sessions.json
        self.storage_path = storage_path or os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "agent_data",
            "channel_sessions.json",
        )
        self.session_timeout = timedelta(hours=session_timeout_hours)
        self.auto_persist = auto_persist

        # In-memory LRU cache
        self._sessions = LRUSessionCache(maxsize=max_sessions)
        # Guards removals done by this manager; the cache's own lock only
        # covers its get()/put().
        self._lock = threading.Lock()

        # Load persisted sessions (already-expired ones are skipped)
        self._load_sessions()

    def get_session(
        self,
        channel: str,
        sender_id: str,
        user_id: Optional[int] = None,
        prompt_id: Optional[int] = None,
        create: bool = True,
    ) -> Optional[ChannelSession]:
        """
        Get or create a session for a channel/user.

        An expired cached session is dropped and, if ``create`` is true,
        replaced by a fresh one. Non-None ``user_id``/``prompt_id`` values
        are written onto the returned session.

        Args:
            channel: Channel name (e.g., "telegram", "discord")
            sender_id: Unique sender identifier from the channel
            user_id: Agent user ID (if known from pairing)
            prompt_id: Agent prompt ID (if known from pairing)
            create: Whether to create if not exists

        Returns:
            ChannelSession or None if not found and create=False
        """
        key = (channel, sender_id)

        # Try cache first
        session = self._sessions.get(key)
        if session is not None and self._is_expired(session):
            # Session expired, remove it
            with self._lock:
                if key in self._sessions:
                    del self._sessions[key]
            session = None

        # Create if not found and allowed
        if session is None and create:
            session = ChannelSession(
                channel=channel,
                sender_id=sender_id,
                user_id=user_id,
                prompt_id=prompt_id,
            )
            self._sessions.put(key, session)

            if self.auto_persist:
                self._save_sessions()

        # Update user/prompt ID if provided
        if session is not None:
            if user_id is not None:
                session.user_id = user_id
            if prompt_id is not None:
                session.prompt_id = prompt_id

        return session

    def has_session(self, channel: str, sender_id: str) -> bool:
        """Check if a live session exists (an expired one is removed as a side effect)."""
        return self.get_session(channel, sender_id, create=False) is not None

    def delete_session(self, channel: str, sender_id: str) -> bool:
        """
        Delete a session.

        Returns:
            True if deleted, False if not found
        """
        key = (channel, sender_id)
        with self._lock:
            if key in self._sessions:
                del self._sessions[key]
                if self.auto_persist:
                    self._save_sessions()
                return True
        return False

    def clear_channel_sessions(self, channel: str) -> int:
        """
        Delete all sessions for a channel.

        Returns:
            Number of sessions deleted
        """
        matching = [key for key in self._sessions.keys() if key[0] == channel]
        return self._remove_keys(matching)

    def clear_user_sessions(self, user_id: int) -> int:
        """
        Delete all sessions for an agent user.

        Returns:
            Number of sessions deleted
        """
        matching = [
            key for key, session in self._sessions.items()
            if session.user_id == user_id
        ]
        return self._remove_keys(matching)

    def list_sessions(
        self,
        channel: Optional[str] = None,
        user_id: Optional[int] = None,
    ) -> List[ChannelSession]:
        """
        List sessions with optional filtering.

        Expired sessions that have not yet been cleaned up are included.

        Args:
            channel: Filter by channel (ignored when empty/None)
            user_id: Filter by user ID

        Returns:
            List of matching sessions
        """
        sessions = list(self._sessions.values())

        if channel:
            sessions = [s for s in sessions if s.channel == channel]
        if user_id is not None:
            sessions = [s for s in sessions if s.user_id == user_id]

        return sessions

    def get_session_count(self, channel: Optional[str] = None) -> int:
        """Get number of cached sessions, optionally for one channel."""
        if channel:
            return sum(1 for s in self._sessions.values() if s.channel == channel)
        return len(self._sessions)

    def cleanup_expired(self) -> int:
        """
        Remove expired sessions.

        Returns:
            Number of sessions removed
        """
        now = datetime.now()
        expired = [
            key for key, session in self._sessions.items()
            if self._is_expired(session, now)
        ]
        return self._remove_keys(expired)

    def persist(self) -> None:
        """Manually persist sessions to storage."""
        self._save_sessions()

    def _is_expired(self, session: ChannelSession, now: Optional[datetime] = None) -> bool:
        """Return True if ``session`` has been inactive longer than the timeout."""
        if now is None:
            now = datetime.now()
        return now - session.last_active > self.session_timeout

    def _remove_keys(self, keys: List[Tuple[str, str]]) -> int:
        """Drop ``keys`` from the cache, persist if any were given, and return how many."""
        with self._lock:
            for key in keys:
                # pop() tolerates a key another caller removed in the meantime.
                self._sessions.pop(key, None)

        if keys and self.auto_persist:
            self._save_sessions()

        return len(keys)

    def _save_sessions(self) -> None:
        """
        Write all cached sessions to ``storage_path`` as JSON (best effort).

        The data is written to a temporary sibling file and moved into place
        with ``os.replace``, so a failed or interrupted write never truncates
        the existing session file. Errors are logged, not raised.
        """
        tmp_path = None
        try:
            data = {
                "sessions": [
                    session.to_dict()
                    for session in self._sessions.values()
                ],
                "saved_at": datetime.now().isoformat(),
            }

            directory = os.path.dirname(self.storage_path)
            # A bare filename has no directory part; os.makedirs("") would raise.
            if directory:
                os.makedirs(directory, exist_ok=True)

            # Per-process/thread temp name so concurrent saves don't share a file.
            tmp_path = f"{self.storage_path}.{os.getpid()}.{threading.get_ident()}.tmp"
            with open(tmp_path, 'w') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.storage_path)
            tmp_path = None

            logger.debug("Saved %d sessions", len(data["sessions"]))

        except Exception as e:
            logger.error("Failed to save sessions: %s", e)
            if tmp_path is not None:
                # Best-effort cleanup of a partial temp file; it may not exist.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def _load_sessions(self) -> None:
        """
        Populate the cache from ``storage_path`` if that file exists.

        Sessions already past the timeout are skipped. A malformed entry is
        logged and skipped; an unreadable or unparsable file is logged and
        nothing further is loaded.
        """
        try:
            if not os.path.exists(self.storage_path):
                return

            with open(self.storage_path, 'r') as f:
                data = json.load(f)

            loaded = 0
            for session_data in data.get("sessions", []):
                try:
                    session = ChannelSession.from_dict(session_data)

                    # Skip expired sessions
                    if self._is_expired(session):
                        continue

                    self._sessions.put(session.session_key, session)
                    loaded += 1
                except Exception as e:
                    logger.warning("Failed to load session: %s", e)

            logger.info("Loaded %d sessions", loaded)

        except Exception as e:
            logger.error("Failed to load sessions: %s", e)
class SessionIsolationMiddleware:
    """
    Route each incoming message to the isolated session of its sender.

    When a pairing manager is supplied, its ``get_user_mapping(channel,
    sender_id)`` result is consulted. A truthy result is unpacked as a
    ``(user_id, prompt_id)`` pair and forwarded to the session manager.

    Usage:
        middleware = SessionIsolationMiddleware(session_manager)

        # In message handler:
        session = middleware.get_session_for_message(message)
        session.add_user_message(message.text)

        # Process with LLM using session.context_window

        session.add_assistant_message(response)
    """

    def __init__(
        self,
        session_manager: ChannelSessionManager,
        pairing_manager: Optional[Any] = None,  # PairingManager
    ):
        # Collaborators are stored as given; nothing is validated here.
        self.session_manager = session_manager
        self.pairing_manager = pairing_manager

    def get_session_for_message(self, message: Any) -> ChannelSession:
        """
        Fetch the session belonging to ``message``'s sender.

        Args:
            message: Any object exposing ``channel`` and ``sender_id`` attributes.

        Returns:
            Whatever ``session_manager.get_session(...)`` returns for that
            sender, with user/prompt IDs taken from pairing when available.
        """
        channel, sender_id = message.channel, message.sender_id

        mapping = (
            self.pairing_manager.get_user_mapping(channel, sender_id)
            if self.pairing_manager
            else None
        )
        # No pairing manager, or no mapping for this sender: forward no IDs.
        user_id, prompt_id = mapping if mapping else (None, None)

        return self.session_manager.get_session(
            channel=channel,
            sender_id=sender_id,
            user_id=user_id,
            prompt_id=prompt_id,
        )
# Singleton instance, created lazily by get_session_manager().
_session_manager: Optional[ChannelSessionManager] = None


def get_session_manager() -> ChannelSessionManager:
    """
    Return the process-wide ChannelSessionManager, building it on first use.

    The instance is constructed with no arguments (all defaults) and the
    same object is returned by every later call.
    """
    # NOTE(review): no lock guards creation; two threads making the first call
    # concurrently could each construct a manager.
    global _session_manager
    if _session_manager is not None:
        return _session_manager
    _session_manager = ChannelSessionManager()
    return _session_manager