Skip to content

Commit a9f1b72

Browse files
authored
[SLM] Enable Debug Dump (#1499)
This PR enables the debug dump feature. The command would be something like ``` mlc_chat compile ./dist/Llama-2-7b-chat-hf-q4f16_1-MLC/mlc-chat-config.json --device cuda -o dist/libs/Llama-2-7b-chat-hf-q4f16_1-cuda.so --debug-dump debug/ ``` And it would dump 6 files in the `debug/` folder: ``` debug-phase0.py debug-phase1.py debug-phase2.py debug-phase3.py debug-phase4.py debug-final.py ```
1 parent 779b1a5 commit a9f1b72

File tree

5 files changed

+61
-5
lines changed

5 files changed

+61
-5
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
tmp/
22
dist/
33
params/
4+
debug/
45
*.bak
56
# Byte-compiled / optimized / DLL files
67
__pycache__/

python/mlc_chat/cli/compile.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import argparse
33
import json
44
import re
5+
from functools import partial
56
from pathlib import Path
67
from typing import Union
78

@@ -37,6 +38,14 @@ def _parse_output(path: Union[str, Path]) -> Path:
3738
raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}")
3839
return path
3940

41+
def _parse_dir(path: Union[str, Path], auto_create: bool = False) -> Path:
42+
path = Path(path)
43+
if not auto_create and not path.is_dir():
44+
raise argparse.ArgumentTypeError(f"Directory does not exist: {path}")
45+
if auto_create and not path.is_dir():
46+
path.mkdir(parents=True)
47+
return path
48+
4049
def _check_system_lib_prefix(prefix: str) -> str:
4150
pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$"
4251
if prefix == "" or re.match(pattern, prefix):
@@ -46,7 +55,7 @@ def _check_system_lib_prefix(prefix: str) -> str:
4655
"numbers (0-9), alphabets (A-Z, a-z) and underscore (_)."
4756
)
4857

49-
parser = ArgumentParser("MLC LLM Compiler")
58+
parser = ArgumentParser("mlc_chat compile")
5059
parser.add_argument(
5160
"model",
5261
type=detect_mlc_chat_config,
@@ -103,6 +112,12 @@ def _check_system_lib_prefix(prefix: str) -> str:
103112
default="",
104113
help=HELP["overrides"] + ' (default: "%(default)s")',
105114
)
115+
parser.add_argument(
116+
"--debug-dump",
117+
type=partial(_parse_dir, auto_create=True),
118+
default=None,
119+
help=HELP["debug_dump"] + " (default: %(default)s)",
120+
)
106121
parsed = parser.parse_args(argv)
107122
target, build_func = detect_target_and_host(parsed.device, parsed.host)
108123
parsed.model_type = detect_model_type(parsed.model_type, parsed.model)
@@ -123,4 +138,5 @@ def _check_system_lib_prefix(prefix: str) -> str:
123138
system_lib_prefix=parsed.system_lib_prefix,
124139
output=parsed.output,
125140
overrides=parsed.overrides,
141+
debug_dump=parsed.debug_dump,
126142
)

