Coverage for integrations / channels / media / image_gen.py: 66.7%
219 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-12 04:49 +0000
1"""
2Image Generator for AI image generation.
4Supports multiple providers: openai, stability, midjourney
5"""
7import asyncio
8import base64
9from dataclasses import dataclass, field
10from enum import Enum
11from typing import Optional, List, Dict, Any, Union
12from pathlib import Path
13import logging
14import os
15import time
16import hashlib
18logger = logging.getLogger(__name__)
20# Docker-compatible paths
21TEMP_DIR = os.environ.get("IMAGE_TEMP_DIR", "/tmp/images")
22APP_TEMP_DIR = os.environ.get("APP_TEMP_DIR", "/app/temp")
class ImageProvider(Enum):
    """Supported image generation providers.

    Values are the lowercase strings accepted by ``ImageGenerator(provider=...)``.
    """
    OPENAI = "openai"          # OpenAI DALL-E
    STABILITY = "stability"    # Stability AI (Stable Diffusion)
    MIDJOURNEY = "midjourney"  # Midjourney (via Discord or API)
class ImageSize(Enum):
    """Standard image sizes as ``"WIDTHxHEIGHT"`` pixel strings."""
    SQUARE_SMALL = "256x256"
    SQUARE_MEDIUM = "512x512"
    SQUARE_LARGE = "1024x1024"
    LANDSCAPE = "1792x1024"
    PORTRAIT = "1024x1792"
    # NOTE(review): the HD sizes below do not appear in any provider's
    # SUPPORTED_SIZES table, so _normalize_size will map them to the
    # closest supported size — confirm this is intended.
    HD_LANDSCAPE = "1920x1080"
    HD_PORTRAIT = "1080x1920"
class ImageStyle(Enum):
    """Image style presets.

    Not every provider accepts every preset; see
    ``ImageGenerator.get_supported_styles`` for the per-provider lists.
    """
    VIVID = "vivid"
    NATURAL = "natural"
    ANIME = "anime"
    PHOTOGRAPHIC = "photographic"
    DIGITAL_ART = "digital-art"
    CINEMATIC = "cinematic"
    FANTASY = "fantasy"
    NEON_PUNK = "neon-punk"
    ISOMETRIC = "isometric"
    ORIGAMI = "origami"
@dataclass
class GeneratedImage:
    """A single image produced by a generation provider."""
    data: bytes
    format: str  # png, jpg, webp
    width: int
    height: int
    prompt: str
    revised_prompt: Optional[str] = None
    seed: Optional[int] = None
    provider: Optional[str] = None
    model: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict; raw bytes are reported only as a byte count."""
        summary = {
            "format": self.format,
            "width": self.width,
            "height": self.height,
            "prompt": self.prompt,
            "revised_prompt": self.revised_prompt,
            "seed": self.seed,
            "provider": self.provider,
            "model": self.model,
            "size": len(self.data),
            "metadata": self.metadata,
        }
        return summary

    def to_base64(self) -> str:
        """Encode the raw image bytes as a base64 string."""
        encoded = base64.b64encode(self.data)
        return encoded.decode('utf-8')

    def to_data_url(self) -> str:
        """Build a ``data:`` URL suitable for inline embedding (e.g. in HTML)."""
        mime_by_format = {
            "png": "image/png",
            "jpg": "image/jpeg",
            "jpeg": "image/jpeg",
            "webp": "image/webp",
        }
        # Unknown formats fall back to image/png, matching the original contract.
        mime = mime_by_format.get(self.format.lower(), "image/png")
        payload = self.to_base64()
        return f"data:{mime};base64,{payload}"
@dataclass
class EditResult:
    """Outcome of an image edit operation."""
    original_size: int
    edited_size: int
    edited_image: GeneratedImage
    operation: str  # "edit", "inpaint", "outpaint"
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, expanding the edited image recursively."""
        result = {
            "original_size": self.original_size,
            "edited_size": self.edited_size,
        }
        result["edited_image"] = self.edited_image.to_dict()
        result["operation"] = self.operation
        result["metadata"] = self.metadata
        return result
