
Commit 077b781

Explicitly zero grad in blackwell_attention Triton Bench
Differential Revision: D88412050
Pull Request resolved: #694
1 parent 038d24c commit 077b781

File tree

1 file changed: +19 −0 lines changed


tritonbench/operators/blackwell_attentions/operator.py

Lines changed: 19 additions & 0 deletions
@@ -182,15 +182,25 @@ def parse_op_args(args: List[str]):
     return parser.parse_args(args)
 
 
+def unpack_inputs(*args):
+    inputs = args
+    if len(args) == 1 and isinstance(args[0], xformers_fmha.Inputs):
+        inp = args[0]
+        inputs = (inp.query, inp.key, inp.value)
+    return (t.detach() for t in inputs)
+
+
 def multi_input_wrapper(fn):
     def wrapper(self, *args):
         preproc_fn, benchmark_fn = fn(self, *args)
         arg_len = len(args)
         assert arg_len % 3 == 0
         inputs = []
+        all_inputs = []
         for i in range(0, arg_len, 3):
             q, k, v = args[i : i + 3]
             inp = preproc_fn(q, k, v)
+            all_inputs += [*unpack_inputs(*inp)]
             inputs.append(inp)
 
         def multi_input_fn():
@@ -199,6 +209,8 @@ def multi_input_fn():
                 outputs.append(benchmark_fn(*i))
             return outputs
 
+        self.optims[multi_input_fn] = torch.optim.SGD(all_inputs)
+
         return multi_input_fn
 
     wrapper.__name__ = fn.__name__
@@ -276,6 +288,7 @@ def __init__(
         self.sm_scale = args.sm_scale if args.sm_scale else 1.0 / math.sqrt(self.D_HEAD)
         self.deterministic = args.deterministic
         self.gen_cache_size_inputs = args.gen_cache_size_inputs
+        self.optims = {}
 
     @register_benchmark()
     @multi_input_wrapper
@@ -573,8 +586,14 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
         o = fwd_fn()
         outputs = [input_filter(lambda x: isinstance(x, torch.Tensor), o_) for o_ in o]
         dOs = [torch.rand_like(o_).detach() for o_ in outputs]
+        zero_grad = (
+            self.optims[fwd_fn].zero_grad
+            if fwd_fn in self.optims
+            else lambda set_to_none: None
+        )
 
         def fn():
+            zero_grad(set_to_none=True)
             for (
                 o_tensor,
                 do,
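For context, a minimal standalone sketch of the pattern this commit applies (the tensor names and shapes below are illustrative, not taken from the operator): the benchmark inputs are registered with a throwaway torch.optim.SGD instance so their .grad buffers can be cleared with zero_grad(set_to_none=True) before each backward benchmark iteration, preventing gradient accumulation across repeated runs.

import torch

# Illustrative stand-ins for the benchmark's attention inputs.
q = torch.randn(8, 64, requires_grad=True)
w = torch.randn(64, 64)

def fwd():
    return q @ w

# SGD is used only as a container for the inputs; step() is never called,
# so the learning rate is irrelevant. zero_grad(set_to_none=True) resets
# .grad on every registered tensor.
optim = torch.optim.SGD([q], lr=0.0)

for _ in range(3):
    optim.zero_grad(set_to_none=True)
    out = fwd()
    out.backward(torch.ones_like(out))
    # q.grad now holds only this iteration's gradient instead of an
    # accumulated sum over all previous backward calls.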
