Skip to content

Commit 1cacbf8

Browse files
author
Godnight1006
committed
feat: add torch.compile triton self-test
1 parent 5e76c4c commit 1cacbf8

4 files changed

Lines changed: 86 additions & 5 deletions

File tree

API.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ torchruntime.install(["torch", "torchvision<0.20"])
1515

1616
On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`).
1717

18+
## Test torch
19+
Run:
20+
`python -m torchruntime test`
21+
22+
To specifically verify `torch.compile` / Triton:
23+
`python -m torchruntime test compile`
24+
1825
## Get device info
1926
You can use the device database built into `torchruntime` for your projects:
2027
```py

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ torchruntime.configure()
4444
```
4545

4646
### (Optional) Step 3. Test torch
47-
Run `python -m torchruntime test` to run a set of tests to check whether the installed version of torch is working correctly.
47+
Run `python -m torchruntime test` to run a set of tests to check whether the installed version of torch is working correctly (including a `torch.compile` / Triton check on CUDA/XPU systems). You can also run `python -m torchruntime test compile` to run only the compile check.
4848

4949
## Customizing packages
5050
By default, `python -m torchruntime install` will install the latest available `torch`, `torchvision` and `torchaudio` suitable on the user's platform.

torchruntime/__main__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def print_usage(entry_command: str):
1010
1111
Commands:
1212
install Install PyTorch packages
13-
test [subcommand] Run tests (subcommands: all, devices, math, functions)
13+
test [subcommand] Run tests (subcommands: all, import, devices, compile, math, functions)
1414
--help Show this help message
1515
1616
Examples:
@@ -20,10 +20,11 @@ def print_usage(entry_command: str):
2020
{entry_command} install --uv torch>=2.0.0 torchaudio
2121
{entry_command} install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0
2222
23-
{entry_command} test # Runs all tests (import, devices, math, functions)
23+
{entry_command} test # Runs all tests (import, devices, compile, math, functions)
2424
{entry_command} test all # Same as above
2525
{entry_command} test import # Test only import
2626
{entry_command} test devices # Test only devices
27+
{entry_command} test compile # Test torch.compile (Triton)
2728
{entry_command} test math # Test only math
2829
{entry_command} test functions # Test only functions
2930

torchruntime/utils/torch_test/__init__.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import importlib.util
2+
import platform
13
import time
24

3-
from ..torch_device_utils import get_installed_torch_platform, get_device_count, get_device_name, get_device
5+
from ..torch_device_utils import get_device, get_device_count, get_device_name, get_installed_torch_platform
46

57

68
def test(subcommand):
@@ -16,7 +18,7 @@ def test(subcommand):
1618

1719

1820
def test_all():
19-
for fn in (test_import, test_devices, test_math, test_functions):
21+
for fn in (test_import, test_devices, test_compile, test_math, test_functions):
2022
fn()
2123
print("")
2224

@@ -101,3 +103,74 @@ def test_functions():
101103
t.run_all_tests()
102104

103105
print("--- / FUNCTIONAL TEST ---")
106+
107+
108+
def test_compile():
109+
print("--- COMPILE TEST ---")
110+
111+
try:
112+
import torch
113+
except ImportError:
114+
print("torch.compile: SKIPPED (torch not installed)")
115+
print("--- / COMPILE TEST ---")
116+
return
117+
118+
if not hasattr(torch, "compile"):
119+
print("torch.compile: SKIPPED (requires torch>=2.0)")
120+
print("--- / COMPILE TEST ---")
121+
return
122+
123+
torch_platform_name, _ = get_installed_torch_platform()
124+
if torch_platform_name not in ("cuda", "xpu"):
125+
print(f"torch.compile: SKIPPED (unsupported backend: {torch_platform_name})")
126+
print("--- / COMPILE TEST ---")
127+
return
128+
129+
if importlib.util.find_spec("triton") is None:
130+
print("triton: NOT INSTALLED")
131+
else:
132+
print("triton: installed")
133+
134+
device = get_device(0)
135+
print("On torch device:", device)
136+
137+
def f(x):
138+
return x * 2 + 1
139+
140+
try:
141+
compiled_f = torch.compile(f)
142+
x = torch.randn((1024,), device=device)
143+
y = compiled_f(x)
144+
expected = f(x)
145+
if not torch.allclose(y, expected):
146+
print("torch.compile: FAILED (output mismatch)")
147+
else:
148+
if torch_platform_name == "cuda":
149+
torch.cuda.synchronize()
150+
if torch_platform_name == "xpu" and hasattr(torch, "xpu") and hasattr(torch.xpu, "synchronize"):
151+
torch.xpu.synchronize()
152+
print("torch.compile: PASSED")
153+
except Exception as e:
154+
print(f"torch.compile: FAILED ({type(e).__name__}: {e})")
155+
156+
hint = None
157+
os_name = platform.system()
158+
if torch_platform_name == "cuda" and os_name == "Windows":
159+
hint = "pip install triton-windows (or: python -m torchruntime install)"
160+
elif torch_platform_name == "cuda" and os_name == "Linux":
161+
if getattr(torch.version, "hip", None):
162+
hint = (
163+
"pip install pytorch-triton-rocm --index-url https://download.pytorch.org/whl "
164+
"(or: python -m torchruntime install)"
165+
)
166+
elif torch_platform_name == "xpu" and os_name == "Linux":
167+
hint = (
168+
"pip install pytorch-triton-xpu --index-url https://download.pytorch.org/whl "
169+
"(or: python -m torchruntime install)"
170+
)
171+
172+
if hint:
173+
print("If this failed due to Triton, try:")
174+
print(" ", hint)
175+
176+
print("--- / COMPILE TEST ---")

0 commit comments

Comments
 (0)