Coverage for integrations / agent_engine / gradient_service.py: 72.0%
132 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Gradient Service — Distributed embedding synchronization service.
4Service Pattern: static methods, db: Session, db.flush() not db.commit().
6Submit embedding deltas, aggregate across peers, request witnesses,
7track convergence. Integrates with CCT (embedding_sync capability),
8IntegrityService (fraud detection), and FederatedAggregator (gossip).
10Phase 1: Embedding delta sync (compressed, <100KB, trimmed mean).
11Phase 2: LoRA gradient sync (stubs in federated_gradient_protocol.py).
12"""
13import logging
14import time
15import uuid
16from datetime import datetime, timedelta
17from typing import Dict, List, Optional
# Module logger. Uses the project-wide 'hevolve_social' channel rather than
# a per-module __name__ logger.
logger = logging.getLogger('hevolve_social')

# ─── Constants ───

# Length of one aggregation round (5 minutes). Also used as the look-back
# window when counting recent deltas and collecting peer magnitudes.
GRADIENT_ROUND_DURATION_SECONDS = 300

# Intended cap on deltas stored per round (not referenced elsewhere in this
# file; presumably enforced by another component — verify).
MAX_DELTAS_PER_ROUND = 200

# Minimum number of peer witnesses needed for an embedding delta to count
# as witnessed.
WITNESS_REQUIREMENT = 2

# Number of rounds tracked for the convergence trend (not referenced
# elsewhere in this file).
CONVERGENCE_WINDOW = 50
class GradientSyncService:
    """Manages distributed embedding delta submission, aggregation, and witnessing.

    Every method is a ``@staticmethod``. Methods that touch the database take
    the session ``db`` as their first argument. Writes use ``db.flush()`` only;
    committing is left to the caller.
    """

    # ─── Delta Submission ───

    @staticmethod
    def submit_embedding_delta(db, node_id: str, delta: Dict,
                               cct_string: Optional[str] = None) -> Dict:
        """Submit a compressed embedding delta from a node.

        Pipeline:
          1. Authorisation. If a CCT is given, it must grant ``embedding_sync``.
             Otherwise the node's learning tier is looked up from ``db``.
             If the check itself cannot run, the delta is allowed and a
             warning is logged (graceful degrade).
          2. Format validation via ``embedding_delta.validate_delta``.
             The delta is rejected if that module is unavailable.
          3. Magnitude anomaly check against this round's peer magnitudes.
          4. Direction-flip check against this node's previous delta.
          5. Store a signed ``embedding_delta`` NodeAttestation (best effort).
          6. Feed the delta to the FederatedAggregator (best effort).

        An anomaly in step 3 or 4 is recorded as a FraudAlert and the delta
        is rejected.

        Args:
            db: Database session (queried, ``add``-ed to and ``flush``-ed).
            node_id: ID of the submitting node.
            delta: Compressed delta dict. Keys read here include
                ``magnitude``, ``dimension``, ``method`` and ``k``.
            cct_string: Optional capability token to check in step 1.

        Returns:
            ``{'accepted': bool, 'reason': str, ...}``.
            An accepted result also carries ``attestation_id``, which is
            ``None`` if storage failed. A tier rejection also carries
            ``current_tier`` and ``required``.
        """
        # 1. Authorisation (fails open if the gate is unavailable).
        rejection = GradientSyncService._check_sync_permission(
            db, node_id, cct_string)
        if rejection is not None:
            return rejection

        # 2. Validate delta format. Without the validator the payload cannot
        # be vetted, so reject rather than degrade.
        try:
            from .embedding_delta import validate_delta
        except ImportError:
            return {'accepted': False,
                    'reason': 'embedding_delta_module_unavailable'}
        valid, reason = validate_delta(delta)
        if not valid:
            return {'accepted': False, 'reason': f'invalid_delta: {reason}'}

        # 3. Magnitude anomaly: record a fraud signal and REJECT the delta.
        # Acting on accumulated signals (e.g. banning) is left to the
        # integrity layer.
        if GradientSyncService._is_magnitude_anomaly(db, delta):
            GradientSyncService._record_gradient_fraud(
                db, node_id, 'gradient_magnitude_anomaly',
                {'magnitude': delta.get('magnitude', 0)})
            return {'accepted': False, 'reason': 'magnitude_anomaly'}

        # 4. Direction flip versus this node's previous delta: record and reject.
        if GradientSyncService._is_direction_flip(db, node_id, delta):
            GradientSyncService._record_gradient_fraud(
                db, node_id, 'gradient_direction_flip',
                {'delta_dimension': delta.get('dimension', 0)})
            return {'accepted': False, 'reason': 'direction_flip'}

        # 5. Persist as a NodeAttestation (best effort).
        attestation_id = GradientSyncService._store_delta_attestation(
            db, node_id, delta)

        # 6. Feed to the FederatedAggregator (best effort).
        try:
            from .federated_aggregator import get_federated_aggregator
            get_federated_aggregator().receive_embedding_delta(node_id, delta)
        except Exception as e:
            logger.debug("Aggregator feed failed: %s", e)

        return {
            'accepted': True,
            'attestation_id': attestation_id,
            'reason': 'ok',
        }

    @staticmethod
    def _check_sync_permission(db, node_id: str,
                               cct_string: Optional[str]) -> Optional[Dict]:
        """Return a rejection dict if the node may not sync embeddings, else None.

        With a CCT, ``check_cct_capability`` decides. Without one, the node's
        tier from ``compute_learning_tier`` must list ``embedding_sync`` in
        ``LEARNING_ACCESS_MATRIX``.

        Fails open: if the check cannot run (import or lookup error), a
        warning is logged and None is returned, so the submission is allowed.
        """
        if cct_string:
            try:
                from .continual_learner_gate import ContinualLearnerGateService
                permitted = ContinualLearnerGateService.check_cct_capability(
                    cct_string, 'embedding_sync', node_id)
            except Exception as e:
                logger.warning("CCT check unavailable for node %s; allowing "
                               "embedding delta (graceful degrade): %s",
                               node_id, e)
                return None
            if not permitted:
                return {'accepted': False, 'reason': 'cct_no_embedding_sync'}
            return None

        # No CCT supplied: fall back to the node's learning tier in the DB.
        try:
            from .continual_learner_gate import (
                ContinualLearnerGateService, LEARNING_ACCESS_MATRIX)
            tier_info = ContinualLearnerGateService.compute_learning_tier(
                db, node_id)
            tier = tier_info.get('tier', 'none')
            permitted = 'embedding_sync' in LEARNING_ACCESS_MATRIX.get(tier, [])
        except Exception as e:
            logger.warning("Tier check unavailable for node %s; allowing "
                           "embedding delta (graceful degrade): %s",
                           node_id, e)
            return None
        if not permitted:
            return {'accepted': False, 'reason': 'tier_insufficient',
                    'current_tier': tier, 'required': 'full'}
        return None

    @staticmethod
    def _is_magnitude_anomaly(db, delta: Dict) -> bool:
        """Return True if the delta's magnitude is anomalous versus recent peers.

        Returns False when there are no recent peer magnitudes, or when the
        check cannot run (logged at debug level).
        """
        try:
            from .embedding_delta import detect_magnitude_anomaly
            magnitude = delta.get('magnitude', 0.0)
            peer_magnitudes = GradientSyncService._get_peer_magnitudes(db)
            if not peer_magnitudes:
                return False
            return bool(detect_magnitude_anomaly(magnitude, peer_magnitudes))
        except Exception as e:
            logger.debug("Magnitude anomaly check skipped: %s", e)
            return False

    @staticmethod
    def _is_direction_flip(db, node_id: str, delta: Dict) -> bool:
        """Return True if the delta flips direction versus the node's previous one.

        Returns False when there is no previous delta, or when the check
        cannot run (logged at debug level).
        """
        try:
            from .embedding_delta import detect_direction_flip, decompress_delta
            previous = GradientSyncService._get_previous_delta(db, node_id)
            if not previous:
                return False
            return bool(detect_direction_flip(
                decompress_delta(delta), decompress_delta(previous)))
        except Exception as e:
            logger.debug("Direction flip check skipped: %s", e)
            return False

    @staticmethod
    def _store_delta_attestation(db, node_id: str, delta: Dict):
        """Persist the delta as a signed ``embedding_delta`` NodeAttestation.

        The attestation stores signed evidence plus the compressed delta,
        which ``_get_previous_delta`` later reads back. It expires after one
        hour.

        Returns:
            The new attestation's id, or None if storage failed. Failures are
            logged at debug level and never raised.
        """
        try:
            from integrations.social.models import NodeAttestation
            from security.node_integrity import (
                get_public_key_hex, sign_json_payload, get_node_identity)
        except ImportError:
            logger.debug("Cannot store embedding delta attestation: "
                         "imports unavailable")
            return None

        try:
            now = datetime.utcnow()
            identity = get_node_identity()
            evidence = {
                'delta_method': delta.get('method', 'unknown'),
                'delta_dimension': delta.get('dimension', 0),
                'delta_k': delta.get('k', 0),
                'magnitude': delta.get('magnitude', 0),
                'submitted_at': now.isoformat(),
            }
            sig = sign_json_payload(evidence)

            attestation = NodeAttestation(
                attester_node_id=identity.get('node_id', 'self'),
                subject_node_id=node_id,
                attestation_type='embedding_delta',
                payload_json={
                    'evidence': evidence,
                    'delta': delta,  # compressed delta kept for replay/flip checks
                },
                signature=sig[:256],  # truncated to 256 chars
                attester_public_key=get_public_key_hex(),
                is_valid=True,
                expires_at=now + timedelta(hours=1),
            )
            db.add(attestation)
            # NOTE(review): a failed flush may leave the session needing a
            # rollback before further use; confirm callers handle that.
            db.flush()
            return attestation.id
        except Exception as e:
            logger.debug("Embedding delta attestation failed: %s", e)
            return None

    # ─── Aggregation Status ───

    @staticmethod
    def get_convergence_status(db) -> Dict:
        """Get current embedding sync convergence status.

        Returns:
            A dict with these keys:
              - ``epoch``, ``peer_count``, ``convergence_score``: from the
                aggregator's ``get_stats()``.
              - ``deltas_this_round``: valid ``embedding_delta`` attestations
                created within the last GRADIENT_ROUND_DURATION_SECONDS.
              - ``round_duration_seconds``
              - ``embedding_sync``: the aggregator's embedding stats, or
                ``{}`` if it has no embedding channel.
            If the aggregator is unavailable, the same keys are returned
            zeroed, plus ``error``.
        """
        try:
            from .federated_aggregator import get_federated_aggregator
            aggregator = get_federated_aggregator()
            stats = aggregator.get_stats()

            # Count embedding deltas in the current round (0 on failure).
            delta_count = 0
            try:
                delta_count = GradientSyncService._recent_delta_query(db).count()
            except Exception as e:
                logger.debug("Embedding delta count failed: %s", e)

            embedding_stats = {}
            try:
                embedding_stats = aggregator.get_embedding_stats()
            except AttributeError:
                pass  # aggregator doesn't have an embedding channel yet

            return {
                'epoch': stats.get('epoch', 0),
                'peer_count': stats.get('peer_count', 0),
                'convergence_score': stats.get('convergence', 0.0),
                'deltas_this_round': delta_count,
                'round_duration_seconds': GRADIENT_ROUND_DURATION_SECONDS,
                'embedding_sync': embedding_stats,
            }
        except Exception as e:
            # Same keys as the success path, so callers need not special-case.
            return {
                'epoch': 0,
                'peer_count': 0,
                'convergence_score': 0.0,
                'deltas_this_round': 0,
                'round_duration_seconds': GRADIENT_ROUND_DURATION_SECONDS,
                'embedding_sync': {},
                'error': str(e),
            }

    # ─── Witness Request ───

    @staticmethod
    def request_embedding_witnesses(db, delta: Dict,
                                    node_id: str) -> Dict:
        """Request peer witnesses for an embedding delta.

        Selects up to ``2 * WITNESS_REQUIREMENT`` active, verified peers other
        than ``node_id``. It then POSTs a witness request to them in order
        until WITNESS_REQUIREMENT have answered HTTP 200. The extra candidates
        stand in for peers that are unreachable or refuse.

        Returns:
            ``{'witnessed', 'witness_count', 'witness_ids', 'required'}``.
            If too few peers qualify, returns ``witnessed=False`` with
            ``reason='insufficient_peers'`` and ``available``. On unexpected
            errors, returns ``{'witnessed': False, 'reason': str}``.
        """
        try:
            from integrations.social.models import PeerNode

            candidates = db.query(PeerNode).filter(
                PeerNode.status == 'active',
                PeerNode.integrity_status == 'verified',
                PeerNode.node_id != node_id,
            ).limit(WITNESS_REQUIREMENT * 2).all()

            if len(candidates) < WITNESS_REQUIREMENT:
                return {
                    'witnessed': False,
                    'reason': 'insufficient_peers',
                    'available': len(candidates),
                    'required': WITNESS_REQUIREMENT,
                }

            try:
                from core.http_pool import pooled_post
            except ImportError as e:
                logger.debug("Witness request skipped, HTTP pool "
                             "unavailable: %s", e)
                return {
                    'witnessed': False,
                    'witness_count': 0,
                    'witness_ids': [],
                    'required': WITNESS_REQUIREMENT,
                    'reason': 'http_pool_unavailable',
                }

            witness_ids = GradientSyncService._collect_witnesses(
                candidates, pooled_post, delta, node_id, WITNESS_REQUIREMENT)
            if len(witness_ids) < WITNESS_REQUIREMENT:
                logger.debug("Only %d/%d witnesses acknowledged delta from %s",
                             len(witness_ids), WITNESS_REQUIREMENT, node_id)

            return {
                'witnessed': len(witness_ids) >= WITNESS_REQUIREMENT,
                'witness_count': len(witness_ids),
                'witness_ids': witness_ids,
                'required': WITNESS_REQUIREMENT,
            }
        except Exception as e:
            return {'witnessed': False, 'reason': str(e)}

    # ─── Internal Helpers ───

    @staticmethod
    def _collect_witnesses(candidates, post, delta: Dict, node_id: str,
                           required: int) -> List[str]:
        """POST witness requests to candidates in order until ``required`` accept.

        Args:
            candidates: Peer objects with ``url`` and ``node_id`` attributes.
            post: Callable invoked as ``post(url, json=..., timeout=5)``.
                It returns a response with ``status_code``.
            delta: Delta forwarded to each peer.
            node_id: Submitting node, forwarded as ``submitter_node_id``.
            required: Stop once this many peers have answered 200.

        Returns:
            ``node_id`` values of the peers that answered HTTP 200, in order.
            Peers without a URL, or whose request raises, are skipped so the
            next candidate can take their place.
        """
        witness_ids = []
        for peer in candidates:
            if len(witness_ids) >= required:
                break
            base_url = peer.url
            if not base_url:
                continue
            url = f"{base_url.rstrip('/')}/api/social/peers/embedding-delta"
            witness_payload = {
                'action': 'witness_request',
                'delta': delta,
                'submitter_node_id': node_id,
                'request_id': uuid.uuid4().hex[:12],
            }
            try:
                resp = post(url, json=witness_payload, timeout=5)
            except Exception:
                continue  # unreachable/erroring peer: fall through to a spare
            if getattr(resp, 'status_code', None) == 200:
                witness_ids.append(peer.node_id)
        return witness_ids

    @staticmethod
    def _recent_delta_query(db):
        """Build a query for valid ``embedding_delta`` attestations from this round.

        "This round" means created within the last
        GRADIENT_ROUND_DURATION_SECONDS. Raises if the models module or the
        query is unavailable, so callers decide how to degrade.
        """
        from integrations.social.models import NodeAttestation
        cutoff = datetime.utcnow() - timedelta(
            seconds=GRADIENT_ROUND_DURATION_SECONDS)
        return db.query(NodeAttestation).filter(
            NodeAttestation.attestation_type == 'embedding_delta',
            NodeAttestation.is_valid == True,  # noqa: E712 -- SQL expression
            NodeAttestation.created_at >= cutoff,
        )

    @staticmethod
    def _get_peer_magnitudes(db) -> List[float]:
        """Return positive magnitudes from this round's valid embedding deltas.

        Includes attestations from every node, the submitter's own included.
        Returns ``[]`` if the query fails.
        """
        try:
            attestations = GradientSyncService._recent_delta_query(db).all()
        except Exception as e:
            logger.debug("Peer magnitude lookup failed: %s", e)
            return []
        return GradientSyncService._extract_magnitudes(attestations)

    @staticmethod
    def _extract_magnitudes(attestations) -> List[float]:
        """Pull finite, positive numeric magnitudes out of attestation payloads.

        Each attestation's ``payload_json`` is read. The magnitude comes from
        its ``evidence`` sub-dict, or from the payload itself when there is no
        ``evidence`` key. Malformed payloads, booleans, non-numbers and
        non-finite values are skipped individually rather than discarding the
        whole batch.
        """
        import math

        magnitudes = []
        for att in attestations:
            payload = att.payload_json or {}
            if not isinstance(payload, dict):
                continue
            evidence = payload.get('evidence', payload)
            if not isinstance(evidence, dict):
                continue
            mag = evidence.get('magnitude', 0.0)
            # bool is an int subclass; True must not count as magnitude 1.
            if isinstance(mag, bool) or not isinstance(mag, (int, float)):
                continue
            if math.isfinite(mag) and mag > 0:
                magnitudes.append(mag)
        return magnitudes

    @staticmethod
    def _get_previous_delta(db, node_id: str) -> Optional[Dict]:
        """Return the compressed delta of this node's most recent valid attestation.

        Returns None if there is none, if its payload holds no dict ``delta``,
        or if the lookup fails.
        """
        try:
            from integrations.social.models import NodeAttestation
            from sqlalchemy import desc
            att = db.query(NodeAttestation).filter_by(
                subject_node_id=node_id,
                attestation_type='embedding_delta',
                is_valid=True,
            ).order_by(desc(NodeAttestation.created_at)).first()
        except Exception as e:
            logger.debug("Previous delta lookup failed: %s", e)
            return None

        payload = att.payload_json if att is not None else None
        if not isinstance(payload, dict):
            return None
        previous = payload.get('delta')
        return previous if isinstance(previous, dict) else None

    @staticmethod
    def _record_gradient_fraud(db, node_id: str, signal_type: str,
                               details: dict, severity: str = 'medium'):
        """Record a gradient fraud signal as a ``FraudAlert`` row (best effort).

        Adds and flushes the alert, then logs a warning. Failures are logged
        at debug level and not propagated.

        Args:
            db: Database session.
            node_id: Node the signal is about.
            signal_type: Stored as ``alert_type``
                (e.g. ``'gradient_magnitude_anomaly'``).
            details: Stored as ``details_json``.
            severity: Alert severity. Defaults to ``'medium'``.
        """
        try:
            from integrations.social.models import FraudAlert
            alert = FraudAlert(
                node_id=node_id,
                alert_type=signal_type,
                severity=severity,
                details_json=details,
            )
            db.add(alert)
            db.flush()
        except Exception as e:
            logger.debug("Fraud signal record failed: %s", e)
            return
        logger.warning("Gradient fraud signal: %s for node %s",
                       signal_type, node_id)
335 logger.debug(f"Fraud signal record failed: {e}")