Coverage for integrations/agent_engine/budget_gate.py: 85.1%

154 statements  

coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2Budget Gate — pre-dispatch spend control for LLM calls and agent goals. 

3 

4Fail-closed economics: don't spend what you don't have. 

5 

6Functions: 

7 estimate_llm_cost_spark(prompt, model_name) — token-based cost estimate 

8 check_goal_budget(goal_id, estimated_cost) — atomic row-lock deduction 

9 check_platform_affordability() — 7-day net revenue check (cached 60s) 

10 pre_dispatch_budget_gate(goal_id, prompt, model_name) — combined gate 

11 

12Pattern extracted from: speculative_dispatcher._check_and_reserve_budget() (lines 314-340) 

13""" 

14import logging 

15import os 

16import threading 

17import time 

18from typing import Dict, Optional, Tuple 

19 

20logger = logging.getLogger(__name__) 

# Per-goal budget-check cache. Daemon ticks call check_goal_budget() on
# every speculative dispatch; py-spy traces showed the SQLAlchemy first()
# + new SQLite connection cycle is the dominant CPU consumer when many
# goals × idle_agents fire in tight succession. A short TTL is enough
# to break the storm — within 10s the goal's row hasn't materially
# changed (check_goal_budget() is the only writer in the daemon path).
# Burst under-counting is bounded: at most one un-deducted hit per goal
# per TTL window. Cache stores the FULL return tuple so callers see the
# exact same shape they'd see from a fresh DB query.
_BUDGET_CACHE_TTL_S = 10.0
_budget_cache: Dict[str, Tuple[float, Tuple[bool, int, str]]] = {}
_budget_cache_lock = threading.Lock()
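
# Illustrative cache entry shape (a sketch; 'goal-123' and the numbers are
# hypothetical): a goal that just reserved 3 Spark out of 45 remaining is
# stored with the full return tuple:
#
#   _budget_cache['goal-123'] = (time.time(), (True, 42, 'budget_reserved'))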

# ── Cost estimation ──────────────────────────────────────────────────

# Approximate Spark cost per 1K tokens by model family.
# Order matters: the lookup is a substring match, so the most specific
# name must come first (gpt-4o-mini before gpt-4o before gpt-4).
_MODEL_COST_MAP = {
    'gpt-4o-mini': 1,
    'gpt-4o': 4,
    'gpt-4': 6,
    'gpt-3.5': 1,
    'groq': 0,     # Groq free tier — zero Spark
    'llama': 0,    # Local model — zero metered cost
    'mistral': 0,  # Local model
    'phi': 0,      # Local model
    'qwen': 0,     # Local model
}


def _is_local_model() -> bool:
    """Detect whether the active LLM is a local model (zero Spark cost).

    Delegates to port_registry.is_local_llm() which checks whether the
    resolved LLM URL points to localhost/127.0.0.1, or if a local model
    name is configured.
    """
    from core.port_registry import is_local_llm
    return is_local_llm()


def estimate_llm_cost_spark(prompt: str, model_name: str = 'gpt-4o') -> int:
    """Estimate Spark cost for an LLM call before execution.

    Uses tiktoken if available (already in codebase), falls back to a
    word-count heuristic (~1.3 tokens per word). Returns an integer Spark
    cost (min 1 for paid models, 0 for local/self-hosted models).

    If the active LLM is local (detected via _is_local_model(), which
    delegates to port_registry.is_local_llm()), cost is always 0 — local
    inference has no metered Spark cost.
    """
    # Local models cost nothing regardless of the model_name parameter.
    # Check the local-backend signal first (definitive when a local
    # backend is active).
    if _is_local_model():
        return 0

    # Map model to per-1K cost (check BEFORE token counting — skip work for free models)
    cost_per_1k = 2  # default for unknown cloud models
    model_lower = (model_name or '').lower()
    for prefix, cost in _MODEL_COST_MAP.items():
        if prefix in model_lower:
            cost_per_1k = cost
            break

    # Free-tier and local models cost 0 Spark even without the env var.
    # This catches cases where model_name is 'qwen', 'llama', 'phi', etc.
    # but HEVOLVE_LOCAL_LLM_URL is not explicitly set.
    if cost_per_1k == 0:
        return 0

    # Token count (only computed for paid models)
    token_count = 0
    try:
        import tiktoken
        enc = tiktoken.encoding_for_model(model_name)
        token_count = len(enc.encode(prompt))
    except Exception:
        # Fallback: ~1.3 tokens per word on average
        token_count = max(1, int(len(prompt.split()) * 1.3))

    spark_cost = max(1, int((token_count / 1000) * cost_per_1k))
    return spark_cost
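
# Worked example of the fallback path (a sketch; numbers assume tiktoken
# is unavailable, so the ~1.3 tokens/word heuristic applies):
#
#   prompt = ' '.join(['word'] * 1500)       # 1,500 words
#   estimate_llm_cost_spark(prompt, 'gpt-4o')
#   # token_count = int(1500 * 1.3) = 1950
#   # spark_cost  = max(1, int((1950 / 1000) * 4)) = 7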


# ── Goal budget (row-lock atomic deduction) ──────────────────────────

def check_goal_budget(goal_id: Optional[str],
                      estimated_cost: int) -> Tuple[bool, int, str]:
    """Check and reserve Spark budget for a goal (atomic row lock).

    Extracted from speculative_dispatcher._check_and_reserve_budget().
    Returns: (allowed, remaining_budget, reason)

    TTL cache (``_BUDGET_CACHE_TTL_S``) breaks the daemon-tick storm —
    repeated calls for the same goal within the window return the cached
    tuple without hitting the DB. Bounds under-counting at one
    un-deducted hit per goal per window; the only writer to
    ``goal.spark_spent`` is this function, so cache freshness is
    self-consistent.
    """
    if not goal_id:
        return True, -1, 'no_goal_constraint'

    # ── Cache fast-path ───────────────────────────────────────────────
    now = time.time()
    with _budget_cache_lock:
        entry = _budget_cache.get(goal_id)
        if entry is not None:
            cached_ts, cached_result = entry
            if (now - cached_ts) < _BUDGET_CACHE_TTL_S:
                cached_allowed, cached_remaining, _ = cached_result
                # Only honor the cache when the cached remaining still covers
                # the current estimated_cost (cost varies per prompt — the
                # check the caller actually cares about is "can I afford
                # THIS one"). Denied results stay denied for the window;
                # allowed results stay allowed only if remaining headroom
                # still covers the new cost.
                if not cached_allowed:
                    return cached_result
                if cached_remaining == -1 or cached_remaining >= estimated_cost:
                    return cached_result

    try:
        from integrations.social.models import get_db, AgentGoal
        db = get_db()
        try:
            goal = db.query(AgentGoal).filter_by(
                id=goal_id).with_for_update().first()
            if not goal:
                result = (True, -1, 'goal_not_found')
                with _budget_cache_lock:
                    _budget_cache[goal_id] = (now, result)
                return result

            budget = goal.spark_budget or 0
            spent = goal.spark_spent or 0
            remaining = budget - spent

            if remaining < estimated_cost:
                db.rollback()
                result = (False, remaining,
                          f'insufficient_budget ({remaining} < {estimated_cost})')
                with _budget_cache_lock:
                    _budget_cache[goal_id] = (now, result)
                return result

            goal.spark_spent = spent + estimated_cost
            db.commit()
            result = (True, remaining - estimated_cost, 'budget_reserved')
            with _budget_cache_lock:
                _budget_cache[goal_id] = (now, result)
            return result
        finally:
            db.close()
    except Exception as e:
        logger.debug(f"Budget check unavailable: {e}")
        return True, -1, 'budget_system_unavailable'
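
# Caller sketch (illustrative; 'goal-123' is a hypothetical id). A
# remaining value of -1 means "no budget constraint applies":
#
#   allowed, remaining, reason = check_goal_budget('goal-123', 5)
#   # → (True, 42, 'budget_reserved')              if 47 Spark remained
#   # → (False, 3, 'insufficient_budget (3 < 5)')  if only 3 remained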


def invalidate_goal_budget_cache(goal_id: Optional[str] = None) -> None:
    """Clear the budget-check TTL cache.

    Call this when the goal's spark_budget changes via a non-daemon
    path (admin top-up, manual goal edit, scheduled budget reset).
    Keeps the daemon's cache from holding a stale 'denied' verdict
    after a top-up. ``goal_id=None`` clears every entry.
    """
    with _budget_cache_lock:
        if goal_id is None:
            _budget_cache.clear()
        else:
            _budget_cache.pop(goal_id, None)
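
# Top-up sketch (the admin handler and its `db`/`goal` objects are
# hypothetical; only the invalidation call is this module's API):
#
#   def admin_top_up(db, goal, amount):
#       goal.spark_budget = (goal.spark_budget or 0) + amount
#       db.commit()
#       invalidate_goal_budget_cache(goal.id)   # drop any cached denial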


# ── Platform affordability (cached 60s) ──────────────────────────────

_affordability_cache: Dict = {}
_CACHE_TTL = 60  # seconds


def check_platform_affordability() -> Tuple[bool, Dict]:
    """Check 7-day platform net revenue flow.

    Uses query_revenue_streams() (revenue_aggregator.py) — single source of truth.
    Caches result for 60s to avoid per-request DB queries.
    Returns: (can_afford, details_dict)
    """
    now = time.time()
    cached = _affordability_cache.get('result')
    if cached and (now - _affordability_cache.get('ts', 0)) < _CACHE_TTL:
        return cached

    try:
        from integrations.social.models import get_db
        from integrations.agent_engine.revenue_aggregator import query_revenue_streams
        db = get_db()
        try:
            streams = query_revenue_streams(db, period_days=7)
            net = streams['total_gross'] - streams['hosting_payouts']
            can_afford = net >= 0
            result = (can_afford, {
                'gross_7d': round(streams['total_gross'], 2),
                'payouts_7d': round(streams['hosting_payouts'], 2),
                'net_7d': round(net, 2),
            })
            _affordability_cache['result'] = result
            _affordability_cache['ts'] = now
            return result
        finally:
            db.close()
    except Exception as e:
        logger.debug(f"Affordability check unavailable: {e}")
        return True, {'reason': 'affordability_check_unavailable'}
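
# Return-shape sketch (dollar values illustrative):
#
#   check_platform_affordability()
#   # → (True, {'gross_7d': 120.5, 'payouts_7d': 80.0, 'net_7d': 40.5})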


# ── Combined gate ────────────────────────────────────────────────────

def _resolve_model_name(model_name: str) -> str:
    """Resolve the effective model name for cost estimation.

    If the caller passes the default 'gpt-4o' but a local model is actually
    active, return the local model name so pricing is correct (0 Spark).
    """
    # If the caller provided an explicit non-default model name, trust it
    if model_name and model_name != 'gpt-4o':
        return model_name

    # Check if a local model is configured — override the default 'gpt-4o'
    local_model = os.environ.get('HEVOLVE_LOCAL_LLM_MODEL', '')
    if local_model:
        return local_model

    # If the resolved LLM URL points to localhost, the active model
    # is local even though we don't know the exact name — use 'llama',
    # which maps to 0 Spark in _MODEL_COST_MAP.
    if _is_local_model():
        return 'llama'

    return model_name


def pre_dispatch_budget_gate(goal_id: Optional[str],
                             prompt: str,
                             model_name: str = 'gpt-4o') -> Tuple[bool, str]:
    """Combined pre-dispatch budget gate.

    1. Resolve effective model name (local vs cloud)
    2. Estimate LLM cost
    3. Check goal budget (atomic deduction)
    4. Check platform affordability (cached)

    Returns: (allowed, reason)
    """
    model_name = _resolve_model_name(model_name)
    estimated_cost = estimate_llm_cost_spark(prompt, model_name)

    # Goal-level budget
    allowed, remaining, reason = check_goal_budget(goal_id, estimated_cost)
    if not allowed:
        logger.warning(f"Budget gate BLOCKED: goal={goal_id}, {reason}")
        return False, f'goal_budget_exceeded: {reason}'

    # Platform-level affordability
    can_afford, details = check_platform_affordability()
    if not can_afford:
        logger.warning(f"Budget gate BLOCKED: platform not affordable: {details}")
        return False, f'platform_not_affordable: net_7d={details.get("net_7d", "?")}'

    return True, f'allowed (est_cost={estimated_cost}, remaining={remaining})'
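
# Integration sketch: call the gate immediately before dispatch. The
# `goal` object and `llm_client` are hypothetical stand-ins for the
# caller's own dispatch machinery:
#
#   allowed, reason = pre_dispatch_budget_gate(goal.id, prompt, 'gpt-4o-mini')
#   if not allowed:
#       logger.info(f"dispatch skipped: {reason}")
#       return
#   response = llm_client.complete(prompt, model='gpt-4o-mini')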


# ── Metered API usage recording ──────────────────────────────────────

def record_metered_usage(node_id: str, model_id: str, task_source: str,
                         tokens_in: int, tokens_out: int,
                         cost_per_1k: float,
                         goal_id: Optional[str] = None,
                         requester_node_id: Optional[str] = None) -> Optional[str]:
    """Record metered API usage for cost recovery. Returns usage ID or None.

    Called after every non-local LLM call. If task_source != 'own', creates
    a MeteredAPIUsage record so the revenue agent can settle it.
    Only records for metered (non-local) models with cost > 0.
    """
    if cost_per_1k <= 0:
        return None  # Local model — no cost to recover

    actual_usd_cost = ((tokens_in + tokens_out) / 1000.0) * cost_per_1k
    if actual_usd_cost <= 0:
        return None

    # Check daily limit for hive/idle tasks
    if task_source in ('hive', 'idle'):
        try:
            from integrations.agent_engine.compute_config import get_compute_policy
            policy = get_compute_policy(os.environ.get('HEVOLVE_NODE_ID'))
            daily_limit = policy.get('metered_daily_limit_usd', 0.0)
            if daily_limit > 0:
                # Check today's spend
                from integrations.social.models import db_session, MeteredAPIUsage
                from sqlalchemy import func as sa_func
                from datetime import datetime
                with db_session() as db:
                    today_start = datetime.utcnow().replace(
                        hour=0, minute=0, second=0, microsecond=0)
                    today_spend = db.query(
                        sa_func.coalesce(sa_func.sum(MeteredAPIUsage.actual_usd_cost), 0)
                    ).filter(
                        MeteredAPIUsage.node_id == node_id,
                        MeteredAPIUsage.task_source.in_(['hive', 'idle']),
                        MeteredAPIUsage.created_at >= today_start,
                    ).scalar() or 0.0
                    if today_spend + actual_usd_cost > daily_limit:
                        logger.warning(
                            f"Metered daily limit exceeded: "
                            f"${today_spend:.2f}+${actual_usd_cost:.2f} > ${daily_limit:.2f}")
                        return None
        except Exception as e:
            logger.debug(f"Daily limit check skipped: {e}")

    # Look up operator_id from PeerNode
    operator_id = None
    try:
        from integrations.social.models import db_session, PeerNode
        with db_session() as db:
            peer = db.query(PeerNode).filter_by(node_id=node_id).first()
            if peer:
                operator_id = peer.node_operator_id
    except Exception:
        pass

    # Estimate Spark cost
    estimated_spark = max(1, int(actual_usd_cost * int(
        os.environ.get('HEVOLVE_SPARK_PER_USD', '100'))))

    # Write MeteredAPIUsage record
    try:
        from integrations.social.models import db_session, MeteredAPIUsage
        with db_session() as db:
            usage = MeteredAPIUsage(
                node_id=node_id,
                operator_id=operator_id,
                model_id=model_id,
                task_source=task_source,
                goal_id=goal_id,
                requester_node_id=requester_node_id,
                tokens_in=tokens_in,
                tokens_out=tokens_out,
                cost_per_1k_tokens=cost_per_1k,
                estimated_spark_cost=estimated_spark,
                actual_usd_cost=actual_usd_cost,
                settlement_status='pending' if task_source != 'own' else 'settled',
            )
            db.add(usage)
            db.commit()
            logger.debug(f"Metered usage recorded: model={model_id}, "
                         f"source={task_source}, cost=${actual_usd_cost:.4f}")
            return usage.id
    except Exception as e:
        logger.debug(f"Metered usage recording failed: {e}")
        return None
382 return None