python/mlc_chat/compiler_pass/pipeline.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""The compilation pipeline for LLM applications."""
2-
from typing import Any, Dict, List
2+
from pathlib import Path
3+
from typing import Any, Dict, List, Optional
34

45
import tvm
56
from tvm import IRModule
@@ -34,13 +35,34 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
3435
return mod
3536

3637

38+
@tvm.transform.module_pass(opt_level=0, name="DebugDump")
class _DebugDump:  # pylint: disable=too-few-public-methods
    """A dummy compiler pass that does nothing but logging.
    Only enabled when debug_dump is not None"""

    def __init__(self, file_name: str, file_path: Optional[Path], show_meta: bool = False):
        self.file_name = file_name
        self.file_path = file_path
        self.show_meta = show_meta

    def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
        """A dummy transformation that dumps the module to file"""
        # Dumping is disabled unless a target directory was supplied.
        if self.file_path is None:
            return mod
        dump_target = self.file_path / self.file_name
        # NOTE: We use debug level here to avoid spamming the console
        logger.debug("Dumping IR to %s", dump_target)
        with open(dump_target, "w", encoding="utf-8") as out_file:
            out_file.write(mod.script(show_meta=self.show_meta))
        return mod
56+
57+
3758
@register_pipeline("mlc_llm")
38-
def _mlc_llm_pipeline(
59+
def _mlc_llm_pipeline( # pylint: disable=too-many-arguments
3960
variable_bounds: Dict[str, int] = None,
4061
additional_tirs: Dict[str, tvm.tir.PrimFunc] = None,
4162
metadata: Dict[str, Any] = None,
4263
ext_mods: List[nn.ExternModule] = None,
4364
skip_gemm: bool = False,
65+
debug_dump: Optional[Path] = None,
4466
):
4567
variable_bounds = variable_bounds or {}
4668
additional_tirs = additional_tirs or {}
@@ -54,23 +76,27 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
5476
# Phase 0. Add additional information for compilation
5577
AttachVariableBounds(variable_bounds),
5678
AttachAdditionalPrimFuncs(additional_tirs),
79+
_DebugDump("debug-phase0.py", debug_dump, show_meta=False),
5780
# Phase 1. Passes on high-level operator graph
5881
_LogProgress("Running TVM Relax graph-level optimizations"),
5982
FuseDequantizeTranspose(skip_gemm=skip_gemm),
6083
FuseTransposeMatmul(),
84+
_DebugDump("debug-phase1.py", debug_dump, show_meta=False),
6185
# Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
6286
_LogProgress("Lowering to TVM TIR kernels"),
6387
tvm.relax.transform.LegalizeOps(),
6488
tvm.relax.transform.AnnotateTIROpPattern(),
6589
tvm.relax.transform.FoldConstant(),
6690
tvm.relax.transform.FuseOps(),
6791
tvm.relax.transform.FuseTIR(),
92+
_DebugDump("debug-phase2.py", debug_dump, show_meta=False),
6893
# Phase 3. Passes on TIR
6994
_LogProgress("Running TVM TIR-level optimizations"),
7095
FuseDequantizeMatmulEwise(),
7196
FuseDequantizeTake(),
7297
tvm.relax.transform.DeadCodeElimination(),
7398
CleanUpTIRAttrs(["op_pattern"]),
99+
_DebugDump("debug-phase3.py", debug_dump, show_meta=False),
74100
# Phase 4. Low-level Optimizations
75101
_LogProgress("Running TVM Dlight low-level optimizations"),
76102
dl.ApplyDefaultSchedule(
@@ -80,6 +106,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
80106
dl.gpu.GeneralReduction(),
81107
dl.gpu.Fallback(),
82108
),
109+
_DebugDump("debug-phase4.py", debug_dump, show_meta=False),
83110
_LogProgress("Lowering to VM bytecode"),
84111
LiftTIRGlobalBufferAlloc(),
85112
tvm.tir.transform.ForceNarrowIndexToInt32(),
@@ -95,6 +122,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
95122
tvm.relax.transform.VMBuiltinLower(),
96123
tvm.relax.transform.VMShapeLower(),
97124
tvm.relax.transform.AttachGlobalSymbol(),
125+
_DebugDump("debug-final.py", debug_dump, show_meta=False),
98126
_LogProgress("Compiling external modules"),
99127
tvm.relax.transform.AttachExternModules(ext_mods),
100128
_LogProgress("Compilation complete! Exporting to disk"),

python/mlc_chat/help.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
"context_window_size": """
5858
Option to provide the maximum sequence length supported by the model.
5959
This is usually explicitly shown as context length or context window in the model card.
60-
If this option is not set explicitly, by default,
60+
If this option is not set explicitly, by default,
6161
it will be determined by `context_window_size` or `max_position_embeddings` in `config.json`,
6262
and the latter is usually inaccurate for some models.
6363
""".strip(),
@@ -110,5 +110,10 @@
110110
`context_window_size`, `prefill_chunk_size`, `sliding_window_size`, `attention_sink_size`,
111111
`max_batch_size` and `tensor_parallel_shards`. Meanwhile, model config could be explicitly
112112
specified via details knobs, e.g. --overrides "context_window_size=1024;prefill_chunk_size=128".
113+
""".strip(),
114+
"debug_dump": """
115+
Specifies the directory where the compiler will store its IRs for debugging purposes
116+
during various phases of compilation. By default, this is set to `None`, indicating
117+
that debug dumping is disabled.
113118
""".strip(),
114119
}

python/mlc_chat/interface/compile.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import dataclasses
33
from io import StringIO
44
from pathlib import Path
5-
from typing import Any, Callable, Dict, List, Tuple
5+
from typing import Any, Callable, Dict, List, Optional, Tuple
66

77
from tvm import IRModule, relax, tir
88
from tvm.ir.transform import Pass
@@ -97,6 +97,7 @@ class CompileArgs: # pylint: disable=too-many-instance-attributes
9797
system_lib_prefix: str
9898
output: Path
9999
overrides: ModelConfigOverride
100+
debug_dump: Optional[Path]
100101

101102
def __post_init__(self) -> None:
102103
self.opt.update(self.target)
@@ -113,6 +114,8 @@ def display(self) -> None:
113114
print(f" {bold('--system-lib-prefix'):<25} \"{self.system_lib_prefix}\"", file=out)
114115
print(f" {bold('--output'):<25} {self.output}", file=out)
115116
print(f" {bold('--overrides'):<25} {self.overrides}", file=out)
117+
# As it's debug only, no need to display
118+
# print(f" {bold('--debug-dump'):<25} {self.debug_dump}", file=out)
116119
print(out.getvalue().rstrip())
117120

118121

@@ -200,6 +203,7 @@ def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]:
200203
additional_tirs=additional_tirs,
201204
ext_mods=ext_mods,
202205
metadata=metadata,
206+
debug_dump=args.debug_dump,
203207
),
204208
)
205209
logger.info("Generated: %s", bold(str(args.output)))
@@ -215,6 +219,7 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin
215219
system_lib_prefix: str,
216220
output: Path,
217221
overrides: ModelConfigOverride,
222+
debug_dump: Optional[Path] = None,
218223
):
219224
"""Compile a model given its configuration and quantization format to a specific target."""
220225
if "model_config" in config:
@@ -231,6 +236,7 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin
231236
system_lib_prefix,
232237
output,
233238
overrides,
239+
debug_dump,
234240
)
235241
args.display()
236242
_compile(args, model_config)

0 commit comments

Comments
 (0)