Coverage for integrations / channels / security.py: 97.4%
232 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2DM Pairing Security System
4Implements secure user-agent linking through pairing codes.
5Users must pair their messaging accounts with agent accounts
6before interacting with the system.
8Features:
9- Pairing code generation and validation
10- Time-limited pairing codes
11- Per-channel user authentication
12- Session persistence
13"""
15from __future__ import annotations
17import hashlib
18import hmac
19import logging
20import os
21import secrets
22import json
23from datetime import datetime, timedelta
24from dataclasses import dataclass, field
25from enum import Enum
26from typing import Optional, Dict, Tuple, List, Any
27from pathlib import Path
# Module-level logger used by the pairing and OAuth state managers below.
logger = logging.getLogger(__name__)
class PairingStatus(Enum):
    """Lifecycle state of a pairing request.

    The member values are the lowercase strings stored in (and parsed back
    from) the serialized form of a pairing code.
    """

    PENDING = "pending"    # issued, not yet redeemed
    VERIFIED = "verified"  # redeemed successfully
    EXPIRED = "expired"
    REJECTED = "rejected"
@dataclass
class PairingCode:
    """A single-use code that links a channel account to an agent user.

    When ``expires_at`` is not supplied it defaults to 15 minutes after
    ``created_at``. Instances round-trip through :meth:`to_dict` and
    :meth:`from_dict`. Datetimes are stored as ISO-8601 strings and the
    status as its enum value.
    """

    code: str
    user_id: int
    prompt_id: int
    created_at: datetime = field(default_factory=datetime.now)
    expires_at: Optional[datetime] = None
    status: PairingStatus = PairingStatus.PENDING

    def __post_init__(self):
        # Only fill in the default lifetime when the caller gave none.
        if self.expires_at is not None:
            return
        self.expires_at = self.created_at + timedelta(minutes=15)

    @property
    def is_expired(self) -> bool:
        """Whether the current local time is later than ``expires_at``."""
        now = datetime.now()
        return now > self.expires_at

    @property
    def is_valid(self) -> bool:
        """Whether the code is still PENDING and has not expired."""
        if self.status != PairingStatus.PENDING:
            return False
        return not self.is_expired

    def to_dict(self) -> Dict[str, Any]:
        """Return a JSON-serializable snapshot of this code."""
        expiry = self.expires_at
        return dict(
            code=self.code,
            user_id=self.user_id,
            prompt_id=self.prompt_id,
            created_at=self.created_at.isoformat(),
            expires_at=None if not expiry else expiry.isoformat(),
            status=self.status.value,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> PairingCode:
        """Rebuild a code from a :meth:`to_dict` snapshot."""
        iso = datetime.fromisoformat
        return cls(
            code=data["code"],
            user_id=data["user_id"],
            prompt_id=data["prompt_id"],
            created_at=iso(data["created_at"]),
            expires_at=iso(data["expires_at"]) if data.get("expires_at") else None,
            status=PairingStatus(data["status"]),
        )
@dataclass
class PairedSession:
    """A confirmed link between a channel account and an agent user.

    Sessions are addressed by ``(channel, sender_id)``. The ``user_id`` and
    ``prompt_id`` fields identify the agent side of the link.
    """

    channel: str
    sender_id: str
    user_id: int
    prompt_id: int
    paired_at: datetime = field(default_factory=datetime.now)
    last_active: datetime = field(default_factory=datetime.now)
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def session_key(self) -> Tuple[str, str]:
        """Lookup key for this session: ``(channel, sender_id)``."""
        key = self.channel, self.sender_id
        return key

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-compatible dict with ISO-8601 timestamps."""
        out: Dict[str, Any] = {}
        out["channel"] = self.channel
        out["sender_id"] = self.sender_id
        out["user_id"] = self.user_id
        out["prompt_id"] = self.prompt_id
        out["paired_at"] = self.paired_at.isoformat()
        out["last_active"] = self.last_active.isoformat()
        out["metadata"] = self.metadata
        return out

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> PairedSession:
        """Inverse of :meth:`to_dict`. A missing ``metadata`` becomes ``{}``."""
        parse = datetime.fromisoformat
        return cls(
            channel=data["channel"],
            sender_id=data["sender_id"],
            user_id=data["user_id"],
            prompt_id=data["prompt_id"],
            paired_at=parse(data["paired_at"]),
            last_active=parse(data["last_active"]),
            metadata=data.get("metadata", {}),
        )
