Skip to content

Commit 17eddee

Browse files
committed
Refactor: Integrate Pybind11 and Refactor NPU Passes for Compile Speedup and Size Reduction
This commit introduces significant improvements to the NPU backend, focusing on code architecture and build performance:
1. NPU code refactoring: NPU-specific logic has been decoupled from the monolithic 'triton-shared' component and re-implemented as dedicated, cleaner passes and patterns. This greatly improves code clarity and maintainability.
2. Performance boost (Pybind11): Integrated Pybind11, effectively eliminating Python-C++ interface overhead (the passes now run in-process through a pass manager instead of through external opt subprocesses and temporary files). Compilation benchmarks show that the intermediate stages (e.g., ttir_post, ttshared, linkedir) are now over 60% faster on average.
3. Binary size optimization: The implementation allows the external 'opt' dependency to be removed, resulting in a substantial reduction of the release package size (before: ~488MB; after: ~330MB).
Note for publish builds: to generate the optimized release version, ensure the environment variable is set: export IS_NOT_PUBLISH=0
1 parent 36e924f commit 17eddee

File tree

30 files changed

+1523
-717
lines changed

30 files changed

+1523
-717
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ tags
9696
kernel_meta/
9797
fusion_result.json
9898
*.log
99+
launcher_cxx11abi*
99100

100101
# package
101102
backend/triton-shared-opt-v3*

CMakeLists.txt

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
set(DC_TRITON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
22
set(DC_TRITON_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
3+
set(TRTION_SHARED_NPU_SPECIFIC_SOURCES "${DC_TRITON_SOURCE_DIR}/compiler/lib/Conversion/TritonToLinalgNPU")
34

45
set(DC_TRITON_INCLUDE_DIR "")
56
set(DC_TRITON_LINK_DIR "")
@@ -39,5 +40,24 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools)
3940
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/compiler)
4041
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/triton_shared)
4142

42-
add_triton_plugin(DICPTriton ${CMAKE_CURRENT_SOURCE_DIR}/dicp_triton.cc LINK_LIBS ${LIBS})
43-
target_include_directories(DICPTriton PUBLIC "${pybind11_INCLUDE_DIRS}")
43+
if (TRITON_BUILD_PYTHON_MODULE)
44+
add_triton_plugin(tritonDicpTriton ${CMAKE_CURRENT_SOURCE_DIR}/triton_dicp_triton.cc
45+
LINK_LIBS
46+
47+
MLIRAffineToStandard
48+
MLIRIR
49+
MLIRPass
50+
MLIRTransforms
51+
MLIRSupport
52+
MLIRBytecodeWriter
53+
54+
TritonToLinalgNPUCoversion
55+
56+
LinalgExtTransforms
57+
TritonExtTransforms
58+
59+
LinalgToLinked
60+
LinkedToHIVM
61+
)
62+
target_link_libraries(tritonDicpTriton PRIVATE Python3::Module pybind11::headers)
63+
endif()

backend/compiler.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,9 @@ def add_stages(self, stages, options, language):
233233
from triton.backends.dicp_triton.npu import (
234234
make_ttir,
235235
ttir_to_linalg,
236-
ttir_to_ttsharedir,
236+
ttir_to_ttsharedir_ascend,
237237
ttsharedir_to_linkedir,
238238
linalg_to_bin_enable_npu_compile,
239-
ttir_post,
240239
)
241240

