-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[C++] Invoke storage allocation for CUDA Graph explicitly
This PR adds a function that invokes the storage allocation function generated by CUDA Graph rewrite. With this function, we now manually trigger the storage allocation at initialization time. The reason we need this is because that the storage allocation may contain CUDA IPC memory alloc that has to run through a Disco session. So when a function that needs CUDA graph storage allocation runs first outside a Disco session, there might be error caused if we did not initialize the allocation in advance.
- Loading branch information
1 parent
e349684
commit 3e0d408
Showing
4 changed files
with
41 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
python/mlc_llm/compiler_pass/attach_cuda_graph_alloc_init_func.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
"""The pass that attaches an empty function for initialization.""" | ||
|
||
import tvm | ||
from tvm import IRModule, relax | ||
|
||
|
||
@tvm.transform.module_pass(opt_level=0, name="AttachCUDAGraphAllocInitFunc") | ||
class AttachCUDAGraphAllocInitFunc: # pylint: disable=too-few-public-methods | ||
"""Attach an empty function for initialization.""" | ||
|
||
def __init__(self): | ||
pass | ||
|
||
def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule: | ||
"""Entrypoint""" | ||
bb = relax.BlockBuilder(mod) | ||
alloc_func_gv = None | ||
for gv, _ in mod.functions_items(): | ||
if gv.name_hint.startswith("cuda_graph_alloc"): | ||
assert alloc_func_gv is None | ||
alloc_func_gv = gv | ||
if alloc_func_gv is None: | ||
return mod | ||
|
||
with bb.function("cuda_graph_alloc_init", []): | ||
bb.emit_func_output( | ||
relax.op.call_builtin_with_ctx( | ||
"vm.builtin.cuda_graph.get_cached_alloc", | ||
args=[alloc_func_gv, relax.PrimValue(0)], | ||
sinfo_args=relax.ObjectStructInfo(), | ||
) | ||
) | ||
return bb.finalize() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters