Skip to content

Commit 6729c3a

Browse files
authored
Refactor: Decouple NPU-related ttshared pass patches, move non-bisheng-compiler stages to Pybind11, and consolidate stages (#125)
* Refactor: Integrate Pybind11 and refactor NPU passes for compile speedup and size reduction
  This commit introduces significant improvements to the NPU backend, focusing on code architecture and build performance:
  1. NPU code refactoring: NPU-specific logic was decoupled from the monolithic 'triton-shared' component and re-implemented as dedicated, cleaner passes and patterns. This greatly improves code clarity and maintainability.
  2. Performance boost (Pybind11): The intermediate passes are now exposed through Pybind11 and run in-process instead of through subprocess calls to external opt tools, removing that Python-C++ interface overhead. Compilation benchmarks show intermediate stages (e.g., ttir_post, ttshared, linkedir) are now over 60% faster on average.
  3. Binary size optimization: The implementation allows the external 'opt' dependency to be removed, substantially reducing the release package size: before ~488MB, after ~330MB.
  Note for publish builds: to generate the optimized release version, ensure the environment variable is set: export IS_NOT_PUBLISH=0
* Fix: Improve error handling for bishengir-compile failures
1 parent 84cfd59 commit 6729c3a

File tree

30 files changed

+1526
-718
lines changed

30 files changed

+1526
-718
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ tags
9696
kernel_meta/
9797
fusion_result.json
9898
*.log
99+
launcher_cxx11abi*
99100

100101
# package
101102
backend/triton-shared-opt-v3*

CMakeLists.txt

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
set(DC_TRITON_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
22
set(DC_TRITON_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
3+
set(TRTION_SHARED_NPU_SPECIFIC_SOURCES "${DC_TRITON_SOURCE_DIR}/compiler/lib/Conversion/TritonToLinalgNPU")
34

45
set(DC_TRITON_INCLUDE_DIR "")
56
set(DC_TRITON_LINK_DIR "")
@@ -39,5 +40,24 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tools)
3940
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/compiler)
4041
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/third_party/triton_shared)
4142

42-
add_triton_plugin(DICPTriton ${CMAKE_CURRENT_SOURCE_DIR}/dicp_triton.cc LINK_LIBS ${LIBS})
43-
target_include_directories(DICPTriton PUBLIC "${pybind11_INCLUDE_DIRS}")
43+
if (TRITON_BUILD_PYTHON_MODULE)
44+
add_triton_plugin(tritonDicpTriton ${CMAKE_CURRENT_SOURCE_DIR}/triton_dicp_triton.cc
45+
LINK_LIBS
46+
47+
MLIRAffineToStandard
48+
MLIRIR
49+
MLIRPass
50+
MLIRTransforms
51+
MLIRSupport
52+
MLIRBytecodeWriter
53+
54+
TritonToLinalgNPUCoversion
55+
56+
LinalgExtTransforms
57+
TritonExtTransforms
58+
59+
LinalgToLinked
60+
LinkedToHIVM
61+
)
62+
target_link_libraries(tritonDicpTriton PRIVATE Python3::Module pybind11::headers)
63+
endif()

backend/compiler.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -235,10 +235,9 @@ def add_stages(self, stages, options, language=None):
235235
from triton.backends.dicp_triton.npu import (
236236
make_ttir,
237237
ttir_to_linalg,
238-
ttir_to_ttsharedir,
238+
ttir_to_ttsharedir_ascend,
239239
ttsharedir_to_linkedir,
240240
linalg_to_bin_enable_npu_compile,
241-
ttir_post,
242241
)
243242

