Coverage for integrations / audio / diarization_server.py: 0.0%
97 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Speaker Diarization Server - standalone WebSocket server for the sidecar.
4Derived from speaker_diarization/main.py but with:
5- Security fixes (no eval, proper JSON)
6- Configurable via CLI args and env vars
7- Buffer cleanup on disconnect
8- GPU/CPU auto-detection
10Usage:
11 python -m integrations.audio.diarization_server --port 8004
12"""
13import argparse
14import asyncio
15import ast
16import io
17import json
18import logging
19import os
20import sys
22import numpy as np
try:
    from integrations.service_tools.vram_manager import clear_cuda_cache
except ImportError:
    # Standalone fallback when the shared vram_manager is unavailable.
    def clear_cuda_cache():
        """Best-effort release of cached CUDA/MPS memory; never raises."""
        try:
            import torch
        except Exception:
            return
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                torch.mps.empty_cache()
        except Exception:
            # Cache clearing is opportunistic — ignore any backend errors.
            pass
logger = logging.getLogger('hevolve_diarization')

# Audio parameters (16kHz, 16-bit, mono)
SAMPLE_RATE = 16000   # samples per second
BYTES_PER_SAMPLE = 2  # 16-bit PCM
CHANNELS = 1          # mono
SECONDS = 1           # buffered window length before each diarization pass
EXPECTED_BYTES = SAMPLE_RATE * BYTES_PER_SAMPLE * CHANNELS * SECONDS  # 32KB

# Per-user audio stream buffers.
# Maps user_id -> io.BytesIO of accumulated raw PCM; entries are removed
# after each diarization pass and on client disconnect.
audio_streams = {}
50def _parse_message(raw):
51 """Parse incoming WebSocket message safely.
53 Handles both JSON format and Python dict format (single quotes)
54 sent by Android SpeechService.java for backward compatibility.
55 """
56 if isinstance(raw, bytes):
57 raw = raw.decode('utf-8')
58 try:
59 return json.loads(raw)
60 except (json.JSONDecodeError, ValueError):
61 return ast.literal_eval(raw)
async def diarization(websocket, diarize_model, output_dir, device):
    """Handle a single WebSocket connection for speaker diarization.

    Accumulates raw PCM chunks per user until EXPECTED_BYTES (one second
    of 16kHz mono 16-bit audio) is buffered, then runs the diarization
    model and sends the speaker count back to the client as JSON.

    Args:
        websocket: Connected WebSocket; each message is JSON (or a
            Python-dict repr) containing 'user_id' and 'chunk'.
        diarize_model: Callable diarization pipeline; returns segments
            with a 'speaker' column (WhisperX DiarizationPipeline).
        output_dir: Directory receiving per-user MP3 audit exports.
        device: Torch device string (model is already bound to it).
    """
    import torch

    user_id = None
    try:
        logger.info("Waiting for audio data...")
        async for message in websocket:
            parsed = _parse_message(message)
            user_id = parsed['user_id']
            pcm_bytes = parsed['chunk']

            if user_id not in audio_streams:
                audio_streams[user_id] = io.BytesIO()

            # Chunks may arrive hex-encoded (text frame) or as raw bytes.
            if isinstance(pcm_bytes, str):
                pcm_bytes = bytes.fromhex(pcm_bytes)
            audio_streams[user_id].write(pcm_bytes)

            if audio_streams[user_id].getbuffer().nbytes >= EXPECTED_BYTES:
                audio_streams[user_id].seek(0)
                audio_data_bytes = audio_streams[user_id].read()
                try:
                    with torch.no_grad():
                        logger.info('Processing audio for user_id %s', user_id)
                        # int16 PCM -> float32 samples in [-1.0, 1.0)
                        audios = (
                            np.frombuffer(audio_data_bytes, np.int16)
                            .flatten()
                            .astype(np.float32) / 32768.0
                        )
                        diarize_segments = diarize_model(audios)
                        unique_speakers = diarize_segments['speaker'].unique()
                        no_of_speakers = len(unique_speakers)
                        logger.info('Speakers: %s for user_id %s',
                                    unique_speakers, user_id)

                        # Export/append MP3 for audit trail
                        _export_audio(audio_data_bytes, user_id, output_dir)

                        # Voice signature enrollment — best-effort; the
                        # identifier is an optional project dependency.
                        try:
                            from core.resonance_identifier import ResonanceIdentifier
                            if user_id:
                                ResonanceIdentifier().enroll_voice(
                                    str(user_id), audio_data_bytes)
                        except ImportError:
                            pass
                        except Exception:
                            # Enrollment must never break diarization.
                            logger.debug(
                                'Voice enrollment failed', exc_info=True)

                        res = {
                            "no_of_speaker": no_of_speakers,
                            "stop_mic": no_of_speakers > 1,
                        }
                        await websocket.send(json.dumps(res))
                        logger.info("Result: %s", res)
                except Exception:
                    # Full traceback beats the old tb_lineno one-liner.
                    logger.exception(
                        'Diarization error for user_id %s', user_id)
                finally:
                    # Window consumed (or failed) — reset the buffer and
                    # release cached GPU memory before the next window.
                    _cleanup_stream(user_id)
                    clear_cuda_cache()
    except Exception as e:
        logger.debug("Connection ended: %s", e)
    finally:
        # Cleanup buffer on disconnect (prevents memory leak). Use an
        # identity check so falsy-but-valid ids (e.g. 0) still clean up.
        if user_id is not None:
            _cleanup_stream(user_id)
def _cleanup_stream(user_id):
    """Close and remove a user's audio buffer."""
    stream = audio_streams.pop(user_id, None)
    if stream is None:
        return
    try:
        stream.close()
    except Exception:
        # Closing a BytesIO is best-effort; removal already happened.
        pass
