diff --git a/src/examples/edit_demo.py b/src/examples/edit_demo.py index 5cfcdfc..30f40b6 100644 --- a/src/examples/edit_demo.py +++ b/src/examples/edit_demo.py @@ -10,9 +10,13 @@ # --- Model Loading --- dtype = torch.bfloat16 device = "cuda" if torch.cuda.is_available() else "cpu" +model = "Qwen/Qwen-Image-Edit" # Load the model pipeline -pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=dtype).to(device) +if torch.cuda.device_count() >1: + pipe = QwenImageEditPipeline.from_pretrained(model, torch_dtype=dtype, device_map="balanced") +else: + pipe = QwenImageEditPipeline.from_pretrained(model, torch_dtype=dtype).to(device) # --- UI Constants and Helpers --- MAX_SEED = np.iinfo(np.int32).max @@ -152,4 +156,4 @@ def infer( ) if __name__ == "__main__": - demo.launch() \ No newline at end of file + demo.launch()