242241
stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options)
@@ -253,10 +252,7 @@ def add_stages(self, stages, options, language):
253252
)
254253
else:
255254
if options.enable_npu_compile:
256-
stages["ttir_post"] = lambda src, metadata: ttir_post(
257-
src, metadata, options
258-
)
259-
stages["ttshared"] = lambda src, metadata: ttir_to_ttsharedir(
255+
stages["ttshared"] = lambda src, metadata: ttir_to_ttsharedir_ascend(
260256
src, metadata, options, named_ops=True
261257
)
262258
stages["linkedir"] = lambda src, metadata: ttsharedir_to_linkedir(

backend/driver.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from triton.runtime.cache import get_cache_manager, get_dump_manager
1010
from triton.backends.driver import DriverBase
1111
from triton.backends.compiler import GPUTarget
12-
from triton.backends.dicp_triton.utils import get_current_backend
12+
from triton.backends.dicp_triton.utils import get_current_backend, _cache_dir
1313

1414
import importlib
1515
import shutil
@@ -372,3 +372,19 @@ def map_python_to_cpp_type(self, ty: str) -> str:
372372
"f32": "float",
373373
"fp64": "double",
374374
}[ty]
375+
376+
@classmethod
377+
def cache_dir_path(self) -> Path:
378+
"""返回缓存目录 Path(不创建目录)。"""
379+
return _cache_dir()
380+
381+
@classmethod
382+
def cache_dir(self) -> Path:
383+
"""返回并创建缓存目录(如果不存在)。"""
384+
p = self.cache_dir_path()
385+
p.mkdir(parents=True, exist_ok=True)
386+
return p
387+
388+
@classmethod
389+
def clear_cache(self, cache):
390+
cache.zero_()

backend/npu.py

Lines changed: 68 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import hashlib
88
from triton.runtime.cache import get_cache_manager, get_dump_manager
99
from triton.backends.compiler import GPUTarget
10-
from triton._C.libtriton import ir, passes
10+
from triton._C.libtriton import ir, passes, dicp_triton
1111
from triton.runtime.cache import get_dump_manager
12+
from triton.backends.dicp_triton.utils import _dump_stage_ir
1213
from dataclasses import dataclass
1314
from typing import Any, Union, Tuple, Dict
1415
import ctypes
@@ -25,8 +26,7 @@
2526
replace_linked_ir = os.environ.get("DLC_REPLACE_LINKED_IR_FILE", None)
2627
if dump_ir or (replace_ttshared_ir is not None) or (replace_linked_ir is not None):
2728
os.environ["TRITON_ALWAYS_COMPILE"] = "1"
28-
if not os.path.exists("./tmp"):
29-
os.makedirs("./tmp")
29+
os.environ["TRITON_DUMP_DIR"] = os.environ.get("TRITON_DUMP_DIR", "./tmp")
3030

3131
local_bishengir_path = os.path.join(os.path.dirname(__file__), "../../_C/bishengir")
3232
bisheng_install_path = os.environ.get("BISHENG_INSTALL_PATH", None)
@@ -49,6 +49,19 @@ def downgrade_llir(llir):
4949
return llir
5050

5151

52+
def _replace_mod_ir_with_file(mod, filepath: str, stage_name: str):
53+
p = Path(filepath)
54+
if not p.exists():
55+
raise FileNotFoundError(f"Replacement MLIR file not found: {filepath}")
56+
print(f"[DEBUG] replacing '{stage_name}' IR with file '{filepath}'")
57+
try:
58+
new_mod = ir.parse_mlir_module(str(p), mod.context)
59+
new_mod.context = mod.context
60+
return new_mod
61+
except Exception as e:
62+
raise RuntimeError(f"Failed to parse replacement MLIR file '{filepath}': {e}")
63+
64+
5265
def _downgrade_mem_attrs(llir: str):
5366
memory_pattern = r"memory\([^()]*\)"
5467

@@ -341,44 +354,17 @@ def make_ttir(mod, metadata, opt):
341354
passes.common.add_licm(pm)
342355
passes.common.add_symbol_dce(pm)
343356
pm.run(mod)
344-
if opt.debug:
345-
dump_manager = get_dump_manager(metadata["hash"])
346-
print(f"Dumping intermediate results to {dump_manager.cache_dir}")
347-
dump_manager.put(str(mod), "kernel.ttir.mlir", binary=False)
348-
357+
if opt.debug or dump_ir:
358+
_dump_stage_ir(str(mod), metadata["hash"], "kernel.ttir.mlir")
349359
return mod
350360

351361

352-
def ttir_post(mod, metadata, opt):
353-
ttir_code = str(mod)
354-
with tempfile.TemporaryDirectory() as tmpdir:
355-
src_path = os.path.join(tmpdir, "kernel.ttir.mlir")
356-
dst_path = os.path.join(tmpdir, "kernel.ttir_post.mlir")
357-
Path(src_path).write_text(ttir_code)
358-
dicp_opt_path = _get_dicp_opt_path()
359-
dicp_cmd_list = [
360-
dicp_opt_path,
361-
src_path,
362-
"--canonicalize-cmpi",
363-
"--canonicalize-triton-ir-ascend",
364-
"-o",
365-
dst_path,
366-
]
367-
if dump_ir:
368-
shutil.copy(src_path, "./tmp/kernel.ttir.mlir")
369-
print(f"DEBUG dump ir[ttir_post] command: {dicp_cmd_list}")
370-
ret = subprocess.run(dicp_cmd_list, capture_output=True, check=True)
371-
if dump_ir:
372-
shutil.copy(dst_path, "./tmp/kernel.ttir_post.mlir")
373-
return Path(dst_path).read_text()
374-
375-
376362
def ttir_to_linalg(mod, metadata, opt, *, named_ops=True):
377363
# use triton_adapter to lower Triton-MLIR to linalg
378364
# Get Triton-MLIR as string
379365
ttir_code = str(mod)
380366
with tempfile.TemporaryDirectory() as tmpdir:
381-
src_path = os.path.join(tmpdir, "kernel.ttir_post.mlir")
367+
src_path = os.path.join(tmpdir, "kernel.ttir.mlir")
382368
dst_path = os.path.join(tmpdir, "kernel.ttadapter.mlir")
383369
Path(src_path).write_text(ttir_code)
384370
triton_adapter_opt_path = _get_triton_adapter_opt_path()
@@ -403,90 +389,67 @@ def ttir_to_linalg(mod, metadata, opt, *, named_ops=True):
403389
return Path(dst_path).read_text()
404390

405391

406-
def ttir_to_ttsharedir(mod, metadata, opt, *, named_ops=False):
407-
ttir_code = str(mod)
408-
# 注释掉gpu.barrier
409-
ttir_code = re.sub(r"gpu\.barrier", r"// gpu.barrier", ttir_code)
410-
with tempfile.TemporaryDirectory() as tmpdir:
411-
src_path = os.path.join(tmpdir, "kernel.ttir_post.mlir")
412-
dst_ttshared_path = os.path.join(tmpdir, "kernel.ttshared.mlir")
413-
Path(src_path).write_text(ttir_code)
414-
triton_shared_opt_path = _get_triton_shared_opt_path()
415-
ttshared_cmd = (
416-
"--triton-to-linalg-experimental"
417-
if "v3_2" not in triton_shared_opt_path
418-
else "--triton-to-linalg"
419-
)
420-
421-
cmd_shared_list = [
422-
triton_shared_opt_path,
423-
src_path,
424-
ttshared_cmd,
425-
"-o",
426-
dst_ttshared_path,
392+
def ttir_to_ttsharedir_ascend(mod, metadata, opt, *, named_ops=False):
393+
mod.set_attr("dicp.backend", ir.builder(mod.context).get_string_attr("ascend"))
394+
pm = ir.pass_manager(mod.context)
395+
dicp_triton.passes.triton_shared_ascend.add_canonicalize_cmpi(pm)
396+
dicp_triton.passes.triton_shared_ascend.add_canonicalize_triton_ir_ascend(pm)
397+
dicp_triton.passes.triton_shared_ascend.add_triton_to_linalg_npu(pm)
398+
pm.run(mod)
399+
if opt.debug or dump_ir:
400+
cmd_list = [
401+
_get_dicp_opt_path(),
402+
"kernel.ttir.mlir",
403+
"--canonicalize-cmpi",
404+
"--canonicalize-triton-ir-ascend",
405+
"--triton-to-linalg-npu-conversion",
427406
]
428-
if dump_ir:
429-
print(f"DEBUG dump ir[ttir_to_ttsharedir] command: {cmd_shared_list}")
430-
ret = subprocess.run(cmd_shared_list, capture_output=True, check=True)
431-
432-
if dump_ir:
433-
shutil.copy(dst_ttshared_path, "./tmp/kernel.ttshared.mlir")
434-
if replace_ttshared_ir is not None:
435-
print(f"[DEBUG] Replace ttsharedir with {replace_ttshared_ir}")
436-
return Path(replace_ttshared_ir).read_text()
437-
return Path(dst_ttshared_path).read_text()
407+
_dump_stage_ir(str(mod), metadata["hash"], "kernel.ttshared.mlir", cmd_list)
408+
if replace_ttshared_ir is not None:
409+
return _replace_mod_ir_with_file(mod, replace_ttshared_ir, "ttir_to_ttsharedir_ascend")
410+
return mod
438411

439412

440413
def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False):
441-
ttsharedir_code = str(mod)
442-
with tempfile.TemporaryDirectory() as tmpdir:
443-
src_path = os.path.join(tmpdir, "kernel.ttshared.mlir")
444-
dst_path = os.path.join(tmpdir, "kernel.linkedir.mlir")
445-
Path(src_path).write_text(ttsharedir_code)
446-
dicp_opt_path = _get_dicp_opt_path()
447-
dicp_cmd_list = [
448-
dicp_opt_path,
449-
src_path,
414+
pm = ir.pass_manager(mod.context)
415+
dicp_triton.passes.linked_npu.add_lower_affine(pm)
416+
dicp_triton.passes.linked_npu.add_normalize_slice_ops(pm)
417+
dicp_triton.passes.linked_npu.add_linalg_if_to_select(pm)
418+
dicp_triton.passes.linked_npu.add_linalg_generic_to_scf(pm)
419+
dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm)
420+
dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True)
421+
dicp_triton.passes.linked_npu.add_linked_to_hivm(pm)
422+
pm.run(mod)
423+
424+
# TODO(zmz): 修改test_path 中内容,暂时在python中处理,bishengir-compile后续会支持,去掉这里逻辑。
425+
content = str(mod)
426+
# 将"*xfxxx"替换成"?xfxxx"
427+
content = content.replace("*xf", "?xf")
428+
content = content.replace("*xi", "?xi")
429+
content = content.replace("*xbf", "?xbf")
430+
# 匹配形如 "memref<...> to tensor<...>" 的模式
431+
pattern = r"(memref\<.*?\>)\s+to\s+(tensor\<.*?\>)"
432+
# 使用正则替换,保留memref和tensor类型,中间插入注释
433+
content = re.sub(pattern, r"\1 // to \2", content)
434+
435+
if opt.debug or dump_ir:
436+
cmd_list = [
437+
_get_dicp_opt_path(),
438+
"kernel.ttshared.mlir",
450439
"--lower-affine",
451440
"--normalize-slice-ops",
452441
"--linalg-if-to-select",
453442
"--linalg-generic-to-scf",
454443
"--scalar-to-1d-tensor",
455444
f"--linalg-to-linked=global-kernel=false named-ops=true",
456445
"--linked-to-hivm",
457-
"-o",
458-
dst_path,
459446
]
460-
if dump_ir:
461-
print(f"DEBUG dump ir[ttsharedir_to_linkedir] command: {dicp_cmd_list}")
462-
ret = subprocess.run(dicp_cmd_list, capture_output=True, check=True)
463-
# TODO(zmz): 修改test_path 中内容,暂时在python中处理,bishengir-compile后续会支持,去掉这里逻辑。
464-
with open(dst_path, "r") as f:
465-
content = f.read()
466-
# 将"*xfxxx"替换成"?xfxxx"
467-
content = content.replace("*xf", "?xf")
468-
content = content.replace("*xi", "?xi")
469-
content = content.replace("*xbf", "?xbf")
470-
with open(dst_path, "w") as f:
471-
f.write(content)
472-
473-
# 匹配形如 "memref<...> to tensor<...>" 的模式
474-
pattern = r"(memref\<.*?\>)\s+to\s+(tensor\<.*?\>)"
475-
with open(dst_path, "r") as f:
476-
lines = f.readlines()
477-
modified = []
478-
for line in lines:
479-
# 使用正则替换,保留memref和tensor类型,中间插入注释
480-
new_line = re.sub(pattern, r"\1 // to \2", line)
481-
modified.append(new_line)
482-
with open(dst_path, "w") as f:
483-
f.writelines(modified)
484-
if dump_ir:
485-
shutil.copy(dst_path, "./tmp/kernel.linkedir.mlir")
486-
if replace_linked_ir is not None:
487-
print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}")
488-
return Path(replace_linked_ir).read_text()
489-
return Path(dst_path).read_text()
447+
_dump_stage_ir(content, metadata["hash"], "kernel.linkedir.mlir", cmd_list)
448+
449+
if replace_linked_ir is not None:
450+
print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}")
451+
return Path(replace_linked_ir).read_text()
452+
return content
490453

491454

492455
def linalg_to_llir(linalg: str, metadata, opt):

0 commit comments

Comments
 (0)