153def _export_audio(audio_data_bytes, user_id, output_dir):
154 """Export audio chunk as MP3, appending to existing file."""
155 try:
156 from pydub import AudioSegment
157 except ImportError:
158 return
160 try:
161 audio_segment = AudioSegment(
162 data=audio_data_bytes,
163 sample_width=BYTES_PER_SAMPLE,
164 frame_rate=SAMPLE_RATE,
165 channels=CHANNELS,
166 )
167 mp3_path = os.path.join(output_dir, f'{user_id}.mp3')
168 if os.path.exists(mp3_path):
169 existing = AudioSegment.from_mp3(mp3_path)
170 audio_segment = existing + audio_segment
171 audio_segment.export(mp3_path, format='mp3')
172 except Exception as e:
173 logging.error(f"Failed to export audio: {e}")
async def main(port, device, output_dir, hf_token):
    """Start diarization model and WebSocket server.

    Args:
        port: TCP port to bind; 0 requests an OS-assigned port.
        device: Torch device string ('cuda', 'mps', or 'cpu').
        output_dir: Directory for MP3 audit exports (created if missing).
        hf_token: HuggingFace token passed to the diarization pipeline.

    Exits the process with status 1 if the model fails to load.
    """
    import torch  # ensures torch is importable before the model load
    import websockets

    # Load diarization model
    logger.info("Loading diarization model on %s...", device)
    try:
        from whisperx.diarize import DiarizationPipeline
        diarize_model = DiarizationPipeline(
            use_auth_token=hf_token, device=device)
    except Exception as e:
        logger.error("Failed to load diarization model: %s", e)
        sys.exit(1)

    logger.info("Diarization model loaded")

    os.makedirs(output_dir, exist_ok=True)

    # Bind with dynamic port support (port=0 -> OS picks a free port).
    server = await websockets.serve(
        lambda ws, path=None: diarization(
            ws, diarize_model, output_dir, device),
        '0.0.0.0', port,
    )

    # Resolve the actual bound port (differs from `port` when port=0).
    actual_port = port
    if server.sockets:
        actual_port = server.sockets[0].getsockname()[1]

    # Readiness signal (DiarizationService reads this from stdout)
    print(f"DIARIZATION_READY:{actual_port}", flush=True)
    logger.info("Speaker diarization server on port %s", actual_port)

    try:
        await asyncio.Future()  # run forever
    finally:
        server.close()
        await server.wait_closed()
if __name__ == "__main__":
    # Port registry supplies the canonical default port for this sidecar.
    from core.port_registry import get_port as _get_port
    parser = argparse.ArgumentParser(
        description='Speaker Diarization Sidecar')
    # Precedence: --port flag > HEVOLVE_DIARIZATION_PORT env > registry.
    parser.add_argument(
        '--port', type=int,
        default=int(os.environ.get('HEVOLVE_DIARIZATION_PORT',
                                   _get_port('diarization'))))
    # None here triggers the auto-detect block below.
    parser.add_argument(
        '--device',
        default=os.environ.get('HEVOLVE_DIARIZATION_DEVICE', None))
    parser.add_argument(
        '--output_dir',
        default=os.path.join(
            os.path.expanduser('~'), '.hevolve', 'audio'))
    args = parser.parse_args()

    # Auto-detect device: prefer cuda, then mps, else cpu; fall back to
    # cpu when torch itself is not installed.
    if args.device is None:
        try:
            import torch
            if torch.cuda.is_available():
                args.device = 'cuda'
            elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                args.device = 'mps'
            else:
                args.device = 'cpu'
        except ImportError:
            args.device = 'cpu'

    # HuggingFace token — forwarded to DiarizationPipeline(use_auth_token=...)
    hf_token = os.environ.get('HEVOLVE_HF_TOKEN', '')
    if not hf_token:
        # Fallback: try config.json in original location
        # (cwd first, then ~/.hevolve); first non-empty token wins.
        for cfg_path in [
            'config.json',
            os.path.join(os.path.expanduser('~'), '.hevolve', 'config.json'),
        ]:
            if os.path.isfile(cfg_path):
                try:
                    with open(cfg_path) as f:
                        cfg = json.load(f)
                    hf_token = cfg.get('huggingface', '')
                    if hf_token:
                        break
                except Exception:
                    # Unreadable or invalid JSON — try next candidate.
                    pass

    if not hf_token:
        print("ERROR: HEVOLVE_HF_TOKEN env var or config.json "
              "'huggingface' key required", file=sys.stderr)
        sys.exit(1)

    # Logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    asyncio.run(main(args.port, args.device, args.output_dir, hf_token))