diff --git a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
index e5b99840d..74e1823e1 100644
--- a/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
+++ b/vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py
@@ -621,19 +621,74 @@ def forward(
         callback_on_step_end_tensor_inputs: list[str] = ["latents"],
         max_sequence_length: int = 512,
     ) -> DiffusionOutput:
-        # # TODO: only support single prompt now
-        # if req.prompt is not None:
-        #     prompt = req.prompt[0] if isinstance(req.prompt, list) else req.prompt
-        prompt = req.prompt if req.prompt is not None else prompt
-        negative_prompt = req.negative_prompt if req.negative_prompt is not None else negative_prompt
-        height = req.height or self.default_sample_size * self.vae_scale_factor
-        width = req.width or self.default_sample_size * self.vae_scale_factor
-        num_inference_steps = req.num_inference_steps or num_inference_steps
-        generator = req.generator or generator
-        true_cfg_scale = req.true_cfg_scale or true_cfg_scale
-        req_num_outputs = getattr(req, "num_outputs_per_prompt", None)
+        # Handle batch of requests
+        if isinstance(req, list):
+            requests = req
+        else:
+            requests = [req]
+
+        batch_size = len(requests)
+
+        # Extract parameters from requests, using defaults for missing values
+        prompts = []
+        negative_prompts = []
+        heights = []
+        widths = []
+        generators = []
+
+        for r in requests:
+            # Handle prompt
+            req_prompt = r.prompt if r.prompt is not None else prompt
+            if isinstance(req_prompt, list):
+                req_prompt = req_prompt[0] if len(req_prompt) > 0 else prompt
+            prompts.append(req_prompt)
+
+            # Handle negative prompt
+            req_neg_prompt = r.negative_prompt if r.negative_prompt is not None else negative_prompt
+            if isinstance(req_neg_prompt, list):
+                req_neg_prompt = req_neg_prompt[0] if len(req_neg_prompt) > 0 else negative_prompt
+            negative_prompts.append(req_neg_prompt)
+
+            # Handle height/width
+            heights.append(r.height or self.default_sample_size * self.vae_scale_factor)
+            widths.append(r.width or self.default_sample_size * self.vae_scale_factor)
+
+            # Handle generator
+            generators.append(r.generator if r.generator is not None else None)
+
+        # For batch processing, we require all images to have the same dimensions
+        # Use the first request's dimensions as the batch dimensions
+        height = heights[0]
+        width = widths[0]
+
+        # Validate that all requests have the same dimensions for batch processing
+        if not all(h == height for h in heights) or not all(w == width for w in widths):
+            logger.warning(
+                "Batch processing requires all requests to have the same height and width. "
+                "Using dimensions from the first request."
+            )
+
+        # Use parameters from the first request for shared settings
+        first_req = requests[0]
+        num_inference_steps = first_req.num_inference_steps or num_inference_steps
+        true_cfg_scale = first_req.true_cfg_scale or true_cfg_scale
+        req_num_outputs = getattr(first_req, "num_outputs_per_prompt", None)
         if req_num_outputs and req_num_outputs > 0:
             num_images_per_prompt = req_num_outputs
+
+        # Handle generator for batch
+        if any(g is not None for g in generators):
+            # Filter out None generators and use the list
+            generator = [g for g in generators if g is not None]
+            if len(generator) == 1:
+                generator = generator[0]
+            elif len(generator) == 0:
+                generator = None
+
+        # Use the batch prompts
+        prompt = prompts
+        negative_prompt = negative_prompts
+
         # 1. check inputs
         # 2. encode prompts
         # 3. prepare latents and timesteps
@@ -668,6 +723,16 @@ def forward(
         has_neg_prompt = negative_prompt is not None or (
             negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
         )
+
+        # Check if negative_prompt is a non-empty list or non-empty string
+        if isinstance(negative_prompt, list):
+            has_neg_prompt = any(np for np in negative_prompt) or (
+                negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+            )
+        elif isinstance(negative_prompt, str):
+            has_neg_prompt = bool(negative_prompt) or (
+                negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+            )
         do_true_cfg = true_cfg_scale > 1 and has_neg_prompt

         prompt_embeds, prompt_embeds_mask = self.encode_prompt(
@@ -697,15 +762,14 @@ def forward(
             generator,
             latents,
         )
-        img_shapes = [
-            [
-                (
-                    1,
-                    height // self.vae_scale_factor // 2,
-                    width // self.vae_scale_factor // 2,
-                )
-            ]
-        ] * batch_size
+
+        # Prepare img_shapes for batch processing
+        img_shape = (
+            1,
+            height // self.vae_scale_factor // 2,
+            width // self.vae_scale_factor // 2,
+        )
+        img_shapes = [[img_shape]] * (batch_size * num_images_per_prompt)
         timesteps, num_inference_steps = self.prepare_timesteps(num_inference_steps, sigmas, latents.shape[1])

         # num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
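
For reference, a minimal sketch of how the batched path above might be exercised. The `FakeRequest` dataclass is a hypothetical stand-in whose fields mirror the attributes the diff reads from each request (`prompt`, `negative_prompt`, `height`, `width`, `generator`, `num_inference_steps`, `true_cfg_scale`, `num_outputs_per_prompt`); the real request type and the pipeline entry point in vllm_omni may differ.

```python
# Illustrative only: the request class below is a hypothetical stand-in,
# not the actual vllm_omni request type.
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass
class FakeRequest:
    prompt: str | list[str] | None = None
    negative_prompt: str | list[str] | None = None
    height: int | None = None
    width: int | None = None
    generator: torch.Generator | None = None
    num_inference_steps: int | None = None
    true_cfg_scale: float | None = None
    num_outputs_per_prompt: int | None = None


# A batch of two requests; the new code path requires matching height/width
# across the batch and takes shared settings (steps, cfg scale) from the
# first request.
reqs = [
    FakeRequest(
        prompt="a red fox in the snow",
        height=1024,
        width=1024,
        generator=torch.Generator("cpu").manual_seed(0),
    ),
    FakeRequest(
        prompt="a lighthouse at dusk",
        negative_prompt="blurry, low quality",
        height=1024,
        width=1024,
    ),
]

# With this change, forward() accepts either a single request or a list:
# output = pipeline.forward(reqs)
```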