Coverage for integrations / audio / diarization_server.py: 0.0%

97 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2Speaker Diarization Server - standalone WebSocket server for the sidecar. 

3 

4Derived from speaker_diarization/main.py but with: 

5- Security fixes (no eval, proper JSON) 

6- Configurable via CLI args and env vars 

7- Buffer cleanup on disconnect 

8- GPU/CPU auto-detection 

9 

10Usage: 

11 python -m integrations.audio.diarization_server --port 8004 

12""" 

13import argparse 

14import asyncio 

15import ast 

16import io 

17import json 

18import logging 

19import os 

20import sys 

21 

22import numpy as np 

23 

try:
    from integrations.service_tools.vram_manager import clear_cuda_cache
except ImportError:
    # Standalone fallback when the service_tools package is unavailable:
    # best-effort release of cached accelerator memory; always a no-op on
    # failure (missing torch, no device, backend errors are all swallowed).
    def clear_cuda_cache():
        """Release cached CUDA/MPS memory if torch and a device are present."""
        try:
            import torch
        except Exception:
            return
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                torch.mps.empty_cache()
        except Exception:
            pass

36 

# Module logger (note: handlers/most call sites below use the root logger
# via logging.info; this named logger exists for external configuration).
logger = logging.getLogger('hevolve_diarization')

# Audio parameters (16kHz, 16-bit, mono)
SAMPLE_RATE = 16000       # samples per second
BYTES_PER_SAMPLE = 2      # 16-bit signed PCM
CHANNELS = 1              # mono
SECONDS = 1               # buffered duration before each diarization pass
# Bytes required before a buffer is processed: 16000 * 2 * 1 * 1 = 32 KB
EXPECTED_BYTES = SAMPLE_RATE * BYTES_PER_SAMPLE * CHANNELS * SECONDS  # 32KB

# Per-user audio stream buffers: user_id -> io.BytesIO of raw PCM.
# Entries are created on first chunk and removed by _cleanup_stream().
audio_streams = {}

48 

49 

50def _parse_message(raw): 

51 """Parse incoming WebSocket message safely. 

52 

53 Handles both JSON format and Python dict format (single quotes) 

54 sent by Android SpeechService.java for backward compatibility. 