244243
stages["ttir"] = lambda src, metadata: make_ttir(src, metadata, options)
@@ -255,11 +254,10 @@ def add_stages(self, stages, options, language=None):
255254
)
256255
else:
257256
if options.enable_npu_compile:
258-
stages["ttir_post"] = lambda src, metadata: ttir_post(
259-
src, metadata, options
260-
)
261-
stages["ttshared"] = lambda src, metadata: ttir_to_ttsharedir(
262-
src, metadata, options, named_ops=True
257+
stages["ttshared"] = (
258+
lambda src, metadata: ttir_to_ttsharedir_ascend(
259+
src, metadata, options, named_ops=True
260+
)
263261
)
264262
stages["linkedir"] = lambda src, metadata: ttsharedir_to_linkedir(
265263
src, metadata, options, named_ops=True

backend/driver.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,3 +381,7 @@ def map_python_to_cpp_type(self, ty: str) -> str:
381381
"f32": "float",
382382
"fp64": "double",
383383
}[ty]
384+
385+
@classmethod
386+
def clear_cache(self, cache):
387+
cache.zero_()

backend/npu.py

Lines changed: 86 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import hashlib
88
from triton.runtime.cache import get_cache_manager, get_dump_manager
99
from triton.backends.compiler import GPUTarget
10-
from triton._C.libtriton import ir, passes
10+
from triton._C.libtriton import ir, passes, dicp_triton
1111
from triton.runtime.cache import get_dump_manager
12+
import triton.backends.dicp_triton.utils as dicp_utils
1213
from dataclasses import dataclass
1314
from typing import Any, Union, Tuple, Dict
1415
import ctypes
@@ -25,8 +26,11 @@
2526
replace_linked_ir = os.environ.get("DLC_REPLACE_LINKED_IR_FILE", None)
2627
if dump_ir or (replace_ttshared_ir is not None) or (replace_linked_ir is not None):
2728
os.environ["TRITON_ALWAYS_COMPILE"] = "1"
28-
if not os.path.exists("./tmp"):
29-
os.makedirs("./tmp")
29+
dump_dir = "./tmp"
30+
os.environ["TRITON_DUMP_DIR"] = os.environ.get("TRITON_DUMP_DIR", dump_dir)
31+
if os.path.exists(dump_dir):
32+
print(f"Directory **{dump_dir}** exists. Deleting the entire directory...")
33+
shutil.rmtree(dump_dir)
3034

3135
local_bishengir_path = os.path.join(os.path.dirname(__file__), "../../_C/bishengir")
3236
bisheng_install_path = os.environ.get("BISHENG_INSTALL_PATH", None)
@@ -49,6 +53,19 @@ def downgrade_llir(llir):
4953
return llir
5054

5155

56+
def _replace_mod_ir_with_file(mod, filepath: str, stage_name: str):
57+
p = Path(filepath)
58+
if not p.exists():
59+
raise FileNotFoundError(f"Replacement MLIR file not found: {filepath}")
60+
print(f"[DEBUG] replacing '{stage_name}' IR with file '{filepath}'")
61+
try:
62+
new_mod = ir.parse_mlir_module(str(p), mod.context)
63+
new_mod.context = mod.context
64+
return new_mod
65+
except Exception as e:
66+
raise RuntimeError(f"Failed to parse replacement MLIR file '{filepath}': {e}")
67+
68+
5269
def _downgrade_mem_attrs(llir: str):
5370
memory_pattern = r"memory\([^()]*\)"
5471

@@ -330,6 +347,7 @@ def min_dot_size(target: GPUTarget):
330347
def make_ttir(mod, metadata, opt):
331348
if "hash" not in metadata:
332349
metadata["hash"] = hashlib.md5(f"{mod}-{metadata}".encode()).hexdigest()
350+
mod.set_attr("dicp.backend", ir.builder(mod.context).get_string_attr("ascend"))
333351
# the same optimize pass for triton-ir as all other backends
334352
pm = ir.pass_manager(mod.context)
335353
pm.enable_debug()
@@ -341,44 +359,17 @@ def make_ttir(mod, metadata, opt):
341359
passes.common.add_licm(pm)
342360
passes.common.add_symbol_dce(pm)
343361
pm.run(mod)
344-
if opt.debug:
345-
dump_manager = get_dump_manager(metadata["hash"])
346-
print(f"Dumping intermediate results to {dump_manager.cache_dir}")
347-
dump_manager.put(str(mod), "kernel.ttir.mlir", binary=False)
348-
362+
if opt.debug or dump_ir:
363+
dicp_utils._dump_stage_ir(str(mod), metadata["hash"], "kernel.ttir.mlir")
349364
return mod
350365

351366

352-
def ttir_post(mod, metadata, opt):
353-
ttir_code = str(mod)
354-
with tempfile.TemporaryDirectory() as tmpdir:
355-
src_path = os.path.join(tmpdir, "kernel.ttir.mlir")
356-
dst_path = os.path.join(tmpdir, "kernel.ttir_post.mlir")
357-
Path(src_path).write_text(ttir_code)
358-
dicp_opt_path = _get_dicp_opt_path()
359-
dicp_cmd_list = [
360-
dicp_opt_path,
361-
src_path,
362-
"--canonicalize-cmpi",
363-
"--canonicalize-triton-ir-ascend",
364-
"-o",
365-
dst_path,
366-
]
367-
if dump_ir:
368-
shutil.copy(src_path, "./tmp/kernel.ttir.mlir")
369-
print(f"DEBUG dump ir[ttir_post] command: {dicp_cmd_list}")
370-
ret = subprocess.run(dicp_cmd_list, capture_output=True, check=True)
371-
if dump_ir:
372-
shutil.copy(dst_path, "./tmp/kernel.ttir_post.mlir")
373-
return Path(dst_path).read_text()
374-
375-
376367
def ttir_to_linalg(mod, metadata, opt, *, named_ops=True):
377368
# use triton_adapter to lower Triton-MLIR to linalg
378369
# Get Triton-MLIR as string
379370
ttir_code = str(mod)
380371
with tempfile.TemporaryDirectory() as tmpdir:
381-
src_path = os.path.join(tmpdir, "kernel.ttir_post.mlir")
372+
src_path = os.path.join(tmpdir, "kernel.ttir.mlir")
382373
dst_path = os.path.join(tmpdir, "kernel.ttadapter.mlir")
383374
Path(src_path).write_text(ttir_code)
384375
triton_adapter_opt_path = _get_triton_adapter_opt_path()
@@ -403,90 +394,72 @@ def ttir_to_linalg(mod, metadata, opt, *, named_ops=True):
403394
return Path(dst_path).read_text()
404395

405396

406-
def ttir_to_ttsharedir(mod, metadata, opt, *, named_ops=False):
407-
ttir_code = str(mod)
408-
# 注释掉gpu.barrier
409-
ttir_code = re.sub(r"gpu\.barrier", r"// gpu.barrier", ttir_code)
410-
with tempfile.TemporaryDirectory() as tmpdir:
411-
src_path = os.path.join(tmpdir, "kernel.ttir_post.mlir")
412-
dst_ttshared_path = os.path.join(tmpdir, "kernel.ttshared.mlir")
413-
Path(src_path).write_text(ttir_code)
414-
triton_shared_opt_path = _get_triton_shared_opt_path()
415-
ttshared_cmd = (
416-
"--triton-to-linalg-experimental"
417-
if "v3_2" not in triton_shared_opt_path
418-
else "--triton-to-linalg"
419-
)
420-
421-
cmd_shared_list = [
422-
triton_shared_opt_path,
423-
src_path,
424-
ttshared_cmd,
425-
"-o",
426-
dst_ttshared_path,
397+
def ttir_to_ttsharedir_ascend(mod, metadata, opt, *, named_ops=False):
398+
pm = ir.pass_manager(mod.context)
399+
dicp_triton.passes.triton_shared_ascend.add_canonicalize_cmpi(pm)
400+
dicp_triton.passes.triton_shared_ascend.add_canonicalize_triton_ir_ascend(pm)
401+
dicp_triton.passes.triton_shared_ascend.add_triton_to_linalg_npu(pm)
402+
pm.run(mod)
403+
if opt.debug or dump_ir:
404+
cmd_list = [
405+
_get_dicp_opt_path(),
406+
"kernel.ttir.mlir",
407+
"--canonicalize-cmpi",
408+
"--canonicalize-triton-ir-ascend",
409+
"--triton-to-linalg-npu-conversion",
427410
]
428-
if dump_ir:
429-
print(f"DEBUG dump ir[ttir_to_ttsharedir] command: {cmd_shared_list}")
430-
ret = subprocess.run(cmd_shared_list, capture_output=True, check=True)
431-
432-
if dump_ir:
433-
shutil.copy(dst_ttshared_path, "./tmp/kernel.ttshared.mlir")
434-
if replace_ttshared_ir is not None:
435-
print(f"[DEBUG] Replace ttsharedir with {replace_ttshared_ir}")
436-
return Path(replace_ttshared_ir).read_text()
437-
return Path(dst_ttshared_path).read_text()
411+
dicp_utils._dump_stage_ir(
412+
str(mod), metadata["hash"], "kernel.ttshared.mlir", cmd_list
413+
)
414+
if replace_ttshared_ir is not None:
415+
return _replace_mod_ir_with_file(
416+
mod, replace_ttshared_ir, "ttir_to_ttsharedir_ascend"
417+
)
418+
return mod
438419

439420

440421
def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False):
441-
ttsharedir_code = str(mod)
442-
with tempfile.TemporaryDirectory() as tmpdir:
443-
src_path = os.path.join(tmpdir, "kernel.ttshared.mlir")
444-
dst_path = os.path.join(tmpdir, "kernel.linkedir.mlir")
445-
Path(src_path).write_text(ttsharedir_code)
446-
dicp_opt_path = _get_dicp_opt_path()
447-
dicp_cmd_list = [
448-
dicp_opt_path,
449-
src_path,
422+
pm = ir.pass_manager(mod.context)
423+
dicp_triton.passes.linked_npu.add_lower_affine(pm)
424+
dicp_triton.passes.linked_npu.add_normalize_slice_ops(pm)
425+
dicp_triton.passes.linked_npu.add_linalg_if_to_select(pm)
426+
dicp_triton.passes.linked_npu.add_linalg_generic_to_scf(pm)
427+
dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm)
428+
dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True)
429+
dicp_triton.passes.linked_npu.add_linked_to_hivm(pm)
430+
pm.run(mod)
431+
432+
# TODO(zmz): 修改test_path 中内容,暂时在python中处理,bishengir-compile后续会支持,去掉这里逻辑。
433+
content = str(mod)
434+
# 将"*xfxxx"替换成"?xfxxx"
435+
content = content.replace("*xf", "?xf")
436+
content = content.replace("*xi", "?xi")
437+
content = content.replace("*xbf", "?xbf")
438+
# 匹配形如 "memref<...> to tensor<...>" 的模式
439+
pattern = r"(memref\<.*?\>)\s+to\s+(tensor\<.*?\>)"
440+
# 使用正则替换,保留memref和tensor类型,中间插入注释
441+
content = re.sub(pattern, r"\1 // to \2", content)
442+
443+
if opt.debug or dump_ir:
444+
cmd_list = [
445+
_get_dicp_opt_path(),
446+
"kernel.ttshared.mlir",
450447
"--lower-affine",
451448
"--normalize-slice-ops",
452449
"--linalg-if-to-select",
453450
"--linalg-generic-to-scf",
454451
"--scalar-to-1d-tensor",
455452
f"--linalg-to-linked=global-kernel=false named-ops=true",
456453
"--linked-to-hivm",
457-
"-o",
458-
dst_path,
459454
]
460-
if dump_ir:
461-
print(f"DEBUG dump ir[ttsharedir_to_linkedir] command: {dicp_cmd_list}")
462-
ret = subprocess.run(dicp_cmd_list, capture_output=True, check=True)
463-
# TODO(zmz): 修改test_path 中内容,暂时在python中处理,bishengir-compile后续会支持,去掉这里逻辑。
464-
with open(dst_path, "r") as f:
465-
content = f.read()
466-
# 将"*xfxxx"替换成"?xfxxx"
467-
content = content.replace("*xf", "?xf")
468-
content = content.replace("*xi", "?xi")
469-
content = content.replace("*xbf", "?xbf")
470-
with open(dst_path, "w") as f:
471-
f.write(content)
472-
473-
# 匹配形如 "memref<...> to tensor<...>" 的模式
474-
pattern = r"(memref\<.*?\>)\s+to\s+(tensor\<.*?\>)"
475-
with open(dst_path, "r") as f:
476-
lines = f.readlines()
477-
modified = []
478-
for line in lines:
479-
# 使用正则替换,保留memref和tensor类型,中间插入注释
480-
new_line = re.sub(pattern, r"\1 // to \2", line)
481-
modified.append(new_line)
482-
with open(dst_path, "w") as f:
483-
f.writelines(modified)
484-
if dump_ir:
485-
shutil.copy(dst_path, "./tmp/kernel.linkedir.mlir")
486-
if replace_linked_ir is not None:
487-
print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}")
488-
return Path(replace_linked_ir).read_text()
489-
return Path(dst_path).read_text()
455+
dicp_utils._dump_stage_ir(
456+
content, metadata["hash"], "kernel.linkedir.mlir", cmd_list
457+
)
458+
459+
if replace_linked_ir is not None:
460+
print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}")
461+
return Path(replace_linked_ir).read_text()
462+
return content
490463

491464

492465
def linalg_to_llir(linalg: str, metadata, opt):
@@ -721,7 +694,14 @@ def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt):
721694
)
722695
if dump_ir:
723696
print(f"DEBUG dump ir[bishengir-compile] command: {cmd_list}")
724-
ret = subprocess.run(cmd_list, capture_output=True, check=True)
697+
try:
698+
ret = subprocess.run(cmd_list, capture_output=True, check=True, text=True)
699+
except subprocess.CalledProcessError as e:
700+
# Print compilation error details
701+
print(f"bishengir-compile compilation failed with exit code {e.returncode}")
702+
print(f"Stderr:\n{e.stderr}")
703+
raise RuntimeError("bishengir-compile compilation failed") from e
704+
725705
if not Path(bin_path).is_file():
726706
print(ret.stderr.decode("utf-8"))
727707
if Path(callback_path).is_file():

0 commit comments

Comments
 (0)