
Commit 7f9aecc

fix bugs in qwen3-omni
Signed-off-by: Chenguang ZHENG <645327136@qq.com>
1 parent 8668907 commit 7f9aecc

File tree

9 files changed: +339 -81 lines changed


examples/offline_inference/image_to_image/image_edit.py

Lines changed: 175 additions & 43 deletions
@@ -22,6 +22,33 @@
         --cfg_scale 4.0 \
         --guidance_scale 1.0
 
+Usage (with cache-dit acceleration):
+    python image_edit.py \
+        --image input.png \
+        --prompt "Edit description" \
+        --cache_backend cache_dit \
+        --cache_dit_max_continuous_cached_steps 3 \
+        --cache_dit_residual_diff_threshold 0.24 \
+        --cache_dit_enable_taylorseer
+
+Usage (with tea_cache acceleration):
+    python image_edit.py \
+        --image input.png \
+        --prompt "Edit description" \
+        --cache_backend tea_cache \
+        --tea_cache_rel_l1_thresh 0.25
+
+Usage (layered):
+    python image_edit.py \
+        --model "Qwen/Qwen-Image-Layered" \
+        --image input.png \
+        --prompt "" \
+        --output "layered" \
+        --num_inference_steps 50 \
+        --cfg_scale 4.0 \
+        --layers 4 \
+        --color-format "RGBA"
+
 For more options, run:
     python image_edit.py --help
 """

@@ -100,7 +127,7 @@ def parse_args() -> argparse.Namespace:
         "--output",
         type=str,
         default="output_image_edit.png",
-        help="Path to save the edited image (PNG).",
+        help=("Path to save the edited image (PNG), or a filename prefix for the images saved by the Qwen-Image-Layered model (PNG)."),
     )
     parser.add_argument(
         "--num_outputs_per_prompt",

@@ -132,6 +159,87 @@ def parse_args() -> argparse.Namespace:
         help="Number of GPUs used for ulysses sequence parallelism.",
     )
 
+    parser.add_argument("--layers", type=int, default=4, help="Number of layers to decompose the input image into.")
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=640,
+        help="Resolution bucket (640 or 1024) that determines the condition and output resolution.",
+    )
+
+    parser.add_argument(
+        "--color-format",
+        type=str,
+        default="RGB",
+        help="For Qwen-Image-Layered, set to RGBA.",
+    )
+
+    # Cache-DiT specific parameters
+    parser.add_argument(
+        "--cache_dit_fn_compute_blocks",
+        type=int,
+        default=1,
+        help="[cache-dit] Number of forward compute blocks. Optimized for single-transformer models.",
+    )
+    parser.add_argument(
+        "--cache_dit_bn_compute_blocks",
+        type=int,
+        default=0,
+        help="[cache-dit] Number of backward compute blocks.",
+    )
+    parser.add_argument(
+        "--cache_dit_max_warmup_steps",
+        type=int,
+        default=4,
+        help="[cache-dit] Maximum warmup steps (works for few-step models).",
+    )
+    parser.add_argument(
+        "--cache_dit_residual_diff_threshold",
+        type=float,
+        default=0.24,
+        help="[cache-dit] Residual diff threshold. Higher values enable more aggressive caching.",
+    )
+    parser.add_argument(
+        "--cache_dit_max_continuous_cached_steps",
+        type=int,
+        default=3,
+        help="[cache-dit] Maximum continuous cached steps to prevent precision degradation.",
+    )
+    parser.add_argument(
+        "--cache_dit_enable_taylorseer",
+        action="store_true",
+        default=False,
+        help="[cache-dit] Enable TaylorSeer acceleration (not suitable for few-step models).",
+    )
+    parser.add_argument(
+        "--cache_dit_taylorseer_order",
+        type=int,
+        default=1,
+        help="[cache-dit] TaylorSeer polynomial order.",
+    )
+    parser.add_argument(
+        "--cache_dit_scm_steps_mask_policy",
+        type=str,
+        default=None,
+        choices=[None, "slow", "medium", "fast", "ultra"],
+        help="[cache-dit] SCM mask policy: None (disabled), slow, medium, fast, ultra.",
+    )
+    parser.add_argument(
+        "--cache_dit_scm_steps_policy",
+        type=str,
+        default="dynamic",
+        choices=["dynamic", "static"],
+        help="[cache-dit] SCM steps policy: dynamic or static.",
+    )
+
+    # TeaCache specific parameters
+    parser.add_argument(
+        "--tea_cache_rel_l1_thresh",
+        type=float,
+        default=0.2,
+        help="[tea_cache] Threshold for accumulated relative L1 distance.",
+    )
+
     return parser.parse_args()
 
 

@@ -143,7 +251,8 @@ def main():
     for image_path in args.image:
         if not os.path.exists(image_path):
             raise FileNotFoundError(f"Input image not found: {image_path}")
-        img = Image.open(image_path).convert("RGB")
+
+        img = Image.open(image_path).convert(args.color_format)
         input_images.append(img)
 
     # Use single image or list based on number of inputs

@@ -164,29 +273,22 @@ def main():
     cache_config = None
     if args.cache_backend == "cache_dit":
         # cache-dit configuration: Hybrid DBCache + SCM + TaylorSeer
-        # All parameters marked with [cache-dit only] in DiffusionCacheConfig
         cache_config = {
-            # DBCache parameters [cache-dit only]
-            "Fn_compute_blocks": 1,  # Optimized for single-transformer models
-            "Bn_compute_blocks": 0,  # Number of backward compute blocks
-            "max_warmup_steps": 4,  # Maximum warmup steps (works for few-step models)
-            "residual_diff_threshold": 0.24,  # Higher threshold for more aggressive caching
-            "max_continuous_cached_steps": 3,  # Limit to prevent precision degradation
-            # TaylorSeer parameters [cache-dit only]
-            "enable_taylorseer": False,  # Disabled by default (not suitable for few-step models)
-            "taylorseer_order": 1,  # TaylorSeer polynomial order
-            # SCM (Step Computation Masking) parameters [cache-dit only]
-            "scm_steps_mask_policy": None,  # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra"
-            "scm_steps_policy": "dynamic",  # SCM steps policy: "dynamic" or "static"
+            "Fn_compute_blocks": args.cache_dit_fn_compute_blocks,
+            "Bn_compute_blocks": args.cache_dit_bn_compute_blocks,
+            "max_warmup_steps": args.cache_dit_max_warmup_steps,
+            "residual_diff_threshold": args.cache_dit_residual_diff_threshold,
+            "max_continuous_cached_steps": args.cache_dit_max_continuous_cached_steps,
+            "enable_taylorseer": args.cache_dit_enable_taylorseer,
+            "taylorseer_order": args.cache_dit_taylorseer_order,
+            "scm_steps_mask_policy": args.cache_dit_scm_steps_mask_policy,
+            "scm_steps_policy": args.cache_dit_scm_steps_policy,
         }
     elif args.cache_backend == "tea_cache":
         # TeaCache configuration
-        # All parameters marked with [tea_cache only] in DiffusionCacheConfig
        cache_config = {
-            # TeaCache parameters [tea_cache only]
-            "rel_l1_thresh": 0.2,  # Threshold for accumulated relative L1 distance
+            "rel_l1_thresh": args.tea_cache_rel_l1_thresh,
             # Note: coefficients will use model-specific defaults based on model_type
-            # (e.g., QwenImagePipeline or FluxPipeline)
         }
 
     # Initialize Omni with appropriate pipeline
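
The hunk above replaces the hard-coded cache_config literals with values taken from the new CLI flags. A minimal, self-contained sketch of that flag-to-dict flow, using a hypothetical stand-alone parser that mirrors only three of the flags rather than the script's full argument set:

    import argparse

    # Hypothetical minimal parser mirroring three of the cache-dit flags,
    # to show how parsed CLI values flow into the cache_config dict.
    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_dit_residual_diff_threshold", type=float, default=0.24)
    parser.add_argument("--cache_dit_max_continuous_cached_steps", type=int, default=3)
    parser.add_argument("--cache_dit_enable_taylorseer", action="store_true", default=False)
    args = parser.parse_args(
        ["--cache_dit_residual_diff_threshold", "0.3", "--cache_dit_enable_taylorseer"]
    )
    cache_config = {
        "residual_diff_threshold": args.cache_dit_residual_diff_threshold,
        "max_continuous_cached_steps": args.cache_dit_max_continuous_cached_steps,
        "enable_taylorseer": args.cache_dit_enable_taylorseer,
    }
    print(cache_config)
    # {'residual_diff_threshold': 0.3, 'max_continuous_cached_steps': 3, 'enable_taylorseer': True}
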
@@ -218,16 +320,20 @@
     try:
         generation_start = time.perf_counter()
         # Generate edited image
-        outputs = omni.generate(
-            prompt=args.prompt,
-            pil_image=input_image,
-            negative_prompt=args.negative_prompt,
-            generator=generator,
-            true_cfg_scale=args.cfg_scale,
-            guidance_scale=args.guidance_scale,
-            num_inference_steps=args.num_inference_steps,
-            num_outputs_per_prompt=args.num_outputs_per_prompt,
-        )
+        generate_kwargs = {
+            "prompt": args.prompt,
+            "pil_image": input_image,
+            "negative_prompt": args.negative_prompt,
+            "generator": generator,
+            "true_cfg_scale": args.cfg_scale,
+            "guidance_scale": args.guidance_scale,
+            "num_inference_steps": args.num_inference_steps,
+            "num_outputs_per_prompt": args.num_outputs_per_prompt,
+            "layers": args.layers,
+            "resolution": args.resolution,
+        }
+
+        outputs = omni.generate(**generate_kwargs)
         generation_end = time.perf_counter()
         generation_time = generation_end - generation_start
 

@@ -239,15 +345,24 @@
         logger.info("Outputs: %s", outputs)
 
         # Extract images from OmniRequestOutput
-        first_output = outputs[0]
+        # Handle both OmniRequestOutput list and direct images list
         images = []
-        if getattr(first_output, "images", None):
-            images = first_output.images
-        elif getattr(first_output, "request_output", None):
-            req_out = first_output.request_output
-            if isinstance(req_out, list):
-                req_out = req_out[0]
-            images = getattr(req_out, "images", None) or []
+        if isinstance(outputs, list) and len(outputs) > 0:
+            first_output = outputs[0]
+            # Check if it's OmniRequestOutput with images attribute
+            if hasattr(first_output, "images") and first_output.images:
+                images = first_output.images
+            elif hasattr(first_output, "request_output") and first_output.request_output:
+                req_out = first_output.request_output
+                if isinstance(req_out, list):
+                    req_out = req_out[0]
+                if hasattr(req_out, "images"):
+                    images = req_out.images or []
+            # Check if outputs is already a list of images
+            elif isinstance(first_output, Image.Image):
+                images = outputs
+        elif isinstance(outputs, Image.Image):
+            images = [outputs]
 
         if not images:
             raise ValueError("No images found in omni.generate() output")

@@ -258,16 +373,33 @@
         suffix = output_path.suffix or ".png"
         stem = output_path.stem or "output_image_edit"
 
-        if len(images) <= 1:
-            images[0].save(output_path)
-            print(f"Saved edited image to {os.path.abspath(output_path)}")
+        # Handle layered output (each image may be a list of layers)
+        if args.num_outputs_per_prompt <= 1:
+            img = images[0]
+            # Check if this is a layered output (list of images)
+            if isinstance(img, list):
+                for sub_idx, sub_img in enumerate(img):
+                    save_path = output_path.parent / f"{stem}_{sub_idx}{suffix}"
+                    sub_img.save(save_path)
+                    print(f"Saved edited image to {os.path.abspath(save_path)}")
+            else:
+                img.save(output_path)
+                print(f"Saved edited image to {os.path.abspath(output_path)}")
         else:
             for idx, img in enumerate(images):
-                save_path = output_path.parent / f"{stem}_{idx}{suffix}"
-                img.save(save_path)
-                print(f"Saved edited image to {os.path.abspath(save_path)}")
+                # Check if this is a layered output (list of images)
+                if isinstance(img, list):
+                    for sub_idx, sub_img in enumerate(img):
+                        save_path = output_path.parent / f"{stem}_{idx}_{sub_idx}{suffix}"
+                        sub_img.save(save_path)
+                        print(f"Saved edited image to {os.path.abspath(save_path)}")
+                else:
+                    save_path = output_path.parent / f"{stem}_{idx}{suffix}"
+                    img.save(save_path)
+                    print(f"Saved edited image to {os.path.abspath(save_path)}")
     finally:
         omni.close()
 
+
 if __name__ == "__main__":
     main()
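
The new save logic treats each entry in images as either a single PIL image or a list of layers (the Qwen-Image-Layered case), and encodes the layer index in the filename. A runnable sketch of just the multi-output naming scheme shown above, with blank placeholder images standing in for model output:

    from pathlib import Path

    from PIL import Image

    def save_outputs(images, output: str = "layered.png") -> None:
        # Mirrors the multi-output branch above: flat entries are saved as
        # {stem}_{idx}{suffix}; layered entries (lists) get an extra
        # {sub_idx} component, one file per layer.
        output_path = Path(output)
        suffix = output_path.suffix or ".png"
        stem = output_path.stem or "output_image_edit"
        for idx, img in enumerate(images):
            if isinstance(img, list):
                for sub_idx, sub_img in enumerate(img):
                    sub_img.save(output_path.parent / f"{stem}_{idx}_{sub_idx}{suffix}")
            else:
                img.save(output_path.parent / f"{stem}_{idx}{suffix}")

    # One prompt decomposed into two RGBA layers, plus one flat image:
    layers = [Image.new("RGBA", (8, 8)), Image.new("RGBA", (8, 8))]
    save_outputs([layers, Image.new("RGBA", (8, 8))])
    # -> layered_0_0.png, layered_0_1.png, layered_1.png
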

vllm_omni/config/model.py

Lines changed: 24 additions & 4 deletions
@@ -87,7 +87,17 @@ def draw_hf_text_config(self):
         # we need to draw the text config from the corresponding model stage.
         if self.hf_config_name is None:
             return get_hf_text_config(self.hf_config)
-        return getattr(self.hf_config, self.hf_config_name).get_text_config()
+        try:
+            # Try to get the stage-specific config (e.g., thinker_config, talker_config)
+            stage_config = getattr(self.hf_config, self.hf_config_name)
+            return stage_config.get_text_config()
+        except AttributeError:
+            # Fallback: if the attribute doesn't exist, use the default get_hf_text_config
+            logger.warning(
+                f"Config attribute '{self.hf_config_name}' not found in hf_config, "
+                "falling back to default get_hf_text_config"
+            )
+            return get_hf_text_config(self.hf_config)
 
     def __post_init__(
         self,

@@ -173,9 +183,19 @@ def __post_init__(
         self.hf_text_config = self.draw_hf_text_config()
         self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None)
         self.encoder_config = self._get_encoder_config()
-        self.hf_image_processor_config = get_hf_image_processor_config(
-            self.model, hf_token=self.hf_token, revision=self.revision
-        )
+        # Try to load image processor config, but allow it to fail for stages that don't need it
+        try:
+            self.hf_image_processor_config = get_hf_image_processor_config(
+                self.model, hf_token=self.hf_token, revision=self.revision
+            )
+        except (OSError, ValueError, IndexError) as e:
+            # Some stages (e.g., code2wav, talker) don't need an image processor.
+            # Log a warning but allow initialization to continue.
+            logger.warning(
+                f"Failed to load image processor config for model '{self.model}': {e}. "
+                "This is expected for stages that don't require image processing."
+            )
+            self.hf_image_processor_config = None
 
         architectures = self.architectures
         registry = self.registry
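
Both hunks apply the same defensive pattern: attempt the stage-specific lookup, and on failure log a warning and fall back to a sensible default. A minimal sketch of the draw_hf_text_config fallback, using SimpleNamespace stand-ins instead of real Hugging Face config objects:

    import logging
    from types import SimpleNamespace

    logger = logging.getLogger(__name__)

    def draw_text_config(hf_config, hf_config_name):
        # Prefer the stage-specific sub-config (e.g. thinker_config);
        # fall back to the top-level config when the stage lacks it.
        try:
            return getattr(hf_config, hf_config_name).get_text_config()
        except AttributeError:
            logger.warning("Config attribute %r not found, falling back", hf_config_name)
            return hf_config  # stands in for get_hf_text_config(hf_config)

    thinker = SimpleNamespace(get_text_config=lambda: "thinker text config")
    cfg = SimpleNamespace(thinker_config=thinker)
    print(draw_text_config(cfg, "thinker_config"))  # thinker text config
    print(draw_text_config(cfg, "talker_config"))   # falls back to cfg itself
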

vllm_omni/diffusion/diffusion_engine.py

Lines changed: 12 additions & 12 deletions
@@ -67,15 +67,15 @@ def step(self, requests: list[OmniDiffusionRequest]):
                 return None
 
             postprocess_start_time = time.time()
-            images = self.post_process_func(output.output)
+            images = self.post_process_func(output.output) if self.post_process_func is not None else output.output
             postprocess_time = time.time() - postprocess_start_time
             logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds")
 
             # Convert to OmniRequestOutput format
             # Ensure images is a list
             if not isinstance(images, list):
                 images = [images] if images is not None else []
-
+
             # Handle single request or multiple requests
             if len(requests) == 1:
                 # Single request: return single OmniRequestOutput

@@ -84,11 +84,11 @@ def step(self, requests: list[OmniDiffusionRequest]):
                 prompt = request.prompt
                 if isinstance(prompt, list):
                     prompt = prompt[0] if prompt else None
-
+
                 metrics = {}
                 if output.trajectory_timesteps is not None:
-                    metrics['trajectory_timesteps'] = output.trajectory_timesteps
-
+                    metrics["trajectory_timesteps"] = output.trajectory_timesteps
+
                 return OmniRequestOutput.from_diffusion(
                     request_id=request_id,
                     images=images,

@@ -101,22 +101,22 @@ def step(self, requests: list[OmniDiffusionRequest]):
             # Split images based on num_outputs_per_prompt for each request
             results = []
             image_idx = 0
-
+
             for request in requests:
                 request_id = request.request_id or ""
                 prompt = request.prompt
                 if isinstance(prompt, list):
                     prompt = prompt[0] if prompt else None
-
+
                 # Get images for this request
                 num_outputs = request.num_outputs_per_prompt
-                request_images = images[image_idx:image_idx + num_outputs] if image_idx < len(images) else []
+                request_images = images[image_idx : image_idx + num_outputs] if image_idx < len(images) else []
                 image_idx += num_outputs
-
+
                 metrics = {}
                 if output.trajectory_timesteps is not None:
-                    metrics['trajectory_timesteps'] = output.trajectory_timesteps
-
+                    metrics["trajectory_timesteps"] = output.trajectory_timesteps
+
                 results.append(
                     OmniRequestOutput.from_diffusion(
                         request_id=request_id,

@@ -126,7 +126,7 @@ def step(self, requests: list[OmniDiffusionRequest]):
                         latents=output.trajectory_latents,
                     )
                 )
-
+
             return results
         except Exception as e:
             logger.error(f"Generation failed: {e}")
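
In the multi-request branch, the engine distributes one flat images list across requests, num_outputs_per_prompt images each, in request order. A self-contained sketch of that slicing logic, with strings standing in for images and a hypothetical Req dataclass standing in for OmniDiffusionRequest:

    from dataclasses import dataclass

    @dataclass
    class Req:
        request_id: str
        num_outputs_per_prompt: int

    def split_images(images, requests):
        # Mirrors the multi-request branch of step(): consume the flat
        # list in order, one num_outputs_per_prompt-sized chunk per request.
        results, image_idx = [], 0
        for request in requests:
            num_outputs = request.num_outputs_per_prompt
            chunk = images[image_idx : image_idx + num_outputs] if image_idx < len(images) else []
            image_idx += num_outputs
            results.append((request.request_id, chunk))
        return results

    print(split_images(["img0", "img1", "img2"], [Req("a", 2), Req("b", 1)]))
    # [('a', ['img0', 'img1']), ('b', ['img2'])]
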
