Commit 06b914e
feat: [AutoDeploy] generalizing cudagraph to multiple dynamic inputs (#3589)
* generalizing cudagraph to multiple dynamic inputs
* fix for failing test

Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent 442386d commit 06b914e

8 files changed: +164 -68 lines changed
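For orientation before the per-file diffs: the commit changes how CompiledGraph keys its captured CUDA graphs. Instead of a single input tensor's shape, the cache key is now the concatenation of the shapes of all batched inputs, with each batch dimension rounded to the nearest captured batch size. Below is a minimal, CPU-only sketch of that keying logic; the batch sizes and tensor shapes are hypothetical and chosen only for illustration, this is not the library code itself.

# Illustrative sketch only -- mirrors the keying scheme used in CompiledGraph.forward.
from typing import List, Tuple

import torch


def round_up_to_closest(batch_sizes: List[int], bs: int):
    """Pick the captured batch size closest to bs, preferring sizes >= bs."""
    if not batch_sizes:
        return None
    return min(batch_sizes, key=lambda x: (x < bs, abs(x - bs)), default=None)


def graph_key(args_batched: List[torch.Tensor], cuda_graph_batch_sizes: List[int]) -> Tuple[int, ...]:
    # Round each input's batch dimension and concatenate all shapes into one tuple.
    rounded = [
        (round_up_to_closest(cuda_graph_batch_sizes, t.shape[0]),) + tuple(t.shape[1:])
        for t in args_batched
    ]
    return sum(rounded, start=())


# Hypothetical example: two batched inputs, graphs captured for batch sizes 1, 2, 4, 8.
inputs = [torch.empty(3, 10), torch.empty(3, 16)]
print(graph_key(inputs, [1, 2, 4, 8]))  # -> (4, 10, 4, 16)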

examples/auto_deploy/build_and_run_ad.py

+6-1
@@ -96,7 +96,12 @@ def main(config: Optional[SimpleConfig] = None):
     print_outputs(outs)
 
     # run a benchmark for the model with batch_size == config.benchmark_bs
-    if config.benchmark:
+    if config.benchmark and config.runtime != "demollm":
+        ad_logger.warning(
+            f"Benchmarking with {config.runtime=} not supported. Please use `demollm` instead for "
+            "quick benchmarking and `trtllm-bench` for full benchmarking."
+        )
+    elif config.benchmark:
         token_ids = torch.randint(0, 100, (config.benchmark_bs, config.benchmark_isl)).tolist()
         sampling_params = SamplingParams(max_tokens=config.benchmark_osl, top_k=None)
         keys = ["compile_backend", "attn_backend", "mla_backend"]

tensorrt_llm/_torch/auto_deploy/compile/backends/torch_opt.py

+67-31
@@ -15,15 +15,22 @@
 
 class CompiledGraph(nn.Module):
     def __init__(
-        self, model: GraphModule, max_batch_size: int, cuda_graph_batch_sizes: List[int] = None
+        self,
+        model: GraphModule,
+        max_batch_size: int,
+        cuda_graph_batch_sizes: List[int] = None,
+        num_batched_inputs: Optional[int] = 1,  # number of batched, dynamic inputs...
     ):
         super().__init__()
         self._in_spec: TreeSpec = model._in_spec
         self._out_spec: TreeSpec = model._out_spec
         self.gm_compiled = torch.compile(model, dynamic=True)
         self.max_batch_size = max_batch_size
+        self.num_batched_inputs = num_batched_inputs if num_batched_inputs is not None else 1
         self.graphs: Dict[Tuple[int, ...], CUDAGraph] = {}
-        self._input_buffer: torch.Tensor = torch.empty(0, 1)
+        self._input_buffers: List[torch.Tensor] = [
+            torch.empty(0, 1) for _ in range(self.num_batched_inputs)
+        ]
         self._out_buffer_flat: List[torch.Tensor] = None
         self._args_hash: Optional[Tuple[int, ...]] = None
         self.cuda_graph_batch_sizes = (
@@ -42,6 +49,10 @@ def round_up_to_closest(batch_sizes: Iterable[int], bs: int) -> Optional[int]:
             return None
         return min(batch_sizes, key=lambda x: (x < bs, abs(x - bs)), default=None)
 
+    def round_to_cuda_batch_size(self, bs: int) -> int:
+        """Round batch size to the nearest cuda batch size."""
+        return self.round_up_to_closest(self.cuda_graph_batch_sizes, bs)
+
     def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:
         """Capture and return one cuda graph."""
         # warm-up
@@ -78,17 +89,31 @@ def _get_graph_batch_sizes(
         # return as sorted list
         return sorted(batch_sizes)
 
-    def _capture_cudagraph(self, input_t: torch.Tensor, flat_args: List[Any]):
-        """Capture graph for variable batch size."""
-        # set the args hash --> this is used to compare the inputs during graph replay
-        self._args_hash = self._get_hash(flat_args)
+    def capture_graph(self, *args, **kwargs):
+        """Capture and pre-fetch the graph for variable batch size."""
+        # flatten args, kwargs
+        all_args_flat = _flatten_args(self._in_spec, *args, **kwargs)
+
+        # extract the batched input tensors
+        args_batched = all_args_flat[: self.num_batched_inputs]
+        args_static = all_args_flat[self.num_batched_inputs :]
 
-        # set the input buffer to the max needed batch size with rest of shape as is
-        assert self.max_batch_size >= input_t.shape[0], "Max batch size too small."
-        self._input_buffer = input_t[:1].repeat_interleave(self.max_batch_size, dim=0)
+        # set the args hash --> this is used to compare the static inputs during graph replay
+        self._args_hash = self._get_hash(args_static)
 
-        # unflatten args, kwargs
-        args, kwargs = self._in_spec.unflatten([self._input_buffer] + flat_args)
+        # sanity checks on the batched inputs
+        msg_bs = "Max batch size too small."
+        msg_ndim = "Expecting at least a 2D for batched input tensors."
+        assert all(self.max_batch_size >= input.shape[0] for input in args_batched), msg_bs
+        assert all(input.ndim > 1 for input in args_batched), msg_ndim
+
+        # repeat the batched input tensors to the max batch size
+        self._input_buffers = [
+            input[:1].repeat_interleave(self.max_batch_size, dim=0) for input in args_batched
+        ]
+
+        # create new args, kwargs with the input buffers and static args
+        args, kwargs = self._in_spec.unflatten(self._input_buffers + args_static)
 
         # capture output once with max batch size to capture output buffers
         with CudaGraphWarmUpPhase():
@@ -101,35 +126,46 @@ def _capture_cudagraph(self, input_t: torch.Tensor, flat_args: List[Any]):
             ad_logger.info(f"Capturing graph for batch size: {bs}")
 
             # setup args, kwargs
-            input_truncated = self._input_buffer[:bs]
-            args, kwargs = self._in_spec.unflatten([input_truncated, *flat_args])
+            inputs_truncated = [in_buffer[:bs] for in_buffer in self._input_buffers]
+            args, kwargs = self._in_spec.unflatten(inputs_truncated + args_static)
 
-            # capture graph
-            self.graphs[input_truncated.shape] = self._capture_one_graph(*args, **kwargs)
-
-    def capture_graph(self, *args, **kwargs):
-        """Capture and pre-fetch the graph."""
-        input_t, flat_args = _flatten_args(self._in_spec, *args, **kwargs)
-        self._capture_cudagraph(input_t, flat_args)
+            # capture graph for truncated inputs
+            combined_shape = sum((input.shape for input in inputs_truncated), start=())
+            self.graphs[combined_shape] = self._capture_one_graph(*args, **kwargs)
 
     def forward(self, *args, **kwargs) -> Any:
         """Run the compiled graph."""
-        input_t, flat_args = _flatten_args(self._in_spec, *args, **kwargs)
-        bs, *other_dims = input_t.shape
+        # flatten args, kwargs
+        all_args_flat = _flatten_args(self._in_spec, *args, **kwargs)
 
-        # round up batch size and construct rounded up shape
-        bs_graph = self.round_up_to_closest([shapes[0] for shapes in self.graphs.keys()], bs)
-        shape_rounded_up = (bs_graph, *other_dims)
+        # extract the batched input tensors
+        args_batched = all_args_flat[: self.num_batched_inputs]
+        args_static = all_args_flat[self.num_batched_inputs :]
 
-        # regular forward for non-matching shapes or non-matching flat_args
-        if shape_rounded_up not in self.graphs or self._args_hash != self._get_hash(flat_args):
+        # check if args_static match the stored hash
+        if self._args_hash != self._get_hash(args_static):
             return self.gm_compiled(*args, **kwargs)
 
+        # Calculate rounded-up shapes for each input
+        rounded_shapes = [
+            (self.round_to_cuda_batch_size(input.shape[0]),) + input.shape[1:]
+            for input in args_batched
+        ]
+        combined_shape = sum(rounded_shapes, start=())
+
+        # regular forward for non-matching shapes
+        if combined_shape not in self.graphs:
+            return self.gm_compiled(*args, **kwargs)
+
+        # copy inputs to input buffers
+        for i, input_tensor in enumerate(args_batched):
+            self._input_buffers[i][: input_tensor.shape[0]] = input_tensor
+
         # run forward pass via graph
-        self._input_buffer[:bs] = input_t
-        self.graphs[shape_rounded_up].replay()
+        self.graphs[combined_shape].replay()
 
         # retrieve output from buffer, cut to batch size, and unflatten
+        bs = args_batched[0].shape[0]
         out_flat = [o_b[:bs].detach().clone() for o_b in self._out_buffer_flat]
         return self._out_spec.unflatten(out_flat)
 
@@ -138,11 +174,11 @@ def forward(self, *args, **kwargs) -> Any:
 class TorchOptCompiler(BackendCompiler):
     @torch.inference_mode()
     def compile(self) -> CompiledGraph:
-        cuda_graph_batch_sizes = self.compiler_kwargs.get("cuda_graph_batch_sizes", None)
         compiled_gm = CompiledGraph(
             self.gm,
             max_batch_size=self.max_batch_size,
-            cuda_graph_batch_sizes=cuda_graph_batch_sizes,
+            cuda_graph_batch_sizes=self.compiler_kwargs.get("cuda_graph_batch_sizes"),
+            num_batched_inputs=self.compiler_kwargs.get("num_batched_inputs"),
         )
 
         # try capturing cudagraph
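The capture/replay flow above relies on a standard CUDA-graph trick: inputs are staged in persistent buffers allocated at the maximum batch size, graphs are captured on truncated views of those buffers, and replay only needs a copy into the buffers. The following is a rough CPU-side sketch of just the buffer handling; the tensor shapes are made up and the actual capture of torch.cuda.CUDAGraph objects is elided.

import torch

max_batch_size = 8
args_batched = [torch.randn(3, 10), torch.randn(3, 16)]  # hypothetical batched inputs

# One persistent buffer per batched input, sized to the max batch size
# (mirrors input[:1].repeat_interleave(max_batch_size, dim=0) in capture_graph).
input_buffers = [t[:1].repeat_interleave(max_batch_size, dim=0) for t in args_batched]

# Graphs are captured on truncated *views* of these buffers, one entry per batch size,
# keyed by the concatenation of all truncated shapes.
bs = 4
inputs_truncated = [buf[:bs] for buf in input_buffers]
combined_shape = sum((tuple(t.shape) for t in inputs_truncated), start=())
print(combined_shape)  # -> (4, 10, 4, 16)

# At replay time, the live inputs are copied into the leading rows of the buffers;
# the captured graph then reads the fresh data through the same views.
for buf, t in zip(input_buffers, args_batched):
    buf[: t.shape[0]] = t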

tensorrt_llm/_torch/auto_deploy/compile/compiler.py

+5-7
@@ -7,7 +7,6 @@
 from abc import ABC, abstractmethod
 from typing import Any, Dict, List, Optional, Tuple, Type
 
-import torch
 import torch.nn as nn
 from torch.fx import GraphModule
 from torch.fx._pytree import tree_flatten_spec
@@ -16,12 +15,10 @@
 from ..utils.logger import ad_logger
 
 
-def _flatten_args(in_spec, *args, **kwargs) -> Tuple[torch.Tensor, List[Any]]:
+def _flatten_args(in_spec, *args, **kwargs) -> List[Any]:
     """Flatten inputs from in_spec where we assume the first input is the main input tensor."""
     all_args: PyTree = (args, kwargs)
-    input_t, *flat_args = tree_flatten_spec(all_args, in_spec)
-    assert input_t.ndim > 1, "Expecting at least a 2D input tensor."
-    return input_t, flat_args
+    return tree_flatten_spec(all_args, in_spec)
 
 
 class BackendRegistry:
@@ -66,8 +63,9 @@ def __init__(
         if self.dynamic_shapes is not None and 0 in self.dynamic_shapes[0]:
             self.max_batch_size = self.dynamic_shapes[0][0].max
         else:
-            idxs, *_ = _flatten_args(self.gm._in_spec, *self.args, **self.kwargs)
-            self.max_batch_size = idxs.shape[0]
+            # NOTE: we assume the first input is the main input tensor with batch dimension
+            batched_input, *_ = _flatten_args(self.gm._in_spec, *self.args, **self.kwargs)
+            self.max_batch_size = batched_input.shape[0]
 
     @abstractmethod
     def compile(self) -> nn.Module:
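_flatten_args now simply returns the full flat argument list; callers rely on the convention that the first num_batched_inputs entries of that list are the batched tensors. The real helper uses torch.fx._pytree.tree_flatten_spec with the exported graph's _in_spec; the generic pytree utility below is only meant to illustrate that ordering assumption, with a hypothetical call signature.

import torch
from torch.utils._pytree import tree_flatten

# Hypothetical call: two positional batched tensors plus one static keyword argument.
args = (torch.zeros(4, 10), torch.zeros(4, 16))
kwargs = {"scale": 1.0}

flat, spec = tree_flatten((args, kwargs))
print([type(x).__name__ for x in flat])  # -> ['Tensor', 'Tensor', 'float']
# The leading entries are the batched inputs; everything after them is treated as static.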

tensorrt_llm/_torch/auto_deploy/transformations/transform.py

+4-1
@@ -170,7 +170,10 @@ def __call__(self, cm: CachedSequenceInterface) -> GraphModule:
         ############################################################################################
 
         cm.info._set_generate_only_batch()
-        compiler_kwargs = {"cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes}
+        compiler_kwargs = {
+            "cuda_graph_batch_sizes": self.ad_config.cuda_graph_batch_sizes,
+            "num_batched_inputs": 1,  # TODO (lucaslie): improve once we have a config system...
+        }
         egm_compiled = compile_and_capture(
             egm,
             self.compile_backend,

tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py

+2-2
@@ -26,7 +26,7 @@
             "model_kwargs": {"num_hidden_layers": 2},
         },
     ),
-    # small llama3.1-8B model with world_size 2 + trtllm runtime
+    # small llama3.1-8B model with world_size 2 + trtllm runtime + torch-opt
     (
         2,
         {
@@ -36,7 +36,7 @@
         ),
         "runtime": "trtllm",
         "attn_backend": "TritonWithFlattenedInputs",
-        "compile_backend": "torch-simple",
+        "compile_backend": "torch-opt",
         "model_kwargs": {"num_hidden_layers": 2},
     },
 ),

tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_torch_opt.py

+76-24
@@ -11,6 +11,20 @@
 from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export_to_gm
 
 
+class ModelWithMultipleInputs(torch.nn.Module):
+    def __init__(self, base_model):
+        super().__init__()
+        self.base_model = base_model
+
+    def forward(self, x0, x1=None, x2=None):
+        out = self.base_model(x0)
+        if x1 is not None:
+            out = out + self.base_model(x1)
+        if x2 is not None:
+            out = out + self.base_model(x2)
+        return out
+
+
 # Using pytest.mark.parametrize to test multiple cases
 @pytest.mark.parametrize(
     "lst, value, expected",
@@ -31,60 +45,98 @@ def test_round_up_to_closest(lst, value, expected):
     assert CompiledGraph.round_up_to_closest(lst, value) == expected
 
 
+@pytest.mark.parametrize("num_inputs", [1, 2, 3])
 @pytest.mark.parametrize(
-    "model_type, model_cls, input_shape, captured_shape_fn, atol",
+    "model_type, model_cls, input_shape, atol",
     [
-        ("llm", TransformerLikeModel, (32, 10), lambda b, s: (b, s), 1e-5),
-        ("vit", VisionTransformerLikeModel, (32, 4096, 16), lambda b, s, c: (b, s, c), 1e-3),
+        ("llm", TransformerLikeModel, (32, 10), 1e-5),
+        ("vit", VisionTransformerLikeModel, (32, 4096, 16), 1e-3),
     ],
 )
-def test_cudagraph_capture_replay(model_type, model_cls, input_shape, captured_shape_fn, atol):
+def test_cudagraph_capture_replay(model_type, model_cls, input_shape, atol, num_inputs):
     batch_size, *seq_shape = input_shape
 
     if model_type == "llm":
         vocab_size = 100  # Vocabulary size
         embed_dim = 32  # Embedding dimension
         hidden_dim = 64  # Hidden layer dimension
-        model = model_cls(vocab_size, embed_dim, hidden_dim).to("cuda")
-        input_data = torch.randint(0, vocab_size, input_shape).to("cuda")
-        captured_shape = captured_shape_fn(batch_size, seq_shape[0])
+        base_model = model_cls(vocab_size, embed_dim, hidden_dim).to("cuda")
+        model = ModelWithMultipleInputs(base_model).to("cuda")
+
+        # Create inputs for the model
+        input_data = [
+            torch.randint(0, vocab_size, input_shape).to("cuda") for _ in range(num_inputs)
+        ]
 
     elif model_type == "vit":
         channels = 16  # Number of channels
         hidden_dim = 64  # Hidden layer dimension
-        model = model_cls(channels, hidden_dim).to("cuda")
-        input_data = torch.randn(*input_shape).to("cuda")
-        captured_shape = captured_shape_fn(batch_size, seq_shape[0], channels)
+        base_model = model_cls(channels, hidden_dim).to("cuda")
+        model = ModelWithMultipleInputs(base_model).to("cuda")
+
+        # Create inputs for the model
+        input_data = [torch.randn(*input_shape).to("cuda") for _ in range(num_inputs)]
+
+    combined_shape = input_shape * num_inputs
 
     model.eval()
-    dynamic_shapes = generate_dynamic_shapes(batch_size, seq_shape[0])
-    graph_module = torch_export_to_gm(model, args=(input_data,), dynamic_shapes=dynamic_shapes)
-    compiled_model = CompiledGraph(graph_module, max_batch_size=batch_size)
+    dynamic_shapes = generate_dynamic_shapes(batch_size, seq_shape[0]) * num_inputs
+
+    # Prepare args - include only the number of inputs needed
+    args = tuple(input_data[:num_inputs])
+    print(args)
+    print(dynamic_shapes)
+
+    graph_module = torch_export_to_gm(model, args=args, dynamic_shapes=dynamic_shapes)
+    compiled_model = CompiledGraph(
+        graph_module, max_batch_size=batch_size, num_batched_inputs=num_inputs
+    )
 
     with torch.inference_mode():
-        full_args = (input_data,)
-        compiled_model.capture_graph(*full_args)
+        # Capture graph with all inputs
+        compiled_model.capture_graph(*args)
 
-        # Ensure the graph is stored for the batch size
-        assert captured_shape in compiled_model.graphs, "Graph for batch size was not captured."
+        # Ensure the graph is stored for the combined shape of all inputs
+        assert combined_shape in compiled_model.graphs, (
+            f"Graph for combined shape {combined_shape} was not captured."
+        )
+
+        # Create smaller inputs for replay
+        if model_type == "llm":
+            replay_input_data = [x[:, :1] for x in input_data[:num_inputs]]
+        else:  # vit
+            replay_input_data = [x[:, :1, :] for x in input_data[:num_inputs]]
+
+        # Prepare replay args - include only the number of inputs needed
+        replay_args = tuple(replay_input_data)
+
+        # Get flat inputs for manual replay
+        all_args_flat = _flatten_args(compiled_model._in_spec, *replay_args)
+        input_args_flat = all_args_flat[:num_inputs]  # Extract just the batched inputs
 
-        input_data_replay = input_data[:, :1] if model_type == "llm" else input_data[:, :1, :]
+        # Update input buffers for replay
+        for i, input_tensor in enumerate(input_args_flat):
+            compiled_model._input_buffers[i][: input_tensor.shape[0]] = input_tensor
 
-        graph = compiled_model.graphs[captured_shape]
-        input_data_flatten, _ = _flatten_args(compiled_model._in_spec, input_data_replay)
-        compiled_model._input_buffer[:] = input_data_flatten  # Update input buffer
+        # Get the appropriate graph and replay
+        graph = compiled_model.graphs[combined_shape]
         graph.replay()
 
+        # Get output from manual replay
         replay_output = compiled_model._out_spec.unflatten(
             [buf[:batch_size].detach().clone() for buf in compiled_model._out_buffer_flat]
        )
-        replay_output2 = compiled_model.forward(input_data_replay)
+
+        # Get output from forward method
+        replay_output2 = compiled_model.forward(*replay_args)
+
+        # Compare outputs
         assert torch.allclose(replay_output, replay_output2, atol=atol), (
             "CUDAGraph replay output mismatch"
         )
 
-        original_output = compiled_model.gm_compiled(input_data_replay)
-
+        # Compare with original model output
+        original_output = compiled_model.gm_compiled(*replay_args)
         assert torch.allclose(original_output, replay_output, atol=atol), (
             "CUDAGraph replay output mismatch"
         )

tests/unittest/_torch/auto_deploy/unit/singlegpu/models/test_deepseek_patches.py

+2
@@ -59,6 +59,8 @@ def _generate_ds_attention_mask(b, s):
 )
 
 
+# TODO (svelury): update unit test to run fast
+@pytest.mark.skip(reason="TODO: too slow for a unit test")
 @pytest.mark.parametrize(
     "model_name, module_name, patch, yarn, inputs",
     [
