@@ -15,15 +15,22 @@

 class CompiledGraph(nn.Module):
     def __init__(
-        self, model: GraphModule, max_batch_size: int, cuda_graph_batch_sizes: List[int] = None
+        self,
+        model: GraphModule,
+        max_batch_size: int,
+        cuda_graph_batch_sizes: List[int] = None,
+        num_batched_inputs: Optional[int] = 1,  # number of batched, dynamic inputs...
     ):
         super().__init__()
         self._in_spec: TreeSpec = model._in_spec
         self._out_spec: TreeSpec = model._out_spec
         self.gm_compiled = torch.compile(model, dynamic=True)
         self.max_batch_size = max_batch_size
+        self.num_batched_inputs = num_batched_inputs if num_batched_inputs is not None else 1
         self.graphs: Dict[Tuple[int, ...], CUDAGraph] = {}
-        self._input_buffer: torch.Tensor = torch.empty(0, 1)
+        self._input_buffers: List[torch.Tensor] = [
+            torch.empty(0, 1) for _ in range(self.num_batched_inputs)
+        ]
         self._out_buffer_flat: List[torch.Tensor] = None
         self._args_hash: Optional[Tuple[int, ...]] = None
         self.cuda_graph_batch_sizes = (
@@ -42,6 +49,10 @@ def round_up_to_closest(batch_sizes: Iterable[int], bs: int) -> Optional[int]:
             return None
         return min(batch_sizes, key=lambda x: (x < bs, abs(x - bs)), default=None)

+    def round_to_cuda_batch_size(self, bs: int) -> int:
+        """Round batch size to the nearest cuda batch size."""
+        return self.round_up_to_closest(self.cuda_graph_batch_sizes, bs)
+
     def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:
         """Capture and return one cuda graph."""
         # warm-up
@@ -78,17 +89,31 @@ def _get_graph_batch_sizes(
         # return as sorted list
         return sorted(batch_sizes)

-    def _capture_cudagraph(self, input_t: torch.Tensor, flat_args: List[Any]):
-        """Capture graph for variable batch size."""
-        # set the args hash --> this is used to compare the inputs during graph replay
-        self._args_hash = self._get_hash(flat_args)
+    def capture_graph(self, *args, **kwargs):
+        """Capture and pre-fetch the graph for variable batch size."""
+        # flatten args, kwargs
+        all_args_flat = _flatten_args(self._in_spec, *args, **kwargs)
+
+        # extract the batched input tensors
+        args_batched = all_args_flat[: self.num_batched_inputs]
+        args_static = all_args_flat[self.num_batched_inputs :]

-        # set the input buffer to the max needed batch size with rest of shape as is
-        assert self.max_batch_size >= input_t.shape[0], "Max batch size too small."
-        self._input_buffer = input_t[:1].repeat_interleave(self.max_batch_size, dim=0)
+        # set the args hash --> this is used to compare the static inputs during graph replay
+        self._args_hash = self._get_hash(args_static)

-        # unflatten args, kwargs
-        args, kwargs = self._in_spec.unflatten([self._input_buffer] + flat_args)
+        # sanity checks on the batched inputs
+        msg_bs = "Max batch size too small."
+        msg_ndim = "Expecting at least 2D batched input tensors."
+        assert all(self.max_batch_size >= input.shape[0] for input in args_batched), msg_bs
+        assert all(input.ndim > 1 for input in args_batched), msg_ndim
+
+        # repeat the batched input tensors to the max batch size
+        self._input_buffers = [
+            input[:1].repeat_interleave(self.max_batch_size, dim=0) for input in args_batched
+        ]
+
+        # create new args, kwargs with the input buffers and static args
+        args, kwargs = self._in_spec.unflatten(self._input_buffers + args_static)

         # capture output once with max batch size to capture output buffers
         with CudaGraphWarmUpPhase():
@@ -101,35 +126,46 @@ def _capture_cudagraph(self, input_t: torch.Tensor, flat_args: List[Any]):
             ad_logger.info(f"Capturing graph for batch size: {bs}")

             # setup args, kwargs
-            input_truncated = self._input_buffer[:bs]
-            args, kwargs = self._in_spec.unflatten([input_truncated, *flat_args])
+            inputs_truncated = [in_buffer[:bs] for in_buffer in self._input_buffers]
+            args, kwargs = self._in_spec.unflatten(inputs_truncated + args_static)

-            # capture graph
-            self.graphs[input_truncated.shape] = self._capture_one_graph(*args, **kwargs)
-
-    def capture_graph(self, *args, **kwargs):
-        """Capture and pre-fetch the graph."""
-        input_t, flat_args = _flatten_args(self._in_spec, *args, **kwargs)
-        self._capture_cudagraph(input_t, flat_args)
+            # capture graph for truncated inputs
+            combined_shape = sum((input.shape for input in inputs_truncated), start=())
+            self.graphs[combined_shape] = self._capture_one_graph(*args, **kwargs)

     def forward(self, *args, **kwargs) -> Any:
         """Run the compiled graph."""
-        input_t, flat_args = _flatten_args(self._in_spec, *args, **kwargs)
-        bs, *other_dims = input_t.shape
+        # flatten args, kwargs
+        all_args_flat = _flatten_args(self._in_spec, *args, **kwargs)

-        # round up batch size and construct rounded up shape
-        bs_graph = self.round_up_to_closest([shapes[0] for shapes in self.graphs.keys()], bs)
-        shape_rounded_up = (bs_graph, *other_dims)
+        # extract the batched input tensors
+        args_batched = all_args_flat[: self.num_batched_inputs]
+        args_static = all_args_flat[self.num_batched_inputs :]

-        # regular forward for non-matching shapes or non-matching flat_args
-        if shape_rounded_up not in self.graphs or self._args_hash != self._get_hash(flat_args):
+        # check if args_static match the stored hash
+        if self._args_hash != self._get_hash(args_static):
             return self.gm_compiled(*args, **kwargs)

+        # Calculate rounded-up shapes for each input
+        rounded_shapes = [
+            (self.round_to_cuda_batch_size(input.shape[0]),) + input.shape[1:]
+            for input in args_batched
+        ]
+        combined_shape = sum(rounded_shapes, start=())
+
+        # regular forward for non-matching shapes
+        if combined_shape not in self.graphs:
+            return self.gm_compiled(*args, **kwargs)
+
+        # copy inputs to input buffers
+        for i, input_tensor in enumerate(args_batched):
+            self._input_buffers[i][: input_tensor.shape[0]] = input_tensor
+
         # run forward pass via graph
-        self._input_buffer[:bs] = input_t
-        self.graphs[shape_rounded_up].replay()
+        self.graphs[combined_shape].replay()

         # retrieve output from buffer, cut to batch size, and unflatten
+        bs = args_batched[0].shape[0]
         out_flat = [o_b[:bs].detach().clone() for o_b in self._out_buffer_flat]
         return self._out_spec.unflatten(out_flat)

@@ -138,11 +174,11 @@ def forward(self, *args, **kwargs) -> Any:
 class TorchOptCompiler(BackendCompiler):
     @torch.inference_mode()
     def compile(self) -> CompiledGraph:
-        cuda_graph_batch_sizes = self.compiler_kwargs.get("cuda_graph_batch_sizes", None)
         compiled_gm = CompiledGraph(
             self.gm,
             max_batch_size=self.max_batch_size,
-            cuda_graph_batch_sizes=cuda_graph_batch_sizes,
+            cuda_graph_batch_sizes=self.compiler_kwargs.get("cuda_graph_batch_sizes"),
+            num_batched_inputs=self.compiler_kwargs.get("num_batched_inputs"),
         )

         # try capturing cudagraph
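
For reference, here is a minimal sketch (not part of the PR) of how the new graph lookup key behaves: each batched input's batch dimension is rounded up to the nearest captured CUDA graph batch size, and the rounded shapes are concatenated into one tuple that keys `self.graphs`. The guard condition inside `round_up_to_closest` is an assumption (only its final lines are visible in the diff), and the tensor shapes are illustrative.

```python
import torch

cuda_graph_batch_sizes = [1, 2, 4, 8]

def round_up_to_closest(batch_sizes, bs):
    # assumed guard: no captured size can serve a larger batch
    if bs > max(batch_sizes, default=0):
        return None
    # selection rule from the diff: prefer sizes >= bs, then smallest distance
    return min(batch_sizes, key=lambda x: (x < bs, abs(x - bs)), default=None)

# two hypothetical batched inputs, each with batch size 3
args_batched = [torch.zeros(3, 16), torch.zeros(3, 16)]

rounded_shapes = [
    (round_up_to_closest(cuda_graph_batch_sizes, t.shape[0]),) + tuple(t.shape[1:])
    for t in args_batched
]
combined_shape = sum(rounded_shapes, start=())
print(combined_shape)  # (4, 16, 4, 16) -> key used to look up the captured graph
```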
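And a hedged end-to-end usage sketch of the new interface, assuming `CompiledGraph` is importable from this module and that the wrapped `GraphModule` takes two batched tensors followed by static arguments; the names, shapes, dtypes, and call signature below are assumptions for illustration only.

```python
import torch

gm = ...  # a torch.fx.GraphModule produced elsewhere in the pipeline (placeholder)

compiled = CompiledGraph(
    gm,
    max_batch_size=8,
    cuda_graph_batch_sizes=[1, 2, 4, 8],
    num_batched_inputs=2,  # e.g. input_ids and position_ids (hypothetical)
)

# capture once with representative inputs at (up to) the max batch size;
# this also records the output buffers reused during replay
input_ids = torch.zeros(8, 128, dtype=torch.long, device="cuda")
position_ids = torch.zeros(8, 128, dtype=torch.long, device="cuda")
compiled.capture_graph(input_ids, position_ids)

# later calls with a smaller batch (here 3) are rounded up to a captured
# size (4), copied into the input buffers, and served by graph replay
out = compiled(input_ids[:3], position_ids[:3])
```

If the static (non-batched) arguments change between capture and replay, the hash check in `forward` falls back to the `torch.compile` path instead of replaying a stale graph.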