diff --git a/.gitignore b/.gitignore index d3dead3618..f9454e52fb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ tmp/ dist/ params/ +debug/ *.bak # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/python/mlc_chat/cli/compile.py b/python/mlc_chat/cli/compile.py index 9a22927bb4..3e65b3110a 100644 --- a/python/mlc_chat/cli/compile.py +++ b/python/mlc_chat/cli/compile.py @@ -2,6 +2,7 @@ import argparse import json import re +from functools import partial from pathlib import Path from typing import Union @@ -37,6 +38,14 @@ def _parse_output(path: Union[str, Path]) -> Path: raise argparse.ArgumentTypeError(f"Directory does not exist: {parent}") return path + def _parse_dir(path: Union[str, Path], auto_create: bool = False) -> Path: + path = Path(path) + if not auto_create and not path.is_dir(): + raise argparse.ArgumentTypeError(f"Directory does not exist: {path}") + if auto_create and not path.is_dir(): + path.mkdir(parents=True) + return path + def _check_system_lib_prefix(prefix: str) -> str: pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*$" if prefix == "" or re.match(pattern, prefix): @@ -46,7 +55,7 @@ def _check_system_lib_prefix(prefix: str) -> str: "numbers (0-9), alphabets (A-Z, a-z) and underscore (_)." ) - parser = ArgumentParser("MLC LLM Compiler") + parser = ArgumentParser("mlc_chat compile") parser.add_argument( "model", type=detect_mlc_chat_config, @@ -103,6 +112,12 @@ def _check_system_lib_prefix(prefix: str) -> str: default="", help=HELP["overrides"] + ' (default: "%(default)s")', ) + parser.add_argument( + "--debug-dump", + type=partial(_parse_dir, auto_create=True), + default=None, + help=HELP["debug_dump"] + " (default: %(default)s)", + ) parsed = parser.parse_args(argv) target, build_func = detect_target_and_host(parsed.device, parsed.host) parsed.model_type = detect_model_type(parsed.model_type, parsed.model) @@ -123,4 +138,5 @@ def _check_system_lib_prefix(prefix: str) -> str: system_lib_prefix=parsed.system_lib_prefix, output=parsed.output, overrides=parsed.overrides, + debug_dump=parsed.debug_dump, ) diff --git a/python/mlc_chat/compiler_pass/pipeline.py b/python/mlc_chat/compiler_pass/pipeline.py index 53415d87d6..2997d7ff52 100644 --- a/python/mlc_chat/compiler_pass/pipeline.py +++ b/python/mlc_chat/compiler_pass/pipeline.py @@ -1,5 +1,6 @@ """The compilation pipeline for LLM applications.""" -from typing import Any, Dict, List +from pathlib import Path +from typing import Any, Dict, List, Optional import tvm from tvm import IRModule @@ -34,13 +35,34 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR return mod +@tvm.transform.module_pass(opt_level=0, name="DebugDump") +class _DebugDump: # pylint: disable=too-few-public-methods + """A dummy compiler pass that does nothing but logging. + Only enabled when debug_dump is not None""" + + def __init__(self, file_name: str, file_path: Optional[Path], show_meta: bool = False): + self.file_name = file_name + self.file_path = file_path + self.show_meta = show_meta + + def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: + """A dummy transformation that dumps the module to file""" + if self.file_path is not None: + # NOTE: We use debug level here to avoid spamming the console + logger.debug("Dumping IR to %s", self.file_path / self.file_name) + with open(self.file_path / self.file_name, "w", encoding="utf-8") as f: + f.write(mod.script(show_meta=self.show_meta)) + return mod + + @register_pipeline("mlc_llm") -def _mlc_llm_pipeline( +def _mlc_llm_pipeline( # pylint: disable=too-many-arguments variable_bounds: Dict[str, int] = None, additional_tirs: Dict[str, tvm.tir.PrimFunc] = None, metadata: Dict[str, Any] = None, ext_mods: List[nn.ExternModule] = None, skip_gemm: bool = False, + debug_dump: Optional[Path] = None, ): variable_bounds = variable_bounds or {} additional_tirs = additional_tirs or {} @@ -54,10 +76,12 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I # Phase 0. Add additional information for compilation AttachVariableBounds(variable_bounds), AttachAdditionalPrimFuncs(additional_tirs), + _DebugDump("debug-phase0.py", debug_dump, show_meta=False), # Phase 1. Passes on high-level operator graph _LogProgress("Running TVM Relax graph-level optimizations"), FuseDequantizeTranspose(skip_gemm=skip_gemm), FuseTransposeMatmul(), + _DebugDump("debug-phase1.py", debug_dump, show_meta=False), # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline _LogProgress("Lowering to TVM TIR kernels"), tvm.relax.transform.LegalizeOps(), @@ -65,12 +89,14 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I tvm.relax.transform.FoldConstant(), tvm.relax.transform.FuseOps(), tvm.relax.transform.FuseTIR(), + _DebugDump("debug-phase2.py", debug_dump, show_meta=False), # Phase 3. Passes on TIR _LogProgress("Running TVM TIR-level optimizations"), FuseDequantizeMatmulEwise(), FuseDequantizeTake(), tvm.relax.transform.DeadCodeElimination(), CleanUpTIRAttrs(["op_pattern"]), + _DebugDump("debug-phase3.py", debug_dump, show_meta=False), # Phase 4. Low-level Optimizations _LogProgress("Running TVM Dlight low-level optimizations"), dl.ApplyDefaultSchedule( @@ -80,6 +106,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I dl.gpu.GeneralReduction(), dl.gpu.Fallback(), ), + _DebugDump("debug-phase4.py", debug_dump, show_meta=False), _LogProgress("Lowering to VM bytecode"), LiftTIRGlobalBufferAlloc(), tvm.tir.transform.ForceNarrowIndexToInt32(), @@ -95,6 +122,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I tvm.relax.transform.VMBuiltinLower(), tvm.relax.transform.VMShapeLower(), tvm.relax.transform.AttachGlobalSymbol(), + _DebugDump("debug-final.py", debug_dump, show_meta=False), _LogProgress("Compiling external modules"), tvm.relax.transform.AttachExternModules(ext_mods), _LogProgress("Compilation complete! Exporting to disk"), diff --git a/python/mlc_chat/help.py b/python/mlc_chat/help.py index da3c002c83..55c019bbfe 100644 --- a/python/mlc_chat/help.py +++ b/python/mlc_chat/help.py @@ -57,7 +57,7 @@ "context_window_size": """ Option to provide the maximum sequence length supported by the model. This is usually explicitly shown as context length or context window in the model card. -If this option is not set explicitly, by default, +If this option is not set explicitly, by default, it will be determined by `context_window_size` or `max_position_embeddings` in `config.json`, and the latter is usually inaccurate for some models. """.strip(), @@ -110,5 +110,10 @@ `context_window_size`, `prefill_chunk_size`, `sliding_window_size`, `attention_sink_size`, `max_batch_size` and `tensor_parallel_shards`. Meanwhile, model config could be explicitly specified via details knobs, e.g. --overrides "context_window_size=1024;prefill_chunk_size=128". +""".strip(), + "debug_dump": """ +Specifies the directory where the compiler will store its IRs for debugging purposes +during various phases of compilation. By default, this is set to `None`, indicating +that debug dumping is disabled. """.strip(), } diff --git a/python/mlc_chat/interface/compile.py b/python/mlc_chat/interface/compile.py index 81b147956a..55229b5ade 100644 --- a/python/mlc_chat/interface/compile.py +++ b/python/mlc_chat/interface/compile.py @@ -2,7 +2,7 @@ import dataclasses from io import StringIO from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from tvm import IRModule, relax, tir from tvm.ir.transform import Pass @@ -97,6 +97,7 @@ class CompileArgs: # pylint: disable=too-many-instance-attributes system_lib_prefix: str output: Path overrides: ModelConfigOverride + debug_dump: Optional[Path] def __post_init__(self) -> None: self.opt.update(self.target) @@ -113,6 +114,8 @@ def display(self) -> None: print(f" {bold('--system-lib-prefix'):<25} \"{self.system_lib_prefix}\"", file=out) print(f" {bold('--output'):<25} {self.output}", file=out) print(f" {bold('--overrides'):<25} {self.overrides}", file=out) + # As it's debug only, no need to display + # print(f" {bold('--debug-dump'):<25} {self.debug_dump}", file=out) print(out.getvalue().rstrip()) @@ -200,6 +203,7 @@ def _get_param_metadata(name: str, param: nn.Parameter) -> Dict[str, Any]: additional_tirs=additional_tirs, ext_mods=ext_mods, metadata=metadata, + debug_dump=args.debug_dump, ), ) logger.info("Generated: %s", bold(str(args.output))) @@ -215,6 +219,7 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin system_lib_prefix: str, output: Path, overrides: ModelConfigOverride, + debug_dump: Optional[Path] = None, ): """Compile a model given its configuration and quantization format to a specific target.""" if "model_config" in config: @@ -231,6 +236,7 @@ def compile( # pylint: disable=too-many-arguments,redefined-builtin system_lib_prefix, output, overrides, + debug_dump, ) args.display() _compile(args, model_config)