Coverage for integrations / agent_lightning / rewards.py: 24.0%

104 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2Agent Lightning Reward Calculator 

3 

4Calculates rewards for agent actions to enable reinforcement learning. 

5""" 

6 

7import logging 

8from enum import Enum 

9from typing import Dict, Optional, Any 

10from collections import defaultdict 

11 

12from .config import get_reward_value 

13 

14logger = logging.getLogger(__name__) 

15 

16 

class RewardType(str, Enum):
    """Types of rewards.

    Inherits from ``str`` so members compare equal to — and serialize as —
    their plain string values (used as keys in stats and config lookups).
    """
    TASK_COMPLETION = "task_completion"          # task finished; eligible for speed/success modifiers
    TASK_FAILURE = "task_failure"                # task failed; penalty scaled by error/tool context
    TOOL_USE_EFFICIENCY = "tool_use_efficiency"  # reward for fast, successful tool execution
    RESPONSE_QUALITY = "response_quality"        # scaled by a quality_score in [0, 1]
    EXECUTION_TIME = "execution_time"            # penalty applied to slow executions
    USER_FEEDBACK = "user_feedback"              # explicit user score; overrides the base reward
    CUSTOM = "custom"                            # caller supplies the reward value directly

26 

27 

28class RewardCalculator: 

29 """ 

30 Calculates rewards for agent actions 

31 

32 Supports multiple reward types: 

33 - Task completion/failure 

34 - Tool usage efficiency 

35 - Response quality 

36 - Execution time penalties 

37 - User feedback 

38 """ 

39 

40 def __init__(self, agent_id: str): 

41 self.agent_id = agent_id 

42 self.stats = defaultdict(float) 

43 self.reward_history = [] 

44 

45 logger.info(f"RewardCalculator initialized for {agent_id}") 

46 

47 def calculate_reward( 

48 self, 

49 reward_type: RewardType, 

50 context: Optional[Dict[str, Any]] = None 

51 ) -> float: 

52 """ 

53 Calculate reward based on type and context 

54 

55 Args: 

56 reward_type: Type of reward 

57 context: Context data for reward calculation 

58 

59 Returns: 

60 Reward value 

61 """ 

62 context = context or {} 

63 

64 # Get base reward value from config 

65 base_reward = get_reward_value(reward_type.value) 

66 

67 # Apply context-based modifications 

68 reward = self._apply_context_modifiers(reward_type, base_reward, context) 

69 

70 # Track statistics 

71 self._track_reward(reward_type, reward, context) 

72 

73 logger.debug(f"Calculated reward: {reward} (type: {reward_type}, context: {context})") 

74 

75 return reward 

76 

77 def _apply_context_modifiers( 

78 self, 

79 reward_type: RewardType, 

80 base_reward: float, 

81 context: Dict[str, Any] 

82 ) -> float: 

83 """ 

84 Apply context-based modifications to base reward 

85 

86 Args: 

87 reward_type: Reward type 

88 base_reward: Base reward value 

89 context: Context data 

90 

91 Returns: 

92 Modified reward 