class PairingManager:
    """
    Manages user pairing for secure channel access.

    A pairing code is issued for an agent ``(user_id, prompt_id)`` pair and
    later redeemed from a messaging channel. Redeeming it creates a
    :class:`PairedSession` keyed by ``(channel, sender_id)``. Paired sessions
    and still-pending codes are persisted as JSON at ``storage_path``.

    Usage:
        manager = PairingManager()

        # Generate a pairing code for a user
        code = manager.generate_pairing_code(user_id=123, prompt_id=456)

        # User enters code in their DM
        session = manager.verify_pairing("telegram", "user123", code)

        # Check if user is paired
        if manager.is_paired("telegram", "user123"):
            user_id, prompt_id = manager.get_user_mapping("telegram", "user123")
    """

    def __init__(
        self,
        code_length: int = 6,
        code_expiry_minutes: int = 15,
        storage_path: Optional[str] = None,
        secret_key: Optional[str] = None,
    ):
        """
        Args:
            code_length: Number of random characters before the signature.
            code_expiry_minutes: Default lifetime of generated codes.
            storage_path: JSON file for persisted state. Defaults to
                ``agent_data/pairing_data.json`` under the grandparent of
                this module's directory.
            secret_key: HMAC key for the code signature. Falls back to the
                ``PAIRING_SECRET_KEY`` environment variable, then to a
                random per-process key.
        """
        self.code_length = code_length
        self.code_expiry_minutes = code_expiry_minutes
        self.storage_path = storage_path or os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
            "agent_data",
            "pairing_data.json",
        )
        # An empty PAIRING_SECRET_KEY must not become an empty HMAC key,
        # so treat it the same as "unset".
        self.secret_key = (
            secret_key
            or os.getenv("PAIRING_SECRET_KEY")
            or secrets.token_hex(32)
        )

        # In-memory stores
        self._pending_codes: Dict[str, PairingCode] = {}  # code -> PairingCode
        self._paired_sessions: Dict[Tuple[str, str], PairedSession] = {}  # (channel, sender) -> session

        # Load persisted data
        self._load_state()

    def generate_pairing_code(
        self,
        user_id: int,
        prompt_id: int,
        expiry_minutes: Optional[int] = None,
    ) -> str:
        """
        Generate a new pairing code for a user.

        The code is stored as pending and persisted, so it survives a
        restart (``_load_state`` reloads still-valid pending codes).

        Args:
            user_id: Agent user ID
            prompt_id: Agent prompt ID
            expiry_minutes: Optional custom expiry time. A falsy value
                (``None`` or ``0``) uses ``code_expiry_minutes``.

        Returns:
            The generated code: ``code_length`` random characters, a hyphen,
            then 4 hex signature characters.
        """
        # Drop stale codes first so the uniqueness check only sees live ones.
        self._cleanup_expired_codes()

        # A duplicate would silently overwrite another user's pending code,
        # so redraw until the code is unused.
        code = self._generate_secure_code()
        while code in self._pending_codes:
            code = self._generate_secure_code()

        expiry = timedelta(minutes=expiry_minutes or self.code_expiry_minutes)
        self._pending_codes[code] = PairingCode(
            code=code,
            user_id=user_id,
            prompt_id=prompt_id,
            expires_at=datetime.now() + expiry,
        )
        self._save_state()

        # Never log the code itself: anyone with log access could redeem it.
        logger.info("Generated pairing code for user %s", user_id)
        return code

    def verify_pairing(
        self,
        channel: str,
        sender_id: str,
        code: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Optional[PairedSession]:
        """
        Verify a pairing code and create session.

        A code can be redeemed only once. On success it is removed from the
        pending store, and the new session replaces any existing session for
        ``(channel, sender_id)``.

        Args:
            channel: Channel name (e.g., "telegram", "discord")
            sender_id: Sender ID from the channel
            code: The pairing code entered by user (case-insensitive,
                surrounding whitespace ignored; ``None`` is rejected)
            metadata: Optional metadata to attach to session

        Returns:
            PairedSession if successful, None otherwise
        """
        code = (code or "").strip().upper()

        pairing = self._pending_codes.get(code)
        if pairing is None:
            # Don't echo the submitted text: it may be a near-miss of a live code.
            logger.warning("Invalid pairing code attempted from %s:%s", channel, sender_id)
            return None

        if not pairing.is_valid:
            # Only PENDING codes are kept in the store, so an invalid entry
            # has expired. Drop it now rather than at the next cleanup.
            self._pending_codes.pop(code, None)
            logger.warning("Expired/used pairing code attempted from %s:%s", channel, sender_id)
            return None

        session = PairedSession(
            channel=channel,
            sender_id=sender_id,
            user_id=pairing.user_id,
            prompt_id=pairing.prompt_id,
            metadata=metadata or {},
        )

        # One-shot: mark as used and remove so the code cannot be replayed.
        pairing.status = PairingStatus.VERIFIED
        del self._pending_codes[code]

        self._paired_sessions[session.session_key] = session
        self._save_state()

        logger.info("Paired %s:%s with user %s", channel, sender_id, pairing.user_id)
        return session

    def is_paired(self, channel: str, sender_id: str) -> bool:
        """Check if a channel user is paired."""
        return (channel, sender_id) in self._paired_sessions

    def get_user_mapping(
        self,
        channel: str,
        sender_id: str,
    ) -> Optional[Tuple[int, int]]:
        """
        Get the user mapping for a paired channel user.

        Touches ``last_active`` in memory. The change is written to disk
        only on the next save.

        Returns:
            Tuple of (user_id, prompt_id) if paired, None otherwise
        """
        session = self._paired_sessions.get((channel, sender_id))
        if session is None:
            return None
        session.last_active = datetime.now()
        return (session.user_id, session.prompt_id)

    def get_session(self, channel: str, sender_id: str) -> Optional[PairedSession]:
        """Get the full session for a paired user."""
        return self._paired_sessions.get((channel, sender_id))

    def unpair(self, channel: str, sender_id: str) -> bool:
        """
        Remove a pairing.

        Returns:
            True if unpaired, False if not found
        """
        key = (channel, sender_id)
        if key not in self._paired_sessions:
            return False
        del self._paired_sessions[key]
        self._save_state()
        logger.info("Unpaired %s:%s", channel, sender_id)
        return True

    def unpair_user(self, user_id: int) -> int:
        """
        Remove all pairings for a user.

        Returns:
            Number of pairings removed
        """
        to_remove = [
            key for key, session in self._paired_sessions.items()
            if session.user_id == user_id
        ]
        for key in to_remove:
            del self._paired_sessions[key]

        if to_remove:
            self._save_state()
            logger.info("Unpaired %d sessions for user %s", len(to_remove), user_id)

        return len(to_remove)

    def list_user_pairings(self, user_id: int) -> List[PairedSession]:
        """List all pairings for a user."""
        return [
            session for session in self._paired_sessions.values()
            if session.user_id == user_id
        ]

    def _generate_secure_code(self) -> str:
        """Generate a random code followed by a short HMAC signature.

        The random part is ``code_length`` characters from an alphabet that
        omits the look-alike characters I, O, 0 and 1. The suffix is the
        first 4 hex characters of HMAC-SHA256(secret_key, random part).
        Verification looks up the whole string, so the suffix is not
        checked on its own. It still means a guess must also match 4
        characters that cannot be derived without the secret.
        """
        alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
        code = ''.join(secrets.choice(alphabet) for _ in range(self.code_length))

        signature = hmac.new(
            self.secret_key.encode(),
            code.encode(),
            hashlib.sha256,
        ).hexdigest()[:4].upper()

        return f"{code}-{signature}"

    def _cleanup_expired_codes(self) -> None:
        """Remove expired pairing codes from memory."""
        expired = [
            code for code, pairing in self._pending_codes.items()
            if pairing.is_expired
        ]
        for code in expired:
            del self._pending_codes[code]

    def _save_state(self) -> None:
        """Persist paired sessions and pending codes to ``storage_path``.

        Best-effort: failures are logged, never raised. The JSON is written
        to a sibling ``.tmp`` file and then moved into place with
        ``os.replace``. A crash mid-write therefore cannot leave a truncated
        file, which ``_load_state`` would fail to parse, losing every session.
        """
        try:
            data = {
                "sessions": [
                    session.to_dict()
                    for session in self._paired_sessions.values()
                ],
                "pending_codes": [
                    pairing.to_dict()
                    for pairing in self._pending_codes.values()
                ],
            }

            # A bare filename has no directory part, and os.makedirs("") raises.
            directory = os.path.dirname(self.storage_path)
            if directory:
                os.makedirs(directory, exist_ok=True)

            tmp_path = f"{self.storage_path}.tmp"
            with open(tmp_path, 'w', encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.storage_path)

        except Exception as e:  # best-effort: keep serving from memory
            logger.error("Failed to save pairing state: %s", e)

    def _load_state(self) -> None:
        """Load persisted sessions and still-valid pending codes.

        Best-effort: an unreadable or unparsable file is logged and ignored,
        and a single malformed record is skipped rather than aborting the
        whole load.
        """
        bad_record = (KeyError, TypeError, ValueError, AttributeError)
        try:
            if not os.path.exists(self.storage_path):
                return
            with open(self.storage_path, 'r', encoding="utf-8") as f:
                data = json.load(f)

            for session_data in data.get("sessions", []):
                try:
                    session = PairedSession.from_dict(session_data)
                except bad_record as e:
                    logger.warning("Skipping malformed paired session: %s", e)
                    continue
                self._paired_sessions[session.session_key] = session

            for code_data in data.get("pending_codes", []):
                try:
                    pairing = PairingCode.from_dict(code_data)
                    keep = pairing.is_valid  # expired codes are not resurrected
                except bad_record as e:
                    logger.warning("Skipping malformed pending code: %s", e)
                    continue
                if keep:
                    self._pending_codes[pairing.code] = pairing

            logger.info("Loaded %d paired sessions", len(self._paired_sessions))

        except Exception as e:  # best-effort: start with empty state
            logger.error("Failed to load pairing state: %s", e)
class PairingMiddleware:
    """
    Middleware that enforces pairing for channel messages.

    Usage:
        middleware = PairingMiddleware(pairing_manager)

        # In message handler:
        result = middleware.check_pairing(channel, sender_id, text)
        if not result.is_paired:
            await send_pairing_instructions(result.instructions)
            return

        # Process message with result.user_id, result.prompt_id
    """

    @dataclass
    class CheckResult:
        """Result of pairing check.

        ``instructions`` holds a message for the sender. When ``is_paired``
        is False it contains pairing instructions. When ``is_paired`` is True
        it contains a confirmation, set only right after a code is redeemed.
        """
        is_paired: bool
        user_id: Optional[int] = None
        prompt_id: Optional[int] = None
        instructions: Optional[str] = None

    def __init__(
        self,
        manager: PairingManager,
        require_pairing: bool = True,
        default_user_id: Optional[int] = None,
        default_prompt_id: Optional[int] = None,
    ):
        """
        Args:
            manager: Pairing store that is consulted and updated.
            require_pairing: When False, unpaired senders are let through
                with the default user/prompt IDs.
            default_user_id: User ID reported for unpaired senders when
                pairing is not required.
            default_prompt_id: Prompt ID reported for unpaired senders when
                pairing is not required.
        """
        self.manager = manager
        self.require_pairing = require_pairing
        self.default_user_id = default_user_id
        self.default_prompt_id = default_prompt_id

    def check_pairing(self, channel: str, sender_id: str, text: str) -> CheckResult:
        """
        Check if sender is paired and handle pairing flow.

        Args:
            channel: Channel name
            sender_id: Sender ID
            text: Message text (to check for pairing code). ``None`` is
                treated as an empty message.

        Returns:
            CheckResult with pairing status
        """
        # Already-paired senders pass straight through.
        mapping = self.manager.get_user_mapping(channel, sender_id)
        if mapping:
            return self.CheckResult(
                is_paired=True,
                user_id=mapping[0],
                prompt_id=mapping[1],
            )

        # Messages without text (e.g. media) must not crash the check.
        text = text or ""

        # A message shaped like a code is an attempt to pair.
        if self._looks_like_pairing_code(text):
            session = self.manager.verify_pairing(channel, sender_id, text.strip())
            if session:
                return self.CheckResult(
                    is_paired=True,
                    user_id=session.user_id,
                    prompt_id=session.prompt_id,
                    instructions="Pairing successful! You can now chat with me.",
                )
            return self.CheckResult(
                is_paired=False,
                instructions="Invalid or expired pairing code. Please get a new code.",
            )

        if self.require_pairing:
            return self.CheckResult(
                is_paired=False,
                instructions=(
                    "Welcome! To use this bot, you need to pair your account.\n\n"
                    "1. Go to the web interface and get a pairing code\n"
                    "2. Send the code here (e.g., ABC123-XYZ1)\n\n"
                    "This links your account securely."
                ),
            )

        # Pairing not required: fall back to the configured defaults.
        return self.CheckResult(
            is_paired=True,
            user_id=self.default_user_id,
            prompt_id=self.default_prompt_id,
        )

    def _looks_like_pairing_code(self, text: str) -> bool:
        """Check if text has the shape of a generated pairing code.

        ``PairingManager._generate_secure_code`` emits ``<random part>-<sig>``.
        The random part has ``code_length`` ASCII alphanumerics and the
        signature has exactly 4 hex digits. The check is case-insensitive and
        ignores surrounding whitespace. If the manager exposes no
        ``code_length``, any non-empty random part is accepted.
        """
        text = (text or "").strip().upper()
        head, sep, signature = text.partition('-')
        if not sep or '-' in signature:
            return False

        expected_len = getattr(self.manager, "code_length", None)
        if expected_len is not None and len(head) != expected_len:
            return False
        if not head or not head.isascii() or not head.isalnum():
            return False

        return len(signature) == 4 and all(ch in "0123456789ABCDEF" for ch in signature)
# Process-wide PairingManager, created lazily by get_pairing_manager().
_pairing_manager: Optional[PairingManager] = None


def get_pairing_manager() -> PairingManager:
    """Return the shared PairingManager, constructing it on first call.

    Later calls return the same instance. The lazy construction is not
    guarded by any lock.
    """
    global _pairing_manager
    if _pairing_manager is not None:
        return _pairing_manager
    _pairing_manager = PairingManager()
    return _pairing_manager
520# ────────────────────────────────────────────────────────────────────
521# OAuth click-through state manager (PR O)
522# ────────────────────────────────────────────────────────────────────
523#
524# CSRF + binding identity for the /oauth/<channel_type>/start →
525# /oauth/<channel_type>/callback round-trip. The provider redirects
526# back to us with a `state` query param; we use that to recover:
527# - the user_id (so we can write the binding to the right account)
528# - the channel_type (so we know which adapter to call)
529# - the PKCE code_verifier (for providers that require PKCE)
530#
531# Stays in-process — fine for single-instance flat/regional deploys.
532# Multi-instance central deploys would need Redis-backed state, but
533# central isn't an OAuth callback target by architecture (#275).
@dataclass
class _OAuthState:
    """One pending OAuth authorization, keyed by its ``state`` token.

    A record counts as expired once more than
    ``OAuthStateManager.STATE_TTL_MIN`` minutes have passed since
    ``created_at``.
    """

    state: str
    user_id: int
    channel_type: str
    code_verifier: Optional[str] = None  # PKCE secret, when the provider uses PKCE
    extra: Dict[str, Any] = field(default_factory=dict)
    created_at: datetime = field(default_factory=datetime.now)

    @property
    def is_expired(self) -> bool:
        now = datetime.now()
        lifetime = timedelta(minutes=OAuthStateManager.STATE_TTL_MIN)
        return now > self.created_at + lifetime
class OAuthStateManager:
    """Issues + verifies the ``state`` parameter for OAuth click-through.

    States are kept in an in-memory dict. Expired entries are swept whenever
    a new state is issued. Verification pops the entry, so each state can be
    redeemed at most once (replay protection).
    """

    STATE_TTL_MIN = 10  # OAuth round-trip should complete in seconds

    def __init__(self):
        self._states: Dict[str, _OAuthState] = {}

    def generate_state(
        self,
        user_id: int,
        channel_type: str,
        code_verifier: Optional[str] = None,
        **extra: Any,
    ) -> str:
        """Create and remember a fresh state token.

        Put the returned value in the authorize URL's ``state`` param.
        ``code_verifier`` is the PKCE secret for providers that need one
        (None otherwise). Any extra keyword arguments are stored and handed
        back by :meth:`verify_state`.
        """
        self._evict_expired()
        # 32 random bytes encode to 43 url-safe characters, far too many to
        # collide within one TTL window.
        token = secrets.token_urlsafe(32)
        record = _OAuthState(
            state=token,
            user_id=user_id,
            channel_type=channel_type,
            code_verifier=code_verifier,
            extra=dict(extra),
        )
        self._states[token] = record
        logger.info(
            "Generated OAuth state for user_id=%s channel=%s (pending=%d)",
            user_id,
            channel_type,
            len(self._states),
        )
        return token

    def verify_state(self, state: str) -> Optional[Dict[str, Any]]:
        """Consume a state token and return its stored context.

        Returns a dict with ``user_id``, ``channel_type``, ``code_verifier``
        and ``extra``. Returns None for a non-string, empty, unknown,
        already-used or expired token. The token is removed on every lookup,
        so it can never succeed twice.
        """
        if not isinstance(state, str) or not state:
            return None

        record = self._states.pop(state, None)
        if record is None:
            logger.warning("OAuth state verify failed: unknown / replayed token")
            return None

        if record.is_expired:
            logger.warning(
                "OAuth state verify failed: expired (created %s)",
                record.created_at.isoformat(),
            )
            return None

        return dict(
            user_id=record.user_id,
            channel_type=record.channel_type,
            code_verifier=record.code_verifier,
            extra=record.extra,
        )

    def _evict_expired(self) -> None:
        """Drop every expired record.

        The cost is linear in the number of pending states. It runs on each
        :meth:`generate_state` call, so no background task is needed.
        """
        stale = [token for token, rec in self._states.items() if rec.is_expired]
        for token in stale:
            self._states.pop(token)
        if stale:
            logger.debug("Evicted %d expired OAuth states", len(stale))
# Process-wide OAuthStateManager, created lazily by get_oauth_state_manager().
_oauth_state_manager: Optional[OAuthStateManager] = None


def get_oauth_state_manager() -> OAuthStateManager:
    """Return the shared OAuthStateManager, constructing it on first call."""
    global _oauth_state_manager
    if _oauth_state_manager is not None:
        return _oauth_state_manager
    _oauth_state_manager = OAuthStateManager()
    return _oauth_state_manager
def generate_pkce_pair() -> Tuple[str, str]:
    """Return a fresh PKCE ``(code_verifier, code_challenge)`` pair.

    The verifier is ``secrets.token_urlsafe(32)``: 43 url-safe characters,
    the RFC 7636 §4.1 minimum length. The challenge is the S256 transform,
    BASE64URL(SHA256(verifier)) with the ``=`` padding stripped.

    Used by providers that set ``oauth_uses_pkce: True`` in metadata —
    Google, Microsoft, Twitter X v2.
    """
    import base64

    verifier = secrets.token_urlsafe(32)
    sha = hashlib.sha256(verifier.encode('ascii')).digest()
    encoded = base64.urlsafe_b64encode(sha).decode('ascii')
    return verifier, encoded.rstrip('=')