55 """ 

56 if isinstance(raw, bytes): 

57 raw = raw.decode('utf-8') 

58 try: 

59 return json.loads(raw) 

60 except (json.JSONDecodeError, ValueError): 

61 return ast.literal_eval(raw) 

62 

63 

async def diarization(websocket, diarize_model, output_dir, device):
    """Handle a single WebSocket connection for speaker diarization.

    Messages are parsed via _parse_message and must contain 'user_id' and
    'chunk' (hex string or raw bytes of 16-bit PCM). PCM is accumulated per
    user in audio_streams; once EXPECTED_BYTES (~1s of audio) is buffered,
    the model runs and a JSON result {"no_of_speaker", "stop_mic"} is sent
    back. The buffer is dropped after each pass and again on disconnect.

    Args:
        websocket: connected WebSocket (async-iterable of messages).
        diarize_model: callable pipeline returning segments with a
            'speaker' column exposing .unique() (whisperx-style output).
        output_dir: directory for per-user MP3 audit exports.
        device: compute device string — not used directly here; the model
            was already placed on it by main(). TODO confirm it is needed.
    """
    import torch

    user_id = None
    try:
        logging.info("Waiting for audio data...")
        async for message in websocket:
            parsed = _parse_message(message)
            user_id = parsed['user_id']
            pcm_bytes = parsed['chunk']

            # Lazily create this user's accumulation buffer.
            if user_id not in audio_streams:
                audio_streams[user_id] = io.BytesIO()

            # Handle hex-encoded or binary bytes
            if isinstance(pcm_bytes, str):
                pcm_bytes = bytes.fromhex(pcm_bytes)

            audio_streams[user_id].write(pcm_bytes)

            # Run a diarization pass once ~1 second of PCM is buffered.
            if audio_streams[user_id].getbuffer().nbytes >= EXPECTED_BYTES:
                audio_streams[user_id].seek(0)
                audio_data_bytes = audio_streams[user_id].read()

                try:
                    with torch.no_grad():
                        logging.info(
                            f'Processing audio for user_id {user_id}')
                        # int16 PCM -> float32 in [-1, 1) expected by the model.
                        audios = (
                            np.frombuffer(audio_data_bytes, np.int16)
                            .flatten()
                            .astype(np.float32) / 32768.0
                        )
                        diarize_segments = diarize_model(audios)
                        unique_speakers = diarize_segments['speaker'].unique()
                        no_of_speakers = len(unique_speakers)
                        logging.info(
                            f'Speakers: {unique_speakers} '
                            f'for user_id {user_id}')

                        # Export/append MP3 for audit trail
                        _export_audio(
                            audio_data_bytes, user_id, output_dir)

                        # Voice signature enrollment — dispatched to HevolveAI.
                        # Optional dependency; all enrollment failures are
                        # deliberately non-fatal to the diarization loop.
                        try:
                            from core.resonance_identifier import ResonanceIdentifier
                            if user_id:
                                _identifier = ResonanceIdentifier()
                                _identifier.enroll_voice(str(user_id), audio_data_bytes)
                        except ImportError:
                            pass
                        except Exception:
                            pass

                        # stop_mic signals the client to mute when more than
                        # one speaker is detected.
                        res = {
                            "no_of_speaker": no_of_speakers,
                            "stop_mic": no_of_speakers > 1,
                        }
                        await websocket.send(json.dumps(res))
                        logging.info(f"Result: {res}")

                except Exception as e:
                    logging.error(
                        f'Diarization error at line '
                        f'{e.__traceback__.tb_lineno}: {e}')
                finally:
                    # Reset the per-user buffer after every pass and release
                    # cached accelerator memory, success or failure.
                    _cleanup_stream(user_id)
                    clear_cuda_cache()

    except Exception as e:
        # Normal disconnects surface here as exceptions from the iterator.
        logging.debug(f"Connection ended: {e}")
    finally:
        # Cleanup buffer on disconnect (prevents memory leak)
        if user_id:
            _cleanup_stream(user_id)

141 

142 

def _cleanup_stream(user_id):
    """Close and remove a user's audio buffer, ignoring close errors."""
    stream = audio_streams.pop(user_id, None)
    if stream is not None:
        try:
            stream.close()
        except Exception:
            pass

151 

152 

def _export_audio(audio_data_bytes, user_id, output_dir):
    """Export audio chunk as MP3, appending to existing file.

    Silently does nothing when pydub is not installed; any export failure
    is logged but never raised to the caller.
    """
    try:
        from pydub import AudioSegment
    except ImportError:
        return

    try:
        chunk = AudioSegment(
            data=audio_data_bytes,
            sample_width=BYTES_PER_SAMPLE,
            frame_rate=SAMPLE_RATE,
            channels=CHANNELS,
        )
        mp3_path = os.path.join(output_dir, f'{user_id}.mp3')
        # Append by decoding the existing file and re-encoding the whole
        # stream (simple, but re-encodes on every chunk).
        if os.path.exists(mp3_path):
            chunk = AudioSegment.from_mp3(mp3_path) + chunk
        chunk.export(mp3_path, format='mp3')
    except Exception as e:
        logging.error(f"Failed to export audio: {e}")

174 

175 

