77import hashlib
88from triton .runtime .cache import get_cache_manager , get_dump_manager
99from triton .backends .compiler import GPUTarget
10- from triton ._C .libtriton import ir , passes
10+ from triton ._C .libtriton import ir , passes , dicp_triton
1111from triton .runtime .cache import get_dump_manager
12+ import triton .backends .dicp_triton .utils as dicp_utils
1213from dataclasses import dataclass
1314from typing import Any , Union , Tuple , Dict
1415import ctypes
2526replace_linked_ir = os .environ .get ("DLC_REPLACE_LINKED_IR_FILE" , None )
2627if dump_ir or (replace_ttshared_ir is not None ) or (replace_linked_ir is not None ):
2728 os .environ ["TRITON_ALWAYS_COMPILE" ] = "1"
28- if not os .path .exists ("./tmp" ):
29- os .makedirs ("./tmp" )
29+ dump_dir = "./tmp"
30+ os .environ ["TRITON_DUMP_DIR" ] = os .environ .get ("TRITON_DUMP_DIR" , dump_dir )
31+ if os .path .exists (dump_dir ):
32+ print (f"Directory **{ dump_dir } ** exists. Deleting the entire directory..." )
33+ shutil .rmtree (dump_dir )
3034
3135local_bishengir_path = os .path .join (os .path .dirname (__file__ ), "../../_C/bishengir" )
3236bisheng_install_path = os .environ .get ("BISHENG_INSTALL_PATH" , None )
@@ -49,6 +53,19 @@ def downgrade_llir(llir):
4953 return llir
5054
5155
def _replace_mod_ir_with_file(mod, filepath: str, stage_name: str):
    """Parse the MLIR file at ``filepath`` and return it in place of ``mod``.

    Args:
        mod: The module being replaced. Only its ``context`` attribute is
            read; the replacement is parsed in, and attached to, that context.
        filepath: Path of the replacement MLIR file.
        stage_name: Name of the compilation stage, used only in the debug
            message printed to stdout.

    Returns:
        The module produced by ``ir.parse_mlir_module``, with its ``context``
        attribute set to ``mod.context``.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        RuntimeError: If parsing fails for any reason. The underlying
            exception is chained as ``__cause__``.
    """
    p = Path(filepath)
    if not p.exists():
        raise FileNotFoundError(f"Replacement MLIR file not found: {filepath}")
    print(f"[DEBUG] replacing '{stage_name}' IR with file '{filepath}'")
    try:
        new_mod = ir.parse_mlir_module(str(p), mod.context)
        new_mod.context = mod.context
    except Exception as e:
        # Chain explicitly so the traceback shows the parser error as the
        # direct cause rather than as an incidental "during handling" context.
        raise RuntimeError(
            f"Failed to parse replacement MLIR file '{filepath}': {e}"
        ) from e
    return new_mod
