Commit

Skip cublas dispatch for single batch (#2315)
vinx13 authored May 10, 2024
1 parent b01cfab commit 347222c
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/mlc_llm/compiler_pass/cublas_dispatch.py
@@ -20,10 +20,14 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
         model_names = [
             gv.name_hint for gv, func in mod.functions.items() if isinstance(func, relax.Function)
         ]
+        model_names = [name for name in model_names if "batch" not in name]
         mod = tvm.transform.Sequential(
             [
                 relax.transform.FuseOpsByPattern(
-                    patterns, bind_constants=False, annotate_codegen=True
+                    patterns,
+                    bind_constants=False,
+                    annotate_codegen=True,
+                    entry_functions=model_names,
                 ),
                 relax.transform.RunCodegen({}, entry_functions=model_names),
             ]
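
The added filter is a plain substring test on each function name, and the filtered list is what both FuseOpsByPattern and RunCodegen now receive as entry_functions. A minimal sketch of what the list comprehension keeps, assuming hypothetical function names (the real names come from mod.functions and are not shown in this diff), and wrapping the expression in an illustrative helper that does not exist in the commit:

```python
# Illustrative helper only; the commit inlines this list comprehension.
def select_entry_functions(names):
    """Keep only names that do not contain the substring "batch"."""
    return [name for name in names if "batch" not in name]


# Hypothetical function names, chosen for illustration; the pass itself
# collects gv.name_hint for every relax.Function in the module.
candidate_names = ["prefill", "decode", "batch_prefill", "batch_decode"]

entry_functions = select_entry_functions(candidate_names)
print(entry_functions)  # ['prefill', 'decode']
```

Because the test is a substring match, any name that merely contains "batch" anywhere is dropped from the list, not only names that start with it.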
