[Feature Request]: RunInference: Content-Aware Dynamic Batching for NLP/LLM Workloads #37414

Description

@Eliaaazzz

What would you like to happen?

I’ve been evaluating RunInference for variable-length inputs (NLP/LLMs), and the current batching approach (via BatchElements) feels too coarse for these workloads.

Right now batching is mostly count-based (batch size) or byte-based. That works well for fixed-shape inputs (e.g., images), but for highly variable sequence lengths it creates a few issues:

  1. Padding waste: One long sequence in a batch forces padding for many short sequences, wasting GPU compute on padding tokens.
  2. Unpredictable OOMs: A single outlier (very long input) can spike memory usage for the entire batch and cause hard-to-reproduce OOMs.
  3. Boilerplate for users: I ended up writing a custom DoFn to “bucket by length” before RunInference. This is a common pattern in NLP and would be better supported natively.
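
For context, item 3 is roughly this kind of hand-rolled batching (a minimal sketch only: the BatchByTokenBudget name, the 4096 budget, and whitespace tokenization as a cost proxy are illustrative, not existing Beam API):

import apache_beam as beam
from apache_beam.transforms.window import GlobalWindow
from apache_beam.utils.timestamp import MIN_TIMESTAMP
from apache_beam.utils.windowed_value import WindowedValue


class BatchByTokenBudget(beam.DoFn):
    """Hand-rolled length-aware batching: buffers elements within a bundle and
    emits a batch whenever adding the next element would exceed the token budget."""

    def __init__(self, max_tokens_per_batch=4096):
        self._max_tokens = max_tokens_per_batch

    def start_bundle(self):
        self._batch, self._cost = [], 0

    def process(self, element):
        cost = len(element.split())  # crude stand-in for a real token count
        if self._batch and self._cost + cost > self._max_tokens:
            yield self._batch
            self._batch, self._cost = [], 0
        self._batch.append(element)
        self._cost += cost

    def finish_bundle(self):
        # Flush whatever is left in the buffer at the end of the bundle.
        if self._batch:
            yield WindowedValue(self._batch, MIN_TIMESTAMP, [GlobalWindow()])

It works, but it is per-bundle (small bundles produce small batches), it does not group similar lengths, and every user has to re-invent it next to their RunInference call.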

Proposed solution

Add content-aware batching to RunInference, so batching can be driven by computational cost rather than only element count.

Conceptually:

  • Cost-based thresholds: Let users provide a cost_fn (e.g., token count) and a max_cost_per_batch (e.g., 4096 tokens) to form batches.

  • Optional bucketing: Buffer elements and group similar lengths to reduce padding overhead.

  • Dynamic padding integration: Ensure the ModelHandler (e.g., PyTorch) pads to the batch max length (per batch), not a global max.
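
To make the padding point concrete, here is the difference between per-batch dynamic padding and global max-length padding with a Hugging Face tokenizer (the model name and the 512 max_length are arbitrary; this only illustrates the behavior the ModelHandler should end up with):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # arbitrary model
batch = ["tiny", "a somewhat longer input sentence", "short text"]

# Dynamic padding: pad only up to the longest sequence in this batch.
dynamic = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")

# Static padding: every batch is padded to a global max, so short inputs
# spend most of their compute on padding tokens.
static = tokenizer(batch, padding="max_length", max_length=512,
                   truncation=True, return_tensors="pt")

print(dynamic["input_ids"].shape)  # (3, <longest in this batch>)
print(static["input_ids"].shape)   # (3, 512)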

Sketch API (conceptual)

pipeline | RunInference(
    model_handler=...,
    batching_kwargs={
        "mode": "dynamic",
        "max_cost": 4096,             # e.g., total tokens per batch
        "cost_fn": lambda x: len(x),  # user-defined cost metric
        "bucket": True,               # optional
    },
)
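
The core batch-forming step could be a greedy pass over a buffered run of elements, optionally sorted by cost to approximate bucketing. A minimal sketch of that logic (form_batches and the whitespace token count are illustrative, not a proposed implementation):

def form_batches(elements, cost_fn, max_cost, bucket=False):
    """Greedily pack elements into batches whose total cost stays under max_cost.
    With bucket=True, sort the buffered elements by cost first so that similar
    lengths land in the same batch and padding waste shrinks. An element whose
    cost alone exceeds max_cost ends up in its own batch rather than being dropped."""
    items = sorted(elements, key=cost_fn) if bucket else list(elements)
    batch, batch_cost = [], 0
    for item in items:
        cost = cost_fn(item)
        if batch and batch_cost + cost > max_cost:
            yield batch
            batch, batch_cost = [], 0
        batch.append(item)
        batch_cost += cost
    if batch:
        yield batch

# Example: the two short inputs and the 2000-token input share a batch;
# the 3000-token input goes into its own batch because 2004 + 3000 > 4096.
texts = ["w " * 3000, "w " * 2000, "short text", "also short"]
batches = list(form_batches(
    texts, cost_fn=lambda s: len(s.split()), max_cost=4096, bucket=True))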

Alternatives considered

  1. GroupIntoBatches with weights: Works (see the sketch after this list), but is verbose and separates batching from inference/padding logic.
  2. Static padding to max length: Too inefficient for production latency/cost requirements.
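
For comparison, alternative 1 looks roughly like this today. This is a sketch only: it assumes the batch_size_bytes and element_size_fn parameters that newer Beam releases document for GroupIntoBatches (reusing the byte budget as a token budget), so check the version you target.

import random

import apache_beam as beam

NUM_KEYS = 16  # illustrative: GroupIntoBatches needs keyed input, so spread
               # elements across a few keys to keep some parallelism

def approx_tokens(text):
    return len(text.split())

with beam.Pipeline() as p:
    batches = (
        p
        | beam.Create(["tiny", "a much longer example sentence", "short text"])
        | "KeyByShard" >> beam.Map(lambda x: (random.randrange(NUM_KEYS), x))
        | "GroupByCost" >> beam.GroupIntoBatches(
            batch_size=64,                  # hard cap on elements per batch
            batch_size_bytes=4096,          # assumed parameter: per-batch budget,
                                            # with "bytes" reinterpreted as tokens
            element_size_fn=approx_tokens)  # assumed parameter: per-element weight
        | "DropKeys" >> beam.Map(lambda kv: list(kv[1])))
    # The batches still have to be unpacked, or sent to a custom DoFn that calls
    # the model directly, because RunInference re-batches internally; that is
    # exactly the coupling problem described in alternative 1.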

Additional context

I’ve skimmed base.py and pytorch_inference.py. My initial thought is a TokenBasedBatcher (or an extension around BatchElements) that can be reused by RunInference.

I’m interested in working on this for GSoC 2026 and can draft a design. Any pointers to existing work or a preferred direction would be appreciated.

Issue Priority

Priority: 2 (default / most feature requests should be filed as P2)

Issue Components

  • Component: Python SDK
  • Component: Java SDK
  • Component: Go SDK
  • Component: Typescript SDK
  • Component: IO connector
  • Component: Beam YAML
  • Component: Beam examples
  • Component: Beam playground
  • Component: Beam katas
  • Component: Website
  • Component: Infrastructure
  • Component: Spark Runner
  • Component: Flink Runner
  • Component: Samza Runner
  • Component: Twister2 Runner
  • Component: Hazelcast Jet Runner
  • Component: Google Cloud Dataflow Runner