67+
68+
5269def _downgrade_mem_attrs (llir : str ):
5370 memory_pattern = r"memory\([^()]*\)"
5471
@@ -330,6 +347,7 @@ def min_dot_size(target: GPUTarget):
330347def make_ttir (mod , metadata , opt ):
331348 if "hash" not in metadata :
332349 metadata ["hash" ] = hashlib .md5 (f"{ mod } -{ metadata } " .encode ()).hexdigest ()
350+ mod .set_attr ("dicp.backend" , ir .builder (mod .context ).get_string_attr ("ascend" ))
333351 # the same optimize pass for triton-ir as all other backends
334352 pm = ir .pass_manager (mod .context )
335353 pm .enable_debug ()
@@ -341,44 +359,17 @@ def make_ttir(mod, metadata, opt):
341359 passes .common .add_licm (pm )
342360 passes .common .add_symbol_dce (pm )
343361 pm .run (mod )
344- if opt .debug :
345- dump_manager = get_dump_manager (metadata ["hash" ])
346- print (f"Dumping intermediate results to { dump_manager .cache_dir } " )
347- dump_manager .put (str (mod ), "kernel.ttir.mlir" , binary = False )
348-
362+ if opt .debug or dump_ir :
363+ dicp_utils ._dump_stage_ir (str (mod ), metadata ["hash" ], "kernel.ttir.mlir" )
349364 return mod
350365
351366
352- def ttir_post (mod , metadata , opt ):
353- ttir_code = str (mod )
354- with tempfile .TemporaryDirectory () as tmpdir :
355- src_path = os .path .join (tmpdir , "kernel.ttir.mlir" )
356- dst_path = os .path .join (tmpdir , "kernel.ttir_post.mlir" )
357- Path (src_path ).write_text (ttir_code )
358- dicp_opt_path = _get_dicp_opt_path ()
359- dicp_cmd_list = [
360- dicp_opt_path ,
361- src_path ,
362- "--canonicalize-cmpi" ,
363- "--canonicalize-triton-ir-ascend" ,
364- "-o" ,
365- dst_path ,
366- ]
367- if dump_ir :
368- shutil .copy (src_path , "./tmp/kernel.ttir.mlir" )
369- print (f"DEBUG dump ir[ttir_post] command: { dicp_cmd_list } " )
370- ret = subprocess .run (dicp_cmd_list , capture_output = True , check = True )
371- if dump_ir :
372- shutil .copy (dst_path , "./tmp/kernel.ttir_post.mlir" )
373- return Path (dst_path ).read_text ()
374-
375-
376367def ttir_to_linalg (mod , metadata , opt , * , named_ops = True ):
377368 # use triton_adapter to lower Triton-MLIR to linalg
378369 # Get Triton-MLIR as string
379370 ttir_code = str (mod )
380371 with tempfile .TemporaryDirectory () as tmpdir :
381- src_path = os .path .join (tmpdir , "kernel.ttir_post .mlir" )
372+ src_path = os .path .join (tmpdir , "kernel.ttir .mlir" )
382373 dst_path = os .path .join (tmpdir , "kernel.ttadapter.mlir" )
383374 Path (src_path ).write_text (ttir_code )
384375 triton_adapter_opt_path = _get_triton_adapter_opt_path ()
@@ -403,90 +394,72 @@ def ttir_to_linalg(mod, metadata, opt, *, named_ops=True):
403394 return Path (dst_path ).read_text ()
404395
405396
def ttir_to_ttsharedir_ascend(mod, metadata, opt, *, named_ops=False):
    """Lower a Triton IR module to the "ttshared" stage for Ascend, in-process.

    The passes are run directly on ``mod`` through a pass manager created
    from ``mod.context``.

    Args:
        mod: Module to lower; ``pm.run(mod)`` is applied to it directly.
        metadata: Compilation metadata; ``metadata["hash"]`` is read when the
            stage IR is dumped.
        opt: Compile options; ``opt.debug`` enables dumping of the stage IR.
        named_ops: Accepted for signature compatibility; not read here.

    Returns:
        ``mod`` after the passes have run, or, when the module-level
        ``replace_ttshared_ir`` is set, the module parsed from that file by
        ``_replace_mod_ir_with_file``.
    """
    pm = ir.pass_manager(mod.context)
    dicp_triton.passes.triton_shared_ascend.add_canonicalize_cmpi(pm)
    dicp_triton.passes.triton_shared_ascend.add_canonicalize_triton_ir_ascend(pm)
    dicp_triton.passes.triton_shared_ascend.add_triton_to_linalg_npu(pm)
    pm.run(mod)

    if opt.debug or dump_ir:
        # NOTE(review): cmd_list looks like the equivalent dicp-opt command
        # line for reproducing this stage by hand; it is only handed to
        # _dump_stage_ir and never executed here -- confirm in dicp_utils.
        cmd_list = [
            _get_dicp_opt_path(),
            "kernel.ttir.mlir",
            "--canonicalize-cmpi",
            "--canonicalize-triton-ir-ascend",
            "--triton-to-linalg-npu-conversion",
        ]
        dicp_utils._dump_stage_ir(
            str(mod), metadata["hash"], "kernel.ttshared.mlir", cmd_list
        )

    # The override is applied after the passes (and after the dump), so the
    # dumped IR is always the generated one, not the replacement.
    if replace_ttshared_ir is not None:
        return _replace_mod_ir_with_file(
            mod, replace_ttshared_ir, "ttir_to_ttsharedir_ascend"
        )
    return mod
