# integrations/service_tools/servers/wan2gp_server.py

1""" 

2Wan2GP Video Generation Sidecar Server — Flask API on a dynamic port. 

3 

4Launched as a subprocess by RuntimeToolManager. On startup: 

51. Finds a free port (OS-assigned) 

62. Prints PORT=NNNNN to stdout (parent reads this) 

73. Lazy-loads Wan2GP model based on VRAM availability 

84. Serves video generation requests (async: submit → poll) 

9 

10Usage (standalone test): 

11 python -m integrations.service_tools.servers.wan2gp_server 

12 

13Pattern from: ltx2_server.py, acestep_tool.py (async task pattern) 

14""" 

15 
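
# How the parent side of the PORT=NNNNN handshake might look (a minimal
# sketch of RuntimeToolManager's role, assuming it launches this file via
# subprocess -- the manager's real code may differ):
#
#     import subprocess, sys
#     proc = subprocess.Popen(
#         [sys.executable, '-m', 'integrations.service_tools.servers.wan2gp_server'],
#         stdout=subprocess.PIPE, text=True,
#     )
#     for line in proc.stdout:
#         if line.startswith('PORT='):
#             port = int(line.strip().split('=', 1)[1])
#             break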

import json
import logging
import os
import socket
import sys
import uuid
from collections import OrderedDict
from pathlib import Path
from threading import Lock, Thread

from flask import Flask, request, jsonify, send_file

try:
    from integrations.service_tools.vram_manager import clear_cuda_cache
except ImportError:
    # Fallback for standalone runs where vram_manager isn't importable:
    # best-effort cache clearing for CUDA and Apple MPS backends.
    def clear_cuda_cache():
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
                torch.mps.empty_cache()
        except Exception:
            pass

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('wan2gp_server')

app = Flask(__name__)

# Global state
_pipeline = None
_pipeline_lock = Lock()
_model_dir = None

# Async task queue (same pattern as ACE-Step: submit → poll).
# Task lifecycle: pending → processing → complete | error.
_tasks = OrderedDict()  # task_id → {status, result, error}
_MAX_TASKS = 100

OUTPUT_DIR = os.path.join(Path.home(), '.hevolve', 'outputs', 'wan2gp')
os.makedirs(OUTPUT_DIR, exist_ok=True)


def _get_model_dir():
    global _model_dir
    if _model_dir:
        return _model_dir
    _model_dir = os.environ.get(
        'WAN2GP_MODEL_DIR',
        str(Path.home() / '.hevolve' / 'models' / 'wan2gp')
    )
    return _model_dir


def _load_pipeline():
    """Lazy-load the Wan2GP video generation pipeline."""
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    with _pipeline_lock:
        if _pipeline is not None:
            return _pipeline

        model_dir = _get_model_dir()
        logger.info(f"Loading Wan2GP pipeline from {model_dir}...")

        if model_dir not in sys.path:
            sys.path.insert(0, model_dir)

        try:
            import torch  # noqa: F401 -- verifies torch is importable before claiming readiness
            offload_mode = os.environ.get('WAN2GP_OFFLOAD', 'gpu')

            # Wan2GP uses the mmgp pattern for model management;
            # actual loading depends on the repo structure.
            _pipeline = {
                'loaded': True,
                'model_dir': model_dir,
                'offload_mode': offload_mode,
            }
            logger.info(f"Wan2GP pipeline loaded (mode: {offload_mode})")
            return _pipeline
        except Exception as e:
            logger.error(f"Failed to load Wan2GP: {e}")
            # Cache the failure so repeated requests don't retry a broken load.
            _pipeline = {'loaded': False, 'error': str(e)}
            return _pipeline
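
# The module docstring promises VRAM-aware loading, but WAN2GP_OFFLOAD is
# currently the only knob. Below is a minimal sketch of deriving a default
# from free VRAM instead; the 12 GB threshold and the 'lowvram' mode name are
# assumptions for illustration, not Wan2GP's documented options.
def _default_offload_mode() -> str:
    try:
        import torch
        if torch.cuda.is_available():
            free_bytes, _total = torch.cuda.mem_get_info()
            return 'gpu' if free_bytes > 12 * 1024**3 else 'lowvram'
    except Exception:
        pass
    return 'lowvram'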


def _generate_video_worker(task_id: str, params: dict):
    """Background worker for video generation."""
    try:
        pipeline = _load_pipeline()
        if not pipeline.get('loaded'):
            _tasks[task_id] = {
                'status': 'error',
                'error': f"Pipeline not loaded: {pipeline.get('error', 'unknown')}",
            }
            return

        _tasks[task_id]['status'] = 'processing'

        prompt = params.get('prompt', '')
        num_frames = params.get('num_frames', 49)
        width = params.get('width', 512)
        height = params.get('height', 320)
        steps = params.get('num_inference_steps', 25)

        output_filename = f"video_{task_id}.mp4"
        output_path = os.path.join(OUTPUT_DIR, output_filename)

        # TODO: Replace with actual Wan2GP inference call once repo is cloned
        # Placeholder: actual generation depends on Wan2GP's API
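        # A hypothetical shape for that call, for orientation only -- the
        # names generate() and export_to_video() are placeholders, not
        # Wan2GP's actual API:
        #
        #     frames = pipeline['model'].generate(
        #         prompt=prompt, num_frames=num_frames, width=width,
        #         height=height, num_inference_steps=steps,
        #     )
        #     export_to_video(frames, output_path)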

        _tasks[task_id] = {
            'status': 'complete',
            'video_url': f"/video/{task_id}",
            'output_path': output_path,
            'params': params,
            'message': 'Wan2GP generation placeholder — model integration pending repo clone',
        }

    except Exception as e:
        logger.error(f"Video generation failed for task {task_id}: {e}")
        _tasks[task_id] = {'status': 'error', 'error': str(e)}


@app.route('/health', methods=['GET'])
def health():
    """Health check with VRAM stats."""
    status = {
        'status': 'ok',
        'service': 'wan2gp',
        'pending_tasks': sum(1 for t in _tasks.values() if t.get('status') == 'pending'),
    }
    try:
        import torch
        if torch.cuda.is_available():
            status['gpu'] = torch.cuda.get_device_name(0)
            status['vram_total_gb'] = round(torch.cuda.get_device_properties(0).total_memory / 1e9, 2)
            status['vram_used_gb'] = round(torch.cuda.memory_allocated(0) / 1e9, 2)
    except ImportError:
        pass
    return jsonify(status)


@app.route('/generate', methods=['POST'])
def generate():
    """Submit a video generation task (async)."""
    # silent=True: newer Flask versions raise on non-JSON bodies otherwise
    data = request.get_json(silent=True) or {}
    prompt = data.get('prompt', '')

    if not prompt:
        return jsonify({'error': 'prompt is required'}), 400

    # Evict oldest tasks first (FIFO) once the queue is full
    while len(_tasks) >= _MAX_TASKS:
        _tasks.popitem(last=False)

    task_id = str(uuid.uuid4())[:12]
    _tasks[task_id] = {'status': 'pending'}

    # Run generation in a background thread
    thread = Thread(target=_generate_video_worker, args=(task_id, data), daemon=True)
    thread.start()

    return jsonify({'task_id': task_id, 'status': 'pending'})


@app.route('/check_result', methods=['POST'])
def check_result():
    """Check the status of a video generation task."""
    data = request.get_json(silent=True) or {}
    task_id = data.get('task_id', '')

    if not task_id or task_id not in _tasks:
        return jsonify({'error': 'invalid task_id'}), 404

    return jsonify(_tasks[task_id])
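
# A client-side view of the submit → poll contract (a sketch, assuming the
# third-party requests package; base_url comes from the PORT=NNNNN handshake,
# e.g. f'http://127.0.0.1:{port}'):
#
#     import time, requests
#     task = requests.post(f'{base_url}/generate', json={'prompt': 'a red fox'}).json()
#     while True:
#         result = requests.post(f'{base_url}/check_result',
#                                json={'task_id': task['task_id']}).json()
#         if result['status'] in ('complete', 'error'):
#             break
#         time.sleep(2)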


@app.route('/video/<task_id>', methods=['GET'])
def serve_video(task_id):
    """Serve a generated video file."""
    path = os.path.join(OUTPUT_DIR, f"video_{task_id}.mp4")
    if os.path.exists(path):
        return send_file(path, mimetype='video/mp4')
    return jsonify({'error': 'not found'}), 404


@app.route('/unload', methods=['POST'])
def unload():
    """Unload the pipeline to free memory."""
    global _pipeline
    with _pipeline_lock:  # avoid racing a concurrent _load_pipeline()
        _pipeline = None
    clear_cuda_cache()
    return jsonify({'status': 'unloaded'})


def _find_free_port() -> int:
    """Find a free port using OS assignment.

    Note: the port is released before app.run() rebinds it, so another
    process could grab it in between; acceptable for a local sidecar.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(('127.0.0.1', 0))
    port = sock.getsockname()[1]
    sock.close()
    return port
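
# A race-free alternative (a sketch): bind once and serve on the same socket,
# using the stdlib's wsgiref instead of app.run(). Caveat: wsgiref is
# single-threaded, unlike app.run(threaded=True).
#
#     from wsgiref.simple_server import make_server
#     server = make_server('127.0.0.1', 0, app)
#     print(f"PORT={server.server_port}", flush=True)
#     server.serve_forever()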


if __name__ == '__main__':
    port = _find_free_port()
    print(f"PORT={port}", flush=True)  # parent parses this line
    app.run(host='127.0.0.1', port=port, threaded=True)