Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tritonbench/operators/blackwell_attentions/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,25 @@ def parse_op_args(args: List[str]):
return parser.parse_args(args)


def unpack_inputs(*args):
    """Yield detached views of the input tensors.

    Accepts either the raw tensors themselves or a single
    ``xformers_fmha.Inputs`` bundle, in which case the query/key/value
    tensors are extracted from it. Returns a generator that lazily
    produces each tensor detached from the autograd graph (the detached
    tensors still share storage with the originals).
    """
    if len(args) == 1 and isinstance(args[0], xformers_fmha.Inputs):
        bundle = args[0]
        tensors = (bundle.query, bundle.key, bundle.value)
    else:
        tensors = args
    return (tensor.detach() for tensor in tensors)


def multi_input_wrapper(fn):
def wrapper(self, *args):
preproc_fn, benchmark_fn = fn(self, *args)
arg_len = len(args)
assert arg_len % 3 == 0
inputs = []
all_inputs = []
for i in range(0, arg_len, 3):
q, k, v = args[i : i + 3]
inp = preproc_fn(q, k, v)
all_inputs += [*unpack_inputs(*inp)]
inputs.append(inp)

def multi_input_fn():
Expand All @@ -199,6 +209,8 @@ def multi_input_fn():
outputs.append(benchmark_fn(*i))
return outputs

self.optims[multi_input_fn] = torch.optim.SGD(all_inputs)

return multi_input_fn

wrapper.__name__ = fn.__name__
Expand Down Expand Up @@ -276,6 +288,7 @@ def __init__(
self.sm_scale = args.sm_scale if args.sm_scale else 1.0 / math.sqrt(self.D_HEAD)
self.deterministic = args.deterministic
self.gen_cache_size_inputs = args.gen_cache_size_inputs
self.optims = {}

@register_benchmark()
@multi_input_wrapper
Expand Down Expand Up @@ -573,8 +586,14 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
o = fwd_fn()
outputs = [input_filter(lambda x: isinstance(x, torch.Tensor), o_) for o_ in o]
dOs = [torch.rand_like(o_).detach() for o_ in outputs]
zero_grad = (
self.optims[fwd_fn].zero_grad
if fwd_fn in self.optims
else lambda set_to_none: None
)

def fn():
zero_grad(set_to_none=True)
for (
o_tensor,
do,
Expand Down