438419
439420
def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False):
    """Lower the "ttshared" module to linked IR text, in-process.

    The lowering passes are run directly on ``mod`` through a pass manager
    created from ``mod.context``; the module is then printed to text and
    patched with a few textual rewrites.

    Args:
        mod: Module to lower; ``pm.run(mod)`` is applied to it directly.
        metadata: Compilation metadata; ``metadata["hash"]`` is read when the
            stage IR is dumped.
        opt: Compile options; ``opt.debug`` enables dumping of the stage IR.
        named_ops: Accepted for signature compatibility; not read here.

    Returns:
        The patched IR as a string, or, when the module-level
        ``replace_linked_ir`` is set, the text of that file instead.
    """
    pm = ir.pass_manager(mod.context)
    dicp_triton.passes.linked_npu.add_lower_affine(pm)
    dicp_triton.passes.linked_npu.add_normalize_slice_ops(pm)
    dicp_triton.passes.linked_npu.add_linalg_if_to_select(pm)
    dicp_triton.passes.linked_npu.add_linalg_generic_to_scf(pm)
    dicp_triton.passes.linked_npu.add_scalar_to_1d_tensor(pm)
    # NOTE(review): (False, True) presumably corresponds to
    # "global-kernel=false named-ops=true" in cmd_list below -- confirm
    # against the binding's argument order.
    dicp_triton.passes.linked_npu.add_linalg_to_linked(pm, False, True)
    dicp_triton.passes.linked_npu.add_linked_to_hivm(pm)
    pm.run(mod)

    # TODO(zmz): the IR text is patched here in Python for now;
    # bishengir-compile is expected to support this later, at which point
    # this logic should be removed.
    content = str(mod)
    # Replace "*xfxxx" with "?xfxxx" (and likewise for the i / bf prefixes).
    content = content.replace("*xf", "?xf")
    content = content.replace("*xi", "?xi")
    content = content.replace("*xbf", "?xbf")
    # Match patterns of the form "memref<...> to tensor<...>".
    pattern = r"(memref\<.*?\>)\s+to\s+(tensor\<.*?\>)"
    # Keep both the memref and the tensor type, but insert a comment marker
    # between them so the " to tensor<...>" part becomes a trailing comment.
    content = re.sub(pattern, r"\1 // to \2", content)

    if opt.debug or dump_ir:
        # Only handed to _dump_stage_ir; never executed here.
        cmd_list = [
            _get_dicp_opt_path(),
            "kernel.ttshared.mlir",
            "--lower-affine",
            "--normalize-slice-ops",
            "--linalg-if-to-select",
            "--linalg-generic-to-scf",
            "--scalar-to-1d-tensor",
            "--linalg-to-linked=global-kernel=false named-ops=true",
            "--linked-to-hivm",
        ]
        dicp_utils._dump_stage_ir(
            content, metadata["hash"], "kernel.linkedir.mlir", cmd_list
        )

    if replace_linked_ir is not None:
        print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}")
        return Path(replace_linked_ir).read_text()
    return content
491464
492465def linalg_to_llir (linalg : str , metadata , opt ):
@@ -721,7 +694,14 @@ def linalg_to_bin_enable_npu_compile(linalg: str, metadata, opt):
721694 )
722695 if dump_ir :
723696 print (f"DEBUG dump ir[bishengir-compile] command: { cmd_list } " )
724- ret = subprocess .run (cmd_list , capture_output = True , check = True )
697+ try :
698+ ret = subprocess .run (cmd_list , capture_output = True , check = True , text = True )
699+ except subprocess .CalledProcessError as e :
700+ # Print compilation error details
701+ print (f"bishengir-compile compilation failed with exit code { e .returncode } " )
702+ print (f"Stderr:\n { e .stderr } " )
703+ raise RuntimeError ("bishengir-compile compilation failed" ) from e
704+
725705 if not Path (bin_path ).is_file ():
726706 print (ret .stderr .decode ("utf-8" ))
727707 if Path (callback_path ).is_file ():
0 commit comments