forked from kandinskylab/kandinsky-5
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathworker.py
More file actions
381 lines (320 loc) · 14.3 KB
/
worker.py
File metadata and controls
381 lines (320 loc) · 14.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
"""
Background Worker for Kandinsky5 Video Generation
This worker process runs independently of Gradio and processes jobs from the queue.
It continues running even when browsers disconnect, ensuring all queued jobs complete.
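Usage:
python worker.py [--poll-interval SECONDS] [--queue-file PATH]
(--poll-interval defaults to 2.0; without --queue-file, a GPU-specific queue
file is auto-detected from CUDA_VISIBLE_DEVICES.)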
The worker will automatically:
1. Poll for pending jobs
2. Execute them via subprocess
3. Update progress in the queue file
4. Handle cancellation requests
5. Clean up old completed jobs periodically
"""
import argparse
import json
import os
import re
import signal
import subprocess
import sys
import threading
import time
from collections import deque
from datetime import datetime
from queue import Empty, SimpleQueue
from typing import Optional, Tuple

from job_queue import get_queue, JobQueue, JobStatus, Job
def add_metadata_to_video(video_path: str, parameters: dict) -> None:
    """Embed generation parameters into a video's metadata using ffmpeg.

    ``parameters`` is serialised as indented JSON into the ``comment``
    metadata tag. Streams are copied without re-encoding (``-codec copy``)
    into a sibling temporary file, which then replaces ``video_path`` via
    ``os.replace``.

    Errors are printed rather than raised. If anything fails, the original
    video is left untouched and any partial temporary file is removed.

    Args:
        video_path: Path of an existing video file; rewritten in place.
        parameters: JSON-serialisable generation parameters.
    """
    params_json = json.dumps(parameters, indent=2)
    # Build the temp name from the real extension. The former
    # str.replace(".mp4", ...) left non-".mp4" paths unchanged, so the
    # failure cleanup deleted the original video. It also rewrote ".mp4"
    # inside directory names. Keeping the extension lets ffmpeg infer the
    # output container.
    root, ext = os.path.splitext(video_path)
    temp_path = f"{root}_temp{ext}"
    cmd = [
        'ffmpeg', '-y',
        '-i', video_path,
        '-metadata', f'comment={params_json}',
        '-codec', 'copy',
        temp_path,
    ]
    try:
        # stdin=DEVNULL: ffmpeg otherwise reads the terminal and can be
        # stopped (SIGTTIN) when the worker runs in the background.
        subprocess.run(
            cmd,
            check=True,
            stdin=subprocess.DEVNULL,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        os.replace(temp_path, video_path)
    except subprocess.CalledProcessError as e:
        detail = e.stderr.decode(errors="replace") if e.stderr else str(e)
        print(f"[Worker] Failed to add metadata: {detail}")
    except Exception as e:
        print(f"[Worker] Metadata error: {str(e)}")
    finally:
        # After a successful os.replace the temp file is gone, so this only
        # removes partial output left by a failed attempt.
        if os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError as e:
                print(f"[Worker] Could not remove temporary file {temp_path}: {e}")
class Worker:
    """
    Background worker that processes video generation jobs.

    Runs as a daemon process or thread, independent of the Gradio frontend.
    It claims pending jobs from a ``JobQueue`` and runs each job's command as
    a subprocess. Parsed progress is streamed back into the queue, and the
    job is finally marked completed or failed. Cancellation is requested by
    changing the job's status in the queue.
    """

    # Max seconds to block waiting for job output before re-checking the
    # queue for cancellation, so jobs that print nothing can still be cancelled.
    OUTPUT_POLL_INTERVAL = 0.5
    # Seconds to wait after terminate() (and again after kill()).
    TERMINATE_TIMEOUT = 5.0
    # Number of trailing output lines searched for an error message.
    ERROR_SCAN_LINES = 10
    # How often old jobs are purged from the queue, and the age threshold.
    CLEANUP_INTERVAL_SECONDS = 3600
    CLEANUP_MAX_AGE_HOURS = 24.0

    # Progress patterns, compiled once because every output line is parsed.
    _SHARDS_RE = re.compile(r'(\d+)%.*?(\d+/\d+)')
    _STEP_RE = re.compile(r'(\d+)%\|.*?\|\s*(\d+)/(\d+)\s*\[.*?<([\d:]+)')
    # Requires at least one digit so float() cannot fail on inputs such as
    # "." or "1.2.3", which the former [\d.]+ pattern accepted.
    _ELAPSED_RE = re.compile(r'TIME ELAPSED:\s*(\d*\.?\d+)')

    def __init__(self, queue: "JobQueue", poll_interval: float = 2.0, use_signals: bool = False):
        self.queue = queue
        self.poll_interval = poll_interval
        self.running = True
        self.current_process: Optional[subprocess.Popen] = None
        self.current_job_id: Optional[str] = None
        # Only install signal handlers when running as the main process (not in a thread).
        if use_signals:
            try:
                signal.signal(signal.SIGINT, self._signal_handler)
                signal.signal(signal.SIGTERM, self._signal_handler)
            except ValueError:
                # signal.signal() is only allowed from the main thread.
                pass

    def stop(self):
        """Stop the worker gracefully, terminating the running job's process, if any."""
        self.running = False
        # Snapshot the process: when the worker runs in a thread, run_job may
        # clear current_process between the check and the terminate call.
        process = self.current_process
        if process is not None:
            self._terminate_process(process)

    def _signal_handler(self, signum, frame):
        """Handle shutdown signals gracefully."""
        print(f"\n[Worker] Received signal {signum}, shutting down...")
        self.stop()

    def _terminate_process(self, process: subprocess.Popen) -> None:
        """Best-effort stop of ``process``: terminate, wait, escalate to kill, wait again."""
        try:
            process.terminate()
            process.wait(timeout=self.TERMINATE_TIMEOUT)
            return
        except (subprocess.TimeoutExpired, OSError):
            pass
        try:
            process.kill()
            # Reap it so it does not linger as a zombie.
            process.wait(timeout=self.TERMINATE_TIMEOUT)
        except (subprocess.TimeoutExpired, OSError) as e:
            print(f"[Worker] Could not kill process {process.pid}: {e}")

    def parse_progress_line(self, line: str) -> Tuple[Optional[float], Optional[str], int, int]:
        """
        Parse a line of job output into progress information.

        The overall 0-100% scale is split as follows: shard loading covers
        0-10%, DiT build/load 10-12%, the step progress bar 15-95%, and
        completion, decoding and saving 95-100%.

        Returns:
            Tuple of (progress_percent, progress_text, current_step, total_steps).
            progress_percent and progress_text are None when the line carries
            no progress information; the step counts are 0 except for
            step-progress lines.
        """
        line = line.strip()

        # Loading checkpoint shards (~10% of total progress)
        if "Loading checkpoint shards:" in line:
            match = self._SHARDS_RE.search(line)
            if match:
                percent = float(match.group(1))
                fraction = match.group(2)
                return percent * 0.1, f"Loading model: {percent:.0f}% ({fraction} shards)", 0, 0

        if "Building DiT with block swapping" in line:
            return 10.0, "Building DiT model...", 0, 0

        if "Loading DiT weights from" in line:
            return 12.0, "Loading DiT weights...", 0, 0

        # Main generation progress bar, e.g. " 50%|#####     | 25/50 [00:10<00:10, ...]"
        match = self._STEP_RE.search(line)
        if match:
            percent = float(match.group(1))
            current = int(match.group(2))
            total = int(match.group(3))
            eta = match.group(4)
            # Main generation is 15% to 95% of total progress
            adjusted_percent = 15.0 + (percent * 0.8)
            return adjusted_percent, f"Generating: {percent:.0f}% ({current}/{total} steps) - ETA: {eta}", current, total

        # Time elapsed (generation finished)
        if "TIME ELAPSED:" in line:
            match = self._ELAPSED_RE.search(line)
            if match:
                elapsed = float(match.group(1))
                return 95.0, f"Generation completed in {elapsed:.1f}s", 0, 0

        if "Generated video is saved to" in line:
            return 100.0, "Video saved successfully!", 0, 0

        # VAE decoding ("VAE" is matched case-insensitively, "Decoding" is not)
        if "Decoding" in line or "VAE" in line.upper():
            return 96.0, "Decoding video...", 0, 0

        return None, None, 0, 0

    def check_cancellation(self, job_id: str) -> bool:
        """Return True if the job was cancelled or no longer exists in the queue."""
        job = self.queue.get_job(job_id)
        return job is None or job.status == JobStatus.CANCELLED.value

    def run_job(self, job: "Job") -> bool:
        """
        Execute a single job and record its outcome in the queue.

        The job's command runs as a subprocess with stdout and stderr merged.
        Each output line is echoed, parsed for progress and reported through
        ``queue.update_progress``. Cancellation is checked on every line and
        at least every ``OUTPUT_POLL_INTERVAL`` seconds while the job is silent.

        Returns:
            True if the job completed successfully (exit code 0 and the output
            file exists), False if it failed, was cancelled, or raised.
        """
        self.current_job_id = job.id
        print(f"\n[Worker] Starting job {job.id}")
        print(f"[Worker] Command: {' '.join(job.command)}")

        preview_path = self._preview_path_for(job)
        process: Optional[subprocess.Popen] = None
        try:
            process = subprocess.Popen(
                job.command,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                # Undecodable bytes must not abort monitoring of a running job.
                errors="replace",
                bufsize=1,
                creationflags=subprocess.CREATE_NEW_PROCESS_GROUP if os.name == 'nt' else 0,
            )
            self.current_process = process
            self.queue.mark_running(job.id, process.pid)

            # Only the tail is needed (for the failure message); keep memory bounded.
            output_tail = deque(maxlen=self.ERROR_SCAN_LINES)
            if not self._monitor_process(job, process, preview_path, output_tail):
                return False  # cancelled
            return self._finish_job(job, process.returncode, output_tail)
        except Exception as e:
            # Do not leave an unmonitored child process running.
            if process is not None and process.poll() is None:
                self._terminate_process(process)
            self.queue.mark_failed(job.id, str(e))
            print(f"[Worker] Job {job.id} exception: {e}")
            return False
        finally:
            # Reset on every exit path, including cancellation.
            self.current_process = None
            self.current_job_id = None

    @staticmethod
    def _preview_path_for(job: "Job") -> str:
        """Return the latent-preview path for ``job`` from its ``--preview_suffix`` argument, or ""."""
        command = list(job.command)
        try:
            flag_index = command.index("--preview_suffix")
        except ValueError:
            return ""
        if flag_index + 1 >= len(command) or not command[flag_index + 1]:
            return ""
        save_path = job.parameters.get('save_path', 'outputs')
        return os.path.join(save_path, "previews", f"latent_preview_{command[flag_index + 1]}.mp4")

    @staticmethod
    def _start_output_reader(stream) -> SimpleQueue:
        """Pump lines from ``stream`` into a queue on a daemon thread; None marks EOF."""
        sink: SimpleQueue = SimpleQueue()

        def pump():
            try:
                for raw_line in stream:
                    sink.put(raw_line)
            except (OSError, ValueError):
                pass  # pipe closed underneath us: treat as EOF
            finally:
                sink.put(None)
                try:
                    stream.close()
                except OSError:
                    pass

        threading.Thread(target=pump, name="job-output-reader", daemon=True).start()
        return sink

    def _monitor_process(self, job: "Job", process: subprocess.Popen, preview_path: str, output_tail: deque) -> bool:
        """Stream the job's output until the process exits; return False if it was cancelled."""
        lines = self._start_output_reader(process.stdout)
        last_preview_mtime = 0.0
        stream_closed = False
        while True:
            if self.check_cancellation(job.id):
                print(f"[Worker] Job {job.id} cancelled, terminating...")
                self._terminate_process(process)
                return False

            if stream_closed:
                # All output consumed; wait for the exit status while still
                # honouring cancellation.
                try:
                    process.wait(timeout=self.OUTPUT_POLL_INTERVAL)
                    return True
                except subprocess.TimeoutExpired:
                    continue

            try:
                raw_line = lines.get(timeout=self.OUTPUT_POLL_INTERVAL)
            except Empty:
                # Exited but the pipe is still held open (e.g. by a grandchild):
                # do not wait for an EOF that may never arrive.
                if process.poll() is not None:
                    return True
                continue

            if raw_line is None:
                stream_closed = True
                continue

            line = raw_line.strip()
            output_tail.append(line)
            print(f"[Job {job.id}] {line}")
            last_preview_mtime = self._report_progress(job.id, line, preview_path, last_preview_mtime)

    def _report_progress(self, job_id: str, line: str, preview_path: str, last_preview_mtime: float) -> float:
        """Push progress parsed from ``line`` to the queue; return the updated preview mtime."""
        progress, progress_text, current_step, total_steps = self.parse_progress_line(line)
        if progress is None:
            return last_preview_mtime

        # Only advertise the preview when it was rewritten since the last report.
        current_preview = ""
        if preview_path:
            try:
                mtime = os.path.getmtime(preview_path)
            except OSError:
                mtime = None  # preview not written yet
            if mtime is not None and mtime > last_preview_mtime:
                current_preview = preview_path
                last_preview_mtime = mtime

        self.queue.update_progress(
            job_id,
            progress=progress,
            progress_text=progress_text,
            current_step=current_step,
            total_steps=total_steps,
            preview_path=current_preview,
        )
        return last_preview_mtime

    def _finish_job(self, job: "Job", return_code: Optional[int], output_tail: deque) -> bool:
        """Mark an exited (non-cancelled) job completed or failed; return True on success."""
        output_exists = os.path.exists(job.output_filename)
        if return_code == 0 and output_exists:
            try:
                add_metadata_to_video(job.output_filename, job.parameters)
                print(f"[Worker] Added metadata to {job.output_filename}")
            except Exception as meta_err:
                print(f"[Worker] Warning: Failed to add metadata: {meta_err}")
            self.queue.mark_completed(job.id, return_code)
            print(f"[Worker] Job {job.id} completed successfully")
            return True

        error_msg = f"Process exited with code {return_code}"
        if not output_exists:
            error_msg += f", output file not found: {job.output_filename}"
        # Prefer the LAST error-looking line: in a Python traceback the final
        # line carries the actual exception message.
        for line in reversed(output_tail):
            lowered = line.lower()
            if "error" in lowered or "exception" in lowered:
                error_msg = line
                break
        self.queue.mark_failed(job.id, error_msg, return_code)
        print(f"[Worker] Job {job.id} failed: {error_msg}")
        return False

    @staticmethod
    def _is_process_alive(pid: int) -> bool:
        """Return True if a process with ``pid`` currently exists and has not exited."""
        if os.name == 'nt':  # Windows
            import ctypes
            PROCESS_QUERY_LIMITED_INFORMATION = 0x1000
            STILL_ACTIVE = 259
            kernel32 = ctypes.windll.kernel32
            handle = kernel32.OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, False, pid)
            if not handle:
                return False
            try:
                # An exited process stays openable while any handle to it is
                # open, so also check that it has not exited yet.
                exit_code = ctypes.c_ulong()
                if kernel32.GetExitCodeProcess(handle, ctypes.byref(exit_code)):
                    return exit_code.value == STILL_ACTIVE
                return True
            finally:
                kernel32.CloseHandle(handle)
        try:
            os.kill(pid, 0)  # signal 0: existence check only
        except OSError:  # includes ProcessLookupError and PermissionError
            return False
        return True

    def recover_stale_jobs(self):
        """
        Mark as failed any job recorded as running whose process is no longer alive.

        This handles jobs left in the running state when a previous worker
        crashed before recording their outcome.
        """
        for job in self.queue.get_running_jobs():
            if job.process_id and self._is_process_alive(job.process_id):
                continue
            print(f"[Worker] Recovering stale job {job.id} (process {job.process_id} not found)")
            self.queue.mark_failed(job.id, "Worker process died unexpectedly")

    def run(self):
        """Main worker loop: process pending jobs until stop() is called."""
        print("=" * 60)
        print(f"[Worker] Started at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"[Worker] Queue file: {self.queue.queue_file}")
        print(f"[Worker] Poll interval: {self.poll_interval}s")
        print("=" * 60)

        # Recover any stale jobs from previous crashes
        self.recover_stale_jobs()

        last_cleanup = time.time()
        while self.running:
            try:
                job = self.queue.get_next_pending()
                if job:
                    self.run_job(job)
                else:
                    time.sleep(self.poll_interval)

                # Periodic cleanup of old jobs
                if time.time() - last_cleanup > self.CLEANUP_INTERVAL_SECONDS:
                    removed = self.queue.cleanup_old_jobs(max_age_hours=self.CLEANUP_MAX_AGE_HOURS)
                    if removed > 0:
                        print(f"[Worker] Cleaned up {removed} old jobs")
                    last_cleanup = time.time()
            except KeyboardInterrupt:
                break
            except Exception as e:
                print(f"[Worker] Error in main loop: {e}")
                time.sleep(self.poll_interval)

        print(f"\n[Worker] Stopped at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
def main():
    """Parse command-line arguments and run a standalone worker until it is stopped."""
    parser = argparse.ArgumentParser(description="Kandinsky5 Background Worker")
    parser.add_argument("--poll-interval", type=float, default=2.0,
                        help="Interval in seconds between queue polls (default: 2.0)")
    parser.add_argument("--queue-file", type=str, default=None,
                        help="Path to the job queue file. Auto-detects GPU-specific file "
                             "from CUDA_VISIBLE_DEVICES if not specified.")
    args = parser.parse_args()
    # The interval is passed to time.sleep(): negative or NaN values raise
    # there (crashing the worker loop's own error handler), infinity
    # overflows, and 0 would busy-poll the queue file. The comparison is
    # written so NaN also fails it.
    if not (0 < args.poll_interval < float("inf")):
        parser.error("--poll-interval must be a positive, finite number of seconds")
    queue = get_queue(args.queue_file)
    # use_signals=True when running as standalone process
    worker = Worker(queue, poll_interval=args.poll_interval, use_signals=True)
    worker.run()
# Start the worker only when executed as a script. Importing this module
# (e.g. to run Worker in a thread) must not start a worker loop.
if __name__ == "__main__":
    main()