Coverage for integrations / channels / media / image_gen.py: 66.7%

219 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-12 04:49 +0000

1""" 

2Image Generator for AI image generation. 

3 

4Supports multiple providers: openai, stability, midjourney 

5""" 

6 

7import asyncio 

8import base64 

9from dataclasses import dataclass, field 

10from enum import Enum 

11from typing import Optional, List, Dict, Any, Union 

12from pathlib import Path 

13import logging 

14import os 

15import time 

16import hashlib 

17 

18logger = logging.getLogger(__name__) 

19 

20# Docker-compatible paths 

21TEMP_DIR = os.environ.get("IMAGE_TEMP_DIR", "/tmp/images") 

22APP_TEMP_DIR = os.environ.get("APP_TEMP_DIR", "/app/temp") 

23 

24 

class ImageProvider(Enum):
    """Backends that can fulfil image-generation requests.

    The enum value is the lowercase string identifier used in config
    and in ``ImageGenerator.providers``.
    """
    OPENAI = "openai"            # DALL-E family
    STABILITY = "stability"      # Stable Diffusion family
    MIDJOURNEY = "midjourney"    # Midjourney (Discord/API)

30 

31 

class ImageSize(Enum):
    """Common output resolutions, expressed as "WIDTHxHEIGHT" strings."""
    SQUARE_SMALL = "256x256"
    SQUARE_MEDIUM = "512x512"
    SQUARE_LARGE = "1024x1024"
    LANDSCAPE = "1792x1024"      # DALL-E 3 wide format
    PORTRAIT = "1024x1792"       # DALL-E 3 tall format
    HD_LANDSCAPE = "1920x1080"
    HD_PORTRAIT = "1080x1920"

41 

42 

class ImageStyle(Enum):
    """Named style presets; values are the provider-facing preset strings."""
    VIVID = "vivid"
    NATURAL = "natural"
    ANIME = "anime"
    PHOTOGRAPHIC = "photographic"
    DIGITAL_ART = "digital-art"
    CINEMATIC = "cinematic"
    FANTASY = "fantasy"
    NEON_PUNK = "neon-punk"
    ISOMETRIC = "isometric"
    ORIGAMI = "origami"

55 

56 

@dataclass
class GeneratedImage:
    """A single image produced by a generation provider.

    Holds the raw encoded bytes plus the prompt/provider bookkeeping
    needed to serialize or re-embed the result.
    """
    data: bytes
    format: str  # png, jpg, webp
    width: int
    height: int
    prompt: str
    revised_prompt: Optional[str] = None
    seed: Optional[int] = None
    provider: Optional[str] = None
    model: Optional[str] = None
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize metadata; raw bytes are represented only by their length."""
        return {
            "format": self.format,
            "width": self.width,
            "height": self.height,
            "prompt": self.prompt,
            "revised_prompt": self.revised_prompt,
            "seed": self.seed,
            "provider": self.provider,
            "model": self.model,
            "size": len(self.data),
            "metadata": self.metadata,
        }

    def to_base64(self) -> str:
        """Return the image bytes as a base64-encoded ASCII string."""
        return base64.b64encode(self.data).decode('utf-8')

    def to_data_url(self) -> str:
        """Return a data: URL suitable for inline embedding (e.g. HTML)."""
        fmt = self.format.lower()
        if fmt in ("jpg", "jpeg"):
            mime = "image/jpeg"
        elif fmt == "webp":
            mime = "image/webp"
        else:
            # "png" and any unrecognized format fall back to PNG.
            mime = "image/png"
        return f"data:{mime};base64,{self.to_base64()}"

99 

100 

@dataclass
class EditResult:
    """Outcome of a single image edit operation."""
    original_size: int
    edited_size: int
    edited_image: GeneratedImage
    operation: str  # one of: "edit", "inpaint", "outpaint"
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Flatten to a plain dict; the nested image is serialized via its own to_dict()."""
        out: Dict[str, Any] = {"original_size": self.original_size}
        out["edited_size"] = self.edited_size
        out["edited_image"] = self.edited_image.to_dict()
        out["operation"] = self.operation
        out["metadata"] = self.metadata
        return out

