Coverage for integrations / channels / session_manager.py: 98.5%

205 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2Channel Session Manager 

3 

4Provides isolated session management for multi-channel messaging. 

5Each channel/user combination gets its own conversation context, 

6preventing cross-channel data leakage. 

7 

8Features: 

9- Per-channel conversation history 

10- Session state isolation 

11- Conversation context management 

12- Session timeout/cleanup 

13- Memory limits 

14""" 

15 

16from __future__ import annotations 

17 

18import json 

19import logging 

20import os 

21import threading 

22from collections import OrderedDict 

23from dataclasses import dataclass, field 

24from datetime import datetime, timedelta 

25from typing import Optional, Dict, List, Any, Tuple 

26from pathlib import Path 

27 

# Module-level logger, named after this module via ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)

29 

30 

@dataclass
class ConversationMessage:
    """One turn of a conversation: who spoke, what was said, and when.

    Instances round-trip through ``to_dict``/``from_dict`` so they can be
    stored as JSON; the timestamp travels as an ISO-8601 string.
    """
    role: str  # "user" or "assistant"
    content: str
    timestamp: datetime = field(default_factory=datetime.now)
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-friendly dict view of this message.

        The ``metadata`` dict is returned as-is (it is not copied).
        """
        return dict(
            role=self.role,
            content=self.content,
            timestamp=self.timestamp.isoformat(),
            metadata=self.metadata,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> ConversationMessage:
        """Rebuild a message from the mapping produced by ``to_dict``.

        ``role``, ``content`` and ``timestamp`` are required (a missing key
        raises ``KeyError``); ``metadata`` falls back to an empty dict.
        """
        role = data["role"]
        text = data["content"]
        when = datetime.fromisoformat(data["timestamp"])
        extra = data.get("metadata", {})
        return cls(role=role, content=text, timestamp=when, metadata=extra)

57 

58 

@dataclass
class ChannelSession:
    """
    Represents an isolated session for a channel/user combination.

    Contains conversation history, state, and metadata for a single
    user's interaction through a specific channel. The history is capped
    at ``max_messages`` entries; the oldest messages are discarded first.
    """
    channel: str
    sender_id: str
    user_id: Optional[int] = None
    prompt_id: Optional[int] = None
    created_at: datetime = field(default_factory=datetime.now)
    last_active: datetime = field(default_factory=datetime.now)
    messages: List[ConversationMessage] = field(default_factory=list)
    state: Dict[str, Any] = field(default_factory=dict)
    metadata: Dict[str, Any] = field(default_factory=dict)

    # Limits: maximum number of history entries retained (values <= 0 keep none).
    max_messages: int = 100

    @property
    def session_key(self) -> Tuple[str, str]:
        """Get the unique ``(channel, sender_id)`` session identifier."""
        return (self.channel, self.sender_id)

    @property
    def message_count(self) -> int:
        """Get number of messages in history."""
        return len(self.messages)

    @property
    def context_window(self) -> List[Dict[str, str]]:
        """Get messages formatted for LLM context (role/content pairs only)."""
        return [{"role": msg.role, "content": msg.content} for msg in self.messages]

    def _enforce_message_limit(self) -> None:
        """Drop the oldest messages so at most ``max_messages`` remain.

        The overflow is computed explicitly rather than slicing with
        ``[-max_messages:]``: for ``max_messages == 0`` that slice is
        ``[-0:]`` == ``[0:]``, which silently kept the entire history.
        The list is rebound (not mutated in place), as before.
        """
        excess = len(self.messages) - max(self.max_messages, 0)
        if excess > 0:
            self.messages = self.messages[excess:]

    def add_message(self, role: str, content: str, metadata: Optional[Dict] = None) -> None:
        """
        Add a message to the conversation history.

        Also refreshes ``last_active`` and trims history to ``max_messages``.

        Args:
            role: "user" or "assistant"
            content: Message content
            metadata: Optional message metadata (``None`` becomes ``{}``)
        """
        self.messages.append(
            ConversationMessage(role=role, content=content, metadata=metadata or {})
        )
        self.last_active = datetime.now()
        self._enforce_message_limit()

    def add_user_message(self, content: str, metadata: Optional[Dict] = None) -> None:
        """Add a user message."""
        self.add_message("user", content, metadata)

    def add_assistant_message(self, content: str, metadata: Optional[Dict] = None) -> None:
        """Add an assistant message."""
        self.add_message("assistant", content, metadata)

    def get_state(self, key: str, default: Any = None) -> Any:
        """Get a state value, or ``default`` when the key is absent."""
        return self.state.get(key, default)

    def set_state(self, key: str, value: Any) -> None:
        """Set a state value and refresh ``last_active``."""
        self.state[key] = value
        self.last_active = datetime.now()

    def clear_state(self) -> None:
        """Clear all state (does not touch ``last_active``)."""
        self.state = {}

    def clear_history(self) -> None:
        """Clear conversation history."""
        self.messages = []

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dictionary (``max_messages`` is not stored)."""
        return {
            "channel": self.channel,
            "sender_id": self.sender_id,
            "user_id": self.user_id,
            "prompt_id": self.prompt_id,
            "created_at": self.created_at.isoformat(),
            "last_active": self.last_active.isoformat(),
            "messages": [m.to_dict() for m in self.messages],
            "state": self.state,
            "metadata": self.metadata,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> ChannelSession:
        """
        Deserialize from a dictionary produced by ``to_dict``.

        Raises:
            KeyError: if a required key (channel, sender_id, created_at,
                last_active) is missing.
        """
        session = cls(
            channel=data["channel"],
            sender_id=data["sender_id"],
            user_id=data.get("user_id"),
            prompt_id=data.get("prompt_id"),
            created_at=datetime.fromisoformat(data["created_at"]),
            last_active=datetime.fromisoformat(data["last_active"]),
            state=data.get("state", {}),
            metadata=data.get("metadata", {}),
        )
        session.messages = [
            ConversationMessage.from_dict(m) for m in data.get("messages", [])
        ]
        # Keep the history invariant even for data written with a larger limit.
        session._enforce_message_limit()
        return session

175 

176 

class LRUSessionCache(OrderedDict):
    """LRU cache for sessions keyed by ``(channel, sender_id)``.

    Iteration order is recency order: the first key is the least recently
    used and is evicted first. ``get`` and ``put`` serialize on an internal
    lock; other inherited ``OrderedDict`` operations do not take it.
    """

    def __init__(self, maxsize: int = 1000):
        super().__init__()
        # Capacity that ``put`` maintains by evicting least-recently-used entries.
        self.maxsize = maxsize
        self._lock = threading.Lock()

    def get(self, key: Tuple[str, str], default=None) -> Optional[ChannelSession]:
        """Return the entry for ``key`` and mark it most recently used, else ``default``."""
        with self._lock:
            if key in self:
                self.move_to_end(key)
                return self[key]
            return default

    def put(self, key: Tuple[str, str], value: ChannelSession) -> None:
        """Insert or replace ``key``, evicting least-recently-used entries if at capacity.

        Replacing an existing key marks it most recently used without evicting.
        """
        with self._lock:
            if key in self:
                self.move_to_end(key)
            else:
                # Loop (not a single eviction) so the cache also shrinks after
                # ``maxsize`` was lowered; the emptiness check keeps this from
                # failing when ``maxsize <= 0`` (previously StopIteration).
                while self and len(self) >= self.maxsize:
                    self.popitem(last=False)
            self[key] = value

204 

205 

class ChannelSessionManager:
    """
    Manages isolated sessions for multi-channel messaging.

    Provides session isolation, ensuring that each channel/user
    combination has its own conversation context. Sessions live in an
    in-memory LRU cache and, when ``auto_persist`` is on, are written to a
    JSON file after each change that adds, removes, or re-pairs a session.

    Usage:
        manager = ChannelSessionManager()

        # Get or create session
        session = manager.get_session("telegram", "user123")

        # Add messages
        session.add_user_message("Hello!")
        session.add_assistant_message("Hi there!")

        # Get context for LLM
        context = session.context_window

        # Store session state
        session.set_state("language", "en")
    """

    def __init__(
        self,
        storage_path: Optional[str] = None,
        max_sessions: int = 1000,
        session_timeout_hours: int = 24,
        auto_persist: bool = True,
    ):
        # Default file: <three directories above this module>/agent_data/channel_sessions.json
        self.storage_path = storage_path or os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "agent_data",
            "channel_sessions.json"
        )
        self.session_timeout = timedelta(hours=session_timeout_hours)
        self.auto_persist = auto_persist

        # In-memory LRU cache
        self._sessions = LRUSessionCache(maxsize=max_sessions)
        # Guards multi-step lookups/mutations of ``_sessions``. Not reentrant:
        # never call ``_save_sessions`` while holding it.
        self._lock = threading.Lock()

        # Load persisted sessions
        self._load_sessions()

    def _is_expired(self, session: ChannelSession, now: datetime) -> bool:
        """Return True when ``session`` has been idle longer than the timeout."""
        return now - session.last_active > self.session_timeout

    def get_session(
        self,
        channel: str,
        sender_id: str,
        user_id: Optional[int] = None,
        prompt_id: Optional[int] = None,
        create: bool = True,
    ) -> Optional[ChannelSession]:
        """
        Get or create a session for a channel/user.

        An expired session is dropped and (if ``create``) replaced by a fresh one.

        Args:
            channel: Channel name (e.g., "telegram", "discord")
            sender_id: Unique sender identifier from the channel
            user_id: Agent user ID (if known from pairing)
            prompt_id: Agent prompt ID (if known from pairing)
            create: Whether to create if not exists

        Returns:
            ChannelSession or None if not found and create=False
        """
        key = (channel, sender_id)
        created = False

        # Lookup, expiry and creation are one critical section so two threads
        # asking for the same new key cannot each create a session (with one
        # silently overwriting the other in the cache).
        with self._lock:
            session = self._sessions.get(key)
            if session is not None and self._is_expired(session, datetime.now()):
                self._sessions.pop(key, None)
                session = None

            if session is None and create:
                session = ChannelSession(
                    channel=channel,
                    sender_id=sender_id,
                    user_id=user_id,
                    prompt_id=prompt_id,
                )
                self._sessions.put(key, session)
                created = True

        # Update user/prompt ID if provided, remembering whether anything changed
        # so the new pairing is persisted rather than lost on restart.
        changed = False
        if session is not None:
            if user_id is not None and session.user_id != user_id:
                session.user_id = user_id
                changed = True
            if prompt_id is not None and session.prompt_id != prompt_id:
                session.prompt_id = prompt_id
                changed = True

        # File I/O happens outside the lock.
        if (created or changed) and self.auto_persist:
            self._save_sessions()

        return session

    def has_session(self, channel: str, sender_id: str) -> bool:
        """Check if a live session exists (an expired one is dropped as a side effect)."""
        return self.get_session(channel, sender_id, create=False) is not None

    def delete_session(self, channel: str, sender_id: str) -> bool:
        """
        Delete a session.

        Returns:
            True if deleted, False if not found
        """
        key = (channel, sender_id)
        with self._lock:
            if key not in self._sessions:
                return False
            del self._sessions[key]

        # Persist outside the lock so file I/O does not block other callers.
        if self.auto_persist:
            self._save_sessions()
        return True

    def _delete_where(self, predicate) -> int:
        """Remove every session for which ``predicate(key, session)`` is true.

        Selection and deletion share one critical section, so a key removed
        concurrently between the two steps can no longer raise ``KeyError``.
        Persists once afterwards when enabled and anything was removed.

        Returns:
            Number of sessions removed
        """
        with self._lock:
            doomed = [
                key for key, session in list(self._sessions.items())
                if predicate(key, session)
            ]
            for key in doomed:
                self._sessions.pop(key, None)

        if doomed and self.auto_persist:
            self._save_sessions()

        return len(doomed)

    def clear_channel_sessions(self, channel: str) -> int:
        """
        Delete all sessions for a channel.

        Returns:
            Number of sessions deleted
        """
        return self._delete_where(lambda key, _session: key[0] == channel)

    def clear_user_sessions(self, user_id: int) -> int:
        """
        Delete all sessions for an agent user.

        Returns:
            Number of sessions deleted
        """
        return self._delete_where(lambda _key, session: session.user_id == user_id)

    def list_sessions(
        self,
        channel: Optional[str] = None,
        user_id: Optional[int] = None,
    ) -> List[ChannelSession]:
        """
        List sessions with optional filtering.

        Args:
            channel: Filter by channel (falsy values disable the filter)
            user_id: Filter by user ID

        Returns:
            List of matching sessions, least recently used first
        """
        with self._lock:
            sessions = list(self._sessions.values())

        if channel:
            sessions = [s for s in sessions if s.channel == channel]
        if user_id is not None:
            sessions = [s for s in sessions if s.user_id == user_id]

        return sessions

    def get_session_count(self, channel: Optional[str] = None) -> int:
        """Get number of cached sessions, optionally only those for ``channel``."""
        with self._lock:
            if channel:
                return sum(1 for s in self._sessions.values() if s.channel == channel)
            return len(self._sessions)

    def cleanup_expired(self) -> int:
        """
        Remove expired sessions.

        Returns:
            Number of sessions removed
        """
        now = datetime.now()
        return self._delete_where(lambda _key, session: self._is_expired(session, now))

    def persist(self) -> None:
        """Manually persist sessions to storage."""
        self._save_sessions()

    def _save_sessions(self) -> None:
        """Save sessions to storage (best effort: failures are logged, not raised).

        The JSON is written to a temporary sibling file and then moved over
        the target with ``os.replace``, so a crash mid-write cannot leave a
        truncated file that would make every persisted session unloadable.
        """
        tmp_path: Optional[str] = None
        try:
            with self._lock:
                sessions = [session.to_dict() for session in self._sessions.values()]
            data = {
                "sessions": sessions,
                "saved_at": datetime.now().isoformat(),
            }

            # A bare filename has no directory part; os.makedirs("") would raise.
            directory = os.path.dirname(self.storage_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            # Per-process/per-thread name so concurrent saves don't share a temp file.
            tmp_path = f"{self.storage_path}.{os.getpid()}.{threading.get_ident()}.tmp"
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.storage_path)
            tmp_path = None

            logger.debug("Saved %d sessions", len(sessions))

        except Exception as e:
            logger.error("Failed to save sessions: %s", e)
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass

    def _load_sessions(self) -> None:
        """Load non-expired sessions from storage (best effort; errors are logged).

        A session record that fails to parse is skipped individually.
        """
        try:
            if not os.path.exists(self.storage_path):
                return

            with open(self.storage_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            now = datetime.now()
            loaded = 0
            for session_data in data.get("sessions", []):
                try:
                    session = ChannelSession.from_dict(session_data)

                    # Skip expired sessions
                    if self._is_expired(session, now):
                        continue

                    self._sessions.put(session.session_key, session)
                    loaded += 1
                except Exception as e:
                    logger.warning("Failed to load session: %s", e)

            logger.info("Loaded %d sessions", loaded)

        except Exception as e:
            logger.error("Failed to load sessions: %s", e)

471 

472 

class SessionIsolationMiddleware:
    """
    Middleware that resolves the isolated session for an incoming message.

    When a pairing manager is supplied, its ``(user_id, prompt_id)`` mapping
    for the sender is forwarded to the session manager.

    Usage:
        middleware = SessionIsolationMiddleware(session_manager)

        # In message handler:
        session = middleware.get_session_for_message(message)
        session.add_user_message(message.text)

        # Process with LLM using session.context_window

        session.add_assistant_message(response)
    """

    def __init__(
        self,
        session_manager: ChannelSessionManager,
        pairing_manager: Optional[Any] = None,  # PairingManager
    ):
        self.session_manager = session_manager
        self.pairing_manager = pairing_manager

    def get_session_for_message(self, message: Any) -> ChannelSession:
        """
        Look up (or create) the session belonging to a message's sender.

        Args:
            message: Object exposing ``channel`` and ``sender_id`` attributes

        Returns:
            The ChannelSession returned by the session manager for that sender
        """
        channel, sender_id = message.channel, message.sender_id

        user_id = prompt_id = None
        pairing = self.pairing_manager
        if pairing:
            # A truthy mapping is unpacked as (user_id, prompt_id).
            mapping = pairing.get_user_mapping(channel, sender_id)
            if mapping:
                user_id, prompt_id = mapping

        return self.session_manager.get_session(
            channel=channel,
            sender_id=sender_id,
            user_id=user_id,
            prompt_id=prompt_id,
        )

524 

525 

# Singleton instance
_session_manager: Optional[ChannelSessionManager] = None
# Serializes first-time construction in get_session_manager().
_session_manager_lock = threading.Lock()


def get_session_manager() -> ChannelSessionManager:
    """Get or create the global session manager.

    Construction loads sessions from disk. The double-checked lock makes
    concurrent first callers share one instance, rather than each building
    its own manager and later overwriting the other's storage file.
    """
    global _session_manager
    if _session_manager is None:
        with _session_manager_lock:
            if _session_manager is None:
                _session_manager = ChannelSessionManager()
    return _session_manager