Coverage for integrations / channels / queue / rate_limit.py: 92.6%

188 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2Rate Limiting System 

3 

4Limits request rates to prevent abuse and comply with API limits. 

5Ported from HevolveBot's src/channels/rate-limit.ts. 

6 

7Features: 

8- Sliding window rate limiting 

9- Per-channel limits 

10- Burst handling 

11- Token bucket algorithm 

12""" 

13 

14from __future__ import annotations 

15 

16import logging 

17import threading 

18import time 

19from collections import deque 

20from dataclasses import dataclass, field 

21from datetime import datetime, timedelta 

22from enum import Enum 

23from typing import Optional, Dict, Deque, Tuple 

24 

25logger = logging.getLogger(__name__) 

26 

27 

class RateLimitResult(Enum):
    """Outcome reported by a single rate-limit check."""

    # Every limiter had capacity; the request may proceed.
    ALLOWED = "allowed"
    # Refused by the per-minute or per-hour sliding window.
    RATE_LIMITED = "rate_limited"
    # Refused because the burst token bucket had no whole token left.
    BURST_EXCEEDED = "burst_exceeded"

33 

34 

@dataclass
class RateLimitConfig:
    """Tunable limits read by ``RateLimiter``.

    Attributes:
        requests_per_minute: Default size of the 60-second sliding window.
            It also sets the burst bucket's refill rate (``per_minute / 60``
            tokens per second).
        requests_per_hour: Size of the 3600-second sliding window. This limit
            is the same for every channel.
        burst_limit: Capacity of the per-key burst token bucket.
        burst_window_seconds: Value reported as ``retry_after_seconds`` when a
            request is refused for exceeding the burst limit.
        per_channel_limits: Optional per-channel override of
            ``requests_per_minute``, keyed by channel name. Only the
            per-minute limit can be overridden.
    """

    # Default per-minute quota (overridable per channel below).
    requests_per_minute: int = 60
    # Hourly quota, shared by all channels.
    requests_per_hour: int = 1000
    # Token-bucket capacity for short bursts.
    burst_limit: int = 10
    # Retry hint, in seconds, returned on a burst refusal.
    burst_window_seconds: int = 1
    # channel name -> per-minute quota; each instance gets its own dict.
    per_channel_limits: Dict[str, int] = field(default_factory=dict)

43 

44 

@dataclass
class RateLimitInfo:
    """Snapshot returned by ``RateLimiter.check`` describing one decision.

    Attributes:
        allowed: True when the request may proceed.
        result: Which outcome was reached (allowed, rate limited, burst exceeded).
        remaining_minute: Free slots left in the per-minute window.
        remaining_hour: Free slots left in the per-hour window.
        remaining_burst: Whole tokens left in the burst bucket.
        reset_minute_at: Naive local ``datetime`` at which the minute window
            is expected to free up.
        reset_hour_at: Naive local ``datetime`` at which the hour window is
            expected to free up.
        retry_after_seconds: Suggested wait before retrying; ``None`` when the
            request was allowed.
    """

    allowed: bool
    result: RateLimitResult
    remaining_minute: int
    remaining_hour: int
    remaining_burst: int
    reset_minute_at: datetime
    reset_hour_at: datetime
    # Left unset (None) on the allowed path.
    retry_after_seconds: Optional[float] = None

56 

57 

@dataclass
class RateLimitStats:
    """Running counters of ``RateLimiter.check`` outcomes."""

    # Every call to check(), whatever the outcome.
    total_requests: int = 0
    # Checks that admitted the request.
    total_allowed: int = 0
    # Checks refused by the minute or hour window.
    total_rate_limited: int = 0
    # Checks refused by the burst bucket.
    total_burst_exceeded: int = 0

65 

66 

class SlidingWindowCounter:
    """Sliding-window request counter.

    Records the timestamp of every consumed request and counts how many fall
    inside the trailing ``window_seconds``. Every public method takes an
    internal lock, so one instance can be shared between threads.

    NOTE(review): timestamps come from ``time.time()`` (the wall clock). A
    clock step forwards expires the whole window early, and a step backwards
    keeps entries alive longer than ``window_seconds``.
    """

    def __init__(self, window_seconds: int, max_requests: int):
        """
        Args:
            window_seconds: Length of the trailing window, in seconds.
            max_requests: Maximum number of requests allowed inside the window.
        """
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        # Oldest timestamp on the left; new ones are appended on the right.
        self._timestamps: Deque[float] = deque()
        self._lock = threading.Lock()

    def _cleanup(self, now: float) -> None:
        """Drop timestamps older than ``now - window_seconds``. Caller holds ``_lock``."""
        cutoff = now - self.window_seconds
        while self._timestamps and self._timestamps[0] < cutoff:
            self._timestamps.popleft()

    def _remaining_locked(self, now: float) -> int:
        """Expire old entries and return the number of free slots. Caller holds ``_lock``."""
        self._cleanup(now)
        return max(0, self.max_requests - len(self._timestamps))

    def check(self) -> Tuple[bool, int]:
        """
        Check if a request would be allowed, without recording it.

        Returns:
            Tuple of (allowed, remaining)
        """
        now = time.time()
        with self._lock:
            remaining = self._remaining_locked(now)
        return remaining > 0, remaining

    def consume(self) -> bool:
        """
        Consume one request slot.

        Returns:
            True if consumed, False if at limit
        """
        now = time.time()
        with self._lock:
            if self._remaining_locked(now) <= 0:
                return False
            self._timestamps.append(now)
            return True

    def get_remaining(self) -> int:
        """Get remaining requests in window."""
        now = time.time()
        with self._lock:
            return self._remaining_locked(now)

    def get_reset_time(self) -> float:
        """Return seconds until the oldest recorded request leaves the window.

        Returns ``0.0`` when the window is empty.
        """
        now = time.time()
        with self._lock:
            self._cleanup(now)
            if not self._timestamps:
                # Float, to match the annotation (previously returned the int 0).
                return 0.0
            reset_at = self._timestamps[0] + self.window_seconds
            return max(0.0, reset_at - now)

    def reset(self) -> None:
        """Forget every recorded request."""
        with self._lock:
            self._timestamps.clear()

132 

133 

class TokenBucket:
    """Token bucket for burst handling.

    Starts full at ``capacity`` tokens and refills continuously at
    ``refill_rate`` tokens per second, never above ``capacity``. Every public
    method takes an internal lock.
    """

    def __init__(self, capacity: int, refill_rate: float):
        """
        Args:
            capacity: Maximum tokens in bucket
            refill_rate: Tokens added per second
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self._tokens = float(capacity)
        self._last_refill = time.time()
        self._lock = threading.Lock()

    def _refill(self) -> None:
        """Credit tokens for the time elapsed since the last refill. Caller holds ``_lock``."""
        now = time.time()
        # time.time() is the wall clock and can step backwards. A negative
        # interval must not *remove* tokens (it could lock a key out for as
        # long as the step), so clamp it to zero. ``_last_refill`` is still
        # rebased to ``now`` so later intervals are measured correctly.
        elapsed = max(0.0, now - self._last_refill)
        self._tokens = min(float(self.capacity), self._tokens + elapsed * self.refill_rate)
        self._last_refill = now

    def consume(self, tokens: int = 1) -> bool:
        """
        Try to consume tokens.

        Args:
            tokens: Number of tokens to take; must be non-negative.

        Returns:
            True if consumed, False if not enough tokens

        Raises:
            ValueError: if ``tokens`` is negative (which would otherwise
                add tokens and overfill the bucket past ``capacity``).
        """
        if tokens < 0:
            raise ValueError(f"tokens must be non-negative, got {tokens}")
        with self._lock:
            self._refill()
            if self._tokens >= tokens:
                self._tokens -= tokens
                return True
            return False

    def get_tokens(self) -> float:
        """Get current (possibly fractional) token count after refilling."""
        with self._lock:
            self._refill()
            return self._tokens

    def reset(self) -> None:
        """Reset to full capacity."""
        with self._lock:
            self._tokens = float(self.capacity)
            self._last_refill = time.time()

