1"""
2ProviderGateway — smart router that agents call for any AI task.
4Usage:
5 from integrations.providers import get_gateway
6 gw = get_gateway()
8 # Text generation
9 result = gw.generate('Tell me a joke', model_type='llm')
11 # Image generation
12 result = gw.generate('A cat in space', model_type='image_gen')
14 # Specific model on specific provider
15 result = gw.generate('Hello', provider_id='groq', model_id='llama-3.3-70b-versatile')
17The gateway:
18 1. Picks the best provider (cheapest/fastest/balanced) from the registry
19 2. Calls the provider's API (OpenAI-compatible, Replicate, or custom)
20 3. Tracks cost, latency, tok/s — feeds back into registry stats
21 4. Falls back to next-best provider on failure
22 5. Falls back to local model as last resort
23"""

import json
import logging
import os
import threading
import time
from typing import Any, Dict, List, Optional, Generator
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class GatewayResult:
    """Result from a gateway call."""
    success: bool
    content: str = ''            # Text response or URL for media
    provider_id: str = ''
    model_id: str = ''
    usage: Dict[str, Any] = field(default_factory=dict)
    cost_usd: float = 0.0        # Estimated cost in USD
    latency_ms: float = 0.0
    tok_per_s: float = 0.0
    model_type: str = 'llm'      # Request type for revenue tracking
    error: str = ''
    raw_response: Any = None     # Full API response for advanced use
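
# A successful LLM call typically produces a result like this (illustrative
# values, not real output):
#   GatewayResult(success=True, content='Why did the ...', provider_id='groq',
#                 model_id='llama-3.3-70b-versatile', cost_usd=0.000118,
#                 latency_ms=412.0, tok_per_s=213.5)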


class ProviderGateway:
    """Smart router for all AI API calls.

    Agents call gateway methods. The gateway picks the optimal provider
    from the registry, calls the API, tracks stats, handles failures.
    """

    def __init__(self):
        from integrations.providers.registry import get_registry
        self._registry = get_registry()
        self._usage_lock = threading.Lock()
        self._total_cost_usd = 0.0
        self._total_requests = 0
        self._request_log: List[Dict] = []  # Last 100 requests for dashboard

    # ═══════════════════════════════════════════════════════════════════
    # Public API — what agents call
    # ═══════════════════════════════════════════════════════════════════

    def generate(self, prompt: str, model_type: str = 'llm',
                 provider_id: str = '', model_id: str = '',
                 strategy: str = 'balanced',
                 system_prompt: str = '',
                 max_tokens: int = 4096,
                 temperature: float = 0.7,
                 stream: bool = False,
                 **kwargs) -> GatewayResult:
        """Generate content via the optimal provider.

        Args:
            prompt: User prompt or generation instruction
            model_type: 'llm', 'image_gen', 'video_gen', 'tts', etc.
            provider_id: Force a specific provider (optional)
            model_id: Force a specific model (optional)
            strategy: 'cheapest', 'fastest', 'quality', or 'balanced'
            system_prompt: System message for LLMs
            max_tokens: Max output tokens for LLMs
            temperature: Sampling temperature
            stream: Whether to stream (LLM only)
            **kwargs: Provider-specific params (image size, video duration, etc.)
        """
        t0 = time.time()

        # Resolve provider + model
        provider, provider_model = self._resolve(
            model_type, provider_id, model_id, strategy)

        if not provider:
            return GatewayResult(
                success=False,
                error=f'No provider available for {model_type}'
                      f'{" model=" + model_id if model_id else ""}'
                      f' (strategy={strategy}). Configure API keys in Settings.',
            )

        # Try the primary provider, then fallbacks
        providers_tried = []
        result = self._call_provider(
            provider, provider_model, prompt, model_type,
            system_prompt=system_prompt, max_tokens=max_tokens,
            temperature=temperature, stream=stream, **kwargs,
        )
        providers_tried.append(provider.id)

        # On failure, try fallbacks (up to 2 more providers)
        if not result.success:
            for _ in range(2):
                fb_provider, fb_model = self._resolve(
                    model_type, '', '', strategy,
                    exclude=providers_tried,
                )
                if not fb_provider:
                    break
                result = self._call_provider(
                    fb_provider, fb_model, prompt, model_type,
                    system_prompt=system_prompt, max_tokens=max_tokens,
                    temperature=temperature, stream=stream, **kwargs,
                )
                providers_tried.append(fb_provider.id)
                if result.success:
                    break

        # Track stats
        elapsed_ms = (time.time() - t0) * 1000
        result.latency_ms = elapsed_ms
        result.model_type = model_type
        self._track(result)

        return result

    def generate_stream(self, prompt: str, model_type: str = 'llm',
                        **kwargs) -> Generator[str, None, None]:
        """Stream text generation. Yields text chunks as they arrive."""
        # Streaming bypasses the fallback loop: resolve a provider once,
        # then stream from it directly.
        provider_id = kwargs.pop('provider_id', '')
        model_id = kwargs.pop('model_id', '')
        strategy = kwargs.pop('strategy', 'balanced')
        system_prompt = kwargs.pop('system_prompt', '')
        max_tokens = kwargs.pop('max_tokens', 4096)
        temperature = kwargs.pop('temperature', 0.7)
        kwargs.pop('stream', None)  # _stream_openai always streams

        provider, provider_model = self._resolve(
            model_type, provider_id, model_id, strategy)

        if not provider:
            yield '[Error: No provider available]'
            return

        yield from self._stream_openai(
            provider, provider_model, prompt,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )
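
    # Usage sketch (assumes a configured text provider):
    #   for chunk in get_gateway().generate_stream('Write a haiku'):
    #       print(chunk, end='', flush=True)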

    def get_stats(self) -> Dict[str, Any]:
        """Return gateway usage stats for dashboards."""
        with self._usage_lock:
            return {
                'total_cost_usd': round(self._total_cost_usd, 6),
                'total_requests': self._total_requests,
                'recent_requests': list(self._request_log[-20:]),
                'capabilities': self._registry.get_capabilities_summary(),
            }
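
    # Illustrative shape of the returned dict (values are made up):
    #   {'total_cost_usd': 0.004217, 'total_requests': 31,
    #    'recent_requests': [{'ts': ..., 'provider': 'groq', ...}],
    #    'capabilities': {...}}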

    # ═══════════════════════════════════════════════════════════════════
    # Resolution — pick the best provider
    # ═══════════════════════════════════════════════════════════════════

    def _resolve(self, model_type, provider_id, model_id, strategy,
                 exclude=None):
        """Resolve the best (Provider, ProviderModel) for the request."""
        from integrations.providers.registry import ProviderModel

        exclude = exclude or []

        # Specific provider requested
        if provider_id:
            p = self._registry.get(provider_id)
            if p and p.id not in exclude:
                if model_id and model_id in p.models:
                    return p, p.models[model_id]
                # Otherwise, first enabled model of the requested type
                for pm in p.models.values():
                    if pm.model_type == model_type and pm.enabled:
                        return p, pm
            return None, None

        # Auto-select from registry
        result = self._registry.find_best(model_type, strategy=strategy)
        if result:
            p, pm = result
            if p.id not in exclude:
                return p, pm

        # Try all candidates, excluding already-tried providers
        for p in self._registry.list_api_providers():
            if p.id in exclude or not p.has_api_key():
                continue
            for pm in p.models.values():
                if pm.model_type == model_type and pm.enabled:
                    return p, pm

        # Last resort: local provider
        local = self._registry.get('local')
        if local and local.id not in exclude:
            return local, ProviderModel(
                model_id='local', canonical_id='local',
                model_type=model_type,
            )

        return None, None
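
    # Resolution order, in short: explicit provider_id -> registry.find_best()
    # -> any enabled API provider with a key -> the 'local' provider fallback.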

    # ═══════════════════════════════════════════════════════════════════
    # Provider Callers — format-specific API calls
    # ═══════════════════════════════════════════════════════════════════

    @staticmethod
    def _build_headers(provider) -> dict:
        """Build HTTP headers with the correct auth scheme for a provider."""
        headers = {'Content-Type': 'application/json'}
        api_key = provider.get_api_key()
        if api_key:
            if provider.id == 'fal':
                headers['Authorization'] = f'Key {api_key}'
            elif provider.auth_method == 'header':
                headers[provider.auth_header] = f'Bearer {api_key}'
            else:  # bearer (default)
                headers['Authorization'] = f'Bearer {api_key}'
        return headers
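
    # Example (hypothetical key): a default bearer provider yields
    #   {'Content-Type': 'application/json', 'Authorization': 'Bearer sk-...'}
    # whereas fal.ai gets 'Authorization': 'Key <api_key>'.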

    def _call_provider(self, provider, provider_model, prompt, model_type,
                       **kwargs) -> GatewayResult:
        """Dispatch to the correct API-format handler."""
        from integrations.providers.registry import PROVIDER_TYPE_LOCAL

        try:
            if provider.provider_type == PROVIDER_TYPE_LOCAL:
                return self._call_local(prompt, model_type, **kwargs)
            elif provider.api_format == 'openai':
                return self._call_openai(provider, provider_model, prompt,
                                         model_type, **kwargs)
            elif provider.api_format == 'replicate':
                return self._call_replicate(provider, provider_model, prompt,
                                            model_type, **kwargs)
            else:
                return self._call_custom(provider, provider_model, prompt,
                                         model_type, **kwargs)
        except Exception as e:
            logger.error("Provider %s call failed: %s", provider.id, e)
            self._registry.update_model_stats(
                provider.id, provider_model.model_id, success=False)
            return GatewayResult(
                success=False, error=str(e),
                provider_id=provider.id, model_id=provider_model.model_id,
            )

    def _call_openai(self, provider, provider_model, prompt, model_type,
                     system_prompt='', max_tokens=4096, temperature=0.7,
                     stream=False, **kwargs) -> GatewayResult:
        """Call an OpenAI-compatible API (Together, Fireworks, Groq, etc.)."""
        import urllib.request
        import urllib.error

        url = f"{provider.base_url.rstrip('/')}/chat/completions"

        messages = []
        if system_prompt:
            messages.append({'role': 'system', 'content': system_prompt})
        messages.append({'role': 'user', 'content': prompt})

        body = {
            'model': provider_model.model_id,
            'messages': messages,
            'max_tokens': max_tokens,
            'temperature': temperature,
            'stream': False,  # Streaming goes through _stream_openai instead
        }

        headers = self._build_headers(provider)

        t0 = time.time()
        req = urllib.request.Request(
            url, data=json.dumps(body).encode(),
            headers=headers, method='POST',
        )
        try:
            with urllib.request.urlopen(req, timeout=120) as resp:
                data = json.loads(resp.read().decode())
        except urllib.error.HTTPError as e:
            error_body = e.read().decode() if e.fp else ''
            logger.error("OpenAI API error %d from %s: %s",
                         e.code, provider.id, error_body[:500])
            return GatewayResult(
                success=False,
                error=f'HTTP {e.code}: {error_body[:200]}',
                provider_id=provider.id,
                model_id=provider_model.model_id,
            )

        elapsed_ms = (time.time() - t0) * 1000

        # Parse response
        content = ''
        usage = data.get('usage', {})
        if 'choices' in data and data['choices']:
            content = data['choices'][0].get('message', {}).get('content', '')

        # Calculate cost
        input_tokens = usage.get('prompt_tokens', 0)
        output_tokens = usage.get('completion_tokens', 0)
        total_tokens = input_tokens + output_tokens
        cost = self._calculate_cost(provider_model, input_tokens, output_tokens)
        tok_per_s = (output_tokens / (elapsed_ms / 1000)) if elapsed_ms > 0 and output_tokens > 0 else 0

        # Update provider stats
        self._registry.update_model_stats(
            provider.id, provider_model.model_id,
            tok_per_s=tok_per_s, latency_ms=elapsed_ms, success=True,
        )

        return GatewayResult(
            success=True,
            content=content,
            provider_id=provider.id,
            model_id=provider_model.model_id,
            usage={'input_tokens': input_tokens, 'output_tokens': output_tokens,
                   'total_tokens': total_tokens},
            cost_usd=cost,
            latency_ms=elapsed_ms,
            tok_per_s=tok_per_s,
            raw_response=data,
        )
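
    # For reference, the parsed payload follows the OpenAI chat-completions
    # shape, roughly:
    #   {'choices': [{'message': {'content': '...'}}],
    #    'usage': {'prompt_tokens': N, 'completion_tokens': M, ...}}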

    def _stream_openai(self, provider, provider_model, prompt,
                       system_prompt='', max_tokens=4096, temperature=0.7,
                       **kwargs) -> Generator[str, None, None]:
        """Stream from an OpenAI-compatible API (SSE 'data:' lines)."""
        import urllib.request

        url = f"{provider.base_url.rstrip('/')}/chat/completions"

        messages = []
        if system_prompt:
            messages.append({'role': 'system', 'content': system_prompt})
        messages.append({'role': 'user', 'content': prompt})

        body = {
            'model': provider_model.model_id,
            'messages': messages,
            'max_tokens': max_tokens,
            'temperature': temperature,
            'stream': True,
        }

        headers = self._build_headers(provider)

        req = urllib.request.Request(
            url, data=json.dumps(body).encode(),
            headers=headers, method='POST',
        )
        try:
            # Context manager ensures the connection closes even if the
            # consumer stops iterating or an exception interrupts the stream.
            with urllib.request.urlopen(req, timeout=120) as resp:
                for line in resp:
                    line = line.decode('utf-8').strip()
                    if line.startswith('data: ') and line != 'data: [DONE]':
                        try:
                            chunk = json.loads(line[6:])
                            delta = chunk.get('choices', [{}])[0].get('delta', {})
                            text = delta.get('content', '')
                            if text:
                                yield text
                        except json.JSONDecodeError:
                            continue
        except Exception as e:
            yield f'\n[Stream error: {e}]'

    def _call_replicate(self, provider, provider_model, prompt, model_type,
                        **kwargs) -> GatewayResult:
        """Call Replicate's prediction API."""
        import urllib.request
        import urllib.error

        url = f"{provider.base_url.rstrip('/')}/predictions"

        # Replicate uses a different input format per model
        input_data = {'prompt': prompt}
        if model_type == 'image_gen':
            input_data.update({
                'width': kwargs.get('width', 1024),
                'height': kwargs.get('height', 1024),
                'num_outputs': kwargs.get('num_outputs', 1),
            })
        elif model_type == 'video_gen':
            input_data['duration'] = kwargs.get('duration', 5)

        body = {
            'version': provider_model.model_id,
            'input': input_data,
        }

        headers = self._build_headers(provider)
        headers['Prefer'] = 'wait'  # Synchronous mode

        t0 = time.time()
        req = urllib.request.Request(
            url, data=json.dumps(body).encode(),
            headers=headers, method='POST',
        )
        try:
            with urllib.request.urlopen(req, timeout=300) as resp:
                data = json.loads(resp.read().decode())
        except urllib.error.HTTPError as e:
            return GatewayResult(
                success=False, error=f'Replicate HTTP {e.code}',
                provider_id=provider.id, model_id=provider_model.model_id,
            )

        elapsed_ms = (time.time() - t0) * 1000
        output = data.get('output', '')
        if isinstance(output, list):
            output = output[0] if output else ''

        return GatewayResult(
            success=True, content=str(output),
            provider_id=provider.id, model_id=provider_model.model_id,
            latency_ms=elapsed_ms, raw_response=data,
        )

    def _call_custom(self, provider, provider_model, prompt, model_type,
                     **kwargs) -> GatewayResult:
        """Call a custom API format (fal.ai, HuggingFace, etc.)."""
        import urllib.request

        api_key = provider.get_api_key()

        if provider.id == 'fal':
            return self._call_fal(provider, provider_model, prompt,
                                  model_type, api_key, **kwargs)

        # Generic: POST JSON to base_url/model_id
        url = f"{provider.base_url.rstrip('/')}/{provider_model.model_id}"
        body = {'inputs': prompt}

        headers = self._build_headers(provider)

        t0 = time.time()
        req = urllib.request.Request(
            url, data=json.dumps(body).encode(),
            headers=headers, method='POST',
        )
        try:
            with urllib.request.urlopen(req, timeout=120) as resp:
                data = json.loads(resp.read().decode())
        except Exception as e:
            return GatewayResult(
                success=False, error=str(e),
                provider_id=provider.id, model_id=provider_model.model_id,
            )

        elapsed_ms = (time.time() - t0) * 1000
        content = data if isinstance(data, str) else json.dumps(data)

        return GatewayResult(
            success=True, content=content,
            provider_id=provider.id, model_id=provider_model.model_id,
            latency_ms=elapsed_ms, raw_response=data,
        )

    def _call_fal(self, provider, provider_model, prompt, model_type,
                  api_key, **kwargs) -> GatewayResult:
        """Call the fal.ai serverless API."""
        import urllib.request

        url = f"https://fal.run/{provider_model.model_id}"
        body = {'prompt': prompt}
        if model_type == 'image_gen':
            body.update({
                'image_size': kwargs.get('image_size', 'landscape_16_9'),
                'num_images': kwargs.get('num_images', 1),
            })

        headers = self._build_headers(provider)

        t0 = time.time()
        req = urllib.request.Request(
            url, data=json.dumps(body).encode(),
            headers=headers, method='POST',
        )
        try:
            with urllib.request.urlopen(req, timeout=300) as resp:
                data = json.loads(resp.read().decode())
        except Exception as e:
            return GatewayResult(
                success=False, error=str(e),
                provider_id=provider.id, model_id=provider_model.model_id,
            )

        elapsed_ms = (time.time() - t0) * 1000
        # fal.ai returns media URLs under 'images', 'video', or 'audio'
        content = ''
        if 'images' in data:
            content = data['images'][0].get('url', '') if data['images'] else ''
        elif 'video' in data:
            content = data['video'].get('url', '')
        elif 'audio' in data:
            content = data['audio'].get('url', '')
        else:
            content = json.dumps(data)

        return GatewayResult(
            success=True, content=content,
            provider_id=provider.id, model_id=provider_model.model_id,
            latency_ms=elapsed_ms, raw_response=data,
        )

    def _call_local(self, prompt, model_type, **kwargs) -> GatewayResult:
        """Route to a local model via the existing HARTOS infrastructure."""
        try:
            if model_type == 'llm':
                # Use the existing LangChain / llama.cpp path
                import urllib.request
                url = os.environ.get('HEVOLVE_LOCAL_LLM_URL',
                                     'http://localhost:8080/v1')
                body = {
                    'model': 'local',
                    'messages': [{'role': 'user', 'content': prompt}],
                    'max_tokens': kwargs.get('max_tokens', 4096),
                    'temperature': kwargs.get('temperature', 0.7),
                    'stream': False,
                }
                if kwargs.get('system_prompt'):
                    body['messages'].insert(0, {
                        'role': 'system',
                        'content': kwargs['system_prompt'],
                    })
                req = urllib.request.Request(
                    f"{url.rstrip('/')}/chat/completions",
                    data=json.dumps(body).encode(),
                    headers={'Content-Type': 'application/json'},
                    method='POST',
                )
                t0 = time.time()
                with urllib.request.urlopen(req, timeout=120) as resp:
                    data = json.loads(resp.read().decode())
                elapsed_ms = (time.time() - t0) * 1000
                content = data.get('choices', [{}])[0].get(
                    'message', {}).get('content', '')
                return GatewayResult(
                    success=True, content=content,
                    provider_id='local', model_id='local-llm',
                    latency_ms=elapsed_ms, cost_usd=0.0,
                )
            else:
                return GatewayResult(
                    success=False,
                    error=f'Local {model_type} not yet implemented via gateway',
                    provider_id='local',
                )
        except Exception as e:
            return GatewayResult(
                success=False, error=f'Local call failed: {e}',
                provider_id='local',
            )

    # ═══════════════════════════════════════════════════════════════════
    # Cost calculation
    # ═══════════════════════════════════════════════════════════════════

    @staticmethod
    def _calculate_cost(provider_model, input_tokens, output_tokens):
        """Estimate the USD cost of a call from the model's pricing unit."""
        from integrations.providers.registry import (
            PRICE_PER_1M_TOKENS, PRICE_PER_1K_TOKENS, PRICE_PER_IMAGE,
            PRICE_PER_SECOND, PRICE_PER_REQUEST, PRICE_FREE,
        )
        unit = provider_model.pricing_unit
        if unit == PRICE_FREE:
            return 0.0
        if unit == PRICE_PER_1M_TOKENS:
            return (input_tokens * provider_model.input_price / 1_000_000 +
                    output_tokens * provider_model.output_price / 1_000_000)
        if unit == PRICE_PER_1K_TOKENS:
            return (input_tokens * provider_model.input_price / 1_000 +
                    output_tokens * provider_model.output_price / 1_000)
        if unit in (PRICE_PER_IMAGE, PRICE_PER_REQUEST):
            return provider_model.input_price
        if unit == PRICE_PER_SECOND:
            return provider_model.input_price  # Per-second; duration-dependent
        return 0.0
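
    # Worked example (illustrative prices, not registry values): at $0.50/$1.50
    # per 1M input/output tokens, a call with 1,200 input and 400 output tokens
    # costs 1200 * 0.50/1e6 + 400 * 1.50/1e6 = $0.0012.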

    # ═══════════════════════════════════════════════════════════════════
    # Tracking
    # ═══════════════════════════════════════════════════════════════════

    def _track(self, result: GatewayResult):
        """Record a request's stats and feed downstream trackers."""
        with self._usage_lock:
            self._total_requests += 1
            self._total_cost_usd += result.cost_usd
            self._request_log.append({
                'ts': time.time(),
                'provider': result.provider_id,
                'model': result.model_id,
                'success': result.success,
                'cost': result.cost_usd,
                'latency_ms': result.latency_ms,
                'tok_per_s': result.tok_per_s,
            })
            # Keep the last 100 entries
            if len(self._request_log) > 100:
                self._request_log = self._request_log[-100:]

        # Feed into the efficiency matrix (continuous learning); best-effort,
        # outside the usage lock to keep the critical section short
        try:
            from integrations.providers.efficiency_matrix import get_matrix
            get_matrix().record_request(
                provider_id=result.provider_id,
                model_id=result.model_id,
                tok_per_s=result.tok_per_s,
                e2e_ms=result.latency_ms,
                cost_usd=result.cost_usd,
                output_tokens=result.usage.get('output_tokens', 0),
                success=result.success,
            )
        except Exception:
            pass

        # Feed into the revenue tracker (cost side — revenue is recorded by the affiliate layer)
        if result.cost_usd > 0:
            try:
                from integrations.providers.revenue_tracker import get_tracker
                get_tracker().record_cost(
                    provider_id=result.provider_id,
                    model_id=result.model_id,
                    cost_usd=result.cost_usd,
                    tokens_used=result.usage.get('total_tokens', 0),
                    request_type=result.model_type,
                )
            except Exception:
                pass


# ═══════════════════════════════════════════════════════════════════════
# Singleton
# ═══════════════════════════════════════════════════════════════════════

_gateway: Optional[ProviderGateway] = None
_gateway_lock = threading.Lock()


def get_gateway() -> ProviderGateway:
    """Return the process-wide ProviderGateway (double-checked locking)."""
    global _gateway
    if _gateway is None:
        with _gateway_lock:
            if _gateway is None:
                _gateway = ProviderGateway()
    return _gateway
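

if __name__ == '__main__':
    # Minimal smoke test (a sketch: assumes at least one provider with an API
    # key is configured in the registry, or a local model is serving at
    # HEVOLVE_LOCAL_LLM_URL).
    gw = get_gateway()
    res = gw.generate('Say hello in five words.', model_type='llm')
    if res.success:
        print(f'[{res.provider_id}/{res.model_id}] {res.content}')
        print(f'cost=${res.cost_usd:.6f}  latency={res.latency_ms:.0f}ms')
    else:
        print(f'Generation failed: {res.error}')
    print(json.dumps(gw.get_stats(), indent=2, default=str))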