77import hashlib
88from triton .runtime .cache import get_cache_manager , get_dump_manager
99from triton .backends .compiler import GPUTarget
10- from triton ._C .libtriton import ir , passes
10+ from triton ._C .libtriton import ir , passes , dicp_triton
1111from triton .runtime .cache import get_dump_manager
12+ from triton .backends .dicp_triton .utils import _dump_stage_ir
1213from dataclasses import dataclass
1314from typing import Any , Union , Tuple , Dict
1415import ctypes
2526replace_linked_ir = os .environ .get ("DLC_REPLACE_LINKED_IR_FILE" , None )
# Debug switches: when IR dumping is on, or either stage's IR is to be replaced
# from a user-supplied file, set TRITON_ALWAYS_COMPILE and make sure a dump
# directory is configured. A TRITON_DUMP_DIR already present in the
# environment is kept; "./tmp" is only the fallback.
# NOTE(review): presumably TRITON_ALWAYS_COMPILE=1 bypasses the compile cache
# so the stages actually re-run -- confirm against the Triton runtime.
if dump_ir or any(path is not None for path in (replace_ttshared_ir, replace_linked_ir)):
    os.environ["TRITON_ALWAYS_COMPILE"] = "1"
    os.environ.setdefault("TRITON_DUMP_DIR", "./tmp")
3030
3131local_bishengir_path = os .path .join (os .path .dirname (__file__ ), "../../_C/bishengir" )
3232bisheng_install_path = os .environ .get ("BISHENG_INSTALL_PATH" , None )
@@ -49,6 +49,19 @@ def downgrade_llir(llir):
4949 return llir
5050
5151
52+ def _replace_mod_ir_with_file (mod , filepath : str , stage_name : str ):
53+ p = Path (filepath )
54+ if not p .exists ():
55+ raise FileNotFoundError (f"Replacement MLIR file not found: { filepath } " )
56+ print (f"[DEBUG] replacing '{ stage_name } ' IR with file '{ filepath } '" )
57+ try :
58+ new_mod = ir .parse_mlir_module (str (p ), mod .context )
59+ new_mod .context = mod .context
60+ return new_mod
61+ except Exception as e :
62+ raise RuntimeError (f"Failed to parse replacement MLIR file '{ filepath } ': { e } " )
63+
64+
5265def _downgrade_mem_attrs (llir : str ):
5366 memory_pattern = r"memory\([^()]*\)"
5467
@@ -341,44 +354,17 @@ def make_ttir(mod, metadata, opt):
341354 passes .common .add_licm (pm )
342355 passes .common .add_symbol_dce (pm )
343356 pm .run (mod )
344- if opt .debug :
345- dump_manager = get_dump_manager (metadata ["hash" ])
346- print (f"Dumping intermediate results to { dump_manager .cache_dir } " )
347- dump_manager .put (str (mod ), "kernel.ttir.mlir" , binary = False )
348-
357+ if opt .debug or dump_ir :
358+ _dump_stage_ir (str (mod ), metadata ["hash" ], "kernel.ttir.mlir" )
349359 return mod
350360
351361
352- def ttir_post (mod , metadata , opt ):
353- ttir_code = str (mod )
354- with tempfile .TemporaryDirectory () as tmpdir :
355- src_path = os .path .join (tmpdir , "kernel.ttir.mlir" )
356- dst_path = os .path .join (tmpdir , "kernel.ttir_post.mlir" )
357- Path (src_path ).write_text (ttir_code )
358- dicp_opt_path = _get_dicp_opt_path ()
359- dicp_cmd_list = [
360- dicp_opt_path ,
361- src_path ,
362- "--canonicalize-cmpi" ,
363- "--canonicalize-triton-ir-ascend" ,
364- "-o" ,
365- dst_path ,
366- ]
367- if dump_ir :
368- shutil .copy (src_path , "./tmp/kernel.ttir.mlir" )
369- print (f"DEBUG dump ir[ttir_post] command: { dicp_cmd_list } " )
370- ret = subprocess .run (dicp_cmd_list , capture_output = True , check = True )
371- if dump_ir :
372- shutil .copy (dst_path , "./tmp/kernel.ttir_post.mlir" )
373- return Path (dst_path ).read_text ()
374-
375-
376362def ttir_to_linalg (mod , metadata , opt , * , named_ops = True ):
377363 # use triton_adapter to lower Triton-MLIR to linalg
378364 # Get Triton-MLIR as string
379365 ttir_code = str (mod )
380366 with tempfile .TemporaryDirectory () as tmpdir :
381- src_path = os .path .join (tmpdir , "kernel.ttir_post .mlir" )
367+ src_path = os .path .join (tmpdir , "kernel.ttir .mlir" )
382368 dst_path = os .path .join (tmpdir , "kernel.ttadapter.mlir" )
383369 Path (src_path ).write_text (ttir_code )
384370 triton_adapter_opt_path = _get_triton_adapter_opt_path ()
@@ -403,90 +389,67 @@ def ttir_to_linalg(mod, metadata, opt, *, named_ops=True):
403389 return Path (dst_path ).read_text ()
404390
405391
def ttir_to_ttsharedir_ascend(mod, metadata, opt, *, named_ops=False):
    """Run the Ascend TTIR -> ttshared passes on ``mod`` in place.

    Args:
        mod: Module to transform; tagged with ``dicp.backend = "ascend"``
            before the passes run.
        metadata: Mapping; ``metadata["hash"]`` is passed to the IR dump.
        opt: Options object; ``opt.debug`` enables the IR dump.
        named_ops: Accepted for signature compatibility; not used here.

    Returns:
        ``mod`` after the passes, or the module loaded from the file named by
        ``replace_ttshared_ir`` when that override is set.
    """
    backend_attr = ir.builder(mod.context).get_string_attr("ascend")
    mod.set_attr("dicp.backend", backend_attr)

    ascend_passes = dicp_triton.passes.triton_shared_ascend
    pm = ir.pass_manager(mod.context)
    for add_pass in (
        ascend_passes.add_canonicalize_cmpi,
        ascend_passes.add_canonicalize_triton_ir_ascend,
        ascend_passes.add_triton_to_linalg_npu,
    ):
        add_pass(pm)
    pm.run(mod)

    if opt.debug or dump_ir:
        # Command-line form of the pipeline above, handed to the dump helper.
        # NOTE(review): presumably recorded so the stage can be reproduced
        # with the opt tool -- confirm in _dump_stage_ir.
        equivalent_cmd = [
            _get_dicp_opt_path(),
            "kernel.ttir.mlir",
            "--canonicalize-cmpi",
            "--canonicalize-triton-ir-ascend",
            "--triton-to-linalg-npu-conversion",
        ]
        _dump_stage_ir(str(mod), metadata["hash"], "kernel.ttshared.mlir", equivalent_cmd)

    if replace_ttshared_ir is None:
        return mod
    # The override is applied after the dump, so the dump always shows the
    # IR that was actually generated.
    return _replace_mod_ir_with_file(mod, replace_ttshared_ir, "ttir_to_ttsharedir_ascend")
