Coverage for integrations / agent_lightning / wrapper.py: 24.7%
93 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Agent Lightning Wrapper
4Wraps AutoGen agents with Agent Lightning instrumentation for training and optimization.
5Provides minimal-change integration with automatic tracing.
6"""
8import logging
9import time
10import json
11from typing import Any, Dict, List, Optional, Callable
12from datetime import datetime
13from functools import wraps
15from .config import get_agent_config, is_enabled
16from .tracer import LightningTracer
17from .rewards import RewardCalculator, RewardType
logger = logging.getLogger(__name__)


class AgentLightningWrapper:
    """
    Wraps an AutoGen agent with Agent Lightning instrumentation

    Provides:
    - Automatic tracing of agent interactions
    - Reward tracking for reinforcement learning
    - Performance monitoring
    - Zero impact on agent behavior (transparent wrapper)

    Every tracer / reward-calculator call made by the method wrappers is
    best-effort: if instrumentation raises, the failure is logged and the
    wrapped call still returns its own result (or raises its own error).

    Registered as a virtual subclass of autogen.Agent so isinstance()
    checks pass in GroupChat speaker selection and transition validation.
    """

    def __init__(
        self,
        agent: Any,
        agent_id: str,
        track_rewards: bool = True,
        auto_trace: bool = True
    ):
        """
        Initialize wrapper

        Args:
            agent: AutoGen agent to wrap
            agent_id: Unique identifier for this agent
            track_rewards: Enable reward tracking
            auto_trace: Enable automatic tracing
        """
        self.agent = agent
        self.agent_id = agent_id
        self.track_rewards = track_rewards
        self.auto_trace = auto_trace

        # Get agent-specific configuration
        self.config = get_agent_config(agent_id)

        # Initialize components (None when the feature is switched off)
        self.tracer = LightningTracer(agent_id) if auto_trace else None
        self.reward_calculator = RewardCalculator(agent_id) if track_rewards else None

        # Execution tracking
        self.execution_count = 0  # incremented once per successful generate_reply
        self.start_time = None  # not assigned elsewhere in this class; kept for compatibility
        # Span started by the most recent generate_reply; tool-call events and
        # custom rewards are attached to it.
        self.current_span_id = None

        # Wrap agent methods (patches attributes on the wrapped agent instance)
        self._wrap_agent_methods()

        logger.info("AgentLightningWrapper initialized for %s", agent_id)

    def _wrap_agent_methods(self):
        """
        Replace the agent's generate_reply / _execute_function with
        instrumented versions, when present and Agent Lightning is enabled.
        """
        if not is_enabled():
            logger.info("Agent Lightning disabled, skipping method wrapping")
            return

        # Wrap generate_reply if it exists (AutoGen pattern)
        if hasattr(self.agent, 'generate_reply'):
            original_generate_reply = self.agent.generate_reply
            self.agent.generate_reply = self._wrap_generate_reply(original_generate_reply)

        # Wrap _execute_function if it exists (tool execution)
        if hasattr(self.agent, '_execute_function'):
            original_execute = self.agent._execute_function
            self.agent._execute_function = self._wrap_tool_execution(original_execute)

    def _best_effort(self, description: str, action: Callable[[], Any]) -> Any:
        """
        Run one instrumentation action without letting it affect the agent.

        Args:
            description: Short label for the log message
            action: Zero-argument callable performing the tracer/reward work

        Returns:
            The action's return value, or None if it raised
        """
        try:
            return action()
        except Exception:
            logger.exception(
                "Agent Lightning instrumentation failed (%s) for %s",
                description, self.agent_id,
            )
            return None

    def _emit_reward(self, span_id: Any, reward_type: Any, context: Dict) -> None:
        """
        Compute a reward and, when tracing a live span, attach it to that span.

        Caller must ensure ``self.reward_calculator`` is set.
        """
        reward = self.reward_calculator.calculate_reward(
            reward_type=reward_type,
            context=context,
        )
        if self.tracer and span_id:
            self.tracer.emit_reward(span_id, reward)

    def _record_reply_success(self, span_id: Any, args: tuple, result: Any,
                              execution_time: float) -> None:
        """Emit prompt/response events, close the span, and reward a successful reply."""
        if self.tracer and span_id:
            self._best_effort('emit_prompt', lambda: self.tracer.emit_prompt(
                span_id=span_id,
                prompt=str(args)[:500],
                context={'execution_time': execution_time},
            ))
            self._best_effort('emit_response', lambda: self.tracer.emit_response(
                span_id=span_id,
                response=str(result)[:500],
                context={'execution_time': execution_time},
            ))
            self._best_effort('end_span', lambda: self.tracer.end_span(
                span_id=span_id,
                status='success',
                result={'execution_time': execution_time},
            ))

        if self.reward_calculator:
            self._best_effort('reward', lambda: self._emit_reward(
                span_id,
                RewardType.TASK_COMPLETION,
                {'execution_time': execution_time, 'success': True},
            ))

    def _record_reply_failure(self, span_id: Any, error: Exception) -> None:
        """Close the span as failed and emit a failure reward for a raising reply."""
        if self.tracer and span_id:
            self._best_effort('end_span', lambda: self.tracer.end_span(
                span_id=span_id,
                status='error',
                result={'error': str(error)},
            ))

        # Negative reward for failure
        if self.reward_calculator:
            self._best_effort('reward', lambda: self._emit_reward(
                span_id,
                RewardType.TASK_FAILURE,
                {'error': str(error)},
            ))

    def _wrap_generate_reply(self, original_func: Callable) -> Callable:
        """
        Wrap generate_reply with span tracing and reward tracking.

        Only ``original_func`` can raise out of the returned wrapper; its
        exception is re-raised unchanged after failure bookkeeping.
        """
        @wraps(original_func)
        def wrapped(*args, **kwargs):
            span_id = None
            if self.tracer:
                span_id = self._best_effort('start_span', lambda: self.tracer.start_span(
                    span_type='generate_reply',
                    context={'args': str(args)[:200], 'kwargs': str(kwargs)[:200]},
                ))
                self.current_span_id = span_id

            start_time = time.perf_counter()
            # Keep the try minimal: only the agent's own call may propagate.
            try:
                result = original_func(*args, **kwargs)
            except Exception as e:
                logger.error("Error in generate_reply: %s", e)
                self._record_reply_failure(span_id, e)
                raise

            execution_time = time.perf_counter() - start_time
            self._record_reply_success(span_id, args, result, execution_time)
            self.execution_count += 1
            return result

        return wrapped

    def _wrap_tool_execution(self, original_func: Callable) -> Callable:
        """
        Wrap tool execution with tool-call events and efficiency/failure rewards.

        Events attach to ``self.current_span_id``; instrumentation failures are
        logged, while the tool's own exceptions are re-raised unchanged.
        """
        @wraps(original_func)
        def wrapped(*args, **kwargs):
            # Emit tool call event
            if self.tracer and self.current_span_id:
                span_id = self.current_span_id
                self._best_effort('emit_tool_call', lambda: self.tracer.emit_tool_call(
                    span_id=span_id,
                    tool_name=str(args[0]) if args else 'unknown',
                    tool_args=str(args[1:])[:200] if len(args) > 1 else '',
                    context=kwargs,
                ))

            start_time = time.perf_counter()
            try:
                result = original_func(*args, **kwargs)
            except Exception as e:
                logger.error("Error in tool execution: %s", e)
                # Negative reward for tool failure
                if self.reward_calculator:
                    self._best_effort('reward', lambda: self._emit_reward(
                        self.current_span_id,
                        RewardType.TASK_FAILURE,
                        {'error': str(e), 'tool': True},
                    ))
                raise

            execution_time = time.perf_counter() - start_time
            # Tool execution reward
            if self.reward_calculator:
                self._best_effort('reward', lambda: self._emit_reward(
                    self.current_span_id,
                    RewardType.TOOL_USE_EFFICIENCY,
                    {'execution_time': execution_time, 'success': True},
                ))
            return result

        return wrapped

    def emit_custom_reward(self, reward_value: float, context: Optional[Dict] = None):
        """
        Emit custom reward value on the current span

        Does nothing when tracing is off or no span has been started yet.

        Args:
            reward_value: Reward value
            context: Optional context
        """
        if self.tracer and self.current_span_id:
            self.tracer.emit_reward(self.current_span_id, reward_value, context)

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get agent statistics

        Returns:
            Dictionary with agent_id, execution_count, config, and the
            tracer / reward statistics when those components are enabled
        """
        stats = {
            'agent_id': self.agent_id,
            'execution_count': self.execution_count,
            'config': self.config
        }

        if self.tracer:
            stats['tracer_stats'] = self.tracer.get_statistics()

        if self.reward_calculator:
            stats['reward_stats'] = self.reward_calculator.get_statistics()

        return stats

    def __getattr__(self, name: str) -> Any:
        """
        Delegate attribute access to wrapped agent

        Only invoked for names not found on the wrapper. ``agent`` is read
        from the instance dict so that a not-yet-initialized instance (e.g.
        during copy/pickle reconstruction) raises AttributeError instead of
        recursing back into __getattr__ forever.
        """
        try:
            agent = self.__dict__['agent']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(agent, name)

    def __repr__(self) -> str:
        return f"AgentLightningWrapper({self.agent_id}, wrapped={self.agent.__class__.__name__})"
