104 changes: 84 additions & 20 deletions vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
Collaborator:
I think we should first update vllm-omni/vllm_omni/diffusion/worker/gpu_worker.py to support multi-request batching:

    def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
        """
        Execute a forward pass.
        """
        assert self.pipeline is not None
        # TODO: dealing with first req for now
        req = reqs[0]
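A minimal sketch of how that batched path might look, assuming forward() accepts a list of requests as this diff makes it do (the exact pipeline call below is illustrative, not confirmed from the repo):

    def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
        """
        Execute a forward pass over the whole batch of requests.
        """
        assert self.pipeline is not None
        # Hand the full list to the pipeline instead of only reqs[0];
        # forward() would then collate prompts, sizes, and generators
        # per request, as in the changes to pipeline_qwen_image.py below.
        return self.pipeline.forward(reqs)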

Author:
I have made this change locally, but just dropping it in here would cause problems with other models.

Collaborator:
Thanks. What are the specific problems?

Author:
Currently, model implementations need some changes to support batching; changing the worker alone would cause inference errors.
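For illustration (a hypothetical fragment, not from the repo): model code written against a single request fails as soon as the worker starts passing a list, which is why forward() gains the isinstance(req, list) handling in the diff below.

    # Hypothetical single-request model code: once execute_model passes a
    # list, `req` is a list[OmniDiffusionRequest] and attribute access fails.
    prompt = req.prompt  # AttributeError: 'list' object has no attribute 'prompt'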

@@ -621,19 +621,74 @@ def forward(
         callback_on_step_end_tensor_inputs: list[str] = ["latents"],
         max_sequence_length: int = 512,
     ) -> DiffusionOutput:
-        # # TODO: only support single prompt now
-        # if req.prompt is not None:
-        #     prompt = req.prompt[0] if isinstance(req.prompt, list) else req.prompt
-        prompt = req.prompt if req.prompt is not None else prompt
-        negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt
-        height = req.height or self.default_sample_size * self.vae_scale_factor
-        width = req.width or self.default_sample_size * self.vae_scale_factor
-        num_inference_steps = req.num_inference_steps or num_inference_steps
-        generator = req.generator or generator
-        true_cfg_scale = req.true_cfg_scale or true_cfg_scale
-        req_num_outputs = getattr(req, "num_outputs_per_prompt", None)
+        # Handle batch of requests
+        if isinstance(req, list):
+            requests = req
+        else:
+            requests = [req]
+
+        batch_size = len(requests)
+
+        # Extract parameters from requests, using defaults for missing values
+        prompts = []
+        negative_prompts = []
+        heights = []
+        widths = []
+        generators = []
+
+        for r in requests:
+            # Handle prompt
+            req_prompt = r.prompt if r.prompt is not None else prompt
+            if isinstance(req_prompt, list):
+                req_prompt = req_prompt[0] if len(req_prompt) > 0 else prompt
+            prompts.append(req_prompt)
+
+            # Handle negative prompt
+            req_neg_prompt = r.negative_prompt if r.negative_prompt is not None else negative_prompt
+            if isinstance(req_neg_prompt, list):
+                req_neg_prompt = req_neg_prompt[0] if len(req_neg_prompt) > 0 else negative_prompt
+            negative_prompts.append(req_neg_prompt)
+
+            # Handle height/width
+            heights.append(r.height or self.default_sample_size * self.vae_scale_factor)
+            widths.append(r.width or self.default_sample_size * self.vae_scale_factor)
+
+            # Handle generator
+            generators.append(r.generator if r.generator is not None else None)
+
+        # For batch processing, we require all images to have the same dimensions
+        # Use the first request's dimensions as the batch dimensions
+        height = heights[0]
+        width = widths[0]
+
+        # Validate that all requests have the same dimensions for batch processing
+        if not all(h == height for h in heights) or not all(w == width for w in widths):
+            logger.warning(
+                "Batch processing requires all requests to have the same height and width. "
+                "Using dimensions from the first request."
+            )
+
+        # Use parameters from the first request for shared settings
+        first_req = requests[0]
+        num_inference_steps = first_req.num_inference_steps or num_inference_steps
+        true_cfg_scale = first_req.true_cfg_scale or true_cfg_scale
+        req_num_outputs = getattr(first_req, "num_outputs_per_prompt", None)
         if req_num_outputs and req_num_outputs > 0:
             num_images_per_prompt = req_num_outputs
+
+        # Handle generator for batch
+        if any(g is not None for g in generators):
+            # Filter out None generators and use the list
+            generator = [g for g in generators if g is not None]
+            if len(generator) == 1:
+                generator = generator[0]
+            elif len(generator) == 0:
+                generator = None
+
+        # Use the batch prompts
+        prompt = prompts
+        negative_prompt = negative_prompts
+
         # 1. check inputs
         # 2. encode prompts
         # 3. prepare latents and timesteps
@@ -668,6 +723,16 @@ def forward(
         has_neg_prompt = negative_prompt is not None or (
             negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
         )
+
+        # Check if negative_prompt is a non-empty list or non-empty string
+        if isinstance(negative_prompt, list):
+            has_neg_prompt = any(np for np in negative_prompt) or (
+                negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+            )
+        elif isinstance(negative_prompt, str):
+            has_neg_prompt = bool(negative_prompt) or (
+                negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+            )
 
         do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
         prompt_embeds, prompt_embeds_mask = self.encode_prompt(
Expand Down Expand Up @@ -697,15 +762,14 @@ def forward(
generator,
latents,
)
img_shapes = [
[
(
1,
height // self.vae_scale_factor // 2,
width // self.vae_scale_factor // 2,
)
]
] * batch_size

# Prepare img_shapes for batch processing
img_shape = (
1,
height // self.vae_scale_factor // 2,
width // self.vae_scale_factor // 2,
)
img_shapes = [[img_shape]] * (batch_size * num_images_per_prompt)

timesteps, num_inference_steps = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1])
# num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)