diff --git a/tritonbench/operators/blackwell_attentions/operator.py b/tritonbench/operators/blackwell_attentions/operator.py
index 784a8e527..2f5eb52f9 100644
--- a/tritonbench/operators/blackwell_attentions/operator.py
+++ b/tritonbench/operators/blackwell_attentions/operator.py
@@ -182,15 +182,25 @@ def parse_op_args(args: List[str]):
     return parser.parse_args(args)
 
 
+def unpack_inputs(*args):
+    inputs = args
+    if len(args) == 1 and isinstance(args[0], xformers_fmha.Inputs):
+        inp = args[0]
+        inputs = (inp.query, inp.key, inp.value)
+    return (t.detach() for t in inputs)
+
+
 def multi_input_wrapper(fn):
     def wrapper(self, *args):
         preproc_fn, benchmark_fn = fn(self, *args)
         arg_len = len(args)
         assert arg_len % 3 == 0
         inputs = []
+        all_inputs = []
         for i in range(0, arg_len, 3):
             q, k, v = args[i : i + 3]
             inp = preproc_fn(q, k, v)
+            all_inputs += [*unpack_inputs(*inp)]
             inputs.append(inp)
 
         def multi_input_fn():
@@ -199,6 +209,8 @@ def multi_input_fn():
                 outputs.append(benchmark_fn(*i))
             return outputs
 
+        self.optims[multi_input_fn] = torch.optim.SGD(all_inputs)
+
         return multi_input_fn
 
     wrapper.__name__ = fn.__name__
@@ -276,6 +288,7 @@ def __init__(
         self.sm_scale = args.sm_scale if args.sm_scale else 1.0 / math.sqrt(self.D_HEAD)
         self.deterministic = args.deterministic
         self.gen_cache_size_inputs = args.gen_cache_size_inputs
+        self.optims = {}
 
     @register_benchmark()
     @multi_input_wrapper
@@ -573,8 +586,14 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
         o = fwd_fn()
         outputs = [input_filter(lambda x: isinstance(x, torch.Tensor), o_) for o_ in o]
         dOs = [torch.rand_like(o_).detach() for o_ in outputs]
+        zero_grad = (
+            self.optims[fwd_fn].zero_grad
+            if fwd_fn in self.optims
+            else lambda set_to_none: None
+        )
 
         def fn():
+            zero_grad(set_to_none=True)
             for (
                 o_tensor,
                 do,
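
For context, a minimal standalone sketch of the pattern this diff introduces: a torch.optim.SGD instance is created purely as a handle for zero_grad(), so gradients accumulated on the benchmark inputs are cleared between timed backward passes (the optimizer is never stepped). All names and shapes below are illustrative assumptions, not taken from tritenbench's actual operator code:

    import torch
    import torch.nn.functional as F

    # Illustrative attention inputs; shapes are arbitrary.
    q = torch.randn(8, 128, 64, requires_grad=True)
    k = torch.randn(8, 128, 64, requires_grad=True)
    v = torch.randn(8, 128, 64, requires_grad=True)

    # The optimizer is never stepped; lr=0.0 makes that explicit. It exists
    # solely so zero_grad() can clear .grad on all tracked tensors at once.
    optim = torch.optim.SGD([q, k, v], lr=0.0)

    # Fixed upstream gradient, analogous to the dOs in get_bwd_fn.
    grad_out = torch.rand_like(F.scaled_dot_product_attention(q, k, v))

    for _ in range(3):  # stand-in for repeated benchmark iterations
        optim.zero_grad(set_to_none=True)  # prevent cross-iteration grad accumulation
        out = F.scaled_dot_product_attention(q, k, v)
        out.backward(grad_out)

Without the zero_grad() call, each benchmark iteration would accumulate into the inputs' .grad tensors, so later iterations would pay an extra add on top of the gradient write and skew the timings.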