# Register as virtual subclass of autogen.Agent so isinstance() checks pass
# in GroupChat (speaker selection, transition validation, graph validity).
# This is the ABC way to say "this class IS-A Agent" without inheriting.
try:
    import autogen
    autogen.Agent.register(AgentLightningWrapper)
except (ImportError, AttributeError) as exc:
    # Best-effort: autogen not installed or Agent doesn't support register().
    # Wrapped agents then won't satisfy isinstance(..., autogen.Agent); log
    # why so that failure is diagnosable instead of silent.
    logger.debug("Skipping autogen.Agent virtual-subclass registration: %s", exc)
def instrument_autogen_agent(
    agent: Any,
    agent_id: str,
    track_rewards: bool = True,
    auto_trace: bool = True
) -> Any:
    """
    Convenience function to instrument an AutoGen agent

    Args:
        agent: AutoGen agent
        agent_id: Agent identifier
        track_rewards: Enable reward tracking
        auto_trace: Enable automatic tracing

    Returns:
        An AgentLightningWrapper around ``agent`` when Agent Lightning is
        enabled; otherwise ``agent`` itself, unchanged (hence the ``Any``
        return annotation).
    """
    if not is_enabled():
        logger.info("Agent Lightning disabled, returning unwrapped agent")
        return agent

    return AgentLightningWrapper(
        agent=agent,
        agent_id=agent_id,
        track_rewards=track_rewards,
        auto_trace=auto_trace
    )
# Names exported by `from ... import *`: the wrapper class and its factory.
__all__ = ['AgentLightningWrapper', 'instrument_autogen_agent']