Skip to content

Commit 290eeee

Browse files
committed
update default workflow resolution to 384x704, warmup to default workflow resolution
1 parent 6157fdd commit 290eeee

File tree

6 files changed

+96
-37
lines changed

6 files changed

+96
-37
lines changed

runner/app/live/pipelines/comfyui.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,49 @@ def validate_prompt(cls, v) -> dict:
5151
raise ValueError("Prompt must be either a JSON object or such JSON object serialized as a string")
5252

5353

54+
class ComfyUtils:
55+
@staticmethod
56+
def get_latent_image_dimensions(workflow: dict) -> tuple[int, int]:
57+
"""Get dimensions from the EmptyLatentImage node in the workflow.
58+
59+
Args:
60+
workflow: The workflow JSON dictionary
61+
62+
Returns:
63+
Tuple of (width, height) from the latent image, or (None, None) if not found
64+
"""
65+
for node_id, node in workflow.items():
66+
if node.get("class_type") == "EmptyLatentImage":
67+
try:
68+
inputs = node.get("inputs", {})
69+
return inputs.get("width"), inputs.get("height")
70+
except Exception as e:
71+
logging.warning(f"Failed to extract dimensions from latent image: {e}")
72+
return None, None
73+
return None, None
74+
75+
@staticmethod
76+
def update_latent_image_dimensions(workflow: dict, width: int, height: int) -> dict | None:
77+
"""Update the EmptyLatentImage node dimensions in the workflow.
78+
79+
Args:
80+
workflow: The workflow JSON dictionary
81+
width: Width to set
82+
height: Height to set
83+
"""
84+
for node_id, node in workflow.items():
85+
if node.get("class_type") == "EmptyLatentImage":
86+
try:
87+
if "inputs" not in node:
88+
node["inputs"] = {}
89+
node["inputs"]["width"] = width
90+
node["inputs"]["height"] = height
91+
logging.info(f"Updated latent image dimensions to {width}x{height}")
92+
except Exception as e:
93+
logging.warning(f"Failed to update latent image dimensions: {e}")
94+
break
95+
96+
5497
class ComfyUI(Pipeline):
5598
def __init__(self):
5699
comfy_ui_workspace = os.getenv(COMFY_UI_WORKSPACE_ENV)
@@ -61,14 +104,30 @@ def __init__(self):
61104
async def initialize(self, **params):
62105
new_params = ComfyUIParams(**params)
63106
logging.info(f"Initializing ComfyUI Pipeline with prompt: {new_params.prompt}")
64-
# TODO: currently its a single prompt, but need to support multiple prompts
107+
108+
# Get dimensions from workflow if it's a dict
109+
110+
if width is None or height is None:
111+
if isinstance(new_params.prompt, dict):
112+
# If dimensions not provided in params, get them from latent image
113+
latent_width, latent_height = ComfyUtils.get_latent_image_dimensions(new_params.prompt)
114+
new_params.width = width or latent_width or new_params.width
115+
new_params.height = height or latent_height or new_params.height
116+
else:
117+
# If dimensions provided in params, update the latent image
118+
ComfyUtils.update_latent_image_dimensions(new_params.prompt, width, height)
119+
120+
# TODO clean up extra vars
121+
width = width or new_params.width
122+
height = height or new_params.height
123+
65124
await self.client.set_prompts([new_params.prompt])
66125
self.params = new_params
67126

68-
# Warm up the pipeline
69-
logging.info(f"Warming up pipeline with dimensions: {new_params.width}x{new_params.height}")
127+
# Warm up the pipeline with the final dimensions
128+
logging.info(f"Warming up pipeline with dimensions: {width}x{height}")
70129
dummy_frame = VideoFrame(None, 0, 0)
71-
dummy_frame.side_data.input = torch.randn(1, new_params.height, new_params.width, 3)
130+
dummy_frame.side_data.input = torch.randn(1, height, width, 3)
72131

73132
for _ in range(WARMUP_RUNS):
74133
self.client.put_video_input(dummy_frame)

