Description
What would you like to happen?
I’ve been evaluating RunInference for variable-length inputs (NLP/LLMs), and the current batching approach (via BatchElements) feels too coarse for these workloads.
Right now batching is mostly count-based (batch size) or byte-based. That works well for fixed-shape inputs (e.g., images), but for highly variable sequence lengths it creates a few issues:
- Padding waste: One long sequence in a batch forces padding for many short sequences, wasting GPU compute on padding tokens.
- Unpredictable OOMs: A single outlier (very long input) can spike memory usage for the entire batch and cause hard-to-reproduce OOMs.
- Boilerplate for users: I ended up writing a custom DoFn to “bucket by length” before RunInference. This is a common pattern in NLP and would be better supported natively (a rough sketch of that workaround follows this list).
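For reference, here is a rough sketch of the kind of workaround I mean. Names and sizes are illustrative, and it uses GroupIntoBatches rather than my hand-rolled DoFn, but the shape is the same: key each element by a coarse length bucket, then batch within each bucket so that sequences batched together have similar lengths.

```python
import apache_beam as beam

def length_bucket(token_ids, bucket_width=64):
    # Coarse bucket so that sequences batched together have similar lengths.
    return len(token_ids) // bucket_width

# `tokenized` is assumed to be a PCollection of token-id lists.
batched = (
    tokenized
    | "KeyByLengthBucket" >> beam.Map(lambda ids: (length_bucket(ids), ids))
    | "BatchWithinBucket" >> beam.GroupIntoBatches(batch_size=32)
    | "DropBucketKey" >> beam.Map(lambda kv: kv[1])
)
# A downstream DoFn then pads each batch to its own max length and calls the
# model directly, which means RunInference's own batching is bypassed.
```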
Proposed solution
Add content-aware batching to RunInference, so batching can be driven by computational cost rather than only element count.
Conceptually:
- Cost-based thresholds: Let users provide a cost_fn (e.g., token count) and a max_cost_per_batch (e.g., 4096 tokens) to form batches.
- Optional bucketing: Buffer elements and group similar lengths to reduce padding overhead.
- Dynamic padding integration: Ensure the ModelHandler (e.g., PyTorch) pads each batch to its own max length, not a global max.
Sketch API (conceptual)
```python
pipeline | RunInference(
    model_handler=...,
    batching_kwargs={
        "mode": "dynamic",
        "max_cost": 4096,             # e.g., total tokens per batch
        "cost_fn": lambda x: len(x),  # user-defined cost metric
        "bucket": True,               # optional
    },
)
```
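To pin down the intended semantics of max_cost and cost_fn, here is a framework-agnostic sketch of the greedy packing I have in mind. All names are hypothetical; inside Beam this logic would presumably live in a buffering DoFn or an extension of BatchElements.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def greedy_cost_batches(
    elements: Iterable[T],
    cost_fn: Callable[[T], int],
    max_cost: int,
) -> Iterator[List[T]]:
    """Greedily packs elements into batches whose total cost stays <= max_cost.

    An element whose individual cost already exceeds max_cost is emitted as a
    singleton batch rather than dropped.
    """
    batch: List[T] = []
    batch_cost = 0
    for element in elements:
        cost = cost_fn(element)
        if batch and batch_cost + cost > max_cost:
            yield batch
            batch, batch_cost = [], 0
        batch.append(element)
        batch_cost += cost
    if batch:
        yield batch
```

With cost_fn=len and max_cost=4096, a stream of token-id lists is packed so that no batch exceeds 4096 total tokens (oversized single elements become singleton batches), which bounds the padded tensor size far more tightly than a fixed element count does.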
Alternatives considered
- GroupIntoBatches with weights: Works, but is verbose and separates batching from inference/padding logic.
- Static padding to max length: Too inefficient for production latency/cost requirements.
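To make the padding overhead concrete, here is a generic PyTorch snippet (not an existing ModelHandler hook) comparing per-batch dynamic padding with static padding to a global maximum:

```python
import torch

def pad_batch(token_ids_batch, pad_id=0, pad_to=None):
    """Pads a batch of variable-length token-id lists into one tensor.

    pad_to=None pads to the longest sequence in this batch (dynamic padding);
    pad_to=512 (say) emulates static padding to a global maximum length.
    """
    target_len = pad_to or max(len(ids) for ids in token_ids_batch)
    return torch.tensor(
        [ids + [pad_id] * (target_len - len(ids)) for ids in token_ids_batch]
    )

batch = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(pad_batch(batch).shape)              # torch.Size([3, 4])   -> 12 positions
print(pad_batch(batch, pad_to=512).shape)  # torch.Size([3, 512]) -> 1536 positions
```

For the three short sequences above, static padding to 512 produces 1,536 positions versus 12 with per-batch padding; at production traffic volumes that gap is what shows up as wasted GPU time.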
Additional context
I’ve skimmed base.py and pytorch_inference.py. My initial thought is a TokenBasedBatcher (or an extension around BatchElements) that can be reused by RunInference.
I’m interested in working on this for GSoC 2026 and can draft a design doc. Any pointers to existing work or a preferred direction would be appreciated.
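As a very rough starting point (hypothetical names, and only a per-bundle approximation of what a real TokenBasedBatcher would need), the cost-based packing could be wrapped in a DoFn like this:

```python
import apache_beam as beam
from apache_beam.transforms import window

class TokenBasedBatchDoFn(beam.DoFn):
    """Hypothetical DoFn emitting batches whose total cost stays <= max_cost.

    This only batches within a bundle and assumes the global window; a real
    implementation would more likely extend BatchElements so RunInference can
    reuse its existing buffering, windowing, and size-estimation machinery.
    """

    def __init__(self, cost_fn, max_cost):
        self._cost_fn = cost_fn
        self._max_cost = max_cost

    def start_bundle(self):
        self._batch = []
        self._batch_cost = 0

    def process(self, element):
        cost = self._cost_fn(element)
        if self._batch and self._batch_cost + cost > self._max_cost:
            yield self._batch
            self._batch, self._batch_cost = [], 0
        self._batch.append(element)
        self._batch_cost += cost

    def finish_bundle(self):
        # Emit any leftover partial batch at the end of the bundle.
        if self._batch:
            yield window.GlobalWindows.windowed_value(self._batch)
```

This would be applied as beam.ParDo(TokenBasedBatchDoFn(cost_fn=len, max_cost=4096)), with the ModelHandler consuming already-formed batches; wiring it into RunInference cleanly, rather than bolting it on in front, is exactly the design question I would like to explore.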
Issue Priority
Priority: 2 (default / most feature requests should be filed as P2)
Issue Components
- Component: Python SDK
- Component: Java SDK
- Component: Go SDK
- Component: Typescript SDK
- Component: IO connector
- Component: Beam YAML
- Component: Beam examples
- Component: Beam playground
- Component: Beam katas
- Component: Website
- Component: Infrastructure
- Component: Spark Runner
- Component: Flink Runner
- Component: Samza Runner
- Component: Twister2 Runner
- Component: Hazelcast Jet Runner
- Component: Google Cloud Dataflow Runner