@@ -182,15 +182,25 @@ def parse_op_args(args: List[str]):
182182 return parser .parse_args (args )
183183
184184
def unpack_inputs(*args):
    """Return a generator of ``.detach()`` results for the given inputs.

    Two calling conventions are accepted:

    * a single ``xformers_fmha.Inputs`` object, in which case its ``query``,
      ``key`` and ``value`` attributes (in that order) are used;
    * any other positional arguments, which are used as-is.

    The returned generator is lazy: ``.detach()`` is only invoked on each
    element as the generator is consumed.
    """
    # Keep the length test first so the xformers type is only consulted
    # when exactly one argument was supplied.
    is_packed = len(args) == 1 and isinstance(args[0], xformers_fmha.Inputs)
    if is_packed:
        packed = args[0]
        tensors = (packed.query, packed.key, packed.value)
    else:
        tensors = args
    return (tensor.detach() for tensor in tensors)
191+
192+
185193def multi_input_wrapper (fn ):
186194 def wrapper (self , * args ):
187195 preproc_fn , benchmark_fn = fn (self , * args )
188196 arg_len = len (args )
189197 assert arg_len % 3 == 0
190198 inputs = []
199+ all_inputs = []
191200 for i in range (0 , arg_len , 3 ):
192201 q , k , v = args [i : i + 3 ]
193202 inp = preproc_fn (q , k , v )
203+ all_inputs += [* unpack_inputs (* inp )]
194204 inputs .append (inp )
195205
196206 def multi_input_fn ():
@@ -199,6 +209,8 @@ def multi_input_fn():
199209 outputs .append (benchmark_fn (* i ))
200210 return outputs
201211
212+ self .optims [multi_input_fn ] = torch .optim .SGD (all_inputs )
213+
202214 return multi_input_fn
203215
204216 wrapper .__name__ = fn .__name__
@@ -276,6 +288,7 @@ def __init__(
276288 self .sm_scale = args .sm_scale if args .sm_scale else 1.0 / math .sqrt (self .D_HEAD )
277289 self .deterministic = args .deterministic
278290 self .gen_cache_size_inputs = args .gen_cache_size_inputs
291+ self .optims = {}
279292
280293 @register_benchmark ()
281294 @multi_input_wrapper
@@ -573,8 +586,14 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
573586 o = fwd_fn ()
574587 outputs = [input_filter (lambda x : isinstance (x , torch .Tensor ), o_ ) for o_ in o ]
575588 dOs = [torch .rand_like (o_ ).detach () for o_ in outputs ]
589+ zero_grad = (
590+ self .optims [fwd_fn ].zero_grad
591+ if fwd_fn in self .optims
592+ else lambda set_to_none : None
593+ )
576594
577595 def fn ():
596+ zero_grad (set_to_none = True )
578597 for (
579598 o_tensor ,
580599 do ,
0 commit comments