Coverage for integrations/service_tools/servers/wan2gp_server.py: 0.0% (105 statements)
1"""
2Wan2GP Video Generation Sidecar Server — Flask API on a dynamic port.
4Launched as a subprocess by RuntimeToolManager. On startup:
51. Finds a free port (OS-assigned)
62. Prints PORT=NNNNN to stdout (parent reads this)
73. Lazy-loads Wan2GP model based on VRAM availability
84. Serves video generation requests (async: submit → poll)
10Usage (standalone test):
11 python -m integrations.service_tools.servers.wan2gp_server
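
Example parent handshake (illustrative sketch; the real wiring lives in
RuntimeToolManager and may differ):

    import subprocess
    import sys

    proc = subprocess.Popen(
        [sys.executable, '-m', 'integrations.service_tools.servers.wan2gp_server'],
        stdout=subprocess.PIPE, text=True,
    )
    first_line = proc.stdout.readline().strip()  # e.g. "PORT=51234"
    port = int(first_line.split('=', 1)[1])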

Pattern from: ltx2_server.py, acestep_tool.py (async task pattern)
"""

import json
import logging
import os
import socket
import sys
import uuid
from collections import OrderedDict
from pathlib import Path
from threading import Lock, Thread

from flask import Flask, request, jsonify, send_file

try:
    from integrations.service_tools.vram_manager import clear_cuda_cache
except ImportError:
    def clear_cuda_cache():
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                torch.mps.empty_cache()
        except Exception:
            pass

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('wan2gp_server')

app = Flask(__name__)

# Global state
_pipeline = None
_pipeline_lock = Lock()
_model_dir = None

# Async task queue (same pattern as ACE-Step: submit → poll)
_tasks = OrderedDict()  # task_id → {status, result, error}
_MAX_TASKS = 100

OUTPUT_DIR = os.path.join(Path.home(), '.hevolve', 'outputs', 'wan2gp')
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _get_model_dir():
    global _model_dir
    if _model_dir:
        return _model_dir
    _model_dir = os.environ.get(
        'WAN2GP_MODEL_DIR',
        str(Path.home() / '.hevolve' / 'models' / 'wan2gp')
    )
    return _model_dir


def _load_pipeline():
    """Lazy-load Wan2GP video generation pipeline."""
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    with _pipeline_lock:
        if _pipeline is not None:
            return _pipeline

        model_dir = _get_model_dir()
        logger.info(f"Loading Wan2GP pipeline from {model_dir}...")

        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)

        try:
            import torch  # imported here to verify availability before marking the pipeline loaded

            offload_mode = os.environ.get('WAN2GP_OFFLOAD', 'gpu')

            # Wan2GP uses the mmgp pattern for model management.
            # Actual loading depends on the repo structure.
            _pipeline = {
                'loaded': True,
                'model_dir': model_dir,
                'offload_mode': offload_mode,
            }
            logger.info(f"Wan2GP pipeline loaded (mode: {offload_mode})")
            return _pipeline
        except Exception as e:
            logger.error(f"Failed to load Wan2GP: {e}")
            _pipeline = {'loaded': False, 'error': str(e)}
            return _pipeline


def _generate_video_worker(task_id: str, params: dict):
    """Background worker for video generation."""
    try:
        pipeline = _load_pipeline()
        if not pipeline.get('loaded'):
            _tasks[task_id] = {
                'status': 'error',
                'error': f"Pipeline not loaded: {pipeline.get('error', 'unknown')}",
            }
            return

        _tasks[task_id]['status'] = 'processing'

        prompt = params.get('prompt', '')
        num_frames = params.get('num_frames', 49)
        width = params.get('width', 512)
        height = params.get('height', 320)
        steps = params.get('num_inference_steps', 25)

        output_filename = f"video_{task_id}.mp4"
        output_path = os.path.join(OUTPUT_DIR, output_filename)

        # TODO: Replace with actual Wan2GP inference call once repo is cloned
        # Placeholder: actual generation depends on Wan2GP's API
        _tasks[task_id] = {
            'status': 'complete',
            'video_url': f"/video/{task_id}",
            'output_path': output_path,
            'params': params,
            'message': 'Wan2GP generation placeholder — model integration pending repo clone',
        }

    except Exception as e:
        logger.error(f"Video generation failed for task {task_id}: {e}")
        _tasks[task_id] = {'status': 'error', 'error': str(e)}


@app.route('/health', methods=['GET'])
def health():
    """Health check with VRAM stats."""
    status = {
        'status': 'ok',
        'service': 'wan2gp',
        'pending_tasks': sum(1 for t in _tasks.values() if t.get('status') == 'pending'),
    }
    try:
        import torch
        if torch.cuda.is_available():
            status['gpu'] = torch.cuda.get_device_name(0)
            status['vram_total_gb'] = round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2)
            status['vram_used_gb'] = round(torch.cuda.memory_allocated(0) / 1e9, 2)
    except Exception:  # torch missing or GPU probe failed: report basic health only
        pass
    return jsonify(status)


@app.route('/generate', methods=['POST'])
def generate():
    """Submit a video generation task (async)."""
    data = request.get_json(silent=True) or {}
    prompt = data.get('prompt', '')

    if not prompt:
        return jsonify({'error': 'prompt is required'}), 400

    # Evict old tasks
    while len(_tasks) >= _MAX_TASKS:
        _tasks.popitem(last=False)

    task_id = str(uuid.uuid4())[:12]
    _tasks[task_id] = {'status': 'pending'}

    # Run generation in background thread
    thread = Thread(target=_generate_video_worker, args=(task_id, data), daemon=True)
    thread.start()

    return jsonify({'task_id': task_id, 'status': 'pending'})


@app.route('/check_result', methods=['POST'])
def check_result():
    """Check status of a video generation task."""
    data = request.get_json(silent=True) or {}
    task_id = data.get('task_id', '')

    if not task_id or task_id not in _tasks:
        return jsonify({'error': 'invalid task_id'}), 404

    return jsonify(_tasks[task_id])
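

# Illustrative client-side sketch of the submit → poll contract (not used by the
# server itself; the real caller is expected to live in RuntimeToolManager).
# `_example_submit_and_poll` is a hypothetical helper; it assumes the `requests`
# package is available.
def _example_submit_and_poll(base_url: str, prompt: str, interval: float = 1.0) -> dict:
    import time

    import requests

    # Submit the task, then poll /check_result until it settles.
    task_id = requests.post(f"{base_url}/generate", json={'prompt': prompt}).json()['task_id']
    while True:
        result = requests.post(f"{base_url}/check_result", json={'task_id': task_id}).json()
        if result.get('status') in ('complete', 'error'):
            return result
        time.sleep(interval)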


@app.route('/video/<task_id>', methods=['GET'])
def serve_video(task_id):
    """Serve generated video file."""
    path = os.path.join(OUTPUT_DIR, f"video_{task_id}.mp4")
    if os.path.exists(path):
        return send_file(path, mimetype='video/mp4')
    return jsonify({'error': 'not found'}), 404


@app.route('/unload', methods=['POST'])
def unload():
    """Unload pipeline to free memory."""
    global _pipeline
    with _pipeline_lock:  # avoid racing a concurrent _load_pipeline()
        _pipeline = None
    clear_cuda_cache()
    return jsonify({'status': 'unloaded'})


def _find_free_port() -> int:
    """Find a free port using OS assignment."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(('127.0.0.1', 0))
    port = sock.getsockname()[1]
    sock.close()
    # The port is released before Flask rebinds it, so a brief reuse race exists;
    # acceptable for a localhost sidecar launched by a single parent process.
    return port


if __name__ == '__main__':
    port = _find_free_port()
    print(f"PORT={port}", flush=True)
    app.run(host='127.0.0.1', port=port, threaded=True)