Coverage for integrations / agent_lightning / store.py: 12.5%

160 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

"""
Agent Lightning Store

Persistence layer for spans and training data.
Supports multiple backends: Redis, JSON, and in-memory.
"""

import json
import logging
import os
from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

from .config import AGENT_LIGHTNING_CONFIG
from .tracer import Span

logger = logging.getLogger(__name__)

19 

20 

class LightningStore:
    """
    Persistence layer for Agent Lightning data.

    Stores:
    - Spans (traces of agent interactions)
    - Training data (for continuous improvement)
    - Statistics and metrics

    Supports multiple backends:
    - redis: Redis backend (recommended for production)
    - json: JSON file backend (for development)
    - memory: In-memory only (for testing)
    """

    def __init__(
        self,
        agent_id: str,
        backend: Optional[str] = None
    ):
        """
        Args:
            agent_id: Identifier of the agent this store belongs to.
            backend: Storage backend name ('redis', 'json', or 'memory').
                Defaults to the configured 'store_backend', then 'json'.
        """
        self.agent_id = agent_id
        self.backend = backend or AGENT_LIGHTNING_CONFIG.get('store_backend', 'json')

        # Backend-specific initialization (may downgrade self.backend on failure)
        self._backend_client = None
        self._init_backend()

        # In-memory cache, always maintained alongside the persistent backend.
        # For the 'memory' backend this cache IS the storage.
        self._cache = {
            'spans': {},
            'stats': defaultdict(float),
        }

        logger.info(f"LightningStore initialized for {agent_id} with backend: {self.backend}")

    def _init_backend(self):
        """Initialize storage backend.

        Fallback chain: redis -> json (on connection failure),
        unknown backend name -> memory.
        """
        if self.backend == 'redis':
            try:
                import redis
                redis_config = AGENT_LIGHTNING_CONFIG.get('redis', {})
                self._backend_client = redis.Redis(
                    host=redis_config.get('host', 'localhost'),
                    port=redis_config.get('port', 6379),
                    db=redis_config.get('db', 0),
                    decode_responses=True,
                    socket_connect_timeout=1, socket_timeout=2,
                    retry_on_timeout=False,
                )
                # Ping eagerly so a dead Redis triggers the JSON fallback now,
                # not on the first save.
                self._backend_client.ping()
                logger.info("Connected to Redis backend")
            except Exception as e:
                logger.warning(f"Redis backend failed: {e}. Falling back to JSON.")
                self.backend = 'json'
                self._init_backend()

        elif self.backend == 'json':
            # Ensure storage directory exists
            self.storage_path = AGENT_LIGHTNING_CONFIG.get(
                'traces_path',
                './agent_data/lightning_traces'
            )
            os.makedirs(self.storage_path, exist_ok=True)
            logger.info(f"Using JSON backend at {self.storage_path}")

        elif self.backend == 'memory':
            logger.info("Using in-memory backend (data will not persist)")

        else:
            logger.warning(f"Unknown backend: {self.backend}. Using memory.")
            self.backend = 'memory'

    def _span_file(self, span_id: str) -> str:
        """Return the JSON file path for a span (json backend only)."""
        return os.path.join(self.storage_path, f"{span_id}.json")

    @staticmethod
    def _matches_filters(
        span: Dict,
        span_type: Optional[str],
        status: Optional[str]
    ) -> bool:
        """Return True if span passes the optional span_type/status filters."""
        if span_type and span.get('span_type') != span_type:
            return False
        if status and span.get('status') != status:
            return False
        return True

    def save_span(self, span: "Span") -> bool:
        """
        Save span to storage

        Args:
            span: Span to save

        Returns:
            Success status (errors are logged, never raised)
        """
        try:
            span_dict = span.to_dict()
            span_key = f"span:{self.agent_id}:{span.span_id}"

            # Save to backend
            if self.backend == 'redis':
                self._backend_client.set(
                    span_key,
                    json.dumps(span_dict)
                )
                # Track the id in the agent's span set so list_spans can find it
                self._backend_client.sadd(
                    f"spans:{self.agent_id}",
                    span.span_id
                )

            elif self.backend == 'json':
                with open(self._span_file(span.span_id), 'w') as f:
                    json.dump(span_dict, f, indent=2)

            # Always cache in memory
            self._cache['spans'][span.span_id] = span_dict

            logger.debug(f"Saved span: {span.span_id}")
            return True

        except Exception as e:
            logger.error(f"Error saving span: {e}")
            return False

    def load_span(self, span_id: str) -> Optional[Dict]:
        """
        Load span from storage

        Args:
            span_id: Span ID

        Returns:
            Span dictionary or None
        """
        # Check cache first
        if span_id in self._cache['spans']:
            return self._cache['spans'][span_id]

        try:
            if self.backend == 'redis':
                span_key = f"span:{self.agent_id}:{span_id}"
                span_json = self._backend_client.get(span_key)
                if span_json:
                    span_dict = json.loads(span_json)
                    self._cache['spans'][span_id] = span_dict
                    return span_dict

            elif self.backend == 'json':
                filename = self._span_file(span_id)
                if os.path.exists(filename):
                    with open(filename, 'r') as f:
                        span_dict = json.load(f)
                    self._cache['spans'][span_id] = span_dict
                    return span_dict

            # 'memory' backend: a cache miss means the span does not exist
            return None

        except Exception as e:
            logger.error(f"Error loading span: {e}")
            return None

    def list_spans(
        self,
        limit: int = 100,
        span_type: Optional[str] = None,
        status: Optional[str] = None
    ) -> List[Dict]:
        """
        List spans with optional filtering

        Note: the limit is applied while collecting (backend order),
        then the collected spans are sorted most-recent-first.

        Args:
            limit: Maximum number of spans to return
            span_type: Filter by span type
            status: Filter by status

        Returns:
            List of span dictionaries, sorted by start_time descending
        """
        try:
            spans = []

            if self.backend == 'redis':
                span_ids = self._backend_client.smembers(f"spans:{self.agent_id}")
                for span_id in span_ids:
                    span = self.load_span(span_id)
                    if span and self._matches_filters(span, span_type, status):
                        spans.append(span)
                        if len(spans) >= limit:
                            break

            elif self.backend == 'json':
                for filename in os.listdir(self.storage_path):
                    if not filename.endswith('.json'):
                        continue

                    span_id = filename[:-5]  # Remove .json
                    span = self.load_span(span_id)
                    if span and self._matches_filters(span, span_type, status):
                        spans.append(span)
                        if len(spans) >= limit:
                            break

            elif self.backend == 'memory':
                # Keys are unused here; iterate values only
                for span in self._cache['spans'].values():
                    if self._matches_filters(span, span_type, status):
                        spans.append(span)
                        if len(spans) >= limit:
                            break

            # Sort by start_time (most recent first)
            spans.sort(key=lambda s: s.get('start_time', 0), reverse=True)

            return spans[:limit]

        except Exception as e:
            logger.error(f"Error listing spans: {e}")
            return []

    def get_training_data(
        self,
        limit: int = 1000,
        min_reward: Optional[float] = None,
        max_reward: Optional[float] = None
    ) -> List[Dict]:
        """
        Get training data for continuous improvement

        Builds (prompt, response, reward) samples from spans that carry at
        least one 'reward' event plus a 'prompt' and a 'response' event.

        Args:
            limit: Maximum samples to return
            min_reward: Minimum reward threshold (inclusive)
            max_reward: Maximum reward threshold (inclusive)

        Returns:
            List of training samples
        """
        spans = self.list_spans(limit=limit)
        training_data = []

        for span in spans:
            events = span.get('events', [])

            # Extract reward events
            rewards = [e for e in events if e.get('type') == 'reward']
            if not rewards:
                continue

            # Calculate total reward over the span
            total_reward = sum(
                e.get('data', {}).get('reward', 0) for e in rewards
            )

            # Apply reward filters
            if min_reward is not None and total_reward < min_reward:
                continue
            if max_reward is not None and total_reward > max_reward:
                continue

            # Extract prompt and response (first of each is used)
            prompt_events = [e for e in events if e.get('type') == 'prompt']
            response_events = [e for e in events if e.get('type') == 'response']

            if prompt_events and response_events:
                training_data.append({
                    'span_id': span.get('span_id'),
                    'agent_id': span.get('agent_id'),
                    'prompt': prompt_events[0].get('data', {}).get('prompt', ''),
                    'response': response_events[0].get('data', {}).get('response', ''),
                    'reward': total_reward,
                    'duration': span.get('duration'),
                    'status': span.get('status'),
                    'timestamp': span.get('start_time')
                })

        return training_data

    def delete_span(self, span_id: str) -> bool:
        """
        Delete span from storage

        Args:
            span_id: Span ID

        Returns:
            Success status
        """
        try:
            if self.backend == 'redis':
                span_key = f"span:{self.agent_id}:{span_id}"
                self._backend_client.delete(span_key)
                self._backend_client.srem(f"spans:{self.agent_id}", span_id)

            elif self.backend == 'json':
                filename = self._span_file(span_id)
                if os.path.exists(filename):
                    os.remove(filename)

            # Remove from cache
            self._cache['spans'].pop(span_id, None)

            logger.debug(f"Deleted span: {span_id}")
            return True

        except Exception as e:
            logger.error(f"Error deleting span: {e}")
            return False

    def cleanup_old_spans(self, days: int = 30) -> int:
        """
        Delete spans older than specified days

        Args:
            days: Age threshold in days

        Returns:
            Number of spans deleted
        """
        try:
            threshold = datetime.now().timestamp() - (days * 86400)
            deleted_count = 0

            spans = self.list_spans(limit=10000)
            for span in spans:
                if span.get('start_time', 0) < threshold:
                    span_id = span.get('span_id')
                    # Skip malformed spans with no id instead of "deleting" None
                    if span_id and self.delete_span(span_id):
                        deleted_count += 1

            logger.info(f"Cleaned up {deleted_count} old spans")
            return deleted_count

        except Exception as e:
            logger.error(f"Error cleaning up spans: {e}")
            return 0

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get storage statistics

        Returns:
            Statistics dictionary with agent_id, backend, cached_spans and
            (when the backend is reachable) total_spans
        """
        stats = {
            'agent_id': self.agent_id,
            'backend': self.backend,
            'cached_spans': len(self._cache['spans'])
        }

        try:
            if self.backend == 'redis':
                stats['total_spans'] = self._backend_client.scard(f"spans:{self.agent_id}")

            elif self.backend == 'json':
                json_files = [
                    f for f in os.listdir(self.storage_path)
                    if f.endswith('.json')
                ]
                stats['total_spans'] = len(json_files)

            elif self.backend == 'memory':
                stats['total_spans'] = len(self._cache['spans'])

        except Exception as e:
            logger.error(f"Error getting statistics: {e}")

        return stats

411 

412 

# Public API of this module.
__all__ = ['LightningStore']