Coverage for integrations / remote_desktop / session_manager.py: 87.7%

130 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2Remote Desktop Session Manager — Session lifecycle, OTP auth, multi-viewer support. 

3 

4Session flow: 

5 1. Host starts hosting → SessionManager.generate_otp(device_id) → 6-char password 

6 2. Viewer connects → SessionManager.create_session(host_id, viewer_id, mode) 

7 3. Auth → SessionManager.authenticate_session(session_id, password) 

8 4. Connected → streaming begins 

9 5. Disconnect → SessionManager.disconnect_session(session_id) 

10 

11Same-user devices auto-accept (no OTP needed), matching compute_mesh_service.py:398. 

12Cross-user requires OTP + explicit consent notification. 

13""" 

14 

15import logging 

16import os 

17import secrets 

18import string 

19import threading 

20import time 

21from dataclasses import dataclass, field 

22from enum import Enum 

23from typing import Dict, List, Optional 

24 

# Module-wide logger. The fixed dotted name places this module's records
# under the 'hevolve' logger hierarchy rather than under __name__.
logger = logging.getLogger(name='hevolve.remote_desktop')

26 

27 

28# ── Enums ─────────────────────────────────────────────────────── 

29 

class SessionMode(Enum):
    """Access mode requested for a remote desktop session.

    The string value of each member is what gets serialized as the
    session's ``'mode'`` field.
    """

    VIEW_ONLY = 'view_only'
    FULL_CONTROL = 'full_control'
    FILE_TRANSFER = 'file_transfer'

34 

35 

class SessionState(Enum):
    """Lifecycle state of a remote desktop session.

    PENDING is the initial state of a freshly built session. The manager
    then moves it to CONNECTED (same-user auto-accept) or AUTHENTICATING
    (OTP required). AUTHENTICATING becomes CONNECTED after a successful OTP
    check. DISCONNECTED is terminal.
    """

    PENDING = 'pending'
    AUTHENTICATING = 'authenticating'
    CONNECTED = 'connected'
    DISCONNECTED = 'disconnected'

41 

42 

43# ── Data Classes ──────────────────────────────────────────────── 

44 

@dataclass
class RemoteSession:
    """A single remote desktop session: one host device plus its viewers.

    Instances are mutable records. SessionManager stores them and updates
    ``state`` and the ``connected_at`` / ``disconnected_at`` timestamps.
    """

    session_id: str
    host_device_id: str
    host_user_id: Optional[str]
    mode: SessionMode
    state: SessionState = SessionState.PENDING
    # Each entry is a dict with 'device_id', 'user_id' and 'joined_at' keys.
    viewers: List[dict] = field(default_factory=list)
    # Timestamps are time.time() values (seconds since the epoch).
    created_at: float = field(default_factory=time.time)
    connected_at: Optional[float] = None
    disconnected_at: Optional[float] = None
    # NOTE(review): not assigned anywhere in this module; presumably set by
    # the transport layer — confirm.
    transport_tier: Optional[str] = None

    def add_viewer(self, device_id: str, user_id: Optional[str] = None) -> None:
        """Register a viewer device; a device already present is ignored."""
        if any(v['device_id'] == device_id for v in self.viewers):
            return
        self.viewers.append({
            'device_id': device_id,
            'user_id': user_id,
            'joined_at': time.time(),
        })

    def remove_viewer(self, device_id: str) -> None:
        """Drop every viewer entry for ``device_id`` (no-op when absent)."""
        self.viewers = [v for v in self.viewers if v['device_id'] != device_id]

    def to_dict(self) -> dict:
        """Return a serializable snapshot of the session.

        Viewer entries are shallow-copied so that callers cannot mutate the
        session's internal viewer list through the returned dict.
        """
        return {
            'session_id': self.session_id,
            'host_device_id': self.host_device_id,
            'host_user_id': self.host_user_id,
            'mode': self.mode.value,
            'state': self.state.value,
            'viewers': [dict(v) for v in self.viewers],
            'created_at': self.created_at,
            'connected_at': self.connected_at,
            'disconnected_at': self.disconnected_at,
            'transport_tier': self.transport_tier,
            'duration_seconds': self._duration(),
        }

    def _duration(self) -> Optional[float]:
        """Seconds spent connected, rounded to 0.1 s; None if never connected.

        The end point is ``disconnected_at`` when set, otherwise the current time.
        """
        if self.connected_at is None:
            return None
        end = self.disconnected_at if self.disconnected_at is not None else time.time()
        return round(end - self.connected_at, 1)

89 

90 

91# ── Session Manager (singleton) ───────────────────────────────── 

92 

class SessionManager:
    """Manages remote desktop sessions, OTP passwords, and multi-viewer support.

    All mutable state (``_sessions`` and ``_otps``) is guarded by a single
    ``threading.Lock``. That lock is NOT reentrant, so any method that already
    holds it must call the ``*_locked`` helpers, never the public methods that
    acquire the lock themselves.
    """

    OTP_LENGTH = 6
    OTP_CHARS = string.ascii_lowercase + string.digits  # a-z, 0-9
    OTP_EXPIRY_SECONDS = 300  # 5 minutes
    MAX_SESSIONS_PER_HOST = 5
    SESSION_TIMEOUT_SECONDS = 86400  # 24 hours

    def __init__(self):
        self._sessions: Dict[str, RemoteSession] = {}
        # device_id -> {'password': str, 'created_at': float, 'used': bool}
        self._otps: Dict[str, dict] = {}
        self._lock = threading.Lock()
        logger.info("SessionManager initialized")

    def generate_otp(self, device_id: str) -> str:
        """Generate a one-time password for a hosting device.

        The password is drawn with ``secrets.choice``. Any earlier OTP for the
        same device is replaced.

        Returns:
            6-char alphanumeric password (e.g., 'a8f2k9')
        """
        password = ''.join(secrets.choice(self.OTP_CHARS) for _ in range(self.OTP_LENGTH))
        with self._lock:
            self._otps[device_id] = {
                'password': password,
                'created_at': time.time(),
                'used': False,
            }
        logger.info(f"OTP generated for device {device_id[:8]}...")
        return password

    def verify_otp(self, device_id: str, password: str) -> bool:
        """Verify a one-time password (single-use, expires after OTP_EXPIRY_SECONDS).

        Returns:
            True if the password matches and has not been used or expired.
            A successful check marks the OTP as used.
        """
        with self._lock:
            return self._verify_otp_locked(device_id, password)

    def _verify_otp_locked(self, device_id: str, password: str) -> bool:
        """Core of verify_otp; the caller must already hold ``self._lock``."""
        otp_entry = self._otps.get(device_id)
        if not otp_entry or otp_entry['used']:
            return False
        if time.time() - otp_entry['created_at'] > self.OTP_EXPIRY_SECONDS:
            del self._otps[device_id]
            return False
        if not isinstance(password, str):
            return False
        # Constant-time comparison: response timing must not reveal how much
        # of a guessed password was correct. Encoding to bytes also avoids
        # compare_digest's TypeError on non-ASCII str input.
        if not secrets.compare_digest(otp_entry['password'].encode('utf-8'),
                                      password.encode('utf-8')):
            return False
        otp_entry['used'] = True  # single-use
        return True

    def is_same_user(self, host_user_id: Optional[str],
                     viewer_user_id: Optional[str]) -> bool:
        """Check if host and viewer belong to same user.

        Same-user devices auto-accept (no OTP needed),
        matching compute_mesh_service.py:398 auto_accept pattern.
        A missing (falsy) user id on either side never matches.
        """
        if not host_user_id or not viewer_user_id:
            return False
        return str(host_user_id) == str(viewer_user_id)

    def create_session(self, host_device_id: str, viewer_device_id: str,
                       mode: SessionMode,
                       host_user_id: Optional[str] = None,
                       viewer_user_id: Optional[str] = None) -> RemoteSession:
        """Create a new remote desktop session.

        Args:
            host_device_id: Device ID of the host
            viewer_device_id: Device ID of the viewer
            mode: Session mode (VIEW_ONLY, FULL_CONTROL, FILE_TRANSFER)
            host_user_id: User ID of the host device owner
            viewer_user_id: User ID of the viewer

        Returns:
            RemoteSession instance, already CONNECTED for same-user pairs and
            AUTHENTICATING (awaiting OTP) otherwise.

        Raises:
            ValueError: the host already has MAX_SESSIONS_PER_HOST sessions in
                PENDING, AUTHENTICATING or CONNECTED state.
        """
        session_id = secrets.token_hex(8)
        session = RemoteSession(
            session_id=session_id,
            host_device_id=host_device_id,
            host_user_id=host_user_id,
            mode=mode,
            state=SessionState.PENDING,
        )
        session.add_viewer(viewer_device_id, viewer_user_id)
        same_user = self.is_same_user(host_user_id, viewer_user_id)

        with self._lock:
            # Count and insert in ONE critical section. Checking and
            # inserting under separate lock acquisitions let concurrent
            # callers exceed MAX_SESSIONS_PER_HOST.
            active_count = sum(
                1 for s in self._sessions.values()
                if s.host_device_id == host_device_id
                and s.state in (SessionState.PENDING, SessionState.AUTHENTICATING,
                                SessionState.CONNECTED)
            )
            if active_count >= self.MAX_SESSIONS_PER_HOST:
                raise ValueError(
                    f"Host {host_device_id[:8]} has {active_count} active sessions "
                    f"(max {self.MAX_SESSIONS_PER_HOST})"
                )

            # Same-user auto-accept (compute_mesh_service.py:398 pattern)
            if same_user:
                session.state = SessionState.CONNECTED
                session.connected_at = time.time()
            else:
                session.state = SessionState.AUTHENTICATING
            self._sessions[session_id] = session

        if same_user:
            logger.info(
                f"Session {session_id}: same-user auto-accept "
                f"(host={host_device_id[:8]}, viewer={viewer_device_id[:8]})"
            )
        else:
            logger.info(
                f"Session {session_id}: cross-user, OTP required "
                f"(host_user={host_user_id}, viewer_user={viewer_user_id})"
            )
        return session

    def authenticate_session(self, session_id: str, password: str) -> bool:
        """Authenticate an AUTHENTICATING session with the host's OTP.

        Returns:
            True if the session was moved to CONNECTED. Unknown sessions and
            sessions in any other state return False.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session or session.state != SessionState.AUTHENTICATING:
                return False
            # Call the lock-held helper: self._lock is a plain (non-reentrant)
            # Lock, so calling verify_otp() here would deadlock.
            authenticated = self._verify_otp_locked(session.host_device_id, password)
            if authenticated:
                session.state = SessionState.CONNECTED
                session.connected_at = time.time()

        if authenticated:
            logger.info(f"Session {session_id} authenticated")
        else:
            logger.warning(f"Session {session_id} auth failed")
        return authenticated

    def add_viewer(self, session_id: str, device_id: str,
                   user_id: Optional[str] = None) -> bool:
        """Add a viewer to an existing CONNECTED session (multi-viewer support).

        Returns:
            True if the session exists and is CONNECTED. Adding a device that
            is already a viewer is a no-op but still returns True.
        """
        # NOTE(review): no same-user or OTP check is done for the extra viewer
        # here; confirm callers authorize it before calling.
        with self._lock:
            session = self._sessions.get(session_id)
            if not session or session.state != SessionState.CONNECTED:
                return False
            session.add_viewer(device_id, user_id)
        logger.info(f"Viewer {device_id[:8]} added to session {session_id}")
        return True

    def disconnect_session(self, session_id: str) -> bool:
        """Disconnect a session.

        Returns:
            True if the session was found and was not already disconnected.
        """
        with self._lock:
            session = self._sessions.get(session_id)
            if not session or session.state == SessionState.DISCONNECTED:
                return False
            session.state = SessionState.DISCONNECTED
            session.disconnected_at = time.time()
        logger.info(f"Session {session_id} disconnected")
        return True

    def get_session(self, session_id: str) -> Optional[RemoteSession]:
        """Get session by ID (the live object, not a copy)."""
        with self._lock:
            return self._sessions.get(session_id)

    def get_active_sessions(self) -> List[RemoteSession]:
        """Get all active (non-disconnected) sessions."""
        with self._lock:
            return [
                s for s in self._sessions.values()
                if s.state != SessionState.DISCONNECTED
            ]

    def get_sessions_for_device(self, device_id: str) -> List[RemoteSession]:
        """Get all sessions where the device is the host or one of the viewers."""
        with self._lock:
            return [
                s for s in self._sessions.values()
                if s.host_device_id == device_id
                or any(v['device_id'] == device_id for v in s.viewers)
            ]

    def cleanup_stale(self) -> int:
        """Remove expired sessions and expired OTPs.

        A session is removed once it is older than SESSION_TIMEOUT_SECONDS
        (or was disconnected before that cutoff), whatever its state. A
        session evicted while still active is marked DISCONNECTED first, so
        code still holding a reference to it sees it as ended.

        Returns:
            Number of sessions removed.
        """
        now = time.time()
        cutoff = now - self.SESSION_TIMEOUT_SECONDS
        with self._lock:
            stale_ids = [
                sid for sid, s in self._sessions.items()
                if s.created_at < cutoff or (
                    s.state == SessionState.DISCONNECTED
                    and s.disconnected_at
                    and s.disconnected_at < cutoff
                )
            ]
            for sid in stale_ids:
                session = self._sessions.pop(sid)
                if session.state != SessionState.DISCONNECTED:
                    session.state = SessionState.DISCONNECTED
                    session.disconnected_at = now

            # Clean expired OTPs
            expired_devices = [
                dev for dev, otp in self._otps.items()
                if now - otp['created_at'] > self.OTP_EXPIRY_SECONDS
            ]
            for dev in expired_devices:
                del self._otps[dev]

        removed = len(stale_ids)
        if removed:
            logger.info(f"Cleaned up {removed} stale sessions")
        return removed

323 

324 

325# ── Singleton ─────────────────────────────────────────────────── 

326 

_session_manager: Optional[SessionManager] = None
# Serializes first-time construction in get_session_manager().
_session_manager_lock = threading.Lock()


def get_session_manager() -> SessionManager:
    """Return the process-wide SessionManager, creating it on first use.

    Uses double-checked locking. Without the lock, concurrent first calls
    could each construct a manager, splitting sessions and OTPs across two
    instances. After initialization the fast path takes no lock.
    """
    global _session_manager
    if _session_manager is None:
        with _session_manager_lock:
            if _session_manager is None:
                _session_manager = SessionManager()
    return _session_manager