Coverage for integrations / agent_engine / gradient_service.py: 72.0%

132 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2Gradient Service — Distributed embedding synchronization service. 

3 

4Service Pattern: static methods, db: Session, db.flush() not db.commit(). 

5 

6Submit embedding deltas, aggregate across peers, request witnesses, 

7track convergence. Integrates with CCT (embedding_sync capability), 

8IntegrityService (fraud detection), and FederatedAggregator (gossip). 

9 

10Phase 1: Embedding delta sync (compressed, <100KB, trimmed mean). 

11Phase 2: LoRA gradient sync (stubs in federated_gradient_protocol.py). 

12""" 

13import logging 

14import time 

15import uuid 

16from datetime import datetime, timedelta 

17from typing import Dict, List, Optional 

18 

19logger = logging.getLogger('hevolve_social') 

20 

21# ─── Constants ─── 

22 

23GRADIENT_ROUND_DURATION_SECONDS = 300 # 5 min per aggregation round 

24MAX_DELTAS_PER_ROUND = 200 # Max deltas stored per round 

25WITNESS_REQUIREMENT = 2 # Min witnesses for embedding delta 

26CONVERGENCE_WINDOW = 50 # Rounds to track for convergence trend 

27 

28 

29class GradientSyncService: 

30 """Manages distributed embedding delta submission, aggregation, and witnessing.""" 

31 

32 # ─── Delta Submission ─── 

33 

34 @staticmethod 

35 def submit_embedding_delta(db, node_id: str, delta: Dict, 

36 cct_string: str = None) -> Dict: 

37 """Submit a compressed embedding delta from a node. 

38 

39 Validates: CCT (embedding_sync capability), delta format, magnitude, 

40 direction. Stores as NodeAttestation and feeds to FederatedAggregator. 

41 

42 Returns: {'accepted': bool, 'reason': str, 'attestation_id': str} 

