         --cfg_scale 4.0 \
         --guidance_scale 1.0
 
+Usage (with cache-dit acceleration):
+    python image_edit.py \
+        --image input.png \
+        --prompt "Edit description" \
+        --cache_backend cache_dit \
+        --cache_dit_max_continuous_cached_steps 3 \
+        --cache_dit_residual_diff_threshold 0.24 \
+        --cache_dit_enable_taylorseer
+
+Usage (with tea_cache acceleration):
+    python image_edit.py \
+        --image input.png \
+        --prompt "Edit description" \
+        --cache_backend tea_cache \
+        --tea_cache_rel_l1_thresh 0.25
+
+Usage (layered):
+    python image_edit.py \
+        --model "Qwen/Qwen-Image-Layered" \
+        --image input.png \
+        --prompt "" \
+        --output "layered" \
+        --num_inference_steps 50 \
+        --cfg_scale 4.0 \
+        --layers 4 \
+        --color-format "RGBA"
+
 For more options, run:
     python image_edit.py --help
 """
@@ -100,7 +127,7 @@ def parse_args() -> argparse.Namespace:
100127 "--output" ,
101128 type = str ,
102129 default = "output_image_edit.png" ,
103- help = "Path to save the edited image (PNG)." ,
130+ help = ( "Path to save the edited image (PNG). Or prefix for Qwen-Image-Layered model save images(PNG)." ) ,
104131 )
105132 parser .add_argument (
106133 "--num_outputs_per_prompt" ,
@@ -132,6 +159,87 @@ def parse_args() -> argparse.Namespace:
132159 help = "Number of GPUs used for ulysses sequence parallelism." ,
133160 )
134161
162+ parser .add_argument ("--layers" , type = int , default = 4 , help = "Number of layers to decompose the input image into." )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=640,
+        help="Resolution bucket in (640, 1024) used to determine the condition and output resolution.",
+    )
+
+    parser.add_argument(
+        "--color-format",
+        type=str,
+        default="RGB",
+        help="Color format used to load input images. For Qwen-Image-Layered, set to RGBA.",
+    )
+
+    # Cache-DiT specific parameters
+    parser.add_argument(
+        "--cache_dit_fn_compute_blocks",
+        type=int,
+        default=1,
+        help="[cache-dit] Number of forward compute blocks. Optimized for single-transformer models.",
+    )
+    parser.add_argument(
+        "--cache_dit_bn_compute_blocks",
+        type=int,
+        default=0,
+        help="[cache-dit] Number of backward compute blocks.",
+    )
+    parser.add_argument(
+        "--cache_dit_max_warmup_steps",
+        type=int,
+        default=4,
+        help="[cache-dit] Maximum warmup steps (works for few-step models).",
+    )
+    parser.add_argument(
+        "--cache_dit_residual_diff_threshold",
+        type=float,
+        default=0.24,
+        help="[cache-dit] Residual diff threshold. Higher values enable more aggressive caching.",
+    )
+    parser.add_argument(
+        "--cache_dit_max_continuous_cached_steps",
+        type=int,
+        default=3,
+        help="[cache-dit] Maximum continuous cached steps to prevent precision degradation.",
+    )
+    parser.add_argument(
+        "--cache_dit_enable_taylorseer",
+        action="store_true",
+        default=False,
+        help="[cache-dit] Enable TaylorSeer acceleration (not suitable for few-step models).",
+    )
+    parser.add_argument(
+        "--cache_dit_taylorseer_order",
+        type=int,
+        default=1,
+        help="[cache-dit] TaylorSeer polynomial order.",
+    )
+    parser.add_argument(
+        "--cache_dit_scm_steps_mask_policy",
+        type=str,
+        default=None,
+        choices=[None, "slow", "medium", "fast", "ultra"],
+        help="[cache-dit] SCM mask policy: None (disabled), slow, medium, fast, ultra.",
+    )
+    parser.add_argument(
+        "--cache_dit_scm_steps_policy",
+        type=str,
+        default="dynamic",
+        choices=["dynamic", "static"],
+        help="[cache-dit] SCM steps policy: dynamic or static.",
+    )
+
+    # TeaCache specific parameters
+    parser.add_argument(
+        "--tea_cache_rel_l1_thresh",
+        type=float,
+        default=0.2,
+        help="[tea_cache] Threshold for accumulated relative L1 distance.",
+    )
+
     return parser.parse_args()
 
 
@@ -143,7 +251,8 @@ def main():
     for image_path in args.image:
         if not os.path.exists(image_path):
             raise FileNotFoundError(f"Input image not found: {image_path}")
-        img = Image.open(image_path).convert("RGB")
+
+        img = Image.open(image_path).convert(args.color_format)
         input_images.append(img)
 
     # Use single image or list based on number of inputs
@@ -164,29 +273,22 @@ def main():
     cache_config = None
     if args.cache_backend == "cache_dit":
         # cache-dit configuration: Hybrid DBCache + SCM + TaylorSeer
-        # All parameters marked with [cache-dit only] in DiffusionCacheConfig
         cache_config = {
-            # DBCache parameters [cache-dit only]
-            "Fn_compute_blocks": 1,  # Optimized for single-transformer models
-            "Bn_compute_blocks": 0,  # Number of backward compute blocks
-            "max_warmup_steps": 4,  # Maximum warmup steps (works for few-step models)
-            "residual_diff_threshold": 0.24,  # Higher threshold for more aggressive caching
-            "max_continuous_cached_steps": 3,  # Limit to prevent precision degradation
-            # TaylorSeer parameters [cache-dit only]
-            "enable_taylorseer": False,  # Disabled by default (not suitable for few-step models)
-            "taylorseer_order": 1,  # TaylorSeer polynomial order
-            # SCM (Step Computation Masking) parameters [cache-dit only]
-            "scm_steps_mask_policy": None,  # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra"
-            "scm_steps_policy": "dynamic",  # SCM steps policy: "dynamic" or "static"
+            "Fn_compute_blocks": args.cache_dit_fn_compute_blocks,
+            "Bn_compute_blocks": args.cache_dit_bn_compute_blocks,
+            "max_warmup_steps": args.cache_dit_max_warmup_steps,
+            "residual_diff_threshold": args.cache_dit_residual_diff_threshold,
+            "max_continuous_cached_steps": args.cache_dit_max_continuous_cached_steps,
+            "enable_taylorseer": args.cache_dit_enable_taylorseer,
+            "taylorseer_order": args.cache_dit_taylorseer_order,
+            "scm_steps_mask_policy": args.cache_dit_scm_steps_mask_policy,
+            "scm_steps_policy": args.cache_dit_scm_steps_policy,
         }
     elif args.cache_backend == "tea_cache":
         # TeaCache configuration
-        # All parameters marked with [tea_cache only] in DiffusionCacheConfig
         cache_config = {
-            # TeaCache parameters [tea_cache only]
-            "rel_l1_thresh": 0.2,  # Threshold for accumulated relative L1 distance
+            "rel_l1_thresh": args.tea_cache_rel_l1_thresh,
             # Note: coefficients will use model-specific defaults based on model_type
-            # (e.g., QwenImagePipeline or FluxPipeline)
         }
 
     # Initialize Omni with appropriate pipeline
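
For reference, here is a minimal sketch (not part of the diff) of the dictionaries these refactored branches build when no cache flags are passed. The values mirror the new argparse defaults and match the previously hard-coded settings, so default behavior should be unchanged; the variable names below are illustrative only.

    # Sketch: cache_config with --cache_backend cache_dit and all defaults.
    cache_config_cache_dit = {
        "Fn_compute_blocks": 1,
        "Bn_compute_blocks": 0,
        "max_warmup_steps": 4,
        "residual_diff_threshold": 0.24,
        "max_continuous_cached_steps": 3,
        "enable_taylorseer": False,
        "taylorseer_order": 1,
        "scm_steps_mask_policy": None,
        "scm_steps_policy": "dynamic",
    }

    # Sketch: cache_config with --cache_backend tea_cache and all defaults.
    cache_config_tea_cache = {"rel_l1_thresh": 0.2}
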
@@ -218,16 +320,20 @@ def main():
     try:
         generation_start = time.perf_counter()
         # Generate edited image
-        outputs = omni.generate(
-            prompt=args.prompt,
-            pil_image=input_image,
-            negative_prompt=args.negative_prompt,
-            generator=generator,
-            true_cfg_scale=args.cfg_scale,
-            guidance_scale=args.guidance_scale,
-            num_inference_steps=args.num_inference_steps,
-            num_outputs_per_prompt=args.num_outputs_per_prompt,
-        )
+        generate_kwargs = {
+            "prompt": args.prompt,
+            "pil_image": input_image,
+            "negative_prompt": args.negative_prompt,
+            "generator": generator,
+            "true_cfg_scale": args.cfg_scale,
+            "guidance_scale": args.guidance_scale,
+            "num_inference_steps": args.num_inference_steps,
+            "num_outputs_per_prompt": args.num_outputs_per_prompt,
+            "layers": args.layers,
+            "resolution": args.resolution,
+        }
+
+        outputs = omni.generate(**generate_kwargs)
         generation_end = time.perf_counter()
         generation_time = generation_end - generation_start
 
@@ -239,15 +345,24 @@ def main():
         logger.info("Outputs: %s", outputs)
 
         # Extract images from OmniRequestOutput
-        first_output = outputs[0]
+        # Handle both OmniRequestOutput list and direct images list
         images = []
-        if getattr(first_output, "images", None):
-            images = first_output.images
-        elif getattr(first_output, "request_output", None):
-            req_out = first_output.request_output
-            if isinstance(req_out, list):
-                req_out = req_out[0]
-            images = getattr(req_out, "images", None) or []
+        if isinstance(outputs, list) and len(outputs) > 0:
+            first_output = outputs[0]
+            # Check if it's OmniRequestOutput with images attribute
+            if hasattr(first_output, "images") and first_output.images:
+                images = first_output.images
+            elif hasattr(first_output, "request_output") and first_output.request_output:
+                req_out = first_output.request_output
+                if isinstance(req_out, list):
+                    req_out = req_out[0]
+                if hasattr(req_out, "images"):
+                    images = req_out.images or []
+            # Check if outputs is already a list of images
+            elif isinstance(first_output, Image.Image):
+                images = outputs
+        elif isinstance(outputs, Image.Image):
+            images = [outputs]
 
         if not images:
             raise ValueError("No images found in omni.generate() output")
@@ -258,16 +373,33 @@ def main():
         suffix = output_path.suffix or ".png"
         stem = output_path.stem or "output_image_edit"
 
-        if len(images) <= 1:
-            images[0].save(output_path)
-            print(f"Saved edited image to {os.path.abspath(output_path)}")
+        # Handle layered output (each image may be a list of layers)
+        if args.num_outputs_per_prompt <= 1:
+            img = images[0]
+            # Check if this is a layered output (list of images)
+            if isinstance(img, list):
+                for sub_idx, sub_img in enumerate(img):
+                    save_path = output_path.parent / f"{stem}_{sub_idx}{suffix}"
+                    sub_img.save(save_path)
+                    print(f"Saved edited image to {os.path.abspath(save_path)}")
+            else:
+                img.save(output_path)
+                print(f"Saved edited image to {os.path.abspath(output_path)}")
         else:
             for idx, img in enumerate(images):
-                save_path = output_path.parent / f"{stem}_{idx}{suffix}"
-                img.save(save_path)
-                print(f"Saved edited image to {os.path.abspath(save_path)}")
+                # Check if this is a layered output (list of images)
+                if isinstance(img, list):
+                    for sub_idx, sub_img in enumerate(img):
+                        save_path = output_path.parent / f"{stem}_{idx}_{sub_idx}{suffix}"
+                        sub_img.save(save_path)
+                        print(f"Saved edited image to {os.path.abspath(save_path)}")
+                else:
+                    save_path = output_path.parent / f"{stem}_{idx}{suffix}"
+                    img.save(save_path)
+                    print(f"Saved edited image to {os.path.abspath(save_path)}")
     finally:
         omni.close()
 
+
 if __name__ == "__main__":
     main()
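
Below is a minimal, standalone sketch of the new save logic for the single-output case, so reviewers can see the resulting file names without running the full pipeline. It assumes `images` has the shape the script produces after extraction for Qwen-Image-Layered: a list whose first entry is itself a list of per-layer RGBA images. The dummy 8x8 images and the `out/` directory are illustrative only.

    from pathlib import Path

    from PIL import Image

    output_path = Path("out/layered.png")  # stand-in for the --output argument
    output_path.parent.mkdir(parents=True, exist_ok=True)
    suffix = output_path.suffix or ".png"
    stem = output_path.stem or "output_image_edit"

    # One prompt output decomposed into 4 RGBA layers (dummy images here).
    images = [[Image.new("RGBA", (8, 8)) for _ in range(4)]]

    img = images[0]
    if isinstance(img, list):  # layered output: save each layer to its own file
        for sub_idx, sub_img in enumerate(img):
            sub_img.save(output_path.parent / f"{stem}_{sub_idx}{suffix}")
    else:  # flat output: a single edited image
        img.save(output_path)

    # With the dummy data above this writes out/layered_0.png ... out/layered_3.png.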