118 

119 

@dataclass
class VariationResult:
    """Outcome of an image variation operation."""
    original_prompt: Optional[str]
    variations: List[GeneratedImage]
    count: int
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize, expanding every variation through GeneratedImage.to_dict()."""
        serialized = [img.to_dict() for img in self.variations]
        return {
            "original_prompt": self.original_prompt,
            "variations": serialized,
            "count": self.count,
            "metadata": self.metadata,
        }

135 

136 

class ImageGenerator:
    """
    Image generator for AI image generation.

    Supports multiple providers for text-to-image generation, editing,
    variations, and upscaling. The actual provider API calls are stubs
    that return empty bytes; size normalization, rate limiting,
    temp-path handling, and cost estimation are implemented.
    """

    # Available provider identifiers (match ImageProvider enum values).
    providers: List[str] = ["openai", "stability", "midjourney"]

    # Default models per provider.
    DEFAULT_MODELS = {
        ImageProvider.OPENAI: "dall-e-3",
        ImageProvider.STABILITY: "stable-diffusion-xl-1024-v1-0",
        ImageProvider.MIDJOURNEY: "v6"
    }

    # Supported "WxH" sizes per provider.
    SUPPORTED_SIZES = {
        ImageProvider.OPENAI: ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
        ImageProvider.STABILITY: ["512x512", "768x768", "1024x1024", "1536x1536"],
        ImageProvider.MIDJOURNEY: ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]
    }

    def __init__(
        self,
        provider: Union[ImageProvider, str] = ImageProvider.OPENAI,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Initialize image generator.

        Args:
            provider: Image generation provider to use (enum or its string value).
            api_key: API key for the provider.
            model: Specific model to use; defaults to DEFAULT_MODELS for the provider.
            config: Additional configuration options. Recognized key:
                "min_request_interval" — minimum seconds between requests (default 1.0).

        Raises:
            ValueError: If a string provider does not match an ImageProvider value.
        """
        if isinstance(provider, str):
            provider = ImageProvider(provider.lower())

        self.provider = provider
        self.api_key = api_key
        self.config = config or {}

        # Set default model per provider.
        self.model = model or self.DEFAULT_MODELS.get(provider, "default")

        # Provider-specific client, created lazily in _ensure_initialized().
        self._client = None
        self._initialized = False

        # Rate limiting. Read from the normalized self.config (fix: the
        # original read the raw `config` parameter, creating two sources
        # of truth; behavior is identical since self.config = config or {}).
        self._last_request_time = 0
        self._min_request_interval = self.config.get("min_request_interval", 1.0)

        # Ensure temp directories exist.
        self._ensure_temp_dirs()

    def _ensure_temp_dirs(self):
        """Ensure temp directories exist (Docker-compatible).

        Failures are ignored so a read-only filesystem does not break
        construction; callers may still write to explicit paths.
        """
        for dir_path in (TEMP_DIR, APP_TEMP_DIR):
            try:
                Path(dir_path).mkdir(parents=True, exist_ok=True)
            except (PermissionError, OSError):
                pass

    async def _ensure_initialized(self):
        """Lazily initialize the provider client (currently a stub)."""
        if self._initialized:
            return

        if self.provider == ImageProvider.OPENAI:
            # Would initialize OpenAI client
            pass
        elif self.provider == ImageProvider.STABILITY:
            # Would initialize Stability AI client
            pass
        elif self.provider == ImageProvider.MIDJOURNEY:
            # Would initialize Midjourney client (via Discord or API)
            pass

        self._initialized = True

    async def _rate_limit(self):
        """Sleep so consecutive requests are >= _min_request_interval apart."""
        now = time.time()
        elapsed = now - self._last_request_time
        if elapsed < self._min_request_interval:
            await asyncio.sleep(self._min_request_interval - elapsed)
        self._last_request_time = time.time()

    def _normalize_size(self, size: Union[ImageSize, str]) -> str:
        """
        Normalize a size to one supported by the current provider.

        Args:
            size: An ImageSize enum or a "WxH" string.
                  (Fix: annotation now reflects that enums are accepted.)

        Returns:
            A supported "WxH" string — the input if already supported,
            otherwise the supported size with the closest pixel count;
            malformed strings fall back to "1024x1024".
        """
        if isinstance(size, ImageSize):
            size = size.value

        supported = self.SUPPORTED_SIZES.get(self.provider, ["1024x1024"])
        if size in supported:
            return size

        # Pick the supported size with the closest total pixel count.
        try:
            w, h = map(int, size.split("x"))
            target_pixels = w * h

            best_size = supported[0]
            best_diff = float("inf")

            for s in supported:
                sw, sh = map(int, s.split("x"))
                diff = abs(sw * sh - target_pixels)
                if diff < best_diff:
                    best_diff = diff
                    best_size = s

            return best_size
        except ValueError:
            return "1024x1024"

    async def generate(
        self,
        prompt: str,
        size: str = "1024x1024",
        style: Optional[Union[ImageStyle, str]] = None,
        quality: str = "standard",
        n: int = 1
    ) -> bytes:
        """
        Generate image from prompt.

        Args:
            prompt: Text description of the image to generate.
            size: Image size (e.g., "1024x1024"); normalized per provider.
            style: Style preset (vivid, natural, etc.).
            quality: Quality level (standard, hd).
            n: Number of images to generate.

        Returns:
            Image bytes (first image if n > 1). Empty bytes while the
            provider backends are stubbed.
        """
        await self._ensure_initialized()
        await self._rate_limit()

        size = self._normalize_size(size)
        if isinstance(style, ImageStyle):
            style = style.value

        logger.info(f"Generating image: {prompt[:50]}... ({size})")

        # Provider-specific generation.
        if self.provider == ImageProvider.OPENAI:
            return await self._generate_openai(prompt, size, style, quality, n)
        elif self.provider == ImageProvider.STABILITY:
            return await self._generate_stability(prompt, size, style, n)
        elif self.provider == ImageProvider.MIDJOURNEY:
            return await self._generate_midjourney(prompt, size, style)

        return b""

    async def _generate_openai(
        self,
        prompt: str,
        size: str,
        style: Optional[str],
        quality: str,
        n: int
    ) -> bytes:
        """Generate using OpenAI DALL-E (stub)."""
        # Would use OpenAI API:
        # response = await self._client.images.generate(
        #     model=self.model,
        #     prompt=prompt,
        #     size=size,
        #     style=style or "vivid",
        #     quality=quality,
        #     n=n,
        #     response_format="b64_json"
        # )
        # return base64.b64decode(response.data[0].b64_json)
        return b""

    async def _generate_stability(
        self,
        prompt: str,
        size: str,
        style: Optional[str],
        n: int
    ) -> bytes:
        """Generate using Stability AI (stub)."""
        # Would use Stability API
        return b""

    async def _generate_midjourney(
        self,
        prompt: str,
        size: str,
        style: Optional[str]
    ) -> bytes:
        """Generate using Midjourney (stub)."""
        # Would use Midjourney API/Discord integration
        return b""

    async def generate_multiple(
        self,
        prompt: str,
        n: int = 4,
        size: str = "1024x1024",
        style: Optional[Union[ImageStyle, str]] = None
    ) -> List[GeneratedImage]:
        """
        Generate multiple images from prompt.

        Args:
            prompt: Text description.
            n: Number of images to generate.
            size: Image size.
            style: Style preset.

        Returns:
            List of GeneratedImage objects (empty results are skipped,
            so the list may be shorter than n).
        """
        await self._ensure_initialized()

        size = self._normalize_size(size)
        w, h = map(int, size.split("x"))

        images = []
        for i in range(n):
            # Fix: generate() already applies rate limiting; the original
            # also called _rate_limit() here, doubling the enforced delay
            # per image.
            data = await self.generate(prompt, size, style, n=1)
            if data:
                images.append(GeneratedImage(
                    data=data,
                    format="png",
                    width=w,
                    height=h,
                    prompt=prompt,
                    provider=self.provider.value,
                    model=self.model,
                    metadata={"index": i}
                ))

        return images

    async def edit(
        self,
        image: bytes,
        prompt: str,
        mask: Optional[bytes] = None,
        size: str = "1024x1024"
    ) -> bytes:
        """
        Edit an existing image based on prompt.

        Args:
            image: Original image bytes.
            prompt: Edit instruction.
            mask: Optional mask indicating areas to edit (transparent = edit).
            size: Output size.

        Returns:
            Edited image bytes (empty while backends are stubbed; also
            empty for providers without edit support, e.g. Midjourney).
        """
        await self._ensure_initialized()
        await self._rate_limit()

        size = self._normalize_size(size)

        logger.info(f"Editing image with prompt: {prompt[:50]}...")

        # Provider-specific editing.
        if self.provider == ImageProvider.OPENAI:
            return await self._edit_openai(image, prompt, mask, size)
        elif self.provider == ImageProvider.STABILITY:
            return await self._edit_stability(image, prompt, mask, size)

        return b""

    async def _edit_openai(
        self,
        image: bytes,
        prompt: str,
        mask: Optional[bytes],
        size: str
    ) -> bytes:
        """Edit using OpenAI DALL-E (stub)."""
        # Would use OpenAI API:
        # response = await self._client.images.edit(
        #     model="dall-e-2",  # Only DALL-E 2 supports edit
        #     image=image,
        #     mask=mask,
        #     prompt=prompt,
        #     size=size,
        #     response_format="b64_json"
        # )
        # return base64.b64decode(response.data[0].b64_json)
        return b""

    async def _edit_stability(
        self,
        image: bytes,
        prompt: str,
        mask: Optional[bytes],
        size: str
    ) -> bytes:
        """Edit using Stability AI (stub)."""
        # Would use Stability API inpainting
        return b""

    async def variations(
        self,
        image: bytes,
        n: int = 1,
        size: str = "1024x1024"
    ) -> List[bytes]:
        """
        Generate variations of an image.

        Args:
            image: Original image bytes.
            n: Number of variations to generate.
            size: Output size.

        Returns:
            List of variation image bytes (empty results are skipped).
        """
        await self._ensure_initialized()

        size = self._normalize_size(size)

        logger.info(f"Generating {n} variations")

        results = []
        for _ in range(n):
            await self._rate_limit()
            var = await self._generate_variation(image, size)
            if var:
                results.append(var)

        return results

    async def _generate_variation(
        self,
        image: bytes,
        size: str
    ) -> bytes:
        """Generate a single variation (stub)."""
        if self.provider == ImageProvider.OPENAI:
            # Would use OpenAI API:
            # response = await self._client.images.create_variation(
            #     model="dall-e-2",  # Only DALL-E 2 supports variations
            #     image=image,
            #     size=size,
            #     response_format="b64_json"
            # )
            # return base64.b64decode(response.data[0].b64_json)
            pass
        elif self.provider == ImageProvider.STABILITY:
            # Would use Stability API image-to-image
            pass

        return b""

    async def upscale(
        self,
        image: bytes,
        scale: int = 2
    ) -> bytes:
        """
        Upscale an image.

        Args:
            image: Image to upscale.
            scale: Scale factor (2x, 4x).

        Returns:
            Upscaled image bytes. Currently returns the input unchanged
            (placeholder — no real upscaling is performed yet).
        """
        await self._ensure_initialized()
        await self._rate_limit()

        if self.provider == ImageProvider.STABILITY:
            # Stability AI has dedicated upscaling
            # Would use their upscale API
            pass

        # Placeholder - would use actual upscaling
        return image

    async def save_to_file(
        self,
        prompt: str,
        file_path: str,
        size: str = "1024x1024",
        style: Optional[Union[ImageStyle, str]] = None
    ) -> str:
        """
        Generate and save image to file.

        Args:
            prompt: Text description.
            file_path: Output file path (parent directories are created).
            size: Image size.
            style: Style preset.

        Returns:
            Path to saved file.
        """
        image_data = await self.generate(prompt, size, style)

        path = Path(file_path)
        path.parent.mkdir(parents=True, exist_ok=True)

        with open(path, 'wb') as f:
            f.write(image_data)

        return str(path)

    def get_temp_path(self, prefix: str = "img") -> str:
        """
        Get a temporary file path for image storage.

        Args:
            prefix: File name prefix.

        Returns:
            Temporary file path under TEMP_DIR (Docker-compatible).

        NOTE(review): the "random_hash" component is an MD5 digest of the
        millisecond timestamp, so it is deterministic — two calls within
        the same millisecond collide. Fine for scratch files; not a
        uniqueness guarantee.
        """
        timestamp = int(time.time() * 1000)
        random_hash = hashlib.md5(str(timestamp).encode()).hexdigest()[:8]
        filename = f"{prefix}_{timestamp}_{random_hash}.png"
        return os.path.join(TEMP_DIR, filename)

    def get_supported_sizes(self) -> List[str]:
        """Get list of supported sizes for the current provider."""
        return self.SUPPORTED_SIZES.get(self.provider, ["1024x1024"])

    def get_supported_styles(self) -> List[str]:
        """Get list of supported styles for the current provider."""
        styles = {
            ImageProvider.OPENAI: ["vivid", "natural"],
            ImageProvider.STABILITY: [s.value for s in ImageStyle],
            ImageProvider.MIDJOURNEY: ["raw", "cute", "scenic", "expressive", "original"]
        }
        return styles.get(self.provider, [])

    def get_max_prompt_length(self) -> int:
        """Get maximum prompt length (characters) for the current provider."""
        limits = {
            ImageProvider.OPENAI: 4000,
            ImageProvider.STABILITY: 2000,
            ImageProvider.MIDJOURNEY: 6000
        }
        return limits.get(self.provider, 2000)

    def estimate_cost(
        self,
        size: str = "1024x1024",
        quality: str = "standard",
        n: int = 1
    ) -> float:
        """
        Estimate cost for generation.

        Args:
            size: Image size.
            quality: Quality level.
            n: Number of images.

        Returns:
            Estimated cost in USD; unknown (size, quality) combinations
            fall back to $0.04 per image.
        """
        # Approximate pricing (may be outdated).
        pricing = {
            ImageProvider.OPENAI: {
                ("1024x1024", "standard"): 0.04,
                ("1024x1024", "hd"): 0.08,
                ("1792x1024", "standard"): 0.08,
                ("1792x1024", "hd"): 0.12,
                ("1024x1792", "standard"): 0.08,
                ("1024x1792", "hd"): 0.12,
            },
            ImageProvider.STABILITY: {
                ("512x512", "standard"): 0.002,
                ("1024x1024", "standard"): 0.008,
            }
        }

        provider_pricing = pricing.get(self.provider, {})
        cost = provider_pricing.get((size, quality), 0.04)

        return cost * n

    def get_provider_info(self) -> Dict[str, Any]:
        """Get capability and configuration summary for the current provider."""
        return {
            "provider": self.provider.value,
            "model": self.model,
            "supported_sizes": self.get_supported_sizes(),
            "supported_styles": self.get_supported_styles(),
            "max_prompt_length": self.get_max_prompt_length(),
            "supports_edit": self.provider in [ImageProvider.OPENAI, ImageProvider.STABILITY],
            "supports_variations": self.provider in [ImageProvider.OPENAI, ImageProvider.STABILITY],
            "supports_upscale": self.provider == ImageProvider.STABILITY
        }