[Feature] Implements batching support for batch processing to qwen-image #390
base: main
Conversation
💡 Codex Review
Here are some automated review suggestions for this pull request.
model, subfolder="vae", local_files_only=local_files_only
).to(self.device)
logger.info("Loaded Qwen-Image VAE successfully")
self.transformer = QwenImageTransformer2DModel()
Pipeline fails to construct transformer
The constructor now calls QwenImageTransformer2DModel() without any arguments, but QwenImageTransformer2DModel.__init__ requires an od_config parameter (see vllm_omni/diffusion/models/qwen_image/qwen_image_transformer.py lines 483-500). Instantiating QwenImagePipeline will therefore raise a TypeError before any model weights are loaded or the new batching logic runs, blocking all uses of the pipeline.
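Assuming the od_config object is available in the pipeline constructor at that point (a later hunk in this PR passes it the same way), a minimal sketch of the fix is:

# Sketch only: assumes `od_config` is already in scope in the constructor.
self.transformer = QwenImageTransformer2DModel(od_config=od_config)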
ZJY0516 left a comment
I’m not sure: do you want to batch multiple requests, or batch multiple prompts within a single request?
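For clarity, the two notions could be sketched roughly as follows; the request shapes here are hypothetical and not the actual OmniDiffusionRequest schema:

# Hypothetical illustration only, not the real vllm-omni request format.

# (a) Prompt-level batching: one request carries several prompts and the
#     pipeline pushes them through the denoising loop together.
request = {"prompts": ["a red fox", "a snowy mountain"], "num_images_per_prompt": 1}

# (b) Request-level batching: the worker/scheduler accumulates several
#     independent single-prompt requests and concatenates them into one forward pass.
pending = [
    {"prompts": ["a red fox"]},
    {"prompts": ["a snowy mountain"]},
]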
self.device
)
self.transformer = QwenImageTransformer2DModel(od_config=od_config)
logger.info("Loaded Qwen-Image scheduler successfully")
same
):
super().__init__()
self.od_config = od_config
self.weights_sources = [
Why do we need to change this?
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt)
This is copied from diffusers. Why do we need to change this?
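For reference, the difference comes down to how torch.Tensor.repeat handles a repeat spec with more dimensions than the tensor. A minimal sketch, assuming the mask is 2D with shape [batch_size, seq_len] (an assumption about this pipeline, used only to illustrate the shapes):

import torch

mask = torch.ones(2, 5)  # assumed shape: [batch_size=2, seq_len=5]

# A 3D repeat spec treats the 2D mask as [1, batch, seq] and tiles dim 1:
print(mask.repeat(1, 3, 1).shape)  # torch.Size([1, 6, 5])

# A 2D repeat spec tiles along the sequence dimension instead:
print(mask.repeat(1, 3).shape)     # torch.Size([2, 15])

Which of the two matches the subsequent view/reshape is presumably what motivated the change, but that is for the author to confirm.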
# Broadcast timestep to match batch size
timestep = t.expand(latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
This may break teacache
My end-to-end use case is to run the following code, where …
The problem is that batching may not yield a performance gain. 1 prompt per request: 66s
Since it's already compute-bound for 1 request :)
Considering the case where there are enough CUDA cores available for more than one request, doesn't processing requests one at a time under-utilize the GPU and hence reduce throughput? Can processing a single prompt use the full capacity? In the scenario where the serving logic accumulates requests with a good heuristic, I think processing requests in batches provides better throughput. Correct me if I'm wrong on this though, you guys are the experts on this :)
I recommend you read this RFC: #290
I think we should first update vllm-omni/vllm_omni/diffusion/worker/gpu_worker.py to support multi-request batching:
def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
    """
    Execute a forward pass.
    """
    assert self.pipeline is not None
    # TODO: dealing with first req for now
    req = reqs[0]
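As a rough sketch only, the direction could look like the following; supports_batching and the pipeline calls are hypothetical and schematic, not the actual vllm-omni interfaces:

def execute_model(self, reqs: list[OmniDiffusionRequest], od_config: OmniDiffusionConfig) -> DiffusionOutput:
    """
    Execute a forward pass (sketch).
    """
    assert self.pipeline is not None
    if getattr(self.pipeline, "supports_batching", False):  # hypothetical attribute
        # Pipelines that implement batching get the whole request list.
        return self.pipeline(reqs)
    # Fall back to the existing single-request path for all other models.
    return self.pipeline(reqs[0])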
I have made this change locally but just dropping this change here would cause problems with other models.
Thanks, what are the detailed problems?
Currently, model implementations need some changes to support batching; changing this now would cause inference errors.
Thanks. I have performed the benchmark. You can see that from 9:40 until 10:00 I enabled batch mode (batch_size = 4) and kept a constant load using 16 workers; you can see the response times and queue size too. Right after the first experiment (from about 10:10) I used the no-batch mode with the same load. At the last peak I used batch_size = 8.
(Screenshots: benchmark dashboards showing response times, queue size, and memory utilization for the batch and no-batch runs.)
Batching seems to improve stability, reduce queue size and increase memory utilization, but it does not seem to improve throughput.
An important thing I want to note is that this section in the code should throw at least a warning when passed multiple requests. It took me some time to find this part of the code and figure out what was wrong with my inference server:
Purpose
Addresses #388
For now it only implements the batching logic in vllm_omni/diffusion/models/qwen_image/pipeline_qwen_image.py.
To finalize this and enable end-user usability, this line in vllm_omni/diffusion/worker/gpu_worker.py should also be updated to work with batches. But since making that change would break existing models, I didn't include it here. Once batching is implemented for the other models, that change is trivial.
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.