Coverage for integrations / channels / security.py: 97.4%

232 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2DM Pairing Security System 

3 

4Implements secure user-agent linking through pairing codes. 

5Users must pair their messaging accounts with agent accounts 

6before interacting with the system. 

7 

8Features: 

9- Pairing code generation and validation 

10- Time-limited pairing codes 

11- Per-channel user authentication 

12- Session persistence 

13""" 

14 

15from __future__ import annotations 

16 

17import hashlib 

18import hmac 

19import logging 

20import os 

21import secrets 

22import json 

23from datetime import datetime, timedelta 

24from dataclasses import dataclass, field 

25from enum import Enum 

26from typing import Optional, Dict, Tuple, List, Any 

27from pathlib import Path 

28 

# Module-level logger; PairingManager and OAuthStateManager log through it.
logger = logging.getLogger(__name__)

30 

31 

class PairingStatus(Enum):
    """Lifecycle state of a pairing code.

    The member values are the strings stored under ``"status"`` in the
    JSON state file and parsed back with ``PairingStatus(value)``.
    """

    PENDING = "pending"  # issued and not yet redeemed
    VERIFIED = "verified"  # redeemed by a channel sender
    EXPIRED = "expired"  # not assigned by the code visible in this module
    REJECTED = "rejected"  # not assigned by the code visible in this module

38 

39 

@dataclass
class PairingCode:
    """A one-time code that links a channel sender to an agent (user_id, prompt_id).

    When ``expires_at`` is omitted it defaults to 15 minutes after
    ``created_at``. Timestamps are naive local times from ``datetime.now()``.
    """

    code: str
    user_id: int
    prompt_id: int
    created_at: datetime = field(default_factory=datetime.now)
    expires_at: Optional[datetime] = None
    status: PairingStatus = PairingStatus.PENDING

    def __post_init__(self):
        # Fall back to a fixed 15-minute lifetime measured from creation.
        if self.expires_at is None:
            self.expires_at = self.created_at + timedelta(minutes=15)

    @property
    def is_expired(self) -> bool:
        """True once the current local time is strictly past ``expires_at``."""
        now = datetime.now()
        return now > self.expires_at

    @property
    def is_valid(self) -> bool:
        """True while the code is still PENDING and has not expired."""
        if self.status != PairingStatus.PENDING:
            return False
        return not self.is_expired

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (datetimes as ISO-8601 strings)."""
        expires = self.expires_at
        return {
            "code": self.code,
            "user_id": self.user_id,
            "prompt_id": self.prompt_id,
            "created_at": self.created_at.isoformat(),
            "expires_at": None if not expires else expires.isoformat(),
            "status": self.status.value,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> PairingCode:
        """Rebuild from :meth:`to_dict` output.

        A missing or falsy ``expires_at`` falls back to the default expiry.
        """
        raw_expiry = data.get("expires_at")
        return cls(
            code=data["code"],
            user_id=data["user_id"],
            prompt_id=data["prompt_id"],
            created_at=datetime.fromisoformat(data["created_at"]),
            expires_at=datetime.fromisoformat(raw_expiry) if raw_expiry else None,
            status=PairingStatus(data["status"]),
        )

87 

88 

@dataclass
class PairedSession:
    """A verified link from a channel sender to an agent (user_id, prompt_id).

    Sessions are keyed by ``(channel, sender_id)``; see :attr:`session_key`.
    """

    channel: str
    sender_id: str
    user_id: int
    prompt_id: int
    paired_at: datetime = field(default_factory=datetime.now)
    last_active: datetime = field(default_factory=datetime.now)
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def session_key(self) -> Tuple[str, str]:
        """The ``(channel, sender_id)`` tuple this session is stored under."""
        key = (self.channel, self.sender_id)
        return key

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict; datetimes become ISO-8601 strings.

        ``metadata`` is included as-is (same dict object, not a copy).
        """
        payload: Dict[str, Any] = {
            name: getattr(self, name)
            for name in ("channel", "sender_id", "user_id", "prompt_id")
        }
        payload["paired_at"] = self.paired_at.isoformat()
        payload["last_active"] = self.last_active.isoformat()
        payload["metadata"] = self.metadata
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> PairedSession:
        """Rebuild from :meth:`to_dict` output; ``metadata`` defaults to ``{}`` when absent."""
        parse = datetime.fromisoformat
        return cls(
            channel=data["channel"],
            sender_id=data["sender_id"],
            user_id=data["user_id"],
            prompt_id=data["prompt_id"],
            paired_at=parse(data["paired_at"]),
            last_active=parse(data["last_active"]),
            metadata=data.get("metadata", {}),
        )

129 

130 

class PairingManager:
    """
    Manages user pairing for secure channel access.

    A short-lived pairing code is issued for an agent ``(user_id, prompt_id)``
    via ``generate_pairing_code``. When a messaging-channel sender submits
    that code, ``verify_pairing`` links ``(channel, sender_id)`` to the agent.
    Paired sessions and pending codes are persisted as JSON at ``storage_path``.

    Usage:
        manager = PairingManager()

        # Generate a pairing code for a user
        code = manager.generate_pairing_code(user_id=123, prompt_id=456)

        # User enters code in their DM
        session = manager.verify_pairing("telegram", "user123", code)

        # Check if user is paired
        if manager.is_paired("telegram", "user123"):
            user_id, prompt_id = manager.get_user_mapping("telegram", "user123")
    """

    def __init__(
        self,
        code_length: int = 6,
        code_expiry_minutes: int = 15,
        storage_path: Optional[str] = None,
        secret_key: Optional[str] = None,
    ):
        """
        Args:
            code_length: Number of random characters before the ``-`` in a code.
            code_expiry_minutes: Default lifetime of a generated code.
            storage_path: JSON file used to persist state. Defaults to
                ``agent_data/pairing_data.json`` under the grandparent of
                this module's directory.
            secret_key: HMAC key for the code suffix. Falls back to the
                ``PAIRING_SECRET_KEY`` env var, then to a random per-process key.
        """
        self.code_length = code_length
        self.code_expiry_minutes = code_expiry_minutes
        self.storage_path = storage_path or os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "agent_data",
            "pairing_data.json",
        )
        # Chain with `or` rather than passing a getenv default: an env var
        # that is set but empty must not become an empty HMAC key.
        self.secret_key = (
            secret_key or os.getenv("PAIRING_SECRET_KEY") or secrets.token_hex(32)
        )

        # In-memory stores
        self._pending_codes: Dict[str, PairingCode] = {}  # code -> PairingCode
        self._paired_sessions: Dict[Tuple[str, str], PairedSession] = {}  # (channel, sender) -> Session

        # Load persisted data
        self._load_state()

    def generate_pairing_code(
        self,
        user_id: int,
        prompt_id: int,
        expiry_minutes: Optional[int] = None,
    ) -> str:
        """
        Generate a new pairing code for a user.

        Args:
            user_id: Agent user ID
            prompt_id: Agent prompt ID
            expiry_minutes: Optional custom expiry time. ``None`` uses
                ``code_expiry_minutes``. An explicit ``0`` is honoured and
                yields a code that is already expired.

        Returns:
            The generated pairing code, formatted ``<code_length chars>-<4 chars>``.
        """
        # Drop stale codes first, so the code created below is not
        # evicted within this same call.
        self._cleanup_expired_codes()

        code = self._generate_secure_code()

        # `is None`, not `or`: an explicit 0 must not fall back to the default.
        minutes = self.code_expiry_minutes if expiry_minutes is None else expiry_minutes
        expires_at = datetime.now() + timedelta(minutes=minutes)

        self._pending_codes[code] = PairingCode(
            code=code,
            user_id=user_id,
            prompt_id=prompt_id,
            expires_at=expires_at,
        )

        # Persist immediately so the pending code survives a restart
        # (_load_state reloads still-valid pending codes).
        self._save_state()

        # Never log the code itself: anyone holding it can pair with this account.
        logger.info(
            "Generated pairing code for user %s (expires %s)",
            user_id, expires_at.isoformat(),
        )
        return code

    def verify_pairing(
        self,
        channel: str,
        sender_id: str,
        code: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Optional[PairedSession]:
        """
        Verify a pairing code and create session.

        The code is matched case-insensitively and with surrounding
        whitespace ignored. A successful match consumes the code.

        Args:
            channel: Channel name (e.g., "telegram", "discord")
            sender_id: Sender ID from the channel
            code: The pairing code entered by user
            metadata: Optional metadata to attach to session

        Returns:
            PairedSession if successful, None otherwise
        """
        code = code.upper().strip()

        pairing = self._pending_codes.get(code)
        if not pairing:
            # Log who tried, not what they typed: a near-miss may contain a real code.
            logger.warning("Invalid pairing code attempted from %s:%s", channel, sender_id)
            return None

        if not pairing.is_valid:
            # An expired or used code can never succeed again, so forget it now.
            del self._pending_codes[code]
            logger.warning("Expired/used pairing code attempted from %s:%s", channel, sender_id)
            return None

        session = PairedSession(
            channel=channel,
            sender_id=sender_id,
            user_id=pairing.user_id,
            prompt_id=pairing.prompt_id,
            metadata=metadata or {},
        )

        # Mark code as used (one-shot)
        pairing.status = PairingStatus.VERIFIED
        del self._pending_codes[code]

        # An existing pairing for the same (channel, sender_id) is replaced.
        self._paired_sessions[session.session_key] = session

        self._save_state()

        logger.info("Paired %s:%s with user %s", channel, sender_id, pairing.user_id)
        return session

    def is_paired(self, channel: str, sender_id: str) -> bool:
        """Check if a channel user is paired."""
        return (channel, sender_id) in self._paired_sessions

    def get_user_mapping(
        self,
        channel: str,
        sender_id: str,
    ) -> Optional[Tuple[int, int]]:
        """
        Get the user mapping for a paired channel user.

        Also bumps the session's ``last_active``. The bump is in memory
        only; it is not persisted here.

        Returns:
            Tuple of (user_id, prompt_id) if paired, None otherwise
        """
        session = self._paired_sessions.get((channel, sender_id))
        if session:
            session.last_active = datetime.now()
            return (session.user_id, session.prompt_id)
        return None

    def get_session(self, channel: str, sender_id: str) -> Optional[PairedSession]:
        """Get the full session for a paired user."""
        return self._paired_sessions.get((channel, sender_id))

    def unpair(self, channel: str, sender_id: str) -> bool:
        """
        Remove a pairing.

        Returns:
            True if unpaired, False if not found
        """
        key = (channel, sender_id)
        if key in self._paired_sessions:
            del self._paired_sessions[key]
            self._save_state()
            logger.info("Unpaired %s:%s", channel, sender_id)
            return True
        return False

    def unpair_user(self, user_id: int) -> int:
        """
        Remove all pairings for a user.

        Returns:
            Number of pairings removed
        """
        to_remove = [
            key for key, session in self._paired_sessions.items()
            if session.user_id == user_id
        ]

        for key in to_remove:
            del self._paired_sessions[key]

        if to_remove:
            self._save_state()
            logger.info("Unpaired %d sessions for user %s", len(to_remove), user_id)

        return len(to_remove)

    def list_user_pairings(self, user_id: int) -> List[PairedSession]:
        """List all pairings for a user."""
        return [
            session for session in self._paired_sessions.values()
            if session.user_id == user_id
        ]

    def _generate_secure_code(self) -> str:
        """Generate a random pairing code of the form ``XXXXXX-YYYY``.

        ``XXXXXX`` is ``code_length`` characters drawn with ``secrets`` from an
        alphabet that omits the ambiguous ``0/O`` and ``1/I``. ``YYYY`` is the
        first 4 upper-cased hex digits of HMAC-SHA256(secret_key, XXXXXX).
        Note that ``verify_pairing`` matches the whole string against stored
        codes and does not recompute the suffix.
        """
        alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
        code = ''.join(secrets.choice(alphabet) for _ in range(self.code_length))

        signature = hmac.new(
            self.secret_key.encode(),
            code.encode(),
            hashlib.sha256,
        ).hexdigest()[:4].upper()

        return f"{code}-{signature}"

    def _cleanup_expired_codes(self) -> None:
        """Remove expired pairing codes from memory."""
        expired = [
            code for code, pairing in self._pending_codes.items()
            if pairing.is_expired
        ]
        for code in expired:
            del self._pending_codes[code]

    def _save_state(self) -> None:
        """Persist paired sessions and pending codes to ``storage_path``.

        Best-effort: failures are logged, never raised. The JSON goes to a
        sibling ``.tmp`` file that is then moved into place with
        ``os.replace``. A crash mid-write therefore cannot leave a truncated
        state file, which ``_load_state`` would fail to parse (dropping
        every session).
        """
        tmp_path = f"{self.storage_path}.tmp"
        try:
            data = {
                "sessions": [
                    session.to_dict()
                    for session in self._paired_sessions.values()
                ],
                "pending_codes": [
                    code.to_dict()
                    for code in self._pending_codes.values()
                ],
            }

            directory = os.path.dirname(self.storage_path)
            # A bare filename has dirname "", and os.makedirs("") raises.
            if directory:
                os.makedirs(directory, exist_ok=True)

            # Owner-only permissions: the file holds live pairing codes.
            fd = os.open(tmp_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.storage_path)

        except Exception as e:
            logger.error("Failed to save pairing state: %s", e)

    def _load_state(self) -> None:
        """Load persisted paired sessions and still-valid pending codes.

        Best-effort: a missing file is a no-op. Any other failure is logged,
        and whatever was loaded before the error is kept.
        """
        try:
            if os.path.exists(self.storage_path):
                with open(self.storage_path, "r", encoding="utf-8") as f:
                    data = json.load(f)

                for session_data in data.get("sessions", []):
                    session = PairedSession.from_dict(session_data)
                    self._paired_sessions[session.session_key] = session

                # Expired or already-used codes are skipped.
                for code_data in data.get("pending_codes", []):
                    code = PairingCode.from_dict(code_data)
                    if code.is_valid:
                        self._pending_codes[code.code] = code

                logger.info("Loaded %d paired sessions", len(self._paired_sessions))

        except Exception as e:
            logger.error("Failed to load pairing state: %s", e)

403 

404 

class PairingMiddleware:
    """
    Middleware that enforces pairing for channel messages.

    Usage:
        middleware = PairingMiddleware(pairing_manager)

        # In message handler:
        result = middleware.check_pairing(channel, sender_id, text)
        if not result.is_paired:
            await send_pairing_instructions(result.instructions)
            return

        # Process message with result.user_id, result.prompt_id.
        # On a successful pairing, result.instructions also carries a
        # confirmation message to send back.
    """

    @dataclass
    class CheckResult:
        """Result of pairing check."""
        is_paired: bool
        user_id: Optional[int] = None
        prompt_id: Optional[int] = None
        instructions: Optional[str] = None

    def __init__(
        self,
        manager: PairingManager,
        require_pairing: bool = True,
        default_user_id: Optional[int] = None,
        default_prompt_id: Optional[int] = None,
    ):
        """
        Args:
            manager: Pairing manager used to look up and verify senders.
            require_pairing: When False, unpaired senders are treated as paired
                with ``default_user_id`` / ``default_prompt_id``. Messages that
                look like a pairing code are still sent for verification first.
            default_user_id: User ID used when pairing is not required.
            default_prompt_id: Prompt ID used when pairing is not required.
        """
        self.manager = manager
        self.require_pairing = require_pairing
        self.default_user_id = default_user_id
        self.default_prompt_id = default_prompt_id

    def check_pairing(self, channel: str, sender_id: str, text: str) -> CheckResult:
        """
        Check if sender is paired and handle pairing flow.

        Args:
            channel: Channel name
            sender_id: Sender ID
            text: Message text (to check for pairing code)

        Returns:
            CheckResult with pairing status
        """
        # Already paired senders short-circuit; their text is never parsed as a code.
        mapping = self.manager.get_user_mapping(channel, sender_id)
        if mapping:
            return self.CheckResult(
                is_paired=True,
                user_id=mapping[0],
                prompt_id=mapping[1],
            )

        # Check if message contains pairing code
        if self._looks_like_pairing_code(text):
            session = self.manager.verify_pairing(channel, sender_id, text.strip())
            if session:
                return self.CheckResult(
                    is_paired=True,
                    user_id=session.user_id,
                    prompt_id=session.prompt_id,
                    instructions="Pairing successful! You can now chat with me.",
                )
            else:
                return self.CheckResult(
                    is_paired=False,
                    instructions="Invalid or expired pairing code. Please get a new code.",
                )

        # Not paired
        if self.require_pairing:
            return self.CheckResult(
                is_paired=False,
                instructions=(
                    "Welcome! To use this bot, you need to pair your account.\n\n"
                    "1. Go to the web interface and get a pairing code\n"
                    "2. Send the code here (e.g., ABC123-XYZ1)\n\n"
                    "This links your account securely."
                ),
            )
        else:
            # Use defaults if pairing not required
            return self.CheckResult(
                is_paired=True,
                user_id=self.default_user_id,
                prompt_id=self.default_prompt_id,
            )

    def _looks_like_pairing_code(self, text: str) -> bool:
        """Return True if *text* has the shape of a pairing code.

        Codes look like ``XXXXXX-YYYY``: ``code_length`` random characters,
        a hyphen, and a 4-character suffix. This is only a cheap shape filter.
        ``verify_pairing`` on the manager does the real check.
        """
        text = text.strip().upper()
        # Keep the historical 10-char minimum, but lower it when the manager
        # issues shorter codes (code_length + "-" + 4-char suffix). Otherwise
        # those codes would never be recognised and the user could not pair.
        code_length = getattr(self.manager, "code_length", 6)
        min_len = min(10, code_length + 5)
        if len(text) < min_len or '-' not in text:
            return False
        parts = text.split('-')
        # Generated codes are pure ASCII. Without isascii(), hyphenated
        # non-Latin words pass isalnum() and get the "Invalid or expired
        # pairing code" reply instead of the welcome instructions.
        return len(parts) == 2 and all(p.isascii() and p.isalnum() for p in parts)

506 

507 

# Process-wide PairingManager, created lazily by get_pairing_manager().
_pairing_manager: Optional[PairingManager] = None


def get_pairing_manager() -> PairingManager:
    """Return the shared PairingManager, building one with default settings on first use."""
    global _pairing_manager
    if _pairing_manager is not None:
        return _pairing_manager
    _pairing_manager = PairingManager()
    return _pairing_manager

518 

519 

520# ──────────────────────────────────────────────────────────────────── 

521# OAuth click-through state manager (PR O) 

522# ──────────────────────────────────────────────────────────────────── 

523# 

524# CSRF + binding identity for the /oauth/<channel_type>/start → 

525# /oauth/<channel_type>/callback round-trip. The provider redirects 

526# back to us with a `state` query param; we use that to recover: 

527# - the user_id (so we can write the binding to the right account) 

528# - the channel_type (so we know which adapter to call) 

529# - the PKCE code_verifier (for providers that require PKCE) 

530# 

531# Stays in-process — fine for single-instance flat/regional deploys. 

532# Multi-instance central deploys would need Redis-backed state, but 

533# central isn't an OAuth callback target by architecture (#275). 

534 

535 

@dataclass
class _OAuthState:
    """One in-flight OAuth authorization, keyed by its ``state`` token.

    A record counts as expired once ``OAuthStateManager.STATE_TTL_MIN``
    minutes have passed since ``created_at`` (naive local time).
    """

    state: str
    user_id: int
    channel_type: str
    code_verifier: Optional[str] = None  # PKCE secret, for providers that use PKCE
    extra: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)

    @property
    def is_expired(self) -> bool:
        """True when the current time is strictly past created_at + TTL."""
        now = datetime.now()
        deadline = self.created_at + timedelta(minutes=OAuthStateManager.STATE_TTL_MIN)
        return now > deadline

551 

552 

class OAuthStateManager:
    """Issue and redeem the OAuth ``state`` parameter for click-through flows.

    Records live in an in-process dict. Expired records are purged each time
    a new state is issued. ``verify_state`` removes the record it looks up,
    so a token can be redeemed at most once (replay protection).
    """

    STATE_TTL_MIN = 10  # OAuth round-trip should complete in seconds

    def __init__(self):
        self._states: Dict[str, _OAuthState] = {}

    def generate_state(
        self,
        user_id: int,
        channel_type: str,
        code_verifier: Optional[str] = None,
        **extra: Any,
    ) -> str:
        """Create and remember a fresh state token.

        The returned value goes in the ``state`` param of the authorize URL.
        ``code_verifier`` is the PKCE secret for providers that need one
        (None otherwise). Any extra keyword arguments are stored and handed
        back by ``verify_state`` under ``'extra'``.
        """
        self._evict_expired()
        # secrets.token_urlsafe(32): 32 random bytes, 43 url-safe characters.
        token = secrets.token_urlsafe(32)
        record = _OAuthState(
            state=token,
            user_id=user_id,
            channel_type=channel_type,
            code_verifier=code_verifier,
            extra=dict(extra),
        )
        self._states[token] = record
        logger.info(
            "Generated OAuth state for user_id=%s channel=%s (pending=%d)",
            user_id, channel_type, len(self._states),
        )
        return token

    def verify_state(self, state: str) -> Optional[Dict[str, Any]]:
        """Redeem a state token.

        Returns a dict with ``user_id``, ``channel_type``, ``code_verifier``
        and ``extra`` on success. Returns None for an empty or non-string
        value, an unknown or already-redeemed token, or an expired one. The
        record is removed before the expiry check, so every lookup consumes it.
        """
        if not (state and isinstance(state, str)):
            return None

        record = self._states.pop(state, None)
        if record is None:
            logger.warning("OAuth state verify failed: unknown / replayed token")
            return None

        if record.is_expired:
            logger.warning(
                "OAuth state verify failed: expired (created %s)",
                record.created_at.isoformat(),
            )
            return None

        return dict(
            user_id=record.user_id,
            channel_type=record.channel_type,
            code_verifier=record.code_verifier,
            extra=record.extra,
        )

    def _evict_expired(self) -> None:
        """Drop every expired record.

        The scan is linear in the number of pending states. It runs on each
        ``generate_state`` call, so no background task is involved.
        """
        stale = [token for token, rec in self._states.items() if rec.is_expired]
        for token in stale:
            self._states.pop(token)
        if stale:
            logger.debug("Evicted %d expired OAuth states", len(stale))

626 

627 

# Process-wide OAuthStateManager, created lazily by get_oauth_state_manager().
_oauth_state_manager: Optional[OAuthStateManager] = None


def get_oauth_state_manager() -> OAuthStateManager:
    """Return the shared OAuthStateManager, creating it on first use."""
    global _oauth_state_manager
    if _oauth_state_manager is not None:
        return _oauth_state_manager
    _oauth_state_manager = OAuthStateManager()
    return _oauth_state_manager

637 

638 

def generate_pkce_pair() -> Tuple[str, str]:
    """Create a PKCE ``(code_verifier, code_challenge)`` pair for the S256 method.

    The verifier is ``secrets.token_urlsafe(32)``: 43 URL-safe characters,
    within the 43..128 length range of RFC 7636 §4.1. The challenge is the
    unpadded base64url encoding of SHA-256(verifier), per RFC 7636 §4.2.

    Used by providers that set ``oauth_uses_pkce: True`` in metadata —
    Google, Microsoft, Twitter X v2.
    """
    import base64

    verifier = secrets.token_urlsafe(32)
    hashed = hashlib.sha256(verifier.encode('ascii')).digest()
    challenge = base64.urlsafe_b64encode(hashed).decode('ascii').rstrip('=')
    return verifier, challenge