@dataclass
class VariationResult:
    """Outcome of an image variation operation."""
    original_prompt: Optional[str]
    variations: List[GeneratedImage]
    count: int
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict, expanding every variation image."""
        serialized_variations = [image.to_dict() for image in self.variations]
        return {
            "original_prompt": self.original_prompt,
            "variations": serialized_variations,
            "count": self.count,
            "metadata": self.metadata,
        }
class ImageGenerator:
    """
    Image generator for AI image generation.

    Supports multiple providers for text-to-image generation.

    NOTE: the provider-specific calls (`_generate_openai`, `_edit_openai`,
    etc.) are placeholders that return empty bytes; the real API calls are
    sketched in comments inside each method.
    """

    # Available providers (string form; mirrors ImageProvider values)
    providers: List[str] = ["openai", "stability", "midjourney"]

    # Default models per provider
    DEFAULT_MODELS = {
        ImageProvider.OPENAI: "dall-e-3",
        ImageProvider.STABILITY: "stable-diffusion-xl-1024-v1-0",
        ImageProvider.MIDJOURNEY: "v6"
    }

    # Supported sizes per provider
    SUPPORTED_SIZES = {
        ImageProvider.OPENAI: ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
        ImageProvider.STABILITY: ["512x512", "768x768", "1024x1024", "1536x1536"],
        ImageProvider.MIDJOURNEY: ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
    }

    def __init__(
        self,
        provider: Union[ImageProvider, str] = ImageProvider.OPENAI,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize image generator.

        Args:
            provider: Image generation provider to use (enum or its string value)
            api_key: API key for the provider
            model: Specific model to use; defaults per provider (see DEFAULT_MODELS)
            config: Additional configuration options. Recognized key:
                "min_request_interval" — minimum seconds between requests (default 1.0)

        Raises:
            ValueError: If a provider string does not match a known ImageProvider value.
        """
        if isinstance(provider, str):
            provider = ImageProvider(provider.lower())

        self.provider = provider
        self.api_key = api_key
        self.config = config or {}

        # Set default model per provider
        self.model = model or self.DEFAULT_MODELS.get(provider, "default")

        # Provider-specific client is created lazily in _ensure_initialized
        self._client = None
        self._initialized = False

        # Rate limiting. Fix: self.config is already normalized to a dict
        # above, so read from it directly instead of the redundant
        # "config.get(...) if config else 1.0" conditional. Seed the
        # timestamp as a float to match time.time() arithmetic.
        self._last_request_time = 0.0
        self._min_request_interval = self.config.get("min_request_interval", 1.0)

        # Ensure temp directories exist
        self._ensure_temp_dirs()

    def _ensure_temp_dirs(self):
        """Ensure temp directories exist (Docker-compatible).

        Failures are deliberately swallowed: in a read-only container these
        paths may be unwritable, and generation itself can still proceed.
        """
        for dir_path in [TEMP_DIR, APP_TEMP_DIR]:
            try:
                Path(dir_path).mkdir(parents=True, exist_ok=True)
            except (PermissionError, OSError):
                pass

    async def _ensure_initialized(self):
        """Ensure provider client is initialized (idempotent, lazy)."""
        if self._initialized:
            return

        if self.provider == ImageProvider.OPENAI:
            # Would initialize OpenAI client
            pass
        elif self.provider == ImageProvider.STABILITY:
            # Would initialize Stability AI client
            pass
        elif self.provider == ImageProvider.MIDJOURNEY:
            # Would initialize Midjourney client (via Discord or API)
            pass

        self._initialized = True

    async def _rate_limit(self):
        """Apply rate limiting between requests.

        Sleeps just long enough to keep at least _min_request_interval
        seconds between consecutive provider calls.
        """
        now = time.time()
        elapsed = now - self._last_request_time
        if elapsed < self._min_request_interval:
            await asyncio.sleep(self._min_request_interval - elapsed)
        self._last_request_time = time.time()

    def _normalize_size(self, size: Union[ImageSize, str]) -> str:
        """Normalize a size to a string supported by the current provider.

        Args:
            size: An ImageSize enum member or a "WxH" string.
            (Fix: annotation was `str` but the body explicitly handles ImageSize.)

        Returns:
            A "WxH" string from SUPPORTED_SIZES: the input itself if already
            supported, otherwise the supported size with the closest pixel
            count, or "1024x1024" if the input cannot be parsed.
        """
        # Handle ImageSize enum
        if isinstance(size, ImageSize):
            size = size.value

        # Check if size is supported
        supported = self.SUPPORTED_SIZES.get(self.provider, ["1024x1024"])
        if size in supported:
            return size

        # Find closest supported size by total pixel count
        try:
            w, h = map(int, size.split("x"))
            target_pixels = w * h

            best_size = supported[0]
            best_diff = float("inf")

            for s in supported:
                sw, sh = map(int, s.split("x"))
                diff = abs(sw * sh - target_pixels)
                if diff < best_diff:
                    best_diff = diff
                    best_size = s

            return best_size
        except ValueError:
            # Unparseable size string — fall back to a safe default
            return "1024x1024"

    async def generate(
        self,
        prompt: str,
        size: str = "1024x1024",
        style: Optional[Union[ImageStyle, str]] = None,
        quality: str = "standard",
        n: int = 1
    ) -> bytes:
        """
        Generate image from prompt.

        Args:
            prompt: Text description of the image to generate
            size: Image size (e.g., "1024x1024")
            style: Style preset (vivid, natural, etc.)
            quality: Quality level (standard, hd)
            n: Number of images to generate

        Returns:
            Image bytes (first image if n > 1); empty bytes while the
            provider backends are unimplemented placeholders.
        """
        await self._ensure_initialized()
        await self._rate_limit()

        size = self._normalize_size(size)
        if isinstance(style, ImageStyle):
            style = style.value

        logger.info(f"Generating image: {prompt[:50]}... ({size})")

        # Provider-specific generation
        if self.provider == ImageProvider.OPENAI:
            return await self._generate_openai(prompt, size, style, quality, n)
        elif self.provider == ImageProvider.STABILITY:
            return await self._generate_stability(prompt, size, style, n)
        elif self.provider == ImageProvider.MIDJOURNEY:
            return await self._generate_midjourney(prompt, size, style)

        return b""

    async def _generate_openai(
        self,
        prompt: str,
        size: str,
        style: Optional[str],
        quality: str,
        n: int
    ) -> bytes:
        """Generate using OpenAI DALL-E. Placeholder — returns empty bytes."""
        # Would use OpenAI API:
        # response = await self._client.images.generate(
        #     model=self.model,
        #     prompt=prompt,
        #     size=size,
        #     style=style or "vivid",
        #     quality=quality,
        #     n=n,
        #     response_format="b64_json"
        # )
        # return base64.b64decode(response.data[0].b64_json)
        return b""

    async def _generate_stability(
        self,
        prompt: str,
        size: str,
        style: Optional[str],
        n: int
    ) -> bytes:
        """Generate using Stability AI. Placeholder — returns empty bytes."""
        # Would use Stability API
        return b""

    async def _generate_midjourney(
        self,
        prompt: str,
        size: str,
        style: Optional[str]
    ) -> bytes:
        """Generate using Midjourney. Placeholder — returns empty bytes."""
        # Would use Midjourney API/Discord integration
        return b""

    async def generate_multiple(
        self,
        prompt: str,
        n: int = 4,
        size: str = "1024x1024",
        style: Optional[Union[ImageStyle, str]] = None
    ) -> List[GeneratedImage]:
        """
        Generate multiple images from prompt, one provider call per image.

        Args:
            prompt: Text description
            n: Number of images to generate
            size: Image size
            style: Style preset

        Returns:
            List of GeneratedImage objects (empty results are skipped, so
            the list may be shorter than n).
        """
        await self._ensure_initialized()

        size = self._normalize_size(size)
        w, h = map(int, size.split("x"))

        images = []
        for i in range(n):
            await self._rate_limit()
            data = await self.generate(prompt, size, style, n=1)
            if data:
                images.append(GeneratedImage(
                    data=data,
                    format="png",
                    width=w,
                    height=h,
                    prompt=prompt,
                    provider=self.provider.value,
                    model=self.model,
                    metadata={"index": i}
                ))

        return images

    async def edit(
        self,
        image: bytes,
        prompt: str,
        mask: Optional[bytes] = None,
        size: str = "1024x1024"
    ) -> bytes:
        """
        Edit an existing image based on prompt.

        Args:
            image: Original image bytes
            prompt: Edit instruction
            mask: Optional mask indicating areas to edit (transparent = edit)
            size: Output size

        Returns:
            Edited image bytes; empty bytes for providers without edit
            support (only OpenAI and Stability are routed here).
        """
        await self._ensure_initialized()
        await self._rate_limit()

        size = self._normalize_size(size)

        logger.info(f"Editing image with prompt: {prompt[:50]}...")

        # Provider-specific editing
        if self.provider == ImageProvider.OPENAI:
            return await self._edit_openai(image, prompt, mask, size)
        elif self.provider == ImageProvider.STABILITY:
            return await self._edit_stability(image, prompt, mask, size)

        return b""

    async def _edit_openai(
        self,
        image: bytes,
        prompt: str,
        mask: Optional[bytes],
        size: str
    ) -> bytes:
        """Edit using OpenAI DALL-E. Placeholder — returns empty bytes."""
        # Would use OpenAI API:
        # response = await self._client.images.edit(
        #     model="dall-e-2",  # Only DALL-E 2 supports edit
        #     image=image,
        #     mask=mask,
        #     prompt=prompt,
        #     size=size,
        #     response_format="b64_json"
        # )
        # return base64.b64decode(response.data[0].b64_json)
        return b""

    async def _edit_stability(
        self,
        image: bytes,
        prompt: str,
        mask: Optional[bytes],
        size: str
    ) -> bytes:
        """Edit using Stability AI. Placeholder — returns empty bytes."""
        # Would use Stability API inpainting
        return b""

    async def variations(
        self,
        image: bytes,
        n: int = 1,
        size: str = "1024x1024"
    ) -> List[bytes]:
        """
        Generate variations of an image.

        Args:
            image: Original image bytes
            n: Number of variations to generate
            size: Output size

        Returns:
            List of variation image bytes (empty results are skipped).
        """
        await self._ensure_initialized()

        size = self._normalize_size(size)

        logger.info(f"Generating {n} variations")

        variations = []
        for _ in range(n):
            await self._rate_limit()
            var = await self._generate_variation(image, size)
            if var:
                variations.append(var)

        return variations

    async def _generate_variation(
        self,
        image: bytes,
        size: str
    ) -> bytes:
        """Generate a single variation. Placeholder — returns empty bytes."""
        if self.provider == ImageProvider.OPENAI:
            # Would use OpenAI API:
            # response = await self._client.images.create_variation(
            #     model="dall-e-2",  # Only DALL-E 2 supports variations
            #     image=image,
            #     size=size,
            #     response_format="b64_json"
            # )
            # return base64.b64decode(response.data[0].b64_json)
            pass
        elif self.provider == ImageProvider.STABILITY:
            # Would use Stability API image-to-image
            pass

        return b""

    async def upscale(
        self,
        image: bytes,
        scale: int = 2
    ) -> bytes:
        """
        Upscale an image.

        Args:
            image: Image to upscale
            scale: Scale factor (2x, 4x)

        Returns:
            Upscaled image bytes (currently the unmodified input — placeholder).
        """
        await self._ensure_initialized()
        await self._rate_limit()

        if self.provider == ImageProvider.STABILITY:
            # Stability AI has dedicated upscaling
            # Would use their upscale API
            pass

        # Placeholder - would use actual upscaling
        return image

    async def save_to_file(
        self,
        prompt: str,
        file_path: str,
        size: str = "1024x1024",
        style: Optional[Union[ImageStyle, str]] = None
    ) -> str:
        """
        Generate and save image to file.

        Args:
            prompt: Text description
            file_path: Output file path (parent directories are created)
            size: Image size
            style: Style preset

        Returns:
            Path to saved file
        """
        image_data = await self.generate(prompt, size, style)

        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)

        with open(path, 'wb') as f:
            f.write(image_data)

        return str(path)

    def get_temp_path(self, prefix: str = "img") -> str:
        """
        Get a unique temporary file path for image storage.

        Args:
            prefix: File name prefix

        Returns:
            Temporary file path (Docker-compatible)
        """
        timestamp = int(time.time() * 1000)
        # Fix: hashing the timestamp alone was deterministic — two calls in
        # the same millisecond produced identical (colliding) paths, and the
        # hash added no information beyond the timestamp already in the name.
        # Mix in OS entropy so every call yields a distinct filename.
        random_hash = hashlib.md5(os.urandom(16)).hexdigest()[:8]
        filename = f"{prefix}_{timestamp}_{random_hash}.png"
        return os.path.join(TEMP_DIR, filename)

    def get_supported_sizes(self) -> List[str]:
        """Get list of supported sizes for current provider."""
        return self.SUPPORTED_SIZES.get(self.provider, ["1024x1024"])

    def get_supported_styles(self) -> List[str]:
        """Get list of supported styles for current provider."""
        styles = {
            ImageProvider.OPENAI: ["vivid", "natural"],
            ImageProvider.STABILITY: [s.value for s in ImageStyle],
            ImageProvider.MIDJOURNEY: ["raw", "cute", "scenic", "expressive", "original"]
        }
        return styles.get(self.provider, [])

    def get_max_prompt_length(self) -> int:
        """Get maximum prompt length (characters) for current provider."""
        limits = {
            ImageProvider.OPENAI: 4000,
            ImageProvider.STABILITY: 2000,
            ImageProvider.MIDJOURNEY: 6000
        }
        return limits.get(self.provider, 2000)

    def estimate_cost(
        self,
        size: str = "1024x1024",
        quality: str = "standard",
        n: int = 1
    ) -> float:
        """
        Estimate cost for generation.

        Args:
            size: Image size
            quality: Quality level
            n: Number of images

        Returns:
            Estimated cost in USD (0.04/image fallback for unknown combos)
        """
        # Approximate pricing (may be outdated)
        pricing = {
            ImageProvider.OPENAI: {
                ("1024x1024", "standard"): 0.04,
                ("1024x1024", "hd"): 0.08,
                ("1792x1024", "standard"): 0.08,
                ("1792x1024", "hd"): 0.12,
                ("1024x1792", "standard"): 0.08,
                ("1024x1792", "hd"): 0.12,
            },
            ImageProvider.STABILITY: {
                ("512x512", "standard"): 0.002,
                ("1024x1024", "standard"): 0.008,
            }
        }

        provider_pricing = pricing.get(self.provider, {})
        cost = provider_pricing.get((size, quality), 0.04)

        return cost * n

    def get_provider_info(self) -> Dict[str, Any]:
        """Get capability and configuration information for the current provider."""
        return {
            "provider": self.provider.value,
            "model": self.model,
            "supported_sizes": self.get_supported_sizes(),
            "supported_styles": self.get_supported_styles(),
            "max_prompt_length": self.get_max_prompt_length(),
            "supports_edit": self.provider in [ImageProvider.OPENAI, ImageProvider.STABILITY],
            "supports_variations": self.provider in [ImageProvider.OPENAI, ImageProvider.STABILITY],
            "supports_upscale": self.provider == ImageProvider.STABILITY
        }