[RunInference] Add content-aware dynamic batching via element_size_fn… #37428
base: master
Conversation
Assigning reviewers: R: @shunping for label python.

Note: The PR bot will only process comments in the main thread (not review comments).
Force-pushed 83ced41 to 9026ad0.
Hi @damccorm, just a heads-up that the PR is ready and all CI checks have passed. I've implemented the weighted `BatchElements` approach as discussed with Robert. Ready for review whenever you have a chance. Thanks!
Review comment on the `with_element_size_fn` definition in `apache_beam/ml/inference/base.py`:

```python
    inference result in order from first applied to last applied."""
    return _PostProcessingModelHandler(self, fn)

  def with_element_size_fn(
```
I think it is a little odd to have element_size batch args as a `with_` function, while our other batching args exist as args to the model handlers themselves (e.g. `min_batch_size: Optional[int] = None`).

I think we would probably be better off adding this as a direct argument to our built-in model handlers (which also creates an easy pattern for others to follow when building their model handlers). That way we're consistent and there is just one way to add batching args. Thoughts?
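A hedged sketch of what that constructor-argument style could look like for one of the built-in handlers; the `max_batch_weight` and `element_size_fn` parameter names are taken from the follow-up commit below, and the model path is a hypothetical placeholder.

```python
# Sketch only: max_batch_weight / element_size_fn are the new args this PR
# proposes adding to the built-in handlers; they are not a released Beam API.
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

handler = SklearnModelHandlerNumpy(
    model_uri='gs://my-bucket/model.pkl',    # hypothetical model location
    max_batch_size=64,                       # existing count-based cap
    max_batch_weight=1_000_000,              # new: cap on total batch "weight"
    element_size_fn=lambda arr: arr.nbytes,  # new: weight = bytes per example
)
```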
- Added `max_batch_weight` and `element_size_fn` to `__init__` of all ModelHandlers (PyTorch, Sklearn, TF, ONNX, XGBoost, TensorRT, Hugging Face, vLLM, VertexAI).
- Updated subclasses to delegate these args to `super().__init__` or internal batching kwargs.
- Removed the `with_element_size_fn` builder method from the base class to enforce API consistency.
- Updated tests to reflect the new API signature.
Force-pushed 9026ad0 to b5a4e7a.
Codecov Report

❌ Patch coverage is …

Additional details and impacted files:

```diff
@@            Coverage Diff            @@
##           master   #37428    +/-   ##
==========================================
+ Coverage   35.95%   35.97%   +0.02%
  Complexity   1676     1676
==========================================
  Files        1062     1062
  Lines      166033   165889    -144
  Branches     1208     1208
==========================================
- Hits        59690    59685      -5
+ Misses     104162   104023    -139
  Partials     2181     2181
```

☔ View full report in Codecov by Sentry.
[RunInference] Add content-aware dynamic batching via element_size_fn (Issue #37414)
Rationale
This PR addresses #37414 by introducing content-aware dynamic batching to `RunInference`.

Currently, `RunInference` relies on `BatchElements` with a strict count-based limit (`max_batch_size`). However, for workloads like NLP and LLMs, variable-length inputs (tokens) lead to significant variance in computational cost and memory usage. A batch of 10 short sentences is vastly different from a batch of 10 long documents.

This change allows users to provide a custom `element_size_fn` to `ModelHandler`, which is then passed down to the underlying `BatchElements` transform. This enables batching based on total "weight" (e.g., token count) rather than just element count, improving GPU utilization and preventing OOM errors.

Design Principles
This implementation prioritizes modularity and type safety through the following design choices:
- **Decorator Pattern (Composition over Inheritance):** Implemented `_SizingModelHandler` as a wrapper to dynamically attach sizing behavior to any `ModelHandler` implementation (see the sketch after this list). This avoids the combinatorial explosion of subclasses (e.g., `TFModelHandlerWithSizing`, `PyTorchModelHandlerWithSizing`) and keeps the codebase DRY.
- **Open-Closed Principle (OCP):** The change is strictly additive. The base `ModelHandler` remains closed for modification, ensuring zero regression risk for existing pipelines. Functionality is extended purely by overriding `batch_elements_kwargs` in the wrapper and safely delegating all other methods to the base instance.
- **Architectural Consistency:** The implementation mirrors the existing `_PreProcessingModelHandler` pattern in Apache Beam. This ensures API consistency and reduces cognitive load for maintainers.
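A minimal sketch of the wrapper shape this describes. Only `_SizingModelHandler`, `batch_elements_kwargs`, and `element_size_fn` come from the PR; the class name suffix and method bodies below are a paraphrase of the described behavior, not the actual diff.

```python
from typing import Any, Callable, Dict, Optional

from apache_beam.ml.inference.base import ModelHandler


class _SizingModelHandlerSketch(ModelHandler):
  """Decorates a base ModelHandler with a per-element sizing function."""

  def __init__(self, base: ModelHandler, element_size_fn: Callable[[Any], int]):
    self._base = base
    self._element_size_fn = element_size_fn

  def batch_elements_kwargs(self) -> Dict[str, Any]:
    # Copy the wrapped handler's batching config so nothing is mutated,
    # then inject the sizing function for the BatchElements transform.
    kwargs = dict(self._base.batch_elements_kwargs())
    kwargs['element_size_fn'] = self._element_size_fn
    return kwargs

  # Everything else is delegated so the wrapper stays transparent
  # (model loading, inference, metrics, model updates).
  def load_model(self):
    return self._base.load_model()

  def run_inference(self, batch, model, inference_args=None):
    return self._base.run_inference(batch, model, inference_args)

  def update_model_path(self, model_path: Optional[str] = None):
    return self._base.update_model_path(model_path)

  def get_metrics_namespace(self) -> str:
    return self._base.get_metrics_namespace()
```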
Changes

`apache_beam/ml/inference/base.py`:
- Added a `with_element_size_fn` method to the `ModelHandler` interface.
- Added a `_SizingModelHandler` wrapper class.
- Overrode `batch_elements_kwargs` to inject `element_size_fn` while preserving existing configuration (using a safe dictionary copy).
- Delegated all other `ModelHandler` methods (e.g., `update_model_paths`, `get_metrics_namespace`) to ensure transparency.

Usage Example
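A minimal, self-contained sketch of the intended usage, assuming the `with_element_size_fn` API described above (the review later moved these options to handler constructor arguments). `FakeTokenModel` and `FakeTokenModelHandler` are illustrative stand-ins, not part of the PR.

```python
from typing import Iterable, Optional, Sequence

import apache_beam as beam
from apache_beam.ml.inference.base import ModelHandler, PredictionResult, RunInference


class FakeTokenModel:
  def predict(self, batch: Sequence[str]) -> Iterable[int]:
    # Stand-in "inference": return the token count of each sentence.
    return [len(s.split()) for s in batch]


class FakeTokenModelHandler(ModelHandler[str, PredictionResult, FakeTokenModel]):
  def load_model(self) -> FakeTokenModel:
    return FakeTokenModel()

  def run_inference(
      self,
      batch: Sequence[str],
      model: FakeTokenModel,
      inference_args: Optional[dict] = None) -> Iterable[PredictionResult]:
    return [
        PredictionResult(example, inference)
        for example, inference in zip(batch, model.predict(batch))
    ]


def token_count(sentence: str) -> int:
  # Each element's batching "weight" is its whitespace-token count.
  return len(sentence.split())


# Batches are now formed by total token weight instead of element count.
handler = FakeTokenModelHandler().with_element_size_fn(token_count)

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([
          'short input',
          'a considerably longer input with many more tokens',
      ])
      | RunInference(handler)
      | beam.Map(print))
```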
Testing
Comprehensive tests were added in `sdks/python/apache_beam/ml/inference/base_test.py`:

- `test_kwargs_are_passed_correctly`: verifies `element_size_fn` is correctly injected into `batch_elements_kwargs` (see the sketch after this list).
- `test_batch_elements_integration_with_beam_pipeline`: runs an end-to-end pipeline with `max_batch_size` set to 10 and elements of weight 5; `BatchElements` correctly creates batches of 2 elements (5+5=10), confirming the dynamic sizing logic.
- `test_element_size_fn_wrapper_delegates_correctly`: verifies that all other `ModelHandler` methods are properly delegated (critical for features like model updates).
- `test_multiple_wrappers_can_be_chained`: verifies the sizing wrapper composes with `with_preprocess_fn`.
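A hedged sketch of the kwargs-injection check, not the actual test code from `base_test.py`; `_NoOpHandler` is an illustrative stand-in and `with_element_size_fn` is the API proposed in this PR.

```python
import unittest

from apache_beam.ml.inference.base import ModelHandler


class _NoOpHandler(ModelHandler):
  def load_model(self):
    return None

  def run_inference(self, batch, model, inference_args=None):
    return batch


class ElementSizeFnKwargsTest(unittest.TestCase):
  def test_element_size_fn_injected_into_batch_kwargs(self):
    size_fn = len
    wrapped = _NoOpHandler().with_element_size_fn(size_fn)
    # The wrapper should expose the sizing function to BatchElements.
    self.assertIs(wrapped.batch_elements_kwargs()['element_size_fn'], size_fn)


if __name__ == '__main__':
  unittest.main()
```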
Thank you for your contribution! Follow this checklist to help us incorporate your contribution quickly and easily:

- [x] fixes #37414