Skip to content

Commit c256340

Browse files
committed
fix tvm-ffi
1 parent aeb7970 commit c256340

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tritonbench/operators/launch_latency/operator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,22 @@ def nop_cutedsl(self, *args):
9393
cute_args = cute_args[:-5]
9494
return lambda: kernel(*cute_args)
9595

96+
@register_benchmark(enabled=HAS_CUTEDSL)
97+
def nop_cutedsl_tvm_ffi(self, *args):
98+
if len(args) == 0:
99+
kernel = cute.compile(cutedsl_nop_kernel)
100+
return lambda: kernel()
101+
cute_args = []
102+
for arg in args:
103+
if isinstance(arg, torch.Tensor):
104+
cute_args.append(cute.runtime.from_dlpack(arg, enable_tvm_ffi=True))
105+
else:
106+
cute_args.append(arg)
107+
kernel = cute.compile(cutedsl_nop_with_args_kernel, *cute_args, options="--enable-tvm-ffi")
108+
# remove constexpr args
109+
cute_args = cute_args[:-5]
110+
return lambda: kernel(*cute_args)
111+
96112
@register_benchmark(baseline=True)
97113
def nop_python_function(self, *args):
98114
def nop():

0 commit comments

Comments
 (0)