async def main(port, device, output_dir, hf_token):
    """Start diarization model and WebSocket server.

    Loads the whisperx DiarizationPipeline on `device`, then serves
    diarization() on 0.0.0.0:`port` (port 0 gets an OS-assigned port) and
    runs until cancelled. Exits the process with status 1 if the model
    fails to load.

    Args:
        port: TCP port to bind; 0 requests a dynamic port.
        device: torch device string ('cuda', 'mps', or 'cpu').
        output_dir: directory for MP3 audit exports (created if missing).
        hf_token: HuggingFace auth token for the pyannote/whisperx model.
    """
    import torch  # NOTE(review): imported but not referenced below — presumably ensures torch loads before the pipeline; confirm.
    import websockets

    # Load diarization model
    logging.info(f"Loading diarization model on {device}...")
    try:
        from whisperx.diarize import DiarizationPipeline
        diarize_model = DiarizationPipeline(
            use_auth_token=hf_token, device=device)
    except Exception as e:
        logging.error(f"Failed to load diarization model: {e}")
        sys.exit(1)

    logging.info("Diarization model loaded")

    os.makedirs(output_dir, exist_ok=True)

    # Bind with dynamic port support. The lambda's optional `path` keeps
    # compatibility with both old (handler(ws, path)) and new (handler(ws))
    # websockets handler signatures.
    server = await websockets.serve(
        lambda ws, path=None: diarization(
            ws, diarize_model, output_dir, device),
        '0.0.0.0', port,
    )

    # Resolve the actual bound port (differs from `port` when port == 0).
    actual_port = port
    if server.sockets:
        actual_port = server.sockets[0].getsockname()[1]

    # Readiness signal (DiarizationService reads this from stdout)
    print(f"DIARIZATION_READY:{actual_port}", flush=True)
    logging.info(
        f"Speaker diarization server on port {actual_port}")

    try:
        await asyncio.Future()  # run forever
    finally:
        # Graceful shutdown on cancellation (e.g. KeyboardInterrupt).
        server.close()
        await server.wait_closed()

216 

217 

if __name__ == "__main__":
    from core.port_registry import get_port as _get_port
    parser = argparse.ArgumentParser(
        description='Speaker Diarization Sidecar')
    # Port precedence: --port flag > HEVOLVE_DIARIZATION_PORT env > registry.
    # NOTE(review): _get_port() is evaluated even when the env var is set
    # (it supplies the dict.get default) — confirm that is cheap/safe.
    parser.add_argument(
        '--port', type=int,
        default=int(os.environ.get('HEVOLVE_DIARIZATION_PORT',
                                   _get_port('diarization'))))
    # Device precedence: --device flag > env var > auto-detection below.
    parser.add_argument(
        '--device',
        default=os.environ.get('HEVOLVE_DIARIZATION_DEVICE', None))
    parser.add_argument(
        '--output_dir',
        default=os.path.join(
            os.path.expanduser('~'), '.hevolve', 'audio'))
    args = parser.parse_args()

    # Auto-detect device: prefer CUDA, then Apple MPS, else CPU.
    if args.device is None:
        try:
            import torch
            if torch.cuda.is_available():
                args.device = 'cuda'
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                args.device = 'mps'
            else:
                args.device = 'cpu'
        except ImportError:
            args.device = 'cpu'

    # HuggingFace token: env var first, then config.json fallbacks.
    hf_token = os.environ.get('HEVOLVE_HF_TOKEN', '')
    if not hf_token:
        # Fallback: try config.json in original location
        for cfg_path in [
            'config.json',
            os.path.join(os.path.expanduser('~'), '.hevolve', 'config.json'),
        ]:
            if os.path.isfile(cfg_path):
                try:
                    with open(cfg_path) as f:
                        cfg = json.load(f)
                    hf_token = cfg.get('huggingface', '')
                    if hf_token:
                        break
                except Exception:
                    # Unreadable/invalid config: try the next candidate path.
                    pass

    # The pyannote-based pipeline cannot load without a token; fail fast.
    if not hf_token:
        print("ERROR: HEVOLVE_HF_TOKEN env var or config.json "
              "'huggingface' key required", file=sys.stderr)
        sys.exit(1)

    # Logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    asyncio.run(main(args.port, args.device, args.output_dir, hf_token))