182 

183 

class RateLimiter:
    """
    Rate limiter with multiple windows and burst handling.

    Each ``(channel, chat_id)`` pair gets its own per-minute sliding window,
    per-hour sliding window and burst token bucket, created lazily on first
    use. A request is admitted only when all three have capacity.

    Usage:
        config = RateLimitConfig(requests_per_minute=60, burst_limit=10)
        limiter = RateLimiter(config)

        # Atomic check + consume (safe when several threads share the limiter)
        result = limiter.check_and_consume("telegram", "chat123")
        if result.allowed:
            # Process request
            ...
        else:
            # Handle rate limit
            print(f"Retry after {result.retry_after_seconds} seconds")

    Calling ``check`` and then ``consume`` separately is not atomic: another
    thread may take the last slot in between, in which case ``consume``
    returns False.

    NOTE(review): per-key state is never evicted, so memory grows with the
    number of distinct ``(channel, chat_id)`` pairs seen.
    """

    def __init__(self, config: RateLimitConfig):
        self.config = config

        # Per-key rate limiters, keyed by "channel:chat_id".
        self._minute_counters: Dict[str, SlidingWindowCounter] = {}
        self._hour_counters: Dict[str, SlidingWindowCounter] = {}
        self._burst_buckets: Dict[str, TokenBucket] = {}
        # Per-key locks making "check every limiter, then consume from every
        # limiter" a single atomic step for that key.
        self._key_locks: Dict[str, threading.Lock] = {}

        # Guards the dicts above and ``_stats``.
        self._lock = threading.Lock()
        self._stats = RateLimitStats()

    def _get_key(self, channel: str, chat_id: str) -> str:
        """Build the per-chat state key ``"channel:chat_id"``."""
        return f"{channel}:{chat_id}"

    def _get_limits(self, channel: str) -> Tuple[int, int, int]:
        """Return ``(per_minute, per_hour, burst)`` for a channel.

        Only the per-minute limit can be overridden per channel (via
        ``config.per_channel_limits``); hour and burst limits are global.
        """
        per_minute = self.config.per_channel_limits.get(
            channel,
            self.config.requests_per_minute
        )
        per_hour = self.config.requests_per_hour
        burst = self.config.burst_limit
        return per_minute, per_hour, burst

    def _get_or_create_counters(
        self,
        key: str,
        channel: str,
    ) -> Tuple[SlidingWindowCounter, SlidingWindowCounter, TokenBucket]:
        """Get or create rate limit counters for a key.

        Returns:
            Tuple of (minute_counter, hour_counter, burst_bucket).
        """
        per_minute, per_hour, burst = self._get_limits(channel)

        with self._lock:
            if key not in self._minute_counters:
                self._minute_counters[key] = SlidingWindowCounter(60, per_minute)
            if key not in self._hour_counters:
                self._hour_counters[key] = SlidingWindowCounter(3600, per_hour)
            if key not in self._burst_buckets:
                # Refill at minute rate
                refill_rate = per_minute / 60.0
                self._burst_buckets[key] = TokenBucket(burst, refill_rate)

            return (
                self._minute_counters[key],
                self._hour_counters[key],
                self._burst_buckets[key],
            )

    def _get_key_lock(self, key: str) -> threading.Lock:
        """Return (creating on first use) the lock serialising consumption for ``key``."""
        with self._lock:
            lock = self._key_locks.get(key)
            if lock is None:
                lock = self._key_locks[key] = threading.Lock()
            return lock

    def _record(self, result: RateLimitResult) -> None:
        """Count one ``check`` outcome; the stats are mutated under ``_lock``."""
        with self._lock:
            self._stats.total_requests += 1
            if result is RateLimitResult.ALLOWED:
                self._stats.total_allowed += 1
            elif result is RateLimitResult.BURST_EXCEEDED:
                self._stats.total_burst_exceeded += 1
            else:
                self._stats.total_rate_limited += 1

    @staticmethod
    def _denied(
        result: RateLimitResult,
        minute_counter: SlidingWindowCounter,
        hour_counter: SlidingWindowCounter,
        *,
        remaining_minute: int,
        remaining_hour: int,
        remaining_burst: int,
        retry_after_seconds: float,
    ) -> RateLimitInfo:
        """Build the ``RateLimitInfo`` for a refused request."""
        now = datetime.now()
        return RateLimitInfo(
            allowed=False,
            result=result,
            remaining_minute=remaining_minute,
            remaining_hour=remaining_hour,
            remaining_burst=remaining_burst,
            reset_minute_at=now + timedelta(seconds=minute_counter.get_reset_time()),
            reset_hour_at=now + timedelta(seconds=hour_counter.get_reset_time()),
            retry_after_seconds=retry_after_seconds,
        )

    @staticmethod
    def _consume_locked(
        minute_counter: SlidingWindowCounter,
        hour_counter: SlidingWindowCounter,
        burst_bucket: TokenBucket,
    ) -> bool:
        """Take one slot from every limiter, or from none of them.

        The caller must hold the key's lock, so that no other consumption for
        the same key can run between the capacity checks and the consumes.
        """
        # Check every limiter before touching any of them, so a refusal by
        # one limiter never burns capacity in another.
        if burst_bucket.get_tokens() < 1:
            return False
        if minute_counter.get_remaining() <= 0 or hour_counter.get_remaining() <= 0:
            return False
        # Capacity only grows with time (refill / window expiry) and other
        # consumers are blocked by the key lock, so these are expected to
        # succeed (barring a wall-clock step backwards).
        burst_bucket.consume()
        minute_counter.consume()
        hour_counter.consume()
        return True

    def check(self, channel: str, chat_id: str) -> RateLimitInfo:
        """
        Check if a request is allowed (does not consume a slot).

        Limiters are consulted in order burst -> minute -> hour; the first
        one without capacity determines the refusal.

        Args:
            channel: Channel name
            chat_id: Chat identifier

        Returns:
            RateLimitInfo with result and remaining quotas
        """
        key = self._get_key(channel, chat_id)
        minute_counter, hour_counter, burst_bucket = self._get_or_create_counters(
            key, channel
        )

        # Check burst limit (a fractional token is not enough for a request)
        burst_tokens = int(burst_bucket.get_tokens())
        if burst_tokens <= 0:
            self._record(RateLimitResult.BURST_EXCEEDED)
            return self._denied(
                RateLimitResult.BURST_EXCEEDED,
                minute_counter,
                hour_counter,
                remaining_minute=minute_counter.get_remaining(),
                remaining_hour=hour_counter.get_remaining(),
                remaining_burst=0,
                retry_after_seconds=self.config.burst_window_seconds,
            )

        # Check minute limit
        minute_allowed, minute_remaining = minute_counter.check()
        if not minute_allowed:
            self._record(RateLimitResult.RATE_LIMITED)
            return self._denied(
                RateLimitResult.RATE_LIMITED,
                minute_counter,
                hour_counter,
                remaining_minute=0,
                remaining_hour=hour_counter.get_remaining(),
                remaining_burst=burst_tokens,
                retry_after_seconds=minute_counter.get_reset_time(),
            )

        # Check hour limit
        hour_allowed, hour_remaining = hour_counter.check()
        if not hour_allowed:
            self._record(RateLimitResult.RATE_LIMITED)
            return self._denied(
                RateLimitResult.RATE_LIMITED,
                minute_counter,
                hour_counter,
                remaining_minute=minute_remaining,
                remaining_hour=0,
                remaining_burst=burst_tokens,
                retry_after_seconds=hour_counter.get_reset_time(),
            )

        self._record(RateLimitResult.ALLOWED)
        now = datetime.now()
        return RateLimitInfo(
            allowed=True,
            result=RateLimitResult.ALLOWED,
            remaining_minute=minute_remaining,
            remaining_hour=hour_remaining,
            remaining_burst=burst_tokens,
            reset_minute_at=now + timedelta(seconds=60),
            reset_hour_at=now + timedelta(seconds=3600),
        )

    def consume(self, channel: str, chat_id: str) -> bool:
        """
        Consume a rate limit slot from the burst bucket and both windows.

        All-or-nothing: if any limiter is at its limit, nothing is consumed.

        Args:
            channel: Channel name
            chat_id: Chat identifier

        Returns:
            True if consumed, False if at limit
        """
        key = self._get_key(channel, chat_id)
        counters = self._get_or_create_counters(key, channel)
        with self._get_key_lock(key):
            return self._consume_locked(*counters)

    def check_and_consume(self, channel: str, chat_id: str) -> RateLimitInfo:
        """
        Check and consume in one atomic operation (per chat).

        Returns:
            RateLimitInfo with result
        """
        key = self._get_key(channel, chat_id)
        # Holding the key lock across both steps prevents two threads from
        # passing the check for the same last slot.
        with self._get_key_lock(key):
            info = self.check(channel, chat_id)
            if info.allowed and not self._consume_locked(
                *self._get_or_create_counters(key, channel)
            ):
                # Only reachable if capacity shrank between the two steps;
                # never report a slot that was not actually taken.
                info.allowed = False
                info.result = RateLimitResult.RATE_LIMITED
        return info

    def get_remaining(self, channel: str, chat_id: str) -> Tuple[int, int, int]:
        """
        Get remaining quotas.

        Returns:
            Tuple of (remaining_minute, remaining_hour, remaining_burst)
        """
        key = self._get_key(channel, chat_id)
        minute_counter, hour_counter, burst_bucket = self._get_or_create_counters(
            key, channel
        )
        return (
            minute_counter.get_remaining(),
            hour_counter.get_remaining(),
            int(burst_bucket.get_tokens()),
        )

    def reset(self, channel: str, chat_id: str) -> None:
        """Reset rate limits for a specific chat (no-op if it was never seen)."""
        key = self._get_key(channel, chat_id)
        with self._lock:
            if key in self._minute_counters:
                self._minute_counters[key].reset()
            if key in self._hour_counters:
                self._hour_counters[key].reset()
            if key in self._burst_buckets:
                self._burst_buckets[key].reset()

    def reset_all(self) -> None:
        """Reset all rate limits (statistics are left untouched)."""
        with self._lock:
            for counter in self._minute_counters.values():
                counter.reset()
            for counter in self._hour_counters.values():
                counter.reset()
            for bucket in self._burst_buckets.values():
                bucket.reset()

    def get_stats(self) -> RateLimitStats:
        """Return a consistent copy of the rate limiter statistics."""
        with self._lock:
            return RateLimitStats(
                total_requests=self._stats.total_requests,
                total_allowed=self._stats.total_allowed,
                total_rate_limited=self._stats.total_rate_limited,
                total_burst_exceeded=self._stats.total_burst_exceeded,
            )