Coverage for integrations / channels / web_adapter.py: 28.3%
325 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Web/Browser Channel Adapter
4Implements web-based messaging with WebSocket and REST API support.
5Designed for Docker-compatible deployments with browser clients.
7Features:
8- WebSocket real-time communication
9- REST API fallback for polling
10- Session management
11- File upload/download
12- Typing indicators
13- Read receipts
14- Multi-tab support
16This adapter creates a WebSocket server that browser clients can connect to.
17It also provides REST endpoints for clients that don't support WebSockets.
18"""
20from __future__ import annotations
22import asyncio
23import logging
24import os
25import uuid
26import json
27import time
28import base64
29import mimetypes
30from typing import Optional, List, Dict, Any, Set
31from datetime import datetime, timedelta
32from pathlib import Path
33from dataclasses import dataclass, field
35try:
36 import aiohttp
37 from aiohttp import web, WSMsgType
38 HAS_AIOHTTP = True
39except ImportError:
40 HAS_AIOHTTP = False
42from .base import (
43 ChannelAdapter,
44 ChannelConfig,
45 ChannelStatus,
46 Message,
47 MessageType,
48 MediaAttachment,
49 SendResult,
50 ChannelConnectionError,
51 ChannelSendError,
52 ChannelRateLimitError,
53)
55logger = logging.getLogger(__name__)
@dataclass
class WebSession:
    """State for one web client identity.

    A session may own several WebSocket connections at once (for example one
    per open browser tab); it counts as connected while at least one of them
    is registered in ``websockets``.
    """

    session_id: str
    user_id: str
    user_name: Optional[str] = None
    connected_at: datetime = field(default_factory=datetime.now)
    last_activity: datetime = field(default_factory=datetime.now)
    websockets: Set[web.WebSocketResponse] = field(default_factory=set)
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def is_connected(self) -> bool:
        """Whether any WebSocket is currently registered for this session."""
        return bool(self.websockets)

    def touch(self) -> None:
        """Refresh ``last_activity`` to the current local time."""
        self.last_activity = datetime.now()
def _default_expiry() -> datetime:
    """Default expiry for queued messages: 24 hours after creation time."""
    return datetime.now() + timedelta(hours=24)


@dataclass
class PendingMessage:
    """A payload held back because its session had no open WebSocket."""

    id: str  # unique id of this queue entry
    session_id: str  # session the payload is addressed to
    data: Dict[str, Any]  # frame later passed to send_json / json_response
    created_at: datetime = field(default_factory=datetime.now)
    expires_at: datetime = field(default_factory=_default_expiry)
class WebAdapter(ChannelAdapter):
    """
    Web/Browser channel adapter with WebSocket and REST API.

    Usage:
        config = ChannelConfig(
            extra={
                "host": "0.0.0.0",
                "port": 8765,
                "upload_dir": "/tmp/uploads",
                "cors_origins": ["*"],
            }
        )
        adapter = WebAdapter(config)
        adapter.on_message(my_handler)
        await adapter.start()

    Browser client example:
        const ws = new WebSocket('ws://localhost:8765/ws?session_id=xxx&user_id=yyy');
        ws.onmessage = (event) => {
            const data = JSON.parse(event.data);
            console.log('Received:', data);
        };
        ws.send(JSON.stringify({type: 'message', text: 'Hello!'}));

    Keys read from ``config.extra``: ``host``, ``port``, ``upload_dir``,
    ``cors_origins`` (list of origins, or ``"*"``) and ``session_timeout``
    (seconds an unconnected session may stay idle before it is dropped).
    """

    # How long a reported typing indicator stays active.
    _TYPING_TTL = timedelta(seconds=5)
    # Once more than _READ_RECEIPTS_MAX message ids are tracked, only the
    # _READ_RECEIPTS_KEEP most recently inserted ones are kept.
    _READ_RECEIPTS_MAX = 10000
    _READ_RECEIPTS_KEEP = 5000
    # Seconds between background cleanup passes.
    _CLEANUP_INTERVAL = 60

    def __init__(self, config: ChannelConfig):
        if not HAS_AIOHTTP:
            raise ImportError(
                "aiohttp not installed. "
                "Install with: pip install aiohttp"
            )

        super().__init__(config)
        self._host = config.extra.get("host", "0.0.0.0")
        self._port = config.extra.get("port", 8765)
        self._upload_dir = Path(config.extra.get("upload_dir", "/tmp/web_adapter_uploads"))
        cors_origins = config.extra.get("cors_origins", ["*"])
        if isinstance(cors_origins, str):
            # Tolerate a single origin given as a plain string.
            cors_origins = [cors_origins]
        self._cors_origins: List[str] = list(cors_origins)
        self._session_timeout = config.extra.get("session_timeout", 3600)  # 1 hour

        self._app: Optional[web.Application] = None
        self._runner: Optional[web.AppRunner] = None
        self._site: Optional[web.TCPSite] = None

        self._sessions: Dict[str, WebSession] = {}
        self._pending_messages: Dict[str, List[PendingMessage]] = {}
        self._read_receipts: Dict[str, Set[str]] = {}  # message_id -> set of session_ids
        self._typing_status: Dict[str, datetime] = {}  # session_id -> typing_until

        self._cleanup_task: Optional[asyncio.Task] = None

    @property
    def name(self) -> str:
        return "web"

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def connect(self) -> bool:
        """Start the HTTP/WebSocket server and the periodic cleanup task.

        Returns:
            True once the server is listening. False if startup failed; the
            error is logged, any partially created runner is cleaned up and
            ``status`` is set to ``ChannelStatus.ERROR``.
        """
        try:
            self._upload_dir.mkdir(parents=True, exist_ok=True)

            # aiohttp recognises a new-style ``(request, handler)`` middleware
            # by the ``@web.middleware`` marker (3.x treats unmarked ones as
            # deprecated old-style ``(app, handler)`` factories). A bound method
            # cannot carry that attribute, so wrap it in a decorated closure.
            @web.middleware
            async def cors_middleware(request, handler):
                return await self._cors_middleware(request, handler)

            self._app = web.Application(middlewares=[cors_middleware])
            self._setup_routes()

            self._runner = web.AppRunner(self._app)
            await self._runner.setup()

            self._site = web.TCPSite(self._runner, self._host, self._port)
            await self._site.start()

            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

            self.status = ChannelStatus.CONNECTED
            logger.info(f"Web adapter started on ws://{self._host}:{self._port}")
            return True

        except Exception as e:
            logger.error(f"Failed to start web adapter: {e}")
            # Release whatever was set up so a later connect() starts cleanly.
            if self._runner is not None:
                try:
                    await self._runner.cleanup()
                except Exception as cleanup_error:
                    logger.debug(f"Error cleaning up after failed start: {cleanup_error}")
            self._app = None
            self._runner = None
            self._site = None
            self.status = ChannelStatus.ERROR
            return False

    async def disconnect(self) -> None:
        """Stop the cleanup task, close every WebSocket and stop the server."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

        # Snapshot both levels: each awaited close() yields to the event loop,
        # where connection handlers may add or discard sessions and sockets.
        for session in list(self._sessions.values()):
            for ws in list(session.websockets):
                try:
                    await ws.close()
                except Exception as e:
                    logger.debug(f"Error closing WebSocket: {e}")

        if self._site:
            await self._site.stop()
        if self._runner:
            await self._runner.cleanup()

        self._app = None
        self._runner = None
        self._site = None
        self.status = ChannelStatus.DISCONNECTED
        logger.info("Web adapter stopped")

    # ------------------------------------------------------------------
    # HTTP plumbing
    # ------------------------------------------------------------------

    def _cors_allow_origin(self, request: web.Request) -> Optional[str]:
        """Return the ``Access-Control-Allow-Origin`` value for *request*.

        The header carries a single origin or ``*``. With an explicit
        allow-list, the request's ``Origin`` is echoed back only when listed;
        otherwise None is returned and no header is sent.
        """
        if "*" in self._cors_origins:
            return "*"
        origin = request.headers.get("Origin")
        if origin and origin in self._cors_origins:
            return origin
        return None

    async def _cors_middleware(self, request, handler):
        """CORS middleware for REST endpoints (registered by connect())."""
        allow_origin = self._cors_allow_origin(request)

        if request.method == "OPTIONS":
            # Answer preflight requests directly; no route is consulted.
            headers = {
                "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
                "Access-Control-Allow-Headers": "Content-Type, Authorization, X-Session-ID",
                "Access-Control-Max-Age": "86400",
            }
            if allow_origin:
                headers["Access-Control-Allow-Origin"] = allow_origin
            if allow_origin != "*":
                # The answer depends on the Origin header; keep caches honest.
                headers["Vary"] = "Origin"
            return web.Response(headers=headers)

        response = await handler(request)
        # An already prepared response (e.g. an upgraded WebSocket) has sent
        # its headers, so there is nothing to add to.
        if allow_origin and not response.prepared:
            response.headers["Access-Control-Allow-Origin"] = allow_origin
            if allow_origin != "*":
                response.headers["Vary"] = "Origin"
        return response

    def _setup_routes(self) -> None:
        """Set up HTTP/WebSocket routes."""
        self._app.router.add_get("/ws", self._handle_websocket)
        self._app.router.add_get("/health", self._handle_health)

        # REST API endpoints
        self._app.router.add_post("/api/messages", self._handle_rest_message)
        self._app.router.add_get("/api/messages", self._handle_get_messages)
        self._app.router.add_post("/api/upload", self._handle_upload)
        self._app.router.add_get("/api/download/{file_id}", self._handle_download)
        self._app.router.add_post("/api/typing", self._handle_typing)
        self._app.router.add_post("/api/read", self._handle_read_receipt)
        self._app.router.add_get("/api/session", self._handle_session_info)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _safe_filename(name: Any, default: str = "upload") -> str:
        """Reduce a client-supplied file name to a bare name without separators.

        Everything up to the last ``/`` or backslash is dropped, as are NUL
        characters, so the name cannot point outside the upload directory.
        A non-string or empty result falls back to *default*.
        """
        if not isinstance(name, str):
            return default
        base = name.replace("\\", "/").rsplit("/", 1)[-1].replace("\x00", "")
        return base or default

    @staticmethod
    def _is_canonical_uuid(value: str) -> bool:
        """True if *value* is a UUID in canonical ``str(uuid.UUID(...))`` form."""
        try:
            return str(uuid.UUID(value)) == value
        except ValueError:
            return False

    @staticmethod
    async def _read_json_object(request: web.Request) -> Optional[Dict[str, Any]]:
        """Parse the request body as a JSON object.

        Returns None when the body is not valid JSON or is not an object.
        """
        try:
            data = await request.json()
        except ValueError:  # json.JSONDecodeError and UnicodeDecodeError
            return None
        return data if isinstance(data, dict) else None

    def _get_or_create_session(
        self,
        session_id: str,
        user_id: str,
        user_name: Optional[str],
    ) -> WebSession:
        """Return the session for *session_id*, creating it when missing.

        An existing session keeps its original user id/name and is touched.
        """
        session = self._sessions.get(session_id)
        if session is None:
            session = WebSession(
                session_id=session_id,
                user_id=user_id,
                user_name=user_name,
            )
            self._sessions[session_id] = session
        else:
            session.touch()
        return session

    def _build_message(self, session: WebSession, payload: Dict[str, Any]) -> Message:
        """Convert a client message payload (WebSocket frame or REST body).

        Reads ``id``, ``chat_id`` (defaults to the session id), ``text``,
        ``reply_to``, ``is_group`` and ``attachments``. The payload is kept
        as ``raw``.

        Raises:
            Whatever ``MessageType(...)`` raises for an unknown attachment
            ``type`` (ValueError if MessageType is an Enum; its definition
            is not visible here).
        """
        message = Message(
            id=payload.get("id") or str(uuid.uuid4()),
            channel=self.name,
            sender_id=session.user_id,
            sender_name=session.user_name,
            chat_id=payload.get("chat_id", session.session_id),
            text=payload.get("text", ""),
            reply_to_id=payload.get("reply_to"),
            timestamp=datetime.now(),
            is_group=payload.get("is_group", False),
            raw=payload,
        )

        for att in payload.get("attachments") or []:
            message.media.append(MediaAttachment(
                type=MessageType(att.get("type", "document")),
                file_id=att.get("file_id"),
                file_name=att.get("file_name"),
                mime_type=att.get("mime_type"),
                url=att.get("url"),
            ))

        return message

    async def _set_typing(self, session: WebSession) -> None:
        """Mark *session* as typing and notify the other sessions."""
        self._typing_status[session.session_id] = datetime.now() + self._TYPING_TTL
        await self._broadcast_typing(session.session_id, session.user_name or session.user_id)

    def _record_read_receipts(self, session_id: str, message_ids: Any) -> None:
        """Record that *session_id* has read every id in *message_ids*.

        Anything other than a list/tuple is ignored (a bare string would
        otherwise be iterated character by character). Ids that are not
        strings or ints are skipped.
        """
        if not isinstance(message_ids, (list, tuple)):
            return
        for msg_id in message_ids:
            if isinstance(msg_id, (str, int)):
                self._read_receipts.setdefault(msg_id, set()).add(session_id)

    # ------------------------------------------------------------------
    # Endpoints
    # ------------------------------------------------------------------

    async def _handle_health(self, request: web.Request) -> web.Response:
        """Health check endpoint."""
        return web.json_response({
            "status": "ok",
            "channel": self.name,
            "connections": sum(len(s.websockets) for s in self._sessions.values()),
            "sessions": len(self._sessions),
        })

    async def _flush_pending(self, session_id: str, ws: web.WebSocketResponse) -> None:
        """Send queued messages for *session_id* over *ws*, oldest first.

        If a send fails, the unsent messages are put back at the front of the
        queue before the error is re-raised, so nothing is silently lost.
        """
        pending = self._pending_messages.pop(session_id, [])
        for position, pm in enumerate(pending):
            try:
                await ws.send_json(pm.data)
            except Exception:
                queued_meanwhile = self._pending_messages.get(session_id, [])
                self._pending_messages[session_id] = pending[position:] + queued_meanwhile
                raise

    async def _handle_websocket(self, request: web.Request) -> web.WebSocketResponse:
        """Serve one WebSocket connection until it closes.

        Query params: ``session_id`` (a random UUID if absent), ``user_id``
        (defaults to the session id) and ``user_name``. The client first gets
        a ``connected`` frame, then any messages queued while it was offline.
        """
        ws = web.WebSocketResponse(heartbeat=30)
        await ws.prepare(request)

        session_id = request.query.get("session_id") or str(uuid.uuid4())
        user_id = request.query.get("user_id", session_id)
        user_name = request.query.get("user_name")

        session = self._get_or_create_session(session_id, user_id, user_name)
        session.websockets.add(ws)
        logger.info(f"WebSocket connected: session={session_id}, user={user_id}")

        # Everything after registration runs under try/finally so the socket
        # is always unregistered, even if the initial sends fail.
        try:
            await ws.send_json({
                "type": "connected",
                "session_id": session_id,
                "user_id": user_id,
            })
            await self._flush_pending(session_id, ws)

            async for msg in ws:
                if msg.type == WSMsgType.TEXT:
                    await self._handle_ws_message(session, msg.data)
                elif msg.type == WSMsgType.BINARY:
                    await self._handle_ws_binary(session, msg.data)
                elif msg.type == WSMsgType.ERROR:
                    logger.error(f"WebSocket error: {ws.exception()}")
                    break
        finally:
            session.websockets.discard(ws)
            if not session.websockets:
                logger.info(f"All connections closed for session {session_id}")
            else:
                logger.info(f"WebSocket disconnected: session={session_id} ({len(session.websockets)} remaining)")

        return ws

    async def _handle_ws_message(self, session: WebSession, data: str) -> None:
        """Handle one text frame of type ``message``, ``typing``, ``read`` or ``ping``.

        Malformed frames are logged and ignored, so a bad frame never closes
        the connection.
        """
        try:
            payload = json.loads(data)
            if not isinstance(payload, dict):
                logger.warning(f"Ignoring non-object JSON frame from session {session.session_id}")
                return
            msg_type = payload.get("type", "message")

            session.touch()

            if msg_type == "message":
                await self._dispatch_message(self._build_message(session, payload))
            elif msg_type == "typing":
                await self._set_typing(session)
            elif msg_type == "read":
                self._record_read_receipts(session.session_id, payload.get("message_ids", []))
            elif msg_type == "ping":
                await self._send_to_session(session.session_id, {"type": "pong"})

        except json.JSONDecodeError:
            logger.warning(f"Invalid JSON from session {session.session_id}")
        except Exception as e:
            logger.error(f"Error handling WebSocket message: {e}")

    async def _handle_ws_binary(self, session: WebSession, data: bytes) -> None:
        """Handle a binary frame carrying a file upload.

        Frame layout: a 4-byte big-endian metadata length, UTF-8 JSON metadata
        (its ``file_name`` is used), then the raw file bytes. The session gets
        an ``upload_complete`` or ``upload_error`` frame back.
        """
        try:
            if len(data) < 4:
                raise ValueError("binary frame too short")
            meta_len = int.from_bytes(data[:4], "big")
            if 4 + meta_len > len(data):
                raise ValueError("metadata length exceeds frame size")
            metadata = json.loads(data[4:4 + meta_len].decode())
            if not isinstance(metadata, dict):
                raise ValueError("metadata must be a JSON object")
            file_data = data[4 + meta_len:]

            file_id = str(uuid.uuid4())
            file_name = self._safe_filename(metadata.get("file_name", "upload"))
            file_path = self._upload_dir / f"{file_id}_{file_name}"
            file_path.write_bytes(file_data)

            await self._send_to_session(session.session_id, {
                "type": "upload_complete",
                "file_id": file_id,
                "file_name": file_name,
                "size": len(file_data),
            })

        except Exception as e:
            logger.error(f"Error handling binary upload: {e}")
            await self._send_to_session(session.session_id, {
                "type": "upload_error",
                "error": str(e),
            })

    async def _handle_rest_message(self, request: web.Request) -> web.Response:
        """Accept a message over REST (``POST /api/messages``).

        Requires ``X-Session-ID``. The JSON body uses the same fields as a
        WebSocket ``message`` frame, plus optional ``user_id``/``user_name``
        used when the session is created. Returns 400 for a missing header or
        malformed body and 500 for other errors.
        """
        try:
            session_id = request.headers.get("X-Session-ID")
            if not session_id:
                return web.json_response({"error": "X-Session-ID header required"}, status=400)

            data = await self._read_json_object(request)
            if data is None:
                return web.json_response({"error": "Request body must be a JSON object"}, status=400)

            # Touching existing sessions keeps REST-only clients from expiring.
            session = self._get_or_create_session(
                session_id,
                data.get("user_id", session_id),
                data.get("user_name"),
            )

            try:
                message = self._build_message(session, data)
            except ValueError as e:
                return web.json_response({"error": str(e)}, status=400)

            await self._dispatch_message(message)

            return web.json_response({
                "success": True,
                "message_id": message.id,
            })

        except Exception as e:
            logger.error(f"REST message error: {e}")
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_get_messages(self, request: web.Request) -> web.Response:
        """Return and clear queued messages for the session (REST polling)."""
        session_id = request.headers.get("X-Session-ID")
        if not session_id:
            return web.json_response({"error": "X-Session-ID header required"}, status=400)

        session = self._sessions.get(session_id)
        if session is not None:
            session.touch()  # polling counts as activity

        messages = [pm.data for pm in self._pending_messages.pop(session_id, [])]
        return web.json_response({"messages": messages})

    async def _handle_upload(self, request: web.Request) -> web.Response:
        """Store files from a multipart request (``POST /api/upload``).

        Every part with a filename is saved as ``<upload_dir>/<file_id>_<name>``.
        Parts without one are skipped. Returns id, name, size and declared
        content type for each stored file.
        """
        try:
            reader = await request.multipart()

            files = []
            async for part in reader:
                client_name = getattr(part, "filename", None)
                if not client_name:
                    continue

                file_id = str(uuid.uuid4())
                file_name = self._safe_filename(client_name)
                file_path = self._upload_dir / f"{file_id}_{file_name}"

                with open(file_path, "wb") as f:
                    while True:
                        chunk = await part.read_chunk()
                        if not chunk:
                            break
                        f.write(chunk)

                files.append({
                    "file_id": file_id,
                    "file_name": file_name,
                    "size": file_path.stat().st_size,
                    "mime_type": part.headers.get("Content-Type"),
                })

            return web.json_response({"files": files})

        except Exception as e:
            logger.error(f"Upload error: {e}")
            return web.json_response({"error": str(e)}, status=500)

    async def _handle_download(self, request: web.Request) -> web.StreamResponse:
        """Serve a previously uploaded file by id (``GET /api/download/{file_id}``)."""
        file_id = request.match_info["file_id"]

        # Only canonical UUIDs are ever issued. Rejecting anything else also
        # keeps glob metacharacters (*, ?, [) out of the pattern below, which
        # would otherwise match arbitrary stored files.
        if not self._is_canonical_uuid(file_id):
            return web.json_response({"error": "File not found"}, status=404)

        for file_path in self._upload_dir.glob(f"{file_id}_*"):
            if file_path.is_file():
                file_name = file_path.name[len(file_id) + 1:]
                mime_type = mimetypes.guess_type(file_name)[0] or "application/octet-stream"
                # Keep the quoted header value well-formed.
                header_name = "".join("_" if ch in '"\\\r\n' else ch for ch in file_name)

                return web.FileResponse(
                    file_path,
                    headers={
                        "Content-Disposition": f'attachment; filename="{header_name}"',
                        "Content-Type": mime_type,
                    }
                )

        return web.json_response({"error": "File not found"}, status=404)

    async def _handle_typing(self, request: web.Request) -> web.Response:
        """Handle typing indicator via REST."""
        session_id = request.headers.get("X-Session-ID")
        if not session_id:
            return web.json_response({"error": "X-Session-ID header required"}, status=400)

        session = self._sessions.get(session_id)
        if session:
            session.touch()
            await self._set_typing(session)

        return web.json_response({"success": True})

    async def _handle_read_receipt(self, request: web.Request) -> web.Response:
        """Record read receipts sent via REST (body: ``{"message_ids": [...]}``)."""
        session_id = request.headers.get("X-Session-ID")
        if not session_id:
            return web.json_response({"error": "X-Session-ID header required"}, status=400)

        data = await self._read_json_object(request)
        if data is None:
            return web.json_response({"error": "Request body must be a JSON object"}, status=400)

        session = self._sessions.get(session_id)
        if session is not None:
            session.touch()

        self._record_read_receipts(session_id, data.get("message_ids", []))
        return web.json_response({"success": True})

    async def _handle_session_info(self, request: web.Request) -> web.Response:
        """Get session information."""
        session_id = request.headers.get("X-Session-ID")
        if not session_id:
            return web.json_response({"error": "X-Session-ID header required"}, status=400)

        session = self._sessions.get(session_id)
        if not session:
            return web.json_response({"error": "Session not found"}, status=404)

        return web.json_response({
            "session_id": session.session_id,
            "user_id": session.user_id,
            "user_name": session.user_name,
            "connected_at": session.connected_at.isoformat(),
            "last_activity": session.last_activity.isoformat(),
            "is_connected": session.is_connected,
            "connection_count": len(session.websockets),
        })

    # ------------------------------------------------------------------
    # Delivery
    # ------------------------------------------------------------------

    async def _broadcast_typing(self, from_session: str, from_name: str) -> None:
        """Send a typing indicator to every session except *from_session*."""
        data = {
            "type": "typing",
            "from_session": from_session,
            "from_name": from_name,
        }

        # Snapshot the ids: each send awaits, and a new connection registering
        # a session meanwhile would break iteration over the live dict. Typing
        # indicators are ephemeral, so they are never queued for offline sessions.
        for session_id in list(self._sessions):
            if session_id != from_session:
                await self._send_to_session(session_id, data, queue_if_offline=False)

    async def _send_to_session(
        self,
        session_id: str,
        data: Dict[str, Any],
        queue_if_offline: bool = True,
    ) -> bool:
        """Send *data* to every open WebSocket of *session_id*.

        Sockets whose send fails are dropped from the session. If no socket
        accepted the data and *queue_if_offline* is true, it is queued for the
        next connect or poll (also for ids with no known session).

        Returns:
            True if at least one socket accepted the data, else False.
        """
        session = self._sessions.get(session_id)

        delivered = False
        if session is not None:
            for ws in list(session.websockets):
                try:
                    await ws.send_json(data)
                    delivered = True
                except Exception:
                    session.websockets.discard(ws)

        if not delivered and queue_if_offline:
            self._pending_messages.setdefault(session_id, []).append(PendingMessage(
                id=str(uuid.uuid4()),
                session_id=session_id,
                data=data,
            ))

        return delivered

    # ------------------------------------------------------------------
    # Housekeeping
    # ------------------------------------------------------------------

    async def _cleanup_loop(self) -> None:
        """Run _cleanup_once every _CLEANUP_INTERVAL seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(self._CLEANUP_INTERVAL)
                self._cleanup_once(datetime.now())
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Cleanup error: {e}")

    def _cleanup_once(self, now: datetime) -> None:
        """Drop idle sessions, expired queued messages, lapsed typing flags
        and excess read receipts, relative to *now*."""
        timeout = timedelta(seconds=self._session_timeout)

        expired_sessions = [
            session_id
            for session_id, session in self._sessions.items()
            if not session.is_connected and (now - session.last_activity) > timeout
        ]
        for session_id in expired_sessions:
            del self._sessions[session_id]
            self._pending_messages.pop(session_id, None)
            self._typing_status.pop(session_id, None)
            logger.debug(f"Cleaned up expired session: {session_id}")

        for session_id, messages in list(self._pending_messages.items()):
            fresh = [pm for pm in messages if pm.expires_at > now]
            if fresh:
                self._pending_messages[session_id] = fresh
            else:
                del self._pending_messages[session_id]

        # Typing flags are only ever written, never removed elsewhere.
        self._typing_status = {
            session_id: until
            for session_id, until in self._typing_status.items()
            if until > now
        }

        if len(self._read_receipts) > self._READ_RECEIPTS_MAX:
            # dicts keep insertion order, so this keeps the newest entries.
            self._read_receipts = dict(
                list(self._read_receipts.items())[-self._READ_RECEIPTS_KEEP:]
            )

    # ------------------------------------------------------------------
    # ChannelAdapter API
    # ------------------------------------------------------------------

    async def send_message(
        self,
        chat_id: str,
        text: str,
        reply_to: Optional[str] = None,
        media: Optional[List[MediaAttachment]] = None,
        buttons: Optional[List[Dict]] = None,
    ) -> SendResult:
        """Send a message to a web client.

        *chat_id* is the target session id. If no socket is open the message
        is queued, and ``raw["queued"]`` reports that.
        """
        message_id = str(uuid.uuid4())

        data = {
            "type": "message",
            "id": message_id,
            "text": text,
            "reply_to": reply_to,
            "timestamp": datetime.now().isoformat(),
        }

        if media:
            data["attachments"] = [
                {
                    "type": m.type.value,
                    "file_id": m.file_id,
                    "file_name": m.file_name,
                    "mime_type": m.mime_type,
                    # Prefer an explicit URL; otherwise point at our download route.
                    "url": m.url or (f"/api/download/{m.file_id}" if m.file_id else None),
                }
                for m in media
            ]

        if buttons:
            data["buttons"] = buttons

        delivered = await self._send_to_session(chat_id, data)

        return SendResult(
            success=True,
            message_id=message_id,
            raw={"delivered": delivered, "queued": not delivered},
        )

    async def edit_message(
        self,
        chat_id: str,
        message_id: str,
        text: str,
        buttons: Optional[List[Dict]] = None,
    ) -> SendResult:
        """Edit an existing message."""
        data = {
            "type": "message_edit",
            "id": message_id,
            "text": text,
        }

        if buttons:
            data["buttons"] = buttons

        await self._send_to_session(chat_id, data)

        return SendResult(success=True, message_id=message_id)

    async def delete_message(self, chat_id: str, message_id: str) -> bool:
        """Delete a message."""
        data = {
            "type": "message_delete",
            "id": message_id,
        }

        await self._send_to_session(chat_id, data)
        return True

    async def send_typing(self, chat_id: str) -> None:
        """Send typing indicator."""
        await self._send_to_session(chat_id, {
            "type": "typing",
            "from_name": "Bot",
        }, queue_if_offline=False)

    async def get_chat_info(self, chat_id: str) -> Optional[Dict[str, Any]]:
        """Get information about a session."""
        session = self._sessions.get(chat_id)
        if not session:
            return None

        return {
            "id": session.session_id,
            "type": "web",
            "user_id": session.user_id,
            "user_name": session.user_name,
            "is_connected": session.is_connected,
            "connection_count": len(session.websockets),
        }

    def get_active_sessions(self) -> List[Dict[str, Any]]:
        """Get all active sessions."""
        return [
            {
                "session_id": s.session_id,
                "user_id": s.user_id,
                "user_name": s.user_name,
                "is_connected": s.is_connected,
                "last_activity": s.last_activity.isoformat(),
            }
            for s in self._sessions.values()
        ]

    def get_read_receipts(self, message_id: str) -> List[str]:
        """Get list of session IDs that have read a message."""
        return list(self._read_receipts.get(message_id, set()))
def create_web_adapter(
    host: Optional[str] = None,
    port: Optional[int] = None,
    **kwargs
) -> WebAdapter:
    """
    Factory function to create Web adapter.

    Args:
        host: Host to bind to (default: WEB_ADAPTER_HOST env var, else 0.0.0.0).
        port: Port to bind to (default: WEB_ADAPTER_PORT env var, else 8765).
            Only ``None`` selects the default, so an explicit ``0`` is passed
            through instead of being treated as "unset".
        **kwargs: Additional ChannelConfig fields. An ``extra`` dict is merged
            into the adapter's ``extra`` config, and its keys take precedence
            over *host* and *port*.

    Returns:
        Configured WebAdapter
    """
    if host is None:
        host = os.getenv("WEB_ADAPTER_HOST", "0.0.0.0")
    if port is None:
        port = int(os.getenv("WEB_ADAPTER_PORT", "8765"))

    # ``extra=None`` is tolerated instead of failing on ``**None``.
    extra = {"host": host, "port": port, **(kwargs.get("extra") or {})}
    other_fields = {k: v for k, v in kwargs.items() if k != "extra"}

    config = ChannelConfig(extra=extra, **other_fields)
    return WebAdapter(config)