Coverage for integrations / channels / queue / batching.py: 0.0%
348 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Message Batching System
4Collects messages by key (chat_id, user_id, or channel) and batches them
5together for efficient processing.
7Ported from HevolveBot's src/auto-reply/reply/batch.ts.
9Features:
10- Collect messages by configurable key
11- Max batch size limit
12- Max wait time before auto-flush
13- Manual and automatic flush methods
14- Thread-safe operation
15- Statistics tracking
16"""
18from __future__ import annotations
20import asyncio
21import logging
22import threading
23import time
24from collections import OrderedDict
25from dataclasses import dataclass, field
26from datetime import datetime, timedelta
27from enum import Enum
28from typing import (
29 Optional,
30 Dict,
31 List,
32 Any,
33 Callable,
34 TypeVar,
35 Generic,
36 Tuple,
37 Union,
38)
40logger = logging.getLogger(__name__)
42T = TypeVar('T')
class BatchKeyType(Enum):
    """Type of key used for batching messages."""
    CHAT_ID = "chat_id"    # group by the item's chat_id attribute
    USER_ID = "user_id"    # group by user_id, falling back to sender_id
    CHANNEL = "channel"    # group by the item's channel attribute
    CUSTOM = "custom"      # intended for use with a key_extractor or explicit key
@dataclass
class BatchConfig:
    """Configuration for message batching."""
    max_batch_size: int = 10  # flush as soon as a batch holds this many items
    max_wait_ms: int = 5000  # auto-flush this long after the last add; 0 disables the timer
    key_type: BatchKeyType = BatchKeyType.CHAT_ID  # which item attribute derives the batch key
    auto_flush: bool = True  # whether to schedule timer-based flushes automatically
    flush_on_shutdown: bool = True  # flush (rather than drop) pending batches on shutdown
@dataclass
class BatchStats:
    """Running counters describing a batcher's activity."""
    total_received: int = 0
    total_batched: int = 0
    total_flushed: int = 0
    total_batches_created: int = 0
    total_auto_flushes: int = 0
    total_manual_flushes: int = 0
    total_size_flushes: int = 0
    current_pending: int = 0
    current_batch_count: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Return the counters as a plain dictionary keyed by field name."""
        counter_names = (
            "total_received",
            "total_batched",
            "total_flushed",
            "total_batches_created",
            "total_auto_flushes",
            "total_manual_flushes",
            "total_size_flushes",
            "current_pending",
            "current_batch_count",
        )
        return {name: getattr(self, name) for name in counter_names}
@dataclass
class Batch(Generic[T]):
    """A keyed group of messages waiting to be flushed together."""
    key: str
    items: List[T] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    last_added: datetime = field(default_factory=datetime.now)
    flush_timer: Optional[asyncio.Task] = field(default=None, repr=False)
    sync_timer: Optional[threading.Timer] = field(default=None, repr=False)

    def add(self, item: T) -> None:
        """Append an item and stamp its arrival time."""
        self.items.append(item)
        self.last_added = datetime.now()

    def clear(self) -> List[T]:
        """Detach and return all buffered items, leaving the batch empty."""
        drained, self.items = self.items, []
        return drained

    def size(self) -> int:
        """Number of items currently buffered."""
        return len(self.items)

    def age_ms(self) -> float:
        """Milliseconds elapsed since the batch was created."""
        elapsed = datetime.now() - self.created_at
        return elapsed.total_seconds() * 1000

    def cancel_timer(self) -> None:
        """Cancel the pending async flush task and/or thread timer, if any."""
        task = self.flush_timer
        if task is not None and not task.done():
            task.cancel()
            self.flush_timer = None
        if self.sync_timer is not None:
            self.sync_timer.cancel()
            self.sync_timer = None
@dataclass
class BatchResult(Generic[T]):
    """Result of a batch flush operation."""
    key: str  # batching key the items were grouped under
    items: List[T]  # the flushed items, in arrival order
    batch_size: int  # len(items) at flush time
    wait_time_ms: float  # milliseconds between batch creation and flush
    flush_reason: str  # "size", "time", "manual", "shutdown"
class MessageBatcher(Generic[T]):
    """
    Collects messages into batches by key.

    Messages are grouped by a key (chat_id, user_id, channel, or custom)
    and batched together until either:
    - Max batch size is reached
    - Max wait time expires
    - Manual flush is called

    Locking: the internal (non-reentrant) lock is held only while
    mutating batch state; the on_flush callback is invoked and awaited
    with the lock released, so a callback may safely call back into the
    batcher without deadlocking.

    Usage:
        config = BatchConfig(max_batch_size=10, max_wait_ms=5000)
        batcher = MessageBatcher(config)

        # Add messages
        result = await batcher.add(message, key="chat123")
        if result:
            # Batch was flushed
            process_batch(result.items)

        # Manual flush
        batch = await batcher.flush("chat123")

        # Flush all
        batches = await batcher.flush_all()
    """

    def __init__(
        self,
        config: BatchConfig,
        key_extractor: Optional[Callable[[T], str]] = None,
        on_flush: Optional[Callable[[BatchResult[T]], Any]] = None,
        on_error: Optional[Callable[[Exception, BatchResult[T]], None]] = None,
    ):
        """
        Initialize the message batcher.

        Args:
            config: Batching configuration
            key_extractor: Function to extract key from message
            on_flush: Callback when batch is flushed (sync, or async — it is awaited)
            on_error: Callback on flush error
        """
        self.config = config
        self.key_extractor = key_extractor
        self.on_flush = on_flush
        self.on_error = on_error
        self._batches: Dict[str, Batch[T]] = {}
        self._lock = threading.Lock()
        self._stats = BatchStats()
        self._shutdown = False

    def _get_key(
        self,
        item: T,
        key: Optional[str] = None,
    ) -> str:
        """
        Get the batching key for an item.

        Resolution order: explicit key, then key_extractor, then an item
        attribute chosen by config.key_type, then "default".

        Args:
            item: The item to get key for
            key: Optional explicit key

        Returns:
            The batching key
        """
        if key is not None:
            return key

        if self.key_extractor is not None:
            return self.key_extractor(item)

        # Fall back to well-known attributes based on key_type
        if self.config.key_type == BatchKeyType.CHAT_ID:
            if hasattr(item, 'chat_id'):
                return str(getattr(item, 'chat_id'))
        elif self.config.key_type == BatchKeyType.USER_ID:
            if hasattr(item, 'user_id') or hasattr(item, 'sender_id'):
                return str(getattr(item, 'user_id', None) or getattr(item, 'sender_id', ''))
        elif self.config.key_type == BatchKeyType.CHANNEL:
            if hasattr(item, 'channel'):
                return str(getattr(item, 'channel'))

        return "default"

    async def add(
        self,
        item: T,
        key: Optional[str] = None,
    ) -> Optional[BatchResult[T]]:
        """
        Add an item to a batch.

        Args:
            item: The item to add
            key: Optional explicit key (otherwise extracted from item)

        Returns:
            BatchResult if batch was flushed, None if buffered
        """
        batch_key = self._get_key(item, key)
        flushed: Optional[BatchResult[T]] = None

        with self._lock:
            # Counter updated under the lock to avoid lost increments
            self._stats.total_received += 1

            # Get or create batch
            batch = self._batches.get(batch_key)
            if batch is None:
                batch = Batch(key=batch_key)
                self._batches[batch_key] = batch
                self._stats.total_batches_created += 1
                self._stats.current_batch_count = len(self._batches)

            batch.add(item)
            self._stats.total_batched += 1
            self._stats.current_pending += 1

            if batch.size() >= self.config.max_batch_size:
                # Full: detach the batch now, but run the callback only
                # after releasing the lock. (The previous code awaited
                # _flush_batch here while holding this non-reentrant
                # lock, which deadlocked the moment a batch filled up.)
                flushed = self._extract_locked(batch_key, "size")
            elif self.config.auto_flush and self.config.max_wait_ms > 0:
                # (Re)arm the auto-flush timer; each add resets the wait
                self._cancel_timers_locked(batch)
                batch.flush_timer = asyncio.create_task(
                    self._timer_flush(batch_key)
                )

        if flushed is None:
            return None
        await self._dispatch(flushed)
        return flushed

    async def _timer_flush(self, key: str) -> None:
        """Timer task: flush the batch once max_wait_ms has elapsed."""
        try:
            await asyncio.sleep(self.config.max_wait_ms / 1000.0)
            await self._flush_batch(key, "time")
        except asyncio.CancelledError:
            pass  # a newer add rescheduled the timer, or the batch was flushed
        except Exception as e:
            logger.error(f"Error in batch timer flush: {e}")

    async def _flush_batch(
        self,
        key: str,
        reason: str,
    ) -> Optional[BatchResult[T]]:
        """
        Flush a specific batch.

        Batch state is detached under the lock; the on_flush callback
        runs after the lock has been released.

        Args:
            key: Batch key
            reason: Reason for flush

        Returns:
            BatchResult with flushed items, or None if nothing to flush
        """
        with self._lock:
            result = self._extract_locked(key, reason)
        if result is None:
            return None
        await self._dispatch(result)
        return result

    def _extract_locked(
        self,
        key: str,
        reason: str,
    ) -> Optional[BatchResult[T]]:
        """
        Remove a batch and package it as a BatchResult, updating stats.

        Caller MUST hold self._lock. Returns None when the batch does
        not exist or is empty; flush-type counters are only incremented
        when items were actually flushed.
        """
        batch = self._batches.pop(key, None)
        if batch is None:
            return None

        items = batch.clear()
        wait_time_ms = batch.age_ms()
        self._cancel_timers_locked(batch)

        self._stats.current_pending -= len(items)
        self._stats.current_batch_count = len(self._batches)

        if not items:
            return None

        self._stats.total_flushed += len(items)
        if reason == "time":
            self._stats.total_auto_flushes += 1
        elif reason == "size":
            self._stats.total_size_flushes += 1
        elif reason == "manual":
            self._stats.total_manual_flushes += 1

        return BatchResult(
            key=key,
            items=items,
            batch_size=len(items),
            wait_time_ms=wait_time_ms,
            flush_reason=reason,
        )

    def _cancel_timers_locked(self, batch: Batch[T]) -> None:
        """
        Cancel a batch's pending timers.

        Never cancels the currently-running timer task: cancelling it
        would raise CancelledError at the next await inside the flush it
        is performing and could abort the on_flush callback mid-flight.
        """
        task = batch.flush_timer
        if task is not None and not task.done():
            try:
                current = asyncio.current_task()
            except RuntimeError:  # no running event loop (sync caller)
                current = None
            if task is not current:
                task.cancel()
        batch.flush_timer = None
        if batch.sync_timer is not None:
            batch.sync_timer.cancel()
            batch.sync_timer = None

    async def _dispatch(self, result: BatchResult[T]) -> None:
        """Invoke on_flush with a flushed result, awaiting it if async."""
        if self.on_flush is None:
            return
        try:
            callback_result = self.on_flush(result)
            if asyncio.iscoroutine(callback_result):
                await callback_result
        except Exception as e:
            if self.on_error:
                self.on_error(e, result)
            logger.error(f"Error in batch flush callback: {e}")

    async def flush(self, key: str) -> Optional[BatchResult[T]]:
        """
        Manually flush a specific batch.

        Args:
            key: Batch key to flush

        Returns:
            BatchResult with flushed items, or None if no batch
        """
        return await self._flush_batch(key, "manual")

    def flush_sync(self, key: str) -> Optional[BatchResult[T]]:
        """
        Synchronously flush a specific batch.

        Does NOT invoke the on_flush callback (there is no event loop
        available to await an async callback on).

        Args:
            key: Batch key to flush

        Returns:
            BatchResult with flushed items, or None if no batch / empty
        """
        with self._lock:
            return self._extract_locked(key, "manual")

    async def flush_all(self) -> List[BatchResult[T]]:
        """
        Flush all batches.

        Returns:
            List of BatchResults for each flushed batch
        """
        with self._lock:
            keys = list(self._batches.keys())

        results = []
        for key in keys:
            result = await self.flush(key)
            if result:
                results.append(result)

        return results

    def flush_all_sync(self) -> List[BatchResult[T]]:
        """
        Synchronously flush all batches (on_flush is not invoked).

        Returns:
            List of BatchResults
        """
        with self._lock:
            keys = list(self._batches.keys())

        results = []
        for key in keys:
            result = self.flush_sync(key)
            if result:
                results.append(result)

        return results

    def get_batch(self, key: str) -> Optional[List[T]]:
        """
        Get a copy of the items in a batch without flushing.

        Args:
            key: Batch key

        Returns:
            List of items or None if no batch
        """
        with self._lock:
            if key not in self._batches:
                return None
            return list(self._batches[key].items)

    def get_batch_size(self, key: str) -> int:
        """
        Get size of a specific batch.

        Args:
            key: Batch key

        Returns:
            Number of items in batch (0 if no batch)
        """
        with self._lock:
            if key not in self._batches:
                return 0
            return self._batches[key].size()

    def get_pending_count(self) -> int:
        """Get total pending items across all batches."""
        with self._lock:
            return sum(b.size() for b in self._batches.values())

    def get_batch_count(self) -> int:
        """Get number of active batches."""
        with self._lock:
            return len(self._batches)

    def get_batch_keys(self) -> List[str]:
        """Get list of active batch keys."""
        with self._lock:
            return list(self._batches.keys())

    def get_stats(self) -> BatchStats:
        """Get a snapshot copy of batching statistics."""
        with self._lock:
            self._stats.current_pending = sum(b.size() for b in self._batches.values())
            self._stats.current_batch_count = len(self._batches)

            return BatchStats(
                total_received=self._stats.total_received,
                total_batched=self._stats.total_batched,
                total_flushed=self._stats.total_flushed,
                total_batches_created=self._stats.total_batches_created,
                total_auto_flushes=self._stats.total_auto_flushes,
                total_manual_flushes=self._stats.total_manual_flushes,
                total_size_flushes=self._stats.total_size_flushes,
                current_pending=self._stats.current_pending,
                current_batch_count=self._stats.current_batch_count,
            )

    def clear(self) -> int:
        """
        Clear all batches without flushing.

        Returns:
            Number of items cleared
        """
        with self._lock:
            total = 0
            for batch in self._batches.values():
                total += batch.size()
                self._cancel_timers_locked(batch)
            self._batches.clear()
            self._stats.current_pending = 0
            self._stats.current_batch_count = 0
            return total

    async def shutdown(self) -> List[BatchResult[T]]:
        """
        Shutdown the batcher, flushing remaining batches if configured.

        Returns:
            List of flushed BatchResults
        """
        self._shutdown = True

        if self.config.flush_on_shutdown:
            return await self.flush_all()

        self.clear()
        return []
class SyncMessageBatcher(Generic[T]):
    """
    Synchronous version of MessageBatcher.

    Uses threading.Timer for auto-flush instead of asyncio. The on_flush
    callback is always invoked after the internal lock has been
    released, so a callback may safely call back into the batcher.

    Usage:
        config = BatchConfig(max_batch_size=10, max_wait_ms=5000)
        batcher = SyncMessageBatcher(config)

        # Add messages
        result = batcher.add(message, key="chat123")
        if result:
            process_batch(result.items)
    """

    def __init__(
        self,
        config: BatchConfig,
        key_extractor: Optional[Callable[[T], str]] = None,
        on_flush: Optional[Callable[[BatchResult[T]], None]] = None,
    ):
        """
        Args:
            config: Batching configuration
            key_extractor: Function to extract key from message
            on_flush: Callback invoked with each flushed BatchResult
        """
        self.config = config
        self.key_extractor = key_extractor
        self.on_flush = on_flush
        self._batches: Dict[str, Batch[T]] = {}
        self._lock = threading.Lock()
        self._stats = BatchStats()

    def _get_key(self, item: T, key: Optional[str] = None) -> str:
        """Get batching key: explicit key, key_extractor, chat_id attr, or "default"."""
        if key is not None:
            return key
        if self.key_extractor is not None:
            return self.key_extractor(item)
        if hasattr(item, 'chat_id'):
            return str(getattr(item, 'chat_id'))
        return "default"

    def add(
        self,
        item: T,
        key: Optional[str] = None,
    ) -> Optional[BatchResult[T]]:
        """
        Add an item to a batch.

        Args:
            item: The item to add
            key: Optional explicit key

        Returns:
            BatchResult if flushed, None if buffered
        """
        batch_key = self._get_key(item, key)
        flushed: Optional[BatchResult[T]] = None

        with self._lock:
            # Counter updated under the lock to avoid lost increments
            self._stats.total_received += 1

            batch = self._batches.get(batch_key)
            if batch is None:
                batch = Batch(key=batch_key)
                self._batches[batch_key] = batch
                self._stats.total_batches_created += 1
                # Keep current_batch_count in sync (matches MessageBatcher)
                self._stats.current_batch_count = len(self._batches)

            batch.add(item)
            self._stats.total_batched += 1
            self._stats.current_pending += 1

            if batch.size() >= self.config.max_batch_size:
                # Full: detach now, invoke on_flush after releasing the
                # lock (previously the callback ran while the lock was
                # held, deadlocking callbacks that re-enter the batcher)
                flushed = self._extract_locked(batch_key, "size")
            elif self.config.auto_flush and self.config.max_wait_ms > 0:
                # (Re)arm the auto-flush timer; each add resets the wait
                batch.cancel_timer()
                timer = threading.Timer(
                    self.config.max_wait_ms / 1000.0,
                    self._timer_flush,
                    args=[batch_key],
                )
                timer.daemon = True
                timer.start()
                batch.sync_timer = timer

        if flushed is not None and self.on_flush:
            self.on_flush(flushed)
        return flushed

    def _timer_flush(self, key: str) -> None:
        """Timer callback; flush() takes care of invoking on_flush."""
        self.flush(key, reason="time")

    def _extract_locked(
        self,
        key: str,
        reason: str,
    ) -> Optional[BatchResult[T]]:
        """
        Remove a batch and package it as a BatchResult, updating stats.

        Caller MUST hold self._lock. Never invokes on_flush. Returns
        None when the batch is missing or empty.
        """
        batch = self._batches.pop(key, None)
        if batch is None:
            return None

        items = batch.clear()
        wait_time_ms = batch.age_ms()
        batch.cancel_timer()

        self._stats.current_pending -= len(items)
        self._stats.current_batch_count = len(self._batches)

        if not items:
            return None

        self._stats.total_flushed += len(items)
        if reason == "time":
            self._stats.total_auto_flushes += 1
        elif reason == "size":
            self._stats.total_size_flushes += 1
        elif reason == "manual":
            self._stats.total_manual_flushes += 1

        return BatchResult(
            key=key,
            items=items,
            batch_size=len(items),
            wait_time_ms=wait_time_ms,
            flush_reason=reason,
        )

    def flush(
        self,
        key: str,
        reason: str = "manual",
    ) -> Optional[BatchResult[T]]:
        """
        Flush a specific batch; on_flush runs after the lock is released.

        Args:
            key: Batch key
            reason: Reason for flush

        Returns:
            BatchResult or None
        """
        with self._lock:
            result = self._extract_locked(key, reason)
        if result is not None and self.on_flush:
            self.on_flush(result)
        return result

    def flush_all(self) -> List[BatchResult[T]]:
        """Flush all batches."""
        with self._lock:
            keys = list(self._batches.keys())

        results = []
        for key in keys:
            result = self.flush(key)
            if result:
                results.append(result)

        return results

    def get_pending_count(self) -> int:
        """Get total pending items."""
        with self._lock:
            return sum(b.size() for b in self._batches.values())

    def get_batch_count(self) -> int:
        """Get number of active batches."""
        with self._lock:
            return len(self._batches)

    def get_batch_keys(self) -> List[str]:
        """Get active batch keys."""
        with self._lock:
            return list(self._batches.keys())

    def get_stats(self) -> BatchStats:
        """Get a snapshot copy of statistics."""
        with self._lock:
            self._stats.current_pending = sum(b.size() for b in self._batches.values())
            self._stats.current_batch_count = len(self._batches)

            return BatchStats(
                total_received=self._stats.total_received,
                total_batched=self._stats.total_batched,
                total_flushed=self._stats.total_flushed,
                total_batches_created=self._stats.total_batches_created,
                total_auto_flushes=self._stats.total_auto_flushes,
                total_manual_flushes=self._stats.total_manual_flushes,
                total_size_flushes=self._stats.total_size_flushes,
                current_pending=self._stats.current_pending,
                current_batch_count=self._stats.current_batch_count,
            )

    def clear(self) -> int:
        """Clear all batches without flushing; returns number of items dropped."""
        with self._lock:
            total = sum(b.size() for b in self._batches.values())
            for batch in self._batches.values():
                batch.cancel_timer()
            self._batches.clear()
            self._stats.current_pending = 0
            self._stats.current_batch_count = 0
            return total
class BatchAggregator(Generic[T]):
    """
    Aggregates multiple batches into larger groups.

    Useful for combining batches from multiple sources before processing.
    """

    def __init__(
        self,
        max_aggregate_size: int = 100,
        max_sources: int = 10,
    ):
        self.max_aggregate_size = max_aggregate_size
        self.max_sources = max_sources
        self._pending: Dict[str, List[BatchResult[T]]] = {}
        self._lock = threading.Lock()

    def add_batch(
        self,
        batch: BatchResult[T],
        aggregate_key: str = "default",
    ) -> Optional[List[BatchResult[T]]]:
        """
        Add a batch to the aggregator.

        Args:
            batch: BatchResult to add
            aggregate_key: Key for grouping batches

        Returns:
            List of batches if aggregate threshold reached
        """
        with self._lock:
            bucket = self._pending.setdefault(aggregate_key, [])
            bucket.append(batch)

            # Flush when the group is big enough by items or by sources
            item_total = sum(entry.batch_size for entry in bucket)
            threshold_hit = (
                item_total >= self.max_aggregate_size
                or len(bucket) >= self.max_sources
            )
            if threshold_hit:
                return self._pending.pop(aggregate_key)

            return None

    def flush(self, aggregate_key: str) -> List[BatchResult[T]]:
        """Remove and return all batches for one aggregate (empty list if none)."""
        with self._lock:
            return self._pending.pop(aggregate_key, [])

    def flush_all(self) -> Dict[str, List[BatchResult[T]]]:
        """Remove and return every pending aggregate, keyed by aggregate key."""
        with self._lock:
            drained = dict(self._pending)
            self._pending.clear()
            return drained

    def get_pending_count(self, aggregate_key: Optional[str] = None) -> int:
        """Count pending batches for one aggregate, or across all of them."""
        with self._lock:
            if aggregate_key:
                return len(self._pending.get(aggregate_key, []))
            return sum(map(len, self._pending.values()))