diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index dde77da212..99ecd5a70b 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -17,14 +17,16 @@
 # pylint: disable=invalid-name, dangerous-default-value, arguments-differ
 """Driver for partitioning and building a Relay module for CUTLASS offload."""
 import logging
-import os
 import multiprocessing
+import os
+
 import tvm
-from tvm import runtime, relay, relax
-from tvm.contrib.nvcc import get_cuda_version
+from tvm import relax, relay, runtime
 from tvm._ffi.registry import register_func
-from .gen_gemm import CutlassGemmProfiler
+from tvm.contrib.nvcc import get_cuda_version
+
 from .gen_conv2d import CutlassConv2DProfiler
+from .gen_gemm import CutlassGemmProfiler
 from .library import ConvKind
 
 logger = logging.getLogger("cutlass")
@@ -532,7 +534,7 @@ def _extract_relax_function_info(f):
 
     def fvisit(e):
         nonlocal op_attrs
-        if isinstance(e, relax.Call) and str(e.op) in ["relax.nn.conv2d"]:
+        if isinstance(e, relax.Call) and e.op.name in ["relax.nn.conv2d"]:
             op_attrs = e.attrs
 
     relax.analysis.post_order_visit(f.body, fvisit)
diff --git a/src/ir/function.cc b/src/ir/function.cc
index 54a4529b91..d1063b312a 100644
--- a/src/ir/function.cc
+++ b/src/ir/function.cc
@@ -49,6 +49,24 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
       LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
     });
 
+TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs")
+    .set_body_typed([](BaseFunc func, Map<String, ObjectRef> attr_map) -> BaseFunc {
+      if (func->IsInstance<tir::PrimFuncNode>()) {
+        return WithAttrs(Downcast<tir::PrimFunc>(std::move(func)), attr_map);
+      }
+      if (const auto* f = runtime::Registry::Get("relay.ir.FuncWithAttrs")) {
+        if (Optional<BaseFunc> ret = (*f)(func, attr_map)) {
+          return ret.value();
+        }
+      }
+      if (const auto* f = runtime::Registry::Get("relax.FuncWithAttrs")) {
+        if (Optional<BaseFunc> ret = (*f)(func, attr_map)) {
+          return ret.value();
+        }
+      }
+      LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
+    });
+
 TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr")
     .set_body_typed([](BaseFunc func, String key) -> BaseFunc {
       if (func->IsInstance<tir::PrimFuncNode>()) {
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index b1ac703abf..067714873d 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -551,6 +551,14 @@ TVM_REGISTER_GLOBAL("relax.FuncWithAttr")
       return NullOpt;
     });
 
+TVM_REGISTER_GLOBAL("relax.FuncWithAttrs")
+    .set_body_typed([](BaseFunc func, Map<String, ObjectRef> attr_map) -> Optional<Function> {
+      if (func->IsInstance<FunctionNode>()) {
+        return WithAttrs(Downcast<Function>(std::move(func)), attr_map);
+      }
+      return NullOpt;
+    });
+
 TVM_REGISTER_GLOBAL("relax.FuncWithoutAttr")
     .set_body_typed([](BaseFunc func, String key) -> Optional<Function> {
      if (func->IsInstance<FunctionNode>()) {
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index b73340df30..b3a2b86d12 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -262,6 +262,14 @@ TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr")
       return NullOpt;
     });
 
+TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttrs")
+    .set_body_typed([](BaseFunc func, Map<String, ObjectRef> attr_map) -> Optional<Function> {
+      if (func->IsInstance<FunctionNode>()) {
+        return WithAttrs(Downcast<Function>(std::move(func)), attr_map);
+      }
+      return NullOpt;
+    });
+
 TVM_REGISTER_GLOBAL("relay.ir.FuncWithoutAttr")
     .set_body_typed([](BaseFunc func, String key) -> Optional<Function> {
      if (func->IsInstance<FunctionNode>()) {
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index 4ed894783a..445eb44f5a 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -667,23 +667,19 @@ def foo(x: R.Tensor((10, 5), "float32")) -> R.Tensor((10, 5), "float32"):
             s = R.add(x, x)
             return s
 
-    with pytest.raises(OSError):
-
-        @I.ir_module
-        class ErrorMod:
-            @R.function
-            def main(x: R.Tensor((10, 5), "float32")):
-                inner = foo
-                gv1 = inner(x)
-                gv2 = foo(x)
-                return (inner, gv1, gv2)
+    @I.ir_module
+    class Mod2:
+        @R.function
+        def main(x: R.Tensor((10, 5), "float32")):
+            inner = foo
+            gv1 = inner(x)
+            gv2 = foo(x)
+            return (inner, gv1, gv2)
 
-            @R.function
-            def foo(
-                x: R.Tensor((10, 5), "float32")
-            ):  # need function ret info since it is parse later than `main`
-                s = R.add(x, x)
-                return s
+        @R.function
+        def foo(x: R.Tensor((10, 5), "float32")):  # the return type is automatically inferred
+            s = R.add(x, x)
+            return s
 
 
 def test_if_branch():
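
Below is a minimal sketch of how the "ir.BaseFuncWithAttrs" packed function registered by this patch can be exercised from Python, assuming a TVM build that includes the change. The standalone relax function, the attribute values, and the variable names are illustrative only; the global is fetched directly with tvm.get_global_func rather than through any Python wrapper.

    import tvm
    from tvm.script import relax as R

    # A small standalone relax function used only for illustration.
    @R.function
    def identity(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        return x

    # Fetch the newly registered FFI global; it dispatches to the PrimFunc,
    # Relay, or Relax handler depending on the function's type.
    base_func_with_attrs = tvm.get_global_func("ir.BaseFuncWithAttrs")

    # Attach several attributes in a single call; the Python dict is converted
    # to a Map<String, ObjectRef> by the FFI, and a new function is returned.
    func = base_func_with_attrs(identity, {"Codegen": "cutlass", "global_symbol": "identity"})
    print(func.attrs)  # both attributes are now present on the returned function

Compared with calling "ir.BaseFuncWithAttr" once per key, this sets the whole attribute map in one FFI round trip, which is what the CUTLASS build driver needs when tagging offloaded functions.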