438411
439412
def ttsharedir_to_linkedir(mod, metadata, opt, *, named_ops=False):
    """Lower ``mod`` through the linked-NPU passes and return patched IR text.

    Args:
        mod: Module to transform in place.
        metadata: Mapping; ``metadata["hash"]`` is passed to the IR dump.
        opt: Options object; ``opt.debug`` enables the IR dump.
        named_ops: Accepted for signature compatibility; not used here.

    Returns:
        The textual IR after the passes and the text fix-ups below, or the
        contents of the file named by ``replace_linked_ir`` when that
        override is set.
    """
    linked_passes = dicp_triton.passes.linked_npu
    pm = ir.pass_manager(mod.context)
    linked_passes.add_lower_affine(pm)
    linked_passes.add_normalize_slice_ops(pm)
    linked_passes.add_linalg_if_to_select(pm)
    linked_passes.add_linalg_generic_to_scf(pm)
    linked_passes.add_scalar_to_1d_tensor(pm)
    linked_passes.add_linalg_to_linked(pm, False, True)
    linked_passes.add_linked_to_hivm(pm)
    pm.run(mod)

    # TODO(zmz): the IR text is patched here in Python for now; bishengir-compile
    # will support this later, at which point this logic should be removed.
    content = str(mod)
    # Replace "*xfxxx" with "?xfxxx" (likewise for the "xi" and "xbf" forms).
    for type_prefix in ("xf", "xi", "xbf"):
        content = content.replace("*" + type_prefix, "?" + type_prefix)
    # Match text of the form "memref<...> to tensor<...>" and keep both types,
    # inserting a comment marker between them.
    content = re.sub(
        r"(memref\<.*?\>)\s+to\s+(tensor\<.*?\>)",
        r"\1 // to \2",
        content,
    )

    if opt.debug or dump_ir:
        # Command-line form of the pipeline above, handed to the dump helper.
        equivalent_cmd = [
            _get_dicp_opt_path(),
            "kernel.ttshared.mlir",
            "--lower-affine",
            "--normalize-slice-ops",
            "--linalg-if-to-select",
            "--linalg-generic-to-scf",
            "--scalar-to-1d-tensor",
            "--linalg-to-linked=global-kernel=false named-ops=true",
            "--linked-to-hivm",
        ]
        _dump_stage_ir(content, metadata["hash"], "kernel.linkedir.mlir", equivalent_cmd)

    if replace_linked_ir is None:
        return content
    print(f"[DEBUG] Replace Linkedir with {replace_linked_ir}")
    return Path(replace_linked_ir).read_text()
490453
491454
492455def linalg_to_llir (linalg : str , metadata , opt ):
0 commit comments