Skip to content

Commit 230116a

Browse files
authored
Zcx/dev: support common IR (#137)
* Fix CI
* Fix ut_tests warning
* Support dumping BiSheng IR
* Support common IR
* Fix formatting
1 parent 6afb3db commit 230116a

File tree

5 files changed

+468
-1
lines changed

5 files changed

+468
-1
lines changed

backend/commonir/__init__.py

Whitespace-only changes.

backend/commonir/adapter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class Adapter:
    """Placeholder adapter for the common-IR backend.

    Defines no behavior yet; serves as an extension point for
    target-specific adapters.
    """

backend/commonir/backend.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
import functools
2+
import os
3+
from typing import Any
4+
from ..compiler import DICPOptions
5+
from ..driver import DICPDriver
6+
from ..utils import get_current_backend
7+
8+
9+
class CommonIRBackend:
    """Dispatch layer that routes Triton common-IR compilation to the
    active hardware backend (dicp / mlu / maca / ascend), as reported by
    the DICP driver.

    Per-target concerns handled here: binary file extension, compilation
    stages, options parsing, codegen hooks, and kernel-metadata packing.
    """

    # Default binary extension; overridden per target in __init__.
    binary_ext = "ttlinalgdir"

    def __init__(self) -> None:
        target = get_current_backend()
        # BUG FIX: the original never assigned ``self.target``, yet
        # parse_options() and get_codegen_implementation() read it —
        # those calls raised AttributeError.
        self.target = target
        self.driver = DICPDriver(target)
        if self.driver.target == "dicp":
            self.binary_ext = "ttlinalgdir"
        elif self.driver.target == "mlu":
            # MLU exposes its compute capability as an integer arch code.
            self.capability = target.arch
            assert isinstance(self.capability, int)
            self.binary_ext = "cnbin"
        elif self.driver.target == "maca":
            self.capability = 80
            self.binary_ext = "mcfatbin"
        elif self.driver.target == "ascend":
            self.binary_ext = "npubin"
        else:
            # BUG FIX: the original formatted the nonexistent attribute
            # ``self.target_type``, which raised AttributeError instead of
            # the intended RuntimeError.
            raise RuntimeError(f"Target '{self.driver.target}' is not supported.")

    def get_attrs_descriptor(self, params, args):
        """Build the backend-specific kernel-attributes descriptor.

        Only the ascend target provides one; every other target raises
        RuntimeError.
        """
        if self.driver.target == "ascend":
            from triton.backends.dicp_triton.npu import AscendAttrsDescriptor

            return AscendAttrsDescriptor(params, args)
        else:
            raise RuntimeError(
                f"backend {self.driver.target} not supported for get_attrs_descriptor."
            )

    def add_stages(self, stages, options, language=None):
        """Register target compilation stages into *stages* (a dict).

        Only ascend is supported: common IR is linked ("linkedir"), then
        lowered to an NPU binary ("npubin"). Other targets raise.
        """
        if self.driver.target == "ascend":
            from triton.backends.dicp_triton.npu import (
                commonir_to_linkedir,
                linalg_to_bin_enable_npu_compile,
            )

            stages["linkedir"] = lambda src, metadata: commonir_to_linkedir(
                src, metadata, options, named_ops=True
            )
            stages["npubin"] = lambda src, metadata: linalg_to_bin_enable_npu_compile(
                src, metadata, options
            )
        else:
            raise RuntimeError("backend not supported")

    def load_dialects(self, ctx):
        """Load extra MLIR dialects into *ctx*; only MLU needs any."""
        if self.driver.target == "mlu":
            from triton._C.libtriton import mlu

            mlu.load_dialects(ctx)
        return

    def get_driver(self):
        """Return the underlying DICPDriver instance."""
        return self.driver

    # parse add_kernel[(16,)](x, y, output, n_elements, BLOCK_SIZE=1024)
    def parse_options(self, options: dict) -> Any:
        """Convert a raw options dict into the target's options dataclass.

        Only keys declared by the target dataclass are kept; target-specific
        defaults are filled in before construction.
        """
        if self.driver.target == "ascend":
            from triton.backends.dicp_triton.npu import NPUOptions

            args = {
                k: options[k]
                for k in NPUOptions.__dataclass_fields__.keys()
                if k in options
            }
            return NPUOptions(**args)
        elif self.driver.target == "mlu":
            from triton.backends.dicp_triton.mlu import MLUOptions

            args = {
                k: options[k]
                for k in MLUOptions.__dataclass_fields__.keys()
                if k in options
            }
            # When arch is less than mtp_5xx, tf32 is not supported, use fp32 for calculation.
            if "allowed_dot_input_precisions" not in args:
                if self.capability < 500:
                    args["allowed_dot_input_precisions"] = "ieee"

            if "supported_fp8_dtypes" not in args:
                supported_fp8_dtypes = set(MLUOptions.supported_fp8_dtypes)
                if self.capability >= 600:
                    supported_fp8_dtypes = supported_fp8_dtypes.union(
                        ("fp8e5", "fp8e4nv")
                    )
                args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

            args["max_num_imprecise_acc_default"] = 0

            if "enable_fp_fusion" not in args:
                args["enable_fp_fusion"] = (
                    os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1"
                )

            if "enable_mlu_bound_check" not in args:
                args["enable_mlu_bound_check"] = (
                    os.getenv("TRITON_ENABLE_MLU_BOUND_CHECK", "0") == "1"
                )
            return MLUOptions(**args)
        elif self.driver.target == "maca":
            from triton.backends.dicp_triton.maca import MACAOptions

            args = {
                k: options[k]
                for k in MACAOptions.__dataclass_fields__.keys()
                if k in options
            }
            # USE_MACA: support allow_fp8e4nv(i.e. float8_e4m3fn)
            args["allow_fp8e4nv"] = True
            args["allow_fp8e4b15"] = False
            args["max_num_imprecise_acc_default"] = (
                2**30 if self.capability == 90 else 0
            )
            return MACAOptions(**args)
        else:
            args = {"arch": self.target}
            args.update(
                {
                    k: options[k]
                    for k in DICPOptions.__dataclass_fields__.keys()
                    if k in options
                }
            )
            return DICPOptions(**args)

    def get_codegen_implementation(self, options=None):
        """Return the dict of codegen hook functions for the active target.

        Unknown targets get an empty dict.
        """
        codegen_fns = dict()
        if self.driver.target == "ascend":
            from triton.backends.dicp_triton.npu import min_dot_size

            codegen_fns = {"min_dot_size": min_dot_size(self.target)}
        elif self.driver.target == "mlu":
            from triton.backends.dicp_triton.mlu import min_dot_size

            codegen_fns = {
                "convert_custom_types": lambda arg, dst_ty: arg,
                "min_dot_size": min_dot_size(self.target),
            }
        elif self.driver.target == "maca":
            import triton.language.extra.cuda as cuda

            codegen_fns = {
                "convert_custom_types": (
                    cuda.convert_custom_float8_sm80
                    if self.capability >= 80
                    else cuda.convert_custom_float8_sm70
                )
            }
        return codegen_fns

    def pack_metadata(self, metadata):
        """Pack compiled-kernel metadata for the launcher.

        Returns a dict for ascend, a 1-tuple (num_warps,) for mlu, and a
        6-tuple (num_warps, num_ctas, shared, cluster_dims[0..2]) otherwise.
        """
        if self.driver.target == "ascend":
            from triton.backends.dicp_triton.npu import TRITON_PROFILER_REGISTERED

            # collect necessary metadata to launch kernels
            # TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 could set unique name.
            # Get this name as the kernel_name to CANN runtime.
            # kernel_name is unique to Ascend backend and should not be public.
            # CANN runtime limits the length of kernel name <= 50.
            # Considering '\n' is appended, thus the real kernel name <= 49.
            KERNEL_NAME_MAX_LEN = 49
            # metadata.name is "<kernel_name> <mix_mode>"; only the name part
            # is used here.
            kernel_name_orig, _mix_mode = metadata.name.split()
            if len(kernel_name_orig) > KERNEL_NAME_MAX_LEN:
                # Keep the (more discriminative) tail of an over-long name.
                kernel_name = kernel_name_orig[-KERNEL_NAME_MAX_LEN:]
            else:
                kernel_name = kernel_name_orig
            return {
                "kernel_name": kernel_name,
                "hash": metadata.hash,
                "debug": metadata.debug,
                "profiler_registered": TRITON_PROFILER_REGISTERED,
            }
        elif self.driver.target == "mlu":
            return (metadata.num_warps,)
        return (
            metadata.num_warps,
            metadata.num_ctas,
            metadata.shared,
            metadata.cluster_dims[0],
            metadata.cluster_dims[1],
            metadata.cluster_dims[2],
        )

    @functools.lru_cache()
    def hash(self):
        """Stable cache key for this backend configuration.

        NOTE(review): lru_cache on an instance method keeps ``self`` alive
        for the cache's lifetime; acceptable here because this class is
        used as a module-level singleton.
        """
        if self.driver.target == "mlu":
            from triton.backends.dicp_triton.mlu import get_cnas_version

            version = get_cnas_version()
            return f"{version}-{self.capability}"
        version_key = self.driver.target
        return str(version_key)
# Module-level singleton shared by all callers. NOTE(review): this runs
# get_current_backend() and DICPDriver(...) at import time, so importing this
# module requires a configured backend environment — confirm that is intended.
commonir_backend = CommonIRBackend()

0 commit comments

Comments
 (0)