1
1
"""The compilation pipeline for LLM applications."""
2
- from typing import Any , Dict , List
2
+ from pathlib import Path
3
+ from typing import Any , Dict , List , Optional
3
4
4
5
import tvm
5
6
from tvm import IRModule
@@ -34,13 +35,34 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
34
35
return mod
35
36
36
37
38
@tvm.transform.module_pass(opt_level=0, name="DebugDump")
class _DebugDump:  # pylint: disable=too-few-public-methods
    """A pass that leaves the module untouched and optionally dumps it to a file.

    When ``file_path`` is None the pass does nothing. Otherwise the text returned by
    ``mod.script(show_meta=...)`` is written to ``file_path / file_name``.

    Parameters
    ----------
    file_name : str
        Name of the file to write inside ``file_path``.
    file_path : Optional[Path]
        Directory that receives the dump, or None to disable dumping.
    show_meta : bool
        Forwarded to ``mod.script`` as its ``show_meta`` argument.
    """

    def __init__(self, file_name: str, file_path: Optional[Path], show_meta: bool = False):
        self.file_name = file_name
        self.file_path = file_path
        self.show_meta = show_meta

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """Write the module to the dump file if dumping is enabled; return it unchanged."""
        if self.file_path is None:
            return mod
        dump_path = self.file_path / self.file_name
        # Create the target directory on demand so that a missing dump directory
        # does not abort the compilation with FileNotFoundError.
        dump_path.parent.mkdir(parents=True, exist_ok=True)
        # NOTE: We use debug level here to avoid spamming the console
        logger.debug("Dumping IR to %s", dump_path)
        with open(dump_path, "w", encoding="utf-8") as f:
            f.write(mod.script(show_meta=self.show_meta))
        return mod
56
+
57
+
37
58
@register_pipeline ("mlc_llm" )
38
- def _mlc_llm_pipeline (
59
+ def _mlc_llm_pipeline ( # pylint: disable=too-many-arguments
39
60
variable_bounds : Dict [str , int ] = None ,
40
61
additional_tirs : Dict [str , tvm .tir .PrimFunc ] = None ,
41
62
metadata : Dict [str , Any ] = None ,
42
63
ext_mods : List [nn .ExternModule ] = None ,
43
64
skip_gemm : bool = False ,
65
+ debug_dump : Optional [Path ] = None ,
44
66
):
45
67
variable_bounds = variable_bounds or {}
46
68
additional_tirs = additional_tirs or {}
@@ -54,23 +76,27 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
54
76
# Phase 0. Add additional information for compilation
55
77
AttachVariableBounds (variable_bounds ),
56
78
AttachAdditionalPrimFuncs (additional_tirs ),
79
+ _DebugDump ("debug-phase0.py" , debug_dump , show_meta = False ),
57
80
# Phase 1. Passes on high-level operator graph
58
81
_LogProgress ("Running TVM Relax graph-level optimizations" ),
59
82
FuseDequantizeTranspose (skip_gemm = skip_gemm ),
60
83
FuseTransposeMatmul (),
84
+ _DebugDump ("debug-phase1.py" , debug_dump , show_meta = False ),
61
85
# Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
62
86
_LogProgress ("Lowering to TVM TIR kernels" ),
63
87
tvm .relax .transform .LegalizeOps (),
64
88
tvm .relax .transform .AnnotateTIROpPattern (),
65
89
tvm .relax .transform .FoldConstant (),
66
90
tvm .relax .transform .FuseOps (),
67
91
tvm .relax .transform .FuseTIR (),
92
+ _DebugDump ("debug-phase2.py" , debug_dump , show_meta = False ),
68
93
# Phase 3. Passes on TIR
69
94
_LogProgress ("Running TVM TIR-level optimizations" ),
70
95
FuseDequantizeMatmulEwise (),
71
96
FuseDequantizeTake (),
72
97
tvm .relax .transform .DeadCodeElimination (),
73
98
CleanUpTIRAttrs (["op_pattern" ]),
99
+ _DebugDump ("debug-phase3.py" , debug_dump , show_meta = False ),
74
100
# Phase 4. Low-level Optimizations
75
101
_LogProgress ("Running TVM Dlight low-level optimizations" ),
76
102
dl .ApplyDefaultSchedule (
@@ -80,6 +106,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
80
106
dl .gpu .GeneralReduction (),
81
107
dl .gpu .Fallback (),
82
108
),
109
+ _DebugDump ("debug-phase4.py" , debug_dump , show_meta = False ),
83
110
_LogProgress ("Lowering to VM bytecode" ),
84
111
LiftTIRGlobalBufferAlloc (),
85
112
tvm .tir .transform .ForceNarrowIndexToInt32 (),
@@ -95,6 +122,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
95
122
tvm .relax .transform .VMBuiltinLower (),
96
123
tvm .relax .transform .VMShapeLower (),
97
124
tvm .relax .transform .AttachGlobalSymbol (),
125
+ _DebugDump ("debug-final.py" , debug_dump , show_meta = False ),
98
126
_LogProgress ("Compiling external modules" ),
99
127
tvm .relax .transform .AttachExternModules (ext_mods ),
100
128
_LogProgress ("Compilation complete! Exporting to disk" ),
0 commit comments