Coverage for integrations / channels / queue / rate_limit.py: 92.6%
188 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Rate Limiting System
4Limits request rates to prevent abuse and comply with API limits.
5Ported from HevolveBot's src/channels/rate-limit.ts.
7Features:
8- Sliding window rate limiting
9- Per-channel limits
10- Burst handling
11- Token bucket algorithm
12"""
14from __future__ import annotations
16import logging
17import threading
18import time
19from collections import deque
20from dataclasses import dataclass, field
21from datetime import datetime, timedelta
22from enum import Enum
23from typing import Optional, Dict, Deque, Tuple
25logger = logging.getLogger(__name__)
class RateLimitResult(Enum):
    """Outcome category produced by a rate-limit check.

    ``ALLOWED`` means the request may proceed, ``RATE_LIMITED`` means a
    sliding window (per-minute or per-hour) is full, and ``BURST_EXCEEDED``
    means the burst token bucket is empty.
    """

    ALLOWED = "allowed"
    RATE_LIMITED = "rate_limited"
    BURST_EXCEEDED = "burst_exceeded"
@dataclass
class RateLimitConfig:
    """Limits consumed by ``RateLimiter``.

    Attributes:
        requests_per_minute: Default cap for the 60-second sliding window;
            also sets the burst bucket's refill rate (per-minute / 60 per second).
        requests_per_hour: Cap for the 3600-second sliding window, shared by
            every channel.
        burst_limit: Capacity of the per-chat token bucket.
        burst_window_seconds: Value reported as ``retry_after_seconds`` when a
            request is denied for exceeding the burst limit.
        per_channel_limits: Optional per-channel overrides of
            ``requests_per_minute``, keyed by channel name.
    """

    requests_per_minute: int = 60
    requests_per_hour: int = 1000
    burst_limit: int = 10
    burst_window_seconds: int = 1
    # default_factory gives each config its own dict instead of a shared one.
    per_channel_limits: Dict[str, int] = field(default_factory=dict)
@dataclass
class RateLimitInfo:
    """Snapshot describing the outcome of one rate-limit check.

    Attributes:
        allowed: Whether the request may proceed.
        result: Outcome category (see ``RateLimitResult``).
        remaining_minute: Free slots left in the per-minute window.
        remaining_hour: Free slots left in the per-hour window.
        remaining_burst: Whole burst tokens reported as available.
        reset_minute_at: Estimated time the per-minute window frees up.
        reset_hour_at: Estimated time the per-hour window frees up.
        retry_after_seconds: Suggested delay before retrying; defaults to
            ``None``.
    """

    allowed: bool
    result: RateLimitResult
    remaining_minute: int
    remaining_hour: int
    remaining_burst: int
    reset_minute_at: datetime
    reset_hour_at: datetime
    retry_after_seconds: Optional[float] = None
@dataclass
class RateLimitStats:
    """Running counters describing rate-limit decisions.

    Attributes:
        total_requests: Number of checks performed.
        total_allowed: Checks that returned an allowed result.
        total_rate_limited: Checks denied by a sliding window (minute or hour).
        total_burst_exceeded: Checks denied because the burst bucket was empty.
    """

    total_requests: int = 0
    total_allowed: int = 0
    total_rate_limited: int = 0
    total_burst_exceeded: int = 0
class SlidingWindowCounter:
    """Count requests inside a trailing time window of fixed length.

    Each accepted request records its ``time.time()`` timestamp in a deque;
    timestamps older than ``window_seconds`` are discarded lazily whenever the
    counter is queried. Every public method takes the instance lock.
    """

    def __init__(self, window_seconds: int, max_requests: int):
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        self._timestamps: Deque[float] = deque()
        self._lock = threading.Lock()

    def _cleanup(self, now: float) -> None:
        """Drop timestamps that fall before the window ending at *now*."""
        oldest_kept = now - self.window_seconds
        stamps = self._timestamps
        # Timestamps are appended in order, so expired ones sit at the left.
        while stamps and stamps[0] < oldest_kept:
            stamps.popleft()

    def _free_slots(self, now: float) -> int:
        """Prune, then return unused capacity; caller holds ``self._lock``."""
        self._cleanup(now)
        return max(0, self.max_requests - len(self._timestamps))

    def check(self) -> Tuple[bool, int]:
        """
        Report whether another request would fit, without recording one.

        Returns:
            Tuple of (allowed, remaining)
        """
        now = time.time()
        with self._lock:
            free = self._free_slots(now)
        return free > 0, free

    def consume(self) -> bool:
        """
        Record one request if the window has room.

        Returns:
            True if consumed, False if at limit
        """
        now = time.time()
        with self._lock:
            self._cleanup(now)
            is_full = len(self._timestamps) >= self.max_requests
            if not is_full:
                self._timestamps.append(now)
            return not is_full

    def get_remaining(self) -> int:
        """Return how many more requests fit in the current window."""
        now = time.time()
        with self._lock:
            return self._free_slots(now)

    def get_reset_time(self) -> float:
        """Return seconds until the oldest recorded request leaves the window."""
        now = time.time()
        with self._lock:
            self._cleanup(now)
            if self._timestamps:
                expires_at = self._timestamps[0] + self.window_seconds
                return max(0, expires_at - now)
            return 0

    def reset(self) -> None:
        """Forget every recorded request."""
        with self._lock:
            self._timestamps.clear()
class TokenBucket:
    """Token bucket for burst handling.

    The bucket starts full. Tokens accrue continuously at ``refill_rate`` per
    second (measured with ``time.time()``) up to ``capacity``; ``consume``
    removes tokens only when enough are available. All public methods take
    the instance lock.
    """

    def __init__(self, capacity: int, refill_rate: float):
        """
        Args:
            capacity: Maximum tokens in bucket
            refill_rate: Tokens added per second
        """
        self.capacity = capacity
        self.refill_rate = refill_rate
        self._tokens = float(capacity)
        self._last_refill = time.time()
        self._lock = threading.Lock()

    def _refill(self) -> None:
        """Credit tokens for the time elapsed since the last refill.

        Must be called with ``self._lock`` held.
        """
        now = time.time()
        # time.time() is a wall clock and can step backwards (e.g. an NTP
        # correction). A negative interval would drain tokens, possibly below
        # zero, so clamp it to zero.
        elapsed = max(0.0, now - self._last_refill)
        self._tokens = min(self.capacity, self._tokens + elapsed * self.refill_rate)
        self._last_refill = now

    def consume(self, tokens: int = 1) -> bool:
        """
        Try to consume tokens.

        Args:
            tokens: Number of tokens to remove.

        Returns:
            True if consumed, False if not enough tokens
        """
        with self._lock:
            self._refill()
            if self._tokens >= tokens:
                self._tokens -= tokens
                return True
            return False

    def get_tokens(self) -> float:
        """Get current token count (after crediting elapsed time)."""
        with self._lock:
            self._refill()
            return self._tokens

    def reset(self) -> None:
        """Reset to full capacity."""
        with self._lock:
            self._tokens = float(self.capacity)
            self._last_refill = time.time()
class RateLimiter:
    """
    Rate limiter with multiple windows and burst handling.

    Every ``(channel, chat_id)`` pair gets its own set of limiters, created
    lazily:

    * a 60-second ``SlidingWindowCounter`` capped at the channel's per-minute
      limit (``config.per_channel_limits[channel]`` if present, otherwise
      ``config.requests_per_minute``),
    * a 3600-second ``SlidingWindowCounter`` capped at
      ``config.requests_per_hour``,
    * a ``TokenBucket`` of capacity ``config.burst_limit`` that refills at the
      per-minute rate.

    Usage:
        config = RateLimitConfig(requests_per_minute=60, burst_limit=10)
        limiter = RateLimiter(config)

        # Check and consume atomically
        result = limiter.check_and_consume("telegram", "chat123")
        if result.allowed:
            # Process request
            ...
        else:
            # Handle rate limit
            print(f"Retry after {result.retry_after_seconds} seconds")

    Calling ``check`` and then ``consume`` separately also works, but another
    thread may take the slot between the two calls; ``consume`` then returns
    False.
    """

    def __init__(self, config: RateLimitConfig):
        self.config = config

        # Per-key rate limiters
        self._minute_counters: Dict[str, SlidingWindowCounter] = {}
        self._hour_counters: Dict[str, SlidingWindowCounter] = {}
        self._burst_buckets: Dict[str, TokenBucket] = {}

        # Re-entrant so check_and_consume() can hold it across check() and
        # consume(), both of which acquire it again internally.
        self._lock = threading.RLock()
        self._stats = RateLimitStats()

    def _get_key(self, channel: str, chat_id: str) -> str:
        """Get rate limit key."""
        return f"{channel}:{chat_id}"

    def _get_limits(self, channel: str) -> Tuple[int, int, int]:
        """Return (per_minute, per_hour, burst) limits for a channel."""
        per_minute = self.config.per_channel_limits.get(
            channel,
            self.config.requests_per_minute
        )
        per_hour = self.config.requests_per_hour
        burst = self.config.burst_limit
        return per_minute, per_hour, burst

    def _get_or_create_counters(
        self,
        key: str,
        channel: str,
    ) -> Tuple[SlidingWindowCounter, SlidingWindowCounter, TokenBucket]:
        """Get or create the (minute, hour, burst) limiters for a key."""
        per_minute, per_hour, burst = self._get_limits(channel)

        with self._lock:
            if key not in self._minute_counters:
                self._minute_counters[key] = SlidingWindowCounter(60, per_minute)
            if key not in self._hour_counters:
                self._hour_counters[key] = SlidingWindowCounter(3600, per_hour)
            if key not in self._burst_buckets:
                # Refill at minute rate
                refill_rate = per_minute / 60.0
                self._burst_buckets[key] = TokenBucket(burst, refill_rate)

            return (
                self._minute_counters[key],
                self._hour_counters[key],
                self._burst_buckets[key],
            )

    def _denied_info(
        self,
        result: RateLimitResult,
        minute_counter: SlidingWindowCounter,
        hour_counter: SlidingWindowCounter,
        remaining_minute: int,
        remaining_hour: int,
        remaining_burst: int,
        retry_after_seconds: float,
    ) -> RateLimitInfo:
        """Build a denial RateLimitInfo whose reset times come from the windows."""
        now = datetime.now()
        return RateLimitInfo(
            allowed=False,
            result=result,
            remaining_minute=remaining_minute,
            remaining_hour=remaining_hour,
            remaining_burst=remaining_burst,
            reset_minute_at=now + timedelta(seconds=minute_counter.get_reset_time()),
            reset_hour_at=now + timedelta(seconds=hour_counter.get_reset_time()),
            retry_after_seconds=retry_after_seconds,
        )

    def check(self, channel: str, chat_id: str) -> RateLimitInfo:
        """
        Check if a request is allowed, without consuming any quota.

        Limits are evaluated in order: burst, then minute, then hour; the
        first one that is exhausted determines the result.

        Args:
            channel: Channel name
            chat_id: Chat identifier

        Returns:
            RateLimitInfo with result and remaining quotas
        """
        key = self._get_key(channel, chat_id)
        minute_counter, hour_counter, burst_bucket = self._get_or_create_counters(
            key, channel
        )

        self._stats.total_requests += 1

        # Check burst limit: a request needs one whole token, so fractional
        # refill below 1.0 counts as empty.
        burst_tokens = int(burst_bucket.get_tokens())
        if burst_tokens <= 0:
            self._stats.total_burst_exceeded += 1
            return self._denied_info(
                RateLimitResult.BURST_EXCEEDED,
                minute_counter,
                hour_counter,
                remaining_minute=minute_counter.get_remaining(),
                remaining_hour=hour_counter.get_remaining(),
                remaining_burst=0,
                retry_after_seconds=self.config.burst_window_seconds,
            )

        # Check minute limit
        minute_allowed, minute_remaining = minute_counter.check()
        if not minute_allowed:
            self._stats.total_rate_limited += 1
            return self._denied_info(
                RateLimitResult.RATE_LIMITED,
                minute_counter,
                hour_counter,
                remaining_minute=0,
                remaining_hour=hour_counter.get_remaining(),
                remaining_burst=burst_tokens,
                retry_after_seconds=minute_counter.get_reset_time(),
            )

        # Check hour limit
        hour_allowed, hour_remaining = hour_counter.check()
        if not hour_allowed:
            self._stats.total_rate_limited += 1
            return self._denied_info(
                RateLimitResult.RATE_LIMITED,
                minute_counter,
                hour_counter,
                remaining_minute=minute_remaining,
                remaining_hour=0,
                remaining_burst=burst_tokens,
                retry_after_seconds=hour_counter.get_reset_time(),
            )

        self._stats.total_allowed += 1
        now = datetime.now()
        return RateLimitInfo(
            allowed=True,
            result=RateLimitResult.ALLOWED,
            remaining_minute=minute_remaining,
            remaining_hour=hour_remaining,
            remaining_burst=burst_tokens,
            reset_minute_at=now + timedelta(seconds=60),
            reset_hour_at=now + timedelta(seconds=3600),
        )

    def consume(self, channel: str, chat_id: str) -> bool:
        """
        Consume a rate limit slot.

        Either every limiter (burst, minute, hour) records the request, or
        none does.

        Args:
            channel: Channel name
            chat_id: Chat identifier

        Returns:
            True if consumed, False if at limit
        """
        key = self._get_key(channel, chat_id)
        minute_counter, hour_counter, burst_bucket = self._get_or_create_counters(
            key, channel
        )

        # Hold the limiter lock so concurrent consume() calls cannot interleave
        # between verifying the limits and recording the request. Both windows
        # are verified before anything is spent, so a denial never uses up a
        # burst token or a slot in the other window.
        with self._lock:
            minute_ok, _ = minute_counter.check()
            hour_ok, _ = hour_counter.check()
            if not (minute_ok and hour_ok):
                return False
            if not burst_bucket.consume():
                return False
            # Both windows were verified under this lock and can only free up
            # as time passes, so these consumes succeed.
            minute_counter.consume()
            hour_counter.consume()

        return True

    def check_and_consume(self, channel: str, chat_id: str) -> RateLimitInfo:
        """
        Check and consume in one atomic operation.

        The limiter lock is held across both steps, so when the returned info
        says ``allowed`` the slot has actually been recorded.

        Returns:
            RateLimitInfo with result
        """
        with self._lock:
            info = self.check(channel, chat_id)
            if info.allowed:
                self.consume(channel, chat_id)
        return info

    def get_remaining(self, channel: str, chat_id: str) -> Tuple[int, int, int]:
        """
        Get remaining quotas.

        Returns:
            Tuple of (remaining_minute, remaining_hour, remaining_burst)
        """
        key = self._get_key(channel, chat_id)
        minute_counter, hour_counter, burst_bucket = self._get_or_create_counters(
            key, channel
        )
        return (
            minute_counter.get_remaining(),
            hour_counter.get_remaining(),
            int(burst_bucket.get_tokens()),
        )

    def reset(self, channel: str, chat_id: str) -> None:
        """Reset rate limits for a specific chat (no-op if never seen)."""
        key = self._get_key(channel, chat_id)
        with self._lock:
            if key in self._minute_counters:
                self._minute_counters[key].reset()
            if key in self._hour_counters:
                self._hour_counters[key].reset()
            if key in self._burst_buckets:
                self._burst_buckets[key].reset()

    def reset_all(self) -> None:
        """Reset all rate limits."""
        with self._lock:
            for counter in self._minute_counters.values():
                counter.reset()
            for counter in self._hour_counters.values():
                counter.reset()
            for bucket in self._burst_buckets.values():
                bucket.reset()

    def get_stats(self) -> RateLimitStats:
        """Return a copy of the rate limiter statistics."""
        return RateLimitStats(
            total_requests=self._stats.total_requests,
            total_allowed=self._stats.total_allowed,
            total_rate_limited=self._stats.total_rate_limited,
            total_burst_exceeded=self._stats.total_burst_exceeded,
        )