runner/app/live/pipelines/comfyui_default_workflow.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
},
2424
"3": {
2525
"inputs": {
26-
"unet_name": "dynamic-dreamshaper8_SD15_dyn-b-1-4-2-h-448-704-512-w-448-704-512_00001_.engine",
26+
"unet_name": "static-dreamshaper8_SD15_$stat-b-1-h-384-w-704_00001_.engine",
2727
"model_type": "SD15"
2828
},
2929
"class_type": "TensorRTLoader",
@@ -194,7 +194,7 @@
194194
},
195195
"16": {
196196
"inputs": {
197-
"width": 448,
197+
"width": 384,
198198
"height": 704,
199199
"batch_size": 1
200200
},

runner/app/live/streamer/protocol/trickle.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,26 +23,31 @@ def __init__(self, subscribe_url: str, publish_url: str, control_url: Optional[s
2323
self.events_publisher = None
2424
self.subscribe_task = None
2525
self.publish_task = None
26-
self.output_width = 512
27-
self.output_height = 512
2826

2927
async def start(self, params: dict = None):
3028
self.subscribe_queue = queue.Queue[InputFrame]()
3129
self.publish_queue = queue.Queue[OutputFrame]()
3230
metadata_cache = LastValueCache[dict]() # to pass video metadata from decoder to encoder
33-
34-
# Get resolution from params if available
35-
if params:
36-
self.output_width = params.get('width', self.output_width)
37-
self.output_height = params.get('height', self.output_height)
31+
32+
#TODO fix this default value issue
33+
output_width = params.get('width', 512)
34+
output_height = params.get('height', 512)
3835

3936
self.subscribe_task = asyncio.create_task(
40-
media.run_subscribe(self.subscribe_url, self.subscribe_queue.put, metadata_cache.put, self.emit_monitoring_event,
41-
output_width=self.output_width, output_height=self.output_height)
37+
media.run_subscribe(self.subscribe_url,
38+
self.subscribe_queue.put,
39+
metadata_cache.put,
40+
self.emit_monitoring_event,
41+
output_width,
42+
output_height)
4243
)
4344
self.publish_task = asyncio.create_task(
44-
media.run_publish(self.publish_url, self.publish_queue.get, metadata_cache.get, self.emit_monitoring_event,
45-
output_width=self.output_width, output_height=self.output_height)
45+
media.run_publish(self.publish_url,
46+
self.publish_queue.get,
47+
metadata_cache.get,
48+
self.emit_monitoring_event,
49+
output_width,
50+
output_height)
4651
)
4752
if self.control_url and self.control_url.strip() != "":
4853
self.control_subscriber = TrickleSubscriber(self.control_url)

runner/app/live/streamer/streamer.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ async def start(self, params: dict):
4848
self.request_id, self.stream_id, params, self
4949
)
5050

51+
params['width'] = params.get('width', self.output_width)
52+
params['height'] = params.get('height', self.output_height)
53+
5154
self.stop_event.clear()
5255
await self.protocol.start(params)
5356

@@ -180,24 +183,16 @@ async def run_ingress_loop(self):
180183
if frame.mode != "RGBA":
181184
frame = frame.convert("RGBA")
182185

183-
# Scale image to 512x512 as most models expect this size, especially when using tensorrt
186+
target_width = self.output_width
187+
target_height = self.output_width
188+
189+
# # Scale image to target size
184190
width, height = frame.size
185-
if (width, height) != (512, 512):
191+
if (width, height) != (target_width, target_height):
186192
frame_array = np.array(frame)
187193

188-
# Crop to the center square if image not already square
189-
square_size = min(width, height)
190-
if width != height:
191-
start_x = width // 2 - square_size // 2
192-
start_y = height // 2 - square_size // 2
193-
frame_array = frame_array[
194-
start_y : start_y + square_size, start_x : start_x + square_size
195-
]
196-
197194
# Resize using cv2 (much faster than PIL)
198-
if square_size != 512:
199-
frame_array = cv2.resize(frame_array, (512, 512))
200-
195+
frame_array = cv2.resize(frame_array, (target_width, target_height))
201196
frame = Image.fromarray(frame_array)
202197

203198
logging.debug(

runner/app/live/trickle/decoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99

1010
MAX_FRAMERATE=24
1111

12-
def decode_av(pipe_input, frame_callback, put_metadata, output_width=512, output_height=512):
12+
def decode_av(pipe_input, frame_callback, put_metadata, output_width, output_height):
1313
"""
1414
Reads from a pipe (or file-like object).
1515
1616
:param pipe_input: File path, 'pipe:', sys.stdin, or another file-like object.
1717
:param frame_callback: A function that accepts an InputFrame object
1818
:param put_metadata: A function that accepts audio/video metadata
19-
:param output_width: Desired output width (default: 512)
20-
:param output_height: Desired output height (default: 512)
19+
:param output_width: Desired output width
20+
:param output_height: Desired output height
2121
"""
2222
container = cast(InputContainer, av.open(pipe_input, 'r'))
2323

runner/app/live/trickle/media.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
MAX_ENCODER_RETRIES = 3
1717
ENCODER_RETRY_RESET_SECONDS = 120 # reset retry counter after 2 minutes
1818

19-
async def run_subscribe(subscribe_url: str, image_callback, put_metadata, monitoring_callback, output_width=512, output_height=512):
19+
async def run_subscribe(subscribe_url: str, image_callback, put_metadata, monitoring_callback, output_width, output_height):
2020
# TODO add some pre-processing parameters, eg image size
2121
try:
2222
in_pipe, out_pipe = os.pipe()
@@ -112,7 +112,7 @@ def decode_runner():
112112
loop = asyncio.get_running_loop()
113113
await loop.run_in_executor(None, decode_runner)
114114

115-
def encode_in(task_pipes, task_lock, image_generator, sync_callback, get_metadata, output_width=512, output_height=512, **kwargs):
115+
def encode_in(task_pipes, task_lock, image_generator, sync_callback, get_metadata, output_width, output_height, **kwargs):
116116
# encode_av has a tendency to crash, so restart as necessary
117117
retryCount = 0
118118
last_retry_time = time.time()
@@ -146,7 +146,7 @@ def encode_in(task_pipes, task_lock, image_generator, sync_callback, get_metadat
146146
logging.exception("Error closing pipe on task list", stack_info=True)
147147
logging.info(f"Closed pipes - {pipe_count}/{total_pipes}")
148148

149-
async def run_publish(publish_url: str, image_generator, get_metadata, monitoring_callback, output_width=512, output_height=512):
149+
async def run_publish(publish_url: str, image_generator, get_metadata, monitoring_callback, output_width, output_height):
150150
first_segment = True
151151

152152
publisher = None

0 commit comments

Comments
 (0)