8 changes: 5 additions & 3 deletions runner/app/live/infer.py
@@ -57,20 +57,22 @@ async def main(
):
loop = asyncio.get_event_loop()
loop.set_exception_handler(asyncio_exception_handler)

process = ProcessGuardian(pipeline, params or {})
# Only initialize the streamer if we have a protocol and URLs to connect to
streamer = None
if stream_protocol and subscribe_url and publish_url:
width = params.get('width')
height = params.get('height')
if stream_protocol == "trickle":
protocol = TrickleProtocol(
subscribe_url, publish_url, control_url, events_url
subscribe_url, publish_url, control_url, events_url,
width=width, height=height
)
elif stream_protocol == "zeromq":
protocol = ZeroMQProtocol(subscribe_url, publish_url)
else:
raise ValueError(f"Unsupported protocol: {stream_protocol}")
streamer = PipelineStreamer(protocol, process, request_id, stream_id)
streamer = PipelineStreamer(protocol, process, request_id, stream_id, width=width, height=height)

api = None
try:
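The change above forwards the optional width/height from the request params into TrickleProtocol and PipelineStreamer. A minimal sketch of that keyword-threading pattern, using a stand-in class since the real constructor signatures are not shown in this diff:

# Minimal sketch (not the real classes) of how infer.py threads optional dimensions.
from typing import Optional

class DemoProtocol:
    def __init__(self, subscribe_url: str, publish_url: str,
                 width: Optional[int] = None, height: Optional[int] = None):
        # None means the caller did not specify a resolution; downstream code
        # (e.g. the pipeline warm-up) falls back to its own defaults.
        self.width, self.height = width, height

params = {"prompt": {}}  # hypothetical request params without dimensions
proto = DemoProtocol("sub://in", "pub://out",
                     width=params.get("width"), height=params.get("height"))
assert proto.width is None and proto.height is None
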
62 changes: 52 additions & 10 deletions runner/app/live/pipelines/comfyui.py
@@ -2,29 +2,35 @@
import json
import torch
import asyncio
from typing import Union
from typing import Union, Optional, Tuple
from pydantic import BaseModel, field_validator
import pathlib

from .interface import Pipeline
from comfystream.client import ComfyStreamClient
from trickle import VideoFrame, VideoOutput
from utils import ComfyUtils

import logging

COMFY_UI_WORKSPACE_ENV = "COMFY_UI_WORKSPACE"
WARMUP_RUNS = 1

_default_workflow_path = pathlib.Path(__file__).parent.absolute() / "comfyui_default_workflow.json"
with open(_default_workflow_path, 'r') as f:
DEFAULT_WORKFLOW_JSON = json.load(f)
def get_default_workflow_json():
_default_workflow_path = pathlib.Path(__file__).parent.absolute() / "comfyui_default_workflow.json"
with open(_default_workflow_path, 'r') as f:
return json.load(f)

# Get the default workflow json during startup
DEFAULT_WORKFLOW_JSON = get_default_workflow_json()

class ComfyUIParams(BaseModel):
class Config:
extra = "forbid"

prompt: Union[str, dict] = DEFAULT_WORKFLOW_JSON
width: Optional[int] = None
height: Optional[int] = None

@field_validator('prompt')
@classmethod
@@ -53,24 +59,38 @@ def __init__(self):
self.client = ComfyStreamClient(cwd=comfy_ui_workspace)
self.params: ComfyUIParams
self.video_incoming_frames: asyncio.Queue[VideoOutput] = asyncio.Queue()
self.width = ComfyUtils.DEFAULT_WIDTH
self.height = ComfyUtils.DEFAULT_HEIGHT
self.pause_input = False

async def initialize(self, **params):
new_params = ComfyUIParams(**params)
logging.info(f"Initializing ComfyUI Pipeline with prompt: {new_params.prompt}")
# TODO: currently its a single prompt, but need to support multiple prompts
await self.client.set_prompts([new_params.prompt])
self.params = new_params

# Warm up the pipeline

# Get dimensions from params or environment variable
width = new_params.width
height = new_params.height

# Fallback to default dimensions if not found
width = width or ComfyUtils.DEFAULT_WIDTH
height = height or ComfyUtils.DEFAULT_HEIGHT

# Warm up the pipeline with the workflow dimensions
logging.info(f"Warming up pipeline with dimensions: {width}x{height}")
dummy_frame = VideoFrame(None, 0, 0)
dummy_frame.side_data.input = torch.randn(1, 512, 512, 3)
dummy_frame.side_data.input = torch.randn(1, height, width, 3)

for _ in range(WARMUP_RUNS):
self.client.put_video_input(dummy_frame)
_ = await self.client.get_video_output()
logging.info("Pipeline initialization and warmup complete")

async def put_video_frame(self, frame: VideoFrame, request_id: str):
if self.pause_input:
return
tensor = frame.tensor
if tensor.is_cuda:
# Clone the tensor to be able to send it on comfystream internal queue
@@ -99,6 +119,28 @@ async def update_params(self, **params):
self.params = new_params

async def stop(self):
logging.info("Stopping ComfyUI pipeline")
await self.client.cleanup()
logging.info("ComfyUI pipeline stopped")
try:
self.pause_input = True
logging.info("Stopping ComfyUI pipeline")
await self.client.cleanup(unload_models=False)
# Wait for the pipeline to stop
await asyncio.sleep(1)
# Clear the video_incoming_frames queue
while not self.video_incoming_frames.empty():
try:
frame = self.video_incoming_frames.get_nowait()
# Ensure any CUDA tensors are properly handled
if frame.tensor is not None and frame.tensor.is_cuda:
frame.tensor.cpu()
except asyncio.QueueEmpty:
break

# Force CUDA cache clear
if torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception as e:
logging.error(f"Error stopping ComfyUI pipeline: {e}")
finally:
self.pause_input = False

logging.info("ComfyUI pipeline stopped")
6 changes: 3 additions & 3 deletions runner/app/live/pipelines/comfyui_default_workflow.json
@@ -20,7 +20,7 @@
},
"3": {
"inputs": {
"unet_name": "static-dreamshaper8_SD15_$stat-b-1-h-512-w-512_00001_.engine",
"unet_name": "static-dreamshaper8_SD15_$stat-b-1-h-704-w-384_00001_.engine",
"model_type": "SD15"
},
"class_type": "TensorRTLoader",
@@ -146,8 +146,8 @@
},
"16": {
"inputs": {
"width": 512,
"height": 512,
"width": 384,
"height": 704,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
17 changes: 16 additions & 1 deletion runner/app/live/pipelines/noop.py
@@ -1,7 +1,7 @@
import logging
import asyncio
from PIL import Image

import torch

from .interface import Pipeline
from trickle import VideoFrame, VideoOutput
@@ -26,3 +26,18 @@ async def update_params(self, **params):

async def stop(self):
logging.info("Stopping pipeline")

# Clear the frame queue and move any CUDA tensors to CPU
while not self.frame_queue.empty():
try:
frame = self.frame_queue.get_nowait()
if frame.tensor.is_cuda:
frame.tensor.cpu() # Move tensor to CPU before deletion
except asyncio.QueueEmpty:
break
except Exception as e:
logging.error(f"Error clearing frame queue: {e}")

# Force CUDA cache clear
if torch.cuda.is_available():
torch.cuda.empty_cache()
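Both stop() paths in this PR pair dropping the queued frames with torch.cuda.empty_cache(). A minimal sketch of why that combination frees GPU memory, assuming a CUDA device is available:

import torch

if torch.cuda.is_available():
    t = torch.randn(1, 704, 384, 3, device="cuda")  # roughly frame-sized tensor
    del t                      # dropping the last reference returns the block
                               # to PyTorch's caching allocator
    torch.cuda.empty_cache()   # releases cached, unused blocks back to the driver
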
50 changes: 27 additions & 23 deletions runner/app/live/streamer/process.py
@@ -11,6 +11,7 @@
from pipelines import load_pipeline, Pipeline
from log import config_logging, config_logging_fields, log_timing
from trickle import InputFrame, AudioFrame, VideoFrame, OutputFrame, VideoOutput, AudioOutput
from utils import ComfyUtils

class PipelineProcess:
@staticmethod
@@ -24,6 +25,7 @@ def start(pipeline_name: str, params: dict):

def __init__(self, pipeline_name: str):
self.pipeline_name = pipeline_name
self.pipeline = None # Initialize pipeline as None
self.ctx = mp.get_context("spawn")

self.input_queue = self.ctx.Queue(maxsize=2)
@@ -165,9 +167,12 @@ async def _initialize_pipeline(self):
logging.info("PipelineProcess: No params found in param_update_queue, loading with default params")

with log_timing(f"PipelineProcess: Pipeline loading with {params}"):
pipeline = load_pipeline(self.pipeline_name)
await pipeline.initialize(**params)
return pipeline
self.pipeline = load_pipeline(self.pipeline_name)

# TODO: We may need to call reset_stream when resolution is changed and start the pipeline again
# Changing the engine causes issues, maybe cleanup related
await self.pipeline.initialize(**params)
return self.pipeline
except Exception as e:
self._report_error(f"Error loading pipeline: {e}")
if not params:
@@ -177,19 +182,19 @@
with log_timing(
f"PipelineProcess: Pipeline loading with default params due to error with params: {params}"
):
pipeline = load_pipeline(self.pipeline_name)
await pipeline.initialize()
return pipeline
self.pipeline = load_pipeline(self.pipeline_name)
await self.pipeline.initialize()
return self.pipeline
except Exception as e:
self._report_error(f"Error loading pipeline with default params: {e}")
raise

async def _run_pipeline_loops(self):
pipeline = await self._initialize_pipeline()
await self._initialize_pipeline()
self.pipeline_initialized.set()
input_task = asyncio.create_task(self._input_loop(pipeline))
output_task = asyncio.create_task(self._output_loop(pipeline))
param_task = asyncio.create_task(self._param_update_loop(pipeline))
input_task = asyncio.create_task(self._input_loop())
output_task = asyncio.create_task(self._output_loop())
param_task = asyncio.create_task(self._param_update_loop())

async def wait_for_stop():
while not self.is_done():
@@ -205,17 +210,17 @@ async def wait_for_stop():
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
await self._cleanup_pipeline(pipeline)
await self._cleanup_pipeline()

logging.info("PipelineProcess: _run_pipeline_loops finished.")

async def _input_loop(self, pipeline: Pipeline):
async def _input_loop(self):
while not self.is_done():
try:
input_frame = await asyncio.to_thread(self.input_queue.get, timeout=0.1)
if isinstance(input_frame, VideoFrame):
input_frame.log_timestamps["pre_process_frame"] = time.time()
await pipeline.put_video_frame(input_frame, self.request_id)
await self.pipeline.put_video_frame(input_frame, self.request_id)
elif isinstance(input_frame, AudioFrame):
self._try_queue_put(self.output_queue, AudioOutput([input_frame], self.request_id))
except queue.Empty:
@@ -224,30 +229,29 @@ async def _input_loop(self, pipeline: Pipeline):
except Exception as e:
self._report_error(f"Error processing input frame: {e}")

async def _output_loop(self, pipeline: Pipeline):
async def _output_loop(self):
while not self.is_done():
try:
output = await pipeline.get_processed_video_frame()
output = await self.pipeline.get_processed_video_frame()
if isinstance(output, VideoOutput) and not output.tensor.is_cuda and torch.cuda.is_available():
output = output.replace_tensor(output.tensor.cuda())
output.log_timestamps["post_process_frame"] = time.time()
self._try_queue_put(self.output_queue, output)
except Exception as e:
self._report_error(f"Error processing output frame: {e}")

async def _param_update_loop(self, pipeline: Pipeline):
async def _param_update_loop(self):
while not self.is_done():
try:
params = await asyncio.to_thread(self.param_update_queue.get, timeout=0.1)

if self._handle_logging_params(params):
logging.info(f"PipelineProcess: Updating pipeline parameters: {params}")
await pipeline.update_params(**params)
await self.pipeline.update_params(**params)
except queue.Empty:
# Timeout ensures the non-daemon threads from to_thread can exit if task is cancelled
continue
except Exception as e:
self._report_error(f"Error updating params: {e}")
self._report_error(f"Error updating parameters: {e}")

def _report_error(self, error_msg: str):
error_event = {
@@ -257,12 +261,12 @@ def _report_error(self, error_msg: str):
logging.error(error_msg)
self._try_queue_put(self.error_queue, error_event)

async def _cleanup_pipeline(self, pipeline):
if pipeline is not None:
async def _cleanup_pipeline(self):
if self.pipeline:
try:
await pipeline.stop()
await self.pipeline.stop()
except Exception as e:
logging.error(f"Error stopping pipeline: {e}")
self._report_error(f"Error cleaning up pipeline: {e}")

def _setup_logging(self):
level = (
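The 0.1 s timeout on the blocking queue reads in the loops above is what lets the worker threads spawned by asyncio.to_thread return regularly, so cancelling the task cannot leave a thread parked in an unbounded Queue.get(). A standalone sketch of that pattern, with hypothetical names:

import asyncio
import queue
import multiprocessing as mp

async def drain(q, stop: asyncio.Event):
    # Hypothetical consumer loop; q is a multiprocessing.Queue fed elsewhere.
    while not stop.is_set():
        try:
            item = await asyncio.to_thread(q.get, timeout=0.1)
        except queue.Empty:
            continue            # periodic wake-up; re-check the stop flag
        print("processed", item)

async def main():
    q, stop = mp.get_context("spawn").Queue(), asyncio.Event()
    task = asyncio.create_task(drain(q, stop))
    q.put("frame-1")
    await asyncio.sleep(0.5)
    stop.set()
    await task

asyncio.run(main())
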
27 changes: 26 additions & 1 deletion runner/app/live/streamer/process_guardian.py
@@ -7,6 +7,7 @@
from trickle import InputFrame, OutputFrame
from .process import PipelineProcess
from .status import PipelineState, PipelineStatus, InferenceStatus, InputStatus
from utils import ComfyUtils

FPS_LOG_INTERVAL = 10.0

Expand All @@ -19,6 +20,9 @@ class StreamerCallbacks(abc.ABC):
@abc.abstractmethod
def is_stream_running(self) -> bool: ...



class ProcessCallbacks(abc.ABC):
@abc.abstractmethod
async def emit_monitoring_event(self, event_data: dict) -> None: ...

@@ -44,6 +48,7 @@ def __init__(
):
self.pipeline = pipeline
self.initial_params = params
self.width, self.height = ComfyUtils.get_latent_image_dimensions(params.get('prompt'))
self.streamer: StreamerCallbacks = _NoopStreamerCallbacks()

self.process: Optional[PipelineProcess] = None
@@ -82,13 +87,33 @@ async def reset_stream(
):
if not self.process:
raise RuntimeError("Process not running")

# Check if resolution has changed
[Author review comment] Will remove, comfyui pipeline will stop itself

new_width = params.get("width", None)
new_height = params.get("height", None)
if (new_width is None or new_height is None):
new_width, new_height = ComfyUtils.DEFAULT_WIDTH, ComfyUtils.DEFAULT_HEIGHT

# If resolution changed, we need to restart the process (does not work for comfyui)
if (new_width != self.width or new_height != self.height):
logging.info(f"Resolution changed from {self.width}x{self.height} to {new_width}x{new_height}, restarting process")
self.width = new_width
self.height = new_height
await self.process._cleanup_pipeline()
await self.stop()
# Create new process with current pipeline name and params
params.update({"width": new_width, "height": new_height})
self.process = PipelineProcess.start(self.pipeline, params)

self.status.start_time = time.time()
self.status.input_status = InputStatus()
self.input_fps_counter.reset()
self.output_fps_counter.reset()
self.streamer = streamer or _NoopStreamerCallbacks()

self.process.reset_stream(request_id, manifest_id, stream_id)
self.process.update_params(params)
[Author review comment] likely not needed

await self.update_params(params)
self.status.update_state(PipelineState.ONLINE)

@@ -310,7 +335,7 @@ async def _monitor_loop(self):
# Hot fix: the comfyui pipeline process is having trouble shutting down and causes restarts not to recover.
# So we skip the restart here and move the state to ERROR so the worker will restart the whole container.
# TODO: Remove this exception once pipeline shutdown is fixed and restarting process is useful again.
raise Exception("Skipping process restart due to pipeline shutdown issues")
#raise Exception("Skipping process restart due to pipeline shutdown issues")
await self._restart_process()
except Exception:
logging.exception("Failed to stop streamer and restart process. Moving to ERROR state", stack_info=True)
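process_guardian.py derives the initial resolution with ComfyUtils.get_latent_image_dimensions(params.get('prompt')), a helper that is not part of this diff. A hedged sketch of what such a helper could look like, assuming it scans the workflow prompt for an EmptyLatentImage node (node "16" in comfyui_default_workflow.json) and that the fallback defaults are 512x512; the real ComfyUtils implementation may differ:

# Hypothetical sketch only; ComfyUtils itself is not shown in this PR.
from typing import Tuple

DEFAULT_WIDTH, DEFAULT_HEIGHT = 512, 512   # assumed defaults, not confirmed here

def get_latent_image_dimensions(prompt) -> Tuple[int, int]:
    """Return (width, height) from the first EmptyLatentImage node, else defaults."""
    if not isinstance(prompt, dict):
        return DEFAULT_WIDTH, DEFAULT_HEIGHT
    for node in prompt.values():
        if isinstance(node, dict) and node.get("class_type") == "EmptyLatentImage":
            inputs = node.get("inputs", {})
            return (inputs.get("width", DEFAULT_WIDTH),
                    inputs.get("height", DEFAULT_HEIGHT))
    return DEFAULT_WIDTH, DEFAULT_HEIGHT

# Against the default workflow in this PR, node "16" would yield (384, 704).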