93 """ 

94 reward = base_reward 

95 

96 # Task completion rewards 

97 if reward_type == RewardType.TASK_COMPLETION: 

98 # Bonus for fast completion 

99 exec_time = context.get('execution_time', 0) 

100 if exec_time > 0 and exec_time < 1.0: # Under 1 second 

101 reward *= 1.2 

102 

103 # Success multiplier 

104 if context.get('success', False): 

105 reward *= 1.0 

106 else: 

107 reward *= 0.5 

108 

109 # Task failure penalties 

110 elif reward_type == RewardType.TASK_FAILURE: 

111 # More severe penalty for errors vs timeouts 

112 if 'error' in context: 

113 reward *= 1.5 # More negative 

114 if context.get('tool', False): 

115 reward *= 0.8 # Less negative for tool failures 

116 

117 # Tool use efficiency 

118 elif reward_type == RewardType.TOOL_USE_EFFICIENCY: 

119 exec_time = context.get('execution_time', 0) 

120 

121 # Reward fast tool execution 

122 if exec_time < 0.5: 

123 reward *= 1.5 

124 elif exec_time > 5.0: 

125 reward *= 0.5 

126 

127 # Penalty for tool failures 

128 if not context.get('success', True): 

129 reward = -abs(reward) 

130 

131 # Response quality (based on metrics if available) 

132 elif reward_type == RewardType.RESPONSE_QUALITY: 

133 quality_score = context.get('quality_score', 0.5) 

134 reward *= (quality_score * 2) # Scale by quality 

135 

136 # Length penalty for very long responses 

137 response_length = context.get('response_length', 0) 

138 if response_length > 2000: 

139 reward *= 0.9 

140 

141 # Execution time penalty 

142 elif reward_type == RewardType.EXECUTION_TIME: 

143 exec_time = context.get('execution_time', 0) 

144 # Penalize slow execution 

145 if exec_time > 10.0: 

146 reward *= (exec_time / 10.0) # More negative for slower 

147 

148 # User feedback 

149 elif reward_type == RewardType.USER_FEEDBACK: 

150 feedback_score = context.get('feedback_score', 0) 

151 # User feedback overrides base reward 

152 reward = feedback_score 

153 

154 # Custom rewards pass through 

155 elif reward_type == RewardType.CUSTOM: 

156 custom_value = context.get('reward_value', base_reward) 

157 reward = custom_value 

158 

159 return reward 

160 

161 def _track_reward( 

162 self, 

163 reward_type: RewardType, 

164 reward: float, 

165 context: Dict[str, Any] 

166 ): 

167 """Track reward statistics""" 

168 self.stats[f'total_{reward_type.value}'] += reward 

169 self.stats[f'count_{reward_type.value}'] += 1 

170 self.stats['total_reward'] += reward 

171 self.stats['reward_count'] += 1 

172 

173 # Track history 

174 self.reward_history.append({ 

175 'type': reward_type.value, 

176 'value': reward, 

177 'context': context 

178 }) 

179 

180 # Keep only last 1000 rewards 

181 if len(self.reward_history) > 1000: 

182 self.reward_history = self.reward_history[-1000:] 

183 

184 def calculate_task_completion_reward( 

185 self, 

186 success: bool, 

187 execution_time: float, 

188 quality_metrics: Optional[Dict] = None 

189 ) -> float: 

190 """ 

191 Convenience method for task completion rewards 

192 

193 Args: 

194 success: Task succeeded 

195 execution_time: Time to complete 

196 quality_metrics: Optional quality metrics 

197 

198 Returns: 

199 Reward value 

200 """ 

201 if success: 

202 context = { 

203 'success': True, 

204 'execution_time': execution_time, 

205 **(quality_metrics or {}) 

206 } 

207 return self.calculate_reward(RewardType.TASK_COMPLETION, context) 

208 else: 

209 context = { 

210 'success': False, 

211 'execution_time': execution_time 

212 } 

213 return self.calculate_reward(RewardType.TASK_FAILURE, context) 

214 

215 def calculate_tool_reward( 

216 self, 

217 tool_name: str, 

218 success: bool, 

219 execution_time: float 

220 ) -> float: 

221 """ 

222 Convenience method for tool execution rewards 

223 

224 Args: 

225 tool_name: Tool name 

226 success: Tool succeeded 

227 execution_time: Execution time 

228 

229 Returns: 

230 Reward value 

231 """ 

232 if success: 

233 context = { 

234 'tool_name': tool_name, 

235 'success': True, 

236 'execution_time': execution_time 

237 } 

238 return self.calculate_reward(RewardType.TOOL_USE_EFFICIENCY, context) 

239 else: 

240 context = { 

241 'tool_name': tool_name, 

242 'success': False, 

243 'execution_time': execution_time, 

244 'tool': True 

245 } 

246 return self.calculate_reward(RewardType.TASK_FAILURE, context) 

247 

248 def get_statistics(self) -> Dict[str, Any]: 

249 """ 

250 Get reward statistics 

251 

252 Returns: 

253 Statistics dictionary 

254 """ 

255 stats = dict(self.stats) 

256 

257 # Calculate averages 

258 if stats.get('reward_count', 0) > 0: 

259 stats['average_reward'] = stats['total_reward'] / stats['reward_count'] 

260 

261 for reward_type in RewardType: 

262 count_key = f'count_{reward_type.value}' 

263 total_key = f'total_{reward_type.value}' 

264 

265 if stats.get(count_key, 0) > 0: 

266 avg_key = f'average_{reward_type.value}' 

267 stats[avg_key] = stats[total_key] / stats[count_key] 

268 

269 return stats 

270 

271 def get_recent_rewards(self, count: int = 10) -> list: 

272 """ 

273 Get recent rewards 

274 

275 Args: 

276 count: Number of recent rewards 

277 

278 Returns: 

279 List of recent rewards 

280 """ 

281 return self.reward_history[-count:] 

282 

283 def reset_statistics(self): 

284 """Reset all statistics""" 

285 self.stats.clear() 

286 self.reward_history.clear() 

287 logger.info(f"Reset reward statistics for {self.agent_id}") 

288 

289 

# Explicit public API of this module (PEP 8: declare exported names).
__all__ = [
    'RewardCalculator',
    'RewardType',
]