Skip to content

Commit 627f73b

Browse files
authored
[launch-latency] Add tvm-ffi to cutedsl (#674)
1 parent 22f7e44 commit 627f73b

File tree

5 files changed

+40
-5
lines changed

5 files changed

+40
-5
lines changed

.gitmodules

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,3 @@
1616
[submodule "submodules/aiter"]
1717
path = submodules/aiter
1818
url = https://github.com/ROCm/aiter.git
19-
[submodule "submodules/quack"]
20-
path = submodules/quack
21-
url = https://github.com/Dao-AILab/quack.git

submodules/quack

Lines changed: 0 additions & 1 deletion
This file was deleted.

tools/quack/install.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,32 @@
11
import os
2+
import shutil
23
import subprocess
34

45
from pathlib import Path
56

67

78
REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent
89
CURRENT_DIR = Path(os.path.abspath(__file__)).parent
9-
QUACK_PATH = REPO_PATH.joinpath("submodules", "quack")
10+
11+
QUACK_REPO = "https://github.com/Dao-AILab/quack.git"
12+
QUACK_SHA = "bceb632dbac9bb0b55d48a7ed3ad204bd952fcb2"
13+
14+
QUACK_INSTALL_PATH = REPO_PATH.joinpath(".install")
1015

1116

1217
def install_quack():
    """Clone quack at the pinned commit and install it in editable mode.

    The repository is cloned fresh into ``QUACK_INSTALL_PATH / "quack"``
    (any existing checkout there is removed first), ``QUACK_SHA`` is checked
    out, and the package is installed with ``pip install -e .[dev]``.

    This is the only definition of ``install_quack``: the earlier variant
    that installed from the ``submodules/quack`` git submodule was shadowed
    dead code and referenced the removed ``QUACK_PATH`` constant.

    Raises:
        subprocess.CalledProcessError: if any git or pip command exits with
            a non-zero status.
    """
    QUACK_INSTALL_PATH.mkdir(parents=True, exist_ok=True)
    quack_path = QUACK_INSTALL_PATH.joinpath("quack")
    # Always start from a clean clone so a stale or partial checkout cannot
    # leak into the install.
    if quack_path.exists():
        shutil.rmtree(quack_path)
    # Name the target directory explicitly so it always matches quack_path,
    # independent of the basename of QUACK_REPO.
    git_clone_cmd = ["git", "clone", QUACK_REPO, quack_path.name]
    subprocess.check_call(git_clone_cmd, cwd=QUACK_INSTALL_PATH)
    # Pin to a known commit for reproducible installs.
    git_checkout_cmd = ["git", "checkout", QUACK_SHA]
    subprocess.check_call(git_checkout_cmd, cwd=quack_path)
    pip_install_cmd = ["pip", "install", "-e", ".[dev]"]
    subprocess.check_call(pip_install_cmd, cwd=quack_path)

tritonbench/operators/launch_latency/operator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,24 @@ def nop_cutedsl(self, *args):
9393
cute_args = cute_args[:-5]
9494
return lambda: kernel(*cute_args)
9595

96+
@register_benchmark(enabled=HAS_CUTEDSL)
def nop_cutedsl_tvm_ffi(self, *args):
    """Return a callable that launches the CuTe DSL no-op kernel via TVM FFI.

    With no arguments, the plain no-op kernel is compiled and launched.
    Otherwise every ``torch.Tensor`` argument is converted with
    ``cute.runtime.from_dlpack(..., enable_tvm_ffi=True)``, the kernel is
    compiled with ``--enable-tvm-ffi``, and the returned callable launches
    it with all but the last five arguments.
    """
    if not args:
        # NOTE(review): this path compiles without "--enable-tvm-ffi";
        # confirm whether the zero-argument case is meant to use TVM FFI too.
        kernel = cute.compile(cutedsl_nop_kernel)
        return lambda: kernel()

    # Wrap tensors for the TVM FFI calling convention; pass everything else
    # through unchanged.
    converted = [
        cute.runtime.from_dlpack(arg, enable_tvm_ffi=True)
        if isinstance(arg, torch.Tensor)
        else arg
        for arg in args
    ]
    kernel = cute.compile(
        cutedsl_nop_with_args_kernel, *converted, options="--enable-tvm-ffi"
    )
    # remove constexpr args: the trailing five entries are only needed at
    # compile time and are not passed at launch.
    launch_args = converted[:-5]
    return lambda: kernel(*launch_args)
113+
96114
@register_benchmark(baseline=True)
97115
def nop_python_function(self, *args):
98116
def nop():

tritonbench/utils/parser.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,4 +416,7 @@ def get_parser(args=None):
416416
)
417417
if args.isolate:
418418
parser.error("A/B testing is not compatible with --isolate mode")
419+
420+
if args.metrics and "walltime_kineto_trace" in args.metrics and args.repcnt is None:
421+
parser.error("Walltime Kineto trace requires --repcnt to be specified")
419422
return parser

0 commit comments

Comments
 (0)