Coverage for integrations / agent_lightning / rewards.py: 24.0%
104 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Agent Lightning Reward Calculator
4Calculates rewards for agent actions to enable reinforcement learning.
5"""
7import logging
8from enum import Enum
9from typing import Dict, Optional, Any
10from collections import defaultdict
12from .config import get_reward_value
14logger = logging.getLogger(__name__)
class RewardType(str, Enum):
    """Types of rewards the calculator can produce.

    Inherits from ``str`` so members compare equal to (and serialize as)
    their plain string values, e.g. ``RewardType.CUSTOM == "custom"``.
    """
    TASK_COMPLETION = "task_completion"          # task finished (success-modulated)
    TASK_FAILURE = "task_failure"                # task failed (penalty)
    TOOL_USE_EFFICIENCY = "tool_use_efficiency"  # fast/successful tool calls
    RESPONSE_QUALITY = "response_quality"        # scaled by quality metrics
    EXECUTION_TIME = "execution_time"            # penalty for slow execution
    USER_FEEDBACK = "user_feedback"              # raw user feedback score
    CUSTOM = "custom"                            # caller-supplied reward value
class RewardCalculator:
    """
    Calculates rewards for agent actions.

    Supports multiple reward types:
    - Task completion/failure
    - Tool usage efficiency
    - Response quality
    - Execution time penalties
    - User feedback
    """

    # Maximum number of entries retained in ``reward_history``.
    _HISTORY_LIMIT = 1000

    def __init__(self, agent_id: str):
        """
        Args:
            agent_id: Identifier of the agent this calculator tracks.
        """
        self.agent_id = agent_id
        # Running totals/counters keyed by reward type; defaults to 0.0.
        self.stats: Dict[str, float] = defaultdict(float)
        # Most recent rewards (bounded to _HISTORY_LIMIT entries).
        self.reward_history: list = []

        logger.info(f"RewardCalculator initialized for {agent_id}")

    def calculate_reward(
        self,
        reward_type: RewardType,
        context: Optional[Dict[str, Any]] = None
    ) -> float:
        """
        Calculate reward based on type and context.

        Args:
            reward_type: Type of reward
            context: Context data for reward calculation

        Returns:
            Reward value
        """
        context = context or {}

        # Get base reward value from config
        base_reward = get_reward_value(reward_type.value)

        # Apply context-based modifications
        reward = self._apply_context_modifiers(reward_type, base_reward, context)

        # Track statistics
        self._track_reward(reward_type, reward, context)

        logger.debug(f"Calculated reward: {reward} (type: {reward_type}, context: {context})")

        return reward

    def _apply_context_modifiers(
        self,
        reward_type: RewardType,
        base_reward: float,
        context: Dict[str, Any]
    ) -> float:
        """
        Apply context-based modifications to base reward.

        Args:
            reward_type: Reward type
            base_reward: Base reward value
            context: Context data

        Returns:
            Modified reward
        """
        reward = base_reward

        if reward_type == RewardType.TASK_COMPLETION:
            # Bonus for fast completion (strictly under 1 second).
            exec_time = context.get('execution_time', 0)
            if 0 < exec_time < 1.0:
                reward *= 1.2

            # Halve the reward when the task did not actually succeed;
            # successful completions keep the full value.
            if not context.get('success', False):
                reward *= 0.5

        elif reward_type == RewardType.TASK_FAILURE:
            # More severe penalty for explicit errors (base reward is
            # presumably negative, so *1.5 makes it more negative —
            # TODO confirm sign against config).
            if 'error' in context:
                reward *= 1.5
            # Tool failures are treated as less severe.
            if context.get('tool', False):
                reward *= 0.8

        elif reward_type == RewardType.TOOL_USE_EFFICIENCY:
            exec_time = context.get('execution_time', 0)

            # Reward fast tool execution; penalize slow calls.
            if exec_time < 0.5:
                reward *= 1.5
            elif exec_time > 5.0:
                reward *= 0.5

            # A failed tool call always yields a non-positive reward.
            if not context.get('success', True):
                reward = -abs(reward)

        elif reward_type == RewardType.RESPONSE_QUALITY:
            # Scale by quality score (0.5 default → neutral multiplier of 1.0).
            quality_score = context.get('quality_score', 0.5)
            reward *= (quality_score * 2)

            # Mild length penalty for very long responses.
            response_length = context.get('response_length', 0)
            if response_length > 2000:
                reward *= 0.9

        elif reward_type == RewardType.EXECUTION_TIME:
            exec_time = context.get('execution_time', 0)
            # Penalize slow execution proportionally beyond 10 seconds.
            if exec_time > 10.0:
                reward *= (exec_time / 10.0)

        elif reward_type == RewardType.USER_FEEDBACK:
            # User feedback overrides the configured base reward entirely.
            reward = context.get('feedback_score', 0)

        elif reward_type == RewardType.CUSTOM:
            # Custom rewards pass through a caller-supplied value.
            reward = context.get('reward_value', base_reward)

        return reward

    def _track_reward(
        self,
        reward_type: RewardType,
        reward: float,
        context: Dict[str, Any]
    ):
        """Track reward statistics and append to the bounded history."""
        self.stats[f'total_{reward_type.value}'] += reward
        self.stats[f'count_{reward_type.value}'] += 1
        self.stats['total_reward'] += reward
        self.stats['reward_count'] += 1

        self.reward_history.append({
            'type': reward_type.value,
            'value': reward,
            'context': context
        })

        # Keep only the most recent entries to bound memory use.
        if len(self.reward_history) > self._HISTORY_LIMIT:
            self.reward_history = self.reward_history[-self._HISTORY_LIMIT:]

    def calculate_task_completion_reward(
        self,
        success: bool,
        execution_time: float,
        quality_metrics: Optional[Dict] = None
    ) -> float:
        """
        Convenience method for task completion rewards.

        Args:
            success: Task succeeded
            execution_time: Time to complete
            quality_metrics: Optional quality metrics (merged into context)

        Returns:
            Reward value
        """
        if success:
            context = {
                'success': True,
                'execution_time': execution_time,
                **(quality_metrics or {})
            }
            return self.calculate_reward(RewardType.TASK_COMPLETION, context)
        else:
            context = {
                'success': False,
                'execution_time': execution_time
            }
            return self.calculate_reward(RewardType.TASK_FAILURE, context)

    def calculate_tool_reward(
        self,
        tool_name: str,
        success: bool,
        execution_time: float
    ) -> float:
        """
        Convenience method for tool execution rewards.

        Args:
            tool_name: Tool name
            success: Tool succeeded
            execution_time: Execution time

        Returns:
            Reward value
        """
        if success:
            context = {
                'tool_name': tool_name,
                'success': True,
                'execution_time': execution_time
            }
            return self.calculate_reward(RewardType.TOOL_USE_EFFICIENCY, context)
        else:
            context = {
                'tool_name': tool_name,
                'success': False,
                'execution_time': execution_time,
                'tool': True  # marks this failure as tool-related (softer penalty)
            }
            return self.calculate_reward(RewardType.TASK_FAILURE, context)

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get reward statistics, including per-type and overall averages.

        Returns:
            Statistics dictionary
        """
        stats = dict(self.stats)

        # Overall average.
        if stats.get('reward_count', 0) > 0:
            stats['average_reward'] = stats['total_reward'] / stats['reward_count']

        # Per-type averages.
        for reward_type in RewardType:
            count_key = f'count_{reward_type.value}'
            total_key = f'total_{reward_type.value}'

            if stats.get(count_key, 0) > 0:
                avg_key = f'average_{reward_type.value}'
                stats[avg_key] = stats[total_key] / stats[count_key]

        return stats

    def get_recent_rewards(self, count: int = 10) -> list:
        """
        Get recent rewards.

        Args:
            count: Number of recent rewards

        Returns:
            List of the most recent reward records (newest last)
        """
        return self.reward_history[-count:]

    def reset_statistics(self):
        """Reset all statistics and clear the reward history."""
        self.stats.clear()
        self.reward_history.clear()
        logger.info(f"Reset reward statistics for {self.agent_id}")
290__all__ = [
291 'RewardCalculator',
292 'RewardType',
293]