43 """ 

44 # 1. Validate CCT has embedding_sync capability 

45 if cct_string: 

46 try: 

47 from .continual_learner_gate import ContinualLearnerGateService 

48 if not ContinualLearnerGateService.check_cct_capability( 

49 cct_string, 'embedding_sync', node_id): 

50 return {'accepted': False, 'reason': 'cct_no_embedding_sync'} 

51 except Exception: 

52 pass # If CCT check unavailable, allow (graceful degrade) 

53 else: 

54 # Check tier directly from DB 

55 try: 

56 from .continual_learner_gate import ( 

57 ContinualLearnerGateService, LEARNING_ACCESS_MATRIX) 

58 tier_info = ContinualLearnerGateService.compute_learning_tier( 

59 db, node_id) 

60 tier = tier_info.get('tier', 'none') 

61 if 'embedding_sync' not in LEARNING_ACCESS_MATRIX.get(tier, []): 

62 return {'accepted': False, 'reason': 'tier_insufficient', 

63 'current_tier': tier, 'required': 'full'} 

64 except Exception: 

65 pass 

66 

67 # 2. Validate delta format 

68 try: 

69 from .embedding_delta import validate_delta 

70 valid, reason = validate_delta(delta) 

71 if not valid: 

72 return {'accepted': False, 'reason': f'invalid_delta: {reason}'} 

73 except ImportError: 

74 return {'accepted': False, 'reason': 'embedding_delta_module_unavailable'} 

75 

76 # 3. Magnitude anomaly check 

77 anomaly_detected = False 

78 try: 

79 from .embedding_delta import detect_magnitude_anomaly 

80 magnitude = delta.get('magnitude', 0.0) 

81 peer_magnitudes = GradientSyncService._get_peer_magnitudes(db) 

82 if peer_magnitudes: 

83 anomaly_detected = detect_magnitude_anomaly( 

84 magnitude, peer_magnitudes) 

85 except Exception: 

86 pass 

87 

88 if anomaly_detected: 

89 # Record fraud signal but still accept (IntegrityService handles banning) 

90 try: 

91 GradientSyncService._record_gradient_fraud( 

92 db, node_id, 'gradient_magnitude_anomaly', 

93 {'magnitude': delta.get('magnitude', 0)}) 

94 except Exception: 

95 pass 

96 return {'accepted': False, 'reason': 'magnitude_anomaly'} 

97 

98 # 4. Direction flip check (vs previous delta from this node) 

99 direction_flipped = False 

100 try: 

101 from .embedding_delta import detect_direction_flip, decompress_delta 

102 previous = GradientSyncService._get_previous_delta(db, node_id) 

103 if previous: 

104 current_vals = decompress_delta(delta) 

105 prev_vals = decompress_delta(previous) 

106 direction_flipped = detect_direction_flip( 

107 current_vals, prev_vals) 

108 except Exception: 

109 pass 

110 

111 if direction_flipped: 

112 try: 

113 GradientSyncService._record_gradient_fraud( 

114 db, node_id, 'gradient_direction_flip', 

115 {'delta_dimension': delta.get('dimension', 0)}) 

116 except Exception: 

117 pass 

118 return {'accepted': False, 'reason': 'direction_flip'} 

119 

120 # 5. Store as NodeAttestation 

121 attestation_id = None 

122 try: 

123 from integrations.social.models import NodeAttestation 

124 from security.node_integrity import ( 

125 get_public_key_hex, sign_json_payload, get_node_identity) 

126 

127 identity = get_node_identity() 

128 evidence = { 

129 'delta_method': delta.get('method', 'unknown'), 

130 'delta_dimension': delta.get('dimension', 0), 

131 'delta_k': delta.get('k', 0), 

132 'magnitude': delta.get('magnitude', 0), 

133 'submitted_at': datetime.utcnow().isoformat(), 

134 } 

135 sig = sign_json_payload(evidence) 

136 

137 attestation = NodeAttestation( 

138 attester_node_id=identity.get('node_id', 'self'), 

139 subject_node_id=node_id, 

140 attestation_type='embedding_delta', 

141 payload_json={ 

142 'evidence': evidence, 

143 'delta': delta, # Store compressed delta for replay 

144 }, 

145 signature=sig[:256], 

146 attester_public_key=get_public_key_hex(), 

147 is_valid=True, 

148 expires_at=datetime.utcnow() + timedelta(hours=1), 

149 ) 

150 db.add(attestation) 

151 db.flush() 

152 attestation_id = attestation.id 

153 except ImportError: 

154 logger.debug("Cannot store embedding delta attestation: imports unavailable") 

155 except Exception as e: 

156 logger.debug(f"Embedding delta attestation failed: {e}") 

157 

158 # 6. Feed to FederatedAggregator 

159 try: 

160 from .federated_aggregator import get_federated_aggregator 

161 aggregator = get_federated_aggregator() 

162 aggregator.receive_embedding_delta(node_id, delta) 

163 except Exception as e: 

164 logger.debug(f"Aggregator feed failed: {e}") 

165 

166 return { 

167 'accepted': True, 

168 'attestation_id': attestation_id, 

169 'reason': 'ok', 

170 } 

171 

172 # ─── Aggregation Status ─── 

173 

174 @staticmethod 

175 def get_convergence_status(db) -> Dict: 

176 """Get current embedding sync convergence status. 

177 

178 Returns: {'epoch': int, 'peer_count': int, 'convergence_score': float, 

179 'deltas_this_round': int, 'round_duration': int} 

180 """ 

181 try: 

182 from .federated_aggregator import get_federated_aggregator 

183 aggregator = get_federated_aggregator() 

184 stats = aggregator.get_stats() 

185 

186 # Count embedding deltas in current round 

187 delta_count = 0 

188 try: 

189 from integrations.social.models import NodeAttestation 

190 cutoff = datetime.utcnow() - timedelta( 

191 seconds=GRADIENT_ROUND_DURATION_SECONDS) 

192 delta_count = db.query(NodeAttestation).filter( 

193 NodeAttestation.attestation_type == 'embedding_delta', 

194 NodeAttestation.is_valid == True, 

195 NodeAttestation.created_at >= cutoff, 

196 ).count() 

197 except Exception: 

198 pass 

199 

200 # Embedding-specific stats from aggregator 

201 embedding_stats = {} 

202 try: 

203 embedding_stats = aggregator.get_embedding_stats() 

204 except AttributeError: 

205 pass # Aggregator doesn't have embedding channel yet 

206 

207 return { 

208 'epoch': stats.get('epoch', 0), 

209 'peer_count': stats.get('peer_count', 0), 

210 'convergence_score': stats.get('convergence', 0.0), 

211 'deltas_this_round': delta_count, 

212 'round_duration_seconds': GRADIENT_ROUND_DURATION_SECONDS, 

213 'embedding_sync': embedding_stats, 

214 } 

215 except Exception as e: 

216 return {'epoch': 0, 'peer_count': 0, 'convergence_score': 0.0, 

217 'error': str(e)} 

218 

219 # ─── Witness Request ─── 

220 

221 @staticmethod 

222 def request_embedding_witnesses(db, delta: Dict, 

223 node_id: str) -> Dict: 

224 """Request peer witnesses for an embedding delta. 

225 

226 Uses IntegrityService witness pattern: need WITNESS_REQUIREMENT+ peers 

227 to validate. Returns witness request status. 

228 """ 

229 try: 

230 from integrations.social.models import PeerNode 

231 

232 # Find eligible witness peers (active, verified, different node) 

233 witnesses = db.query(PeerNode).filter( 

234 PeerNode.status == 'active', 

235 PeerNode.integrity_status == 'verified', 

236 PeerNode.node_id != node_id, 

237 ).limit(WITNESS_REQUIREMENT * 2).all() 

238 

239 if len(witnesses) < WITNESS_REQUIREMENT: 

240 return { 

241 'witnessed': False, 

242 'reason': 'insufficient_peers', 

243 'available': len(witnesses), 

244 'required': WITNESS_REQUIREMENT, 

245 } 

246 

247 # Request witnesses via gossip 

248 witness_ids = [] 

249 for peer in witnesses[:WITNESS_REQUIREMENT]: 

250 try: 

251 from core.http_pool import pooled_post 

252 url = f"{peer.url.rstrip('/')}/api/social/peers/embedding-delta" 

253 witness_payload = { 

254 'action': 'witness_request', 

255 'delta': delta, 

256 'submitter_node_id': node_id, 

257 'request_id': uuid.uuid4().hex[:12], 

258 } 

259 resp = pooled_post(url, json=witness_payload, timeout=5) 

260 if resp.status_code == 200: 

261 witness_ids.append(peer.node_id) 

262 except Exception: 

263 pass 

264 

265 return { 

266 'witnessed': len(witness_ids) >= WITNESS_REQUIREMENT, 

267 'witness_count': len(witness_ids), 

268 'witness_ids': witness_ids, 

269 'required': WITNESS_REQUIREMENT, 

270 } 

271 except Exception as e: 

272 return {'witnessed': False, 'reason': str(e)} 

273 

274 # ─── Internal Helpers ─── 

275 

276 @staticmethod 

277 def _get_peer_magnitudes(db) -> List[float]: 

278 """Get recent embedding delta magnitudes from all peers.""" 

279 try: 

280 from integrations.social.models import NodeAttestation 

281 cutoff = datetime.utcnow() - timedelta( 

282 seconds=GRADIENT_ROUND_DURATION_SECONDS) 

283 attestations = db.query(NodeAttestation).filter( 

284 NodeAttestation.attestation_type == 'embedding_delta', 

285 NodeAttestation.is_valid == True, 

286 NodeAttestation.created_at >= cutoff, 

287 ).all() 

288 

289 magnitudes = [] 

290 for att in attestations: 

291 payload = att.payload_json or {} 

292 evidence = payload.get('evidence', payload) 

293 mag = evidence.get('magnitude', 0.0) 

294 if isinstance(mag, (int, float)) and mag > 0: 

295 magnitudes.append(mag) 

296 return magnitudes 

297 except Exception: 

298 return [] 

299 

300 @staticmethod 

301 def _get_previous_delta(db, node_id: str) -> Optional[Dict]: 

302 """Get the most recent embedding delta from this node.""" 

303 try: 

304 from integrations.social.models import NodeAttestation 

305 from sqlalchemy import desc 

306 att = db.query(NodeAttestation).filter_by( 

307 subject_node_id=node_id, 

308 attestation_type='embedding_delta', 

309 is_valid=True, 

310 ).order_by(desc(NodeAttestation.created_at)).first() 

311 

312 if att and att.payload_json: 

313 return att.payload_json.get('delta') 

314 except Exception: 

315 pass 

316 return None 

317 

318 @staticmethod 

319 def _record_gradient_fraud(db, node_id: str, signal_type: str, 

320 details: dict): 

321 """Record a gradient fraud signal via IntegrityService.""" 

322 try: 

323 from integrations.social.models import FraudAlert 

324 alert = FraudAlert( 

325 node_id=node_id, 

326 alert_type=signal_type, 

327 severity='medium', 

328 details_json=details, 

329 ) 

330 db.add(alert) 

331 db.flush() 

332 logger.warning(f"Gradient fraud signal: {signal_type} " 

333 f"for node {node_id}") 

334 except Exception as e: 

335 logger.debug(f"Fraud signal record failed: {e}")