Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
Fix after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Feb 9, 2023
1 parent c8b3123 commit 3cbe967
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 21 deletions.
12 changes: 7 additions & 5 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@
# pylint: disable=invalid-name, dangerous-default-value, arguments-differ
"""Driver for partitioning and building a Relay module for CUTLASS offload."""
import logging
import os
import multiprocessing
import os

import tvm
from tvm import runtime, relay, relax
from tvm.contrib.nvcc import get_cuda_version
from tvm import relax, relay, runtime
from tvm._ffi.registry import register_func
from .gen_gemm import CutlassGemmProfiler
from tvm.contrib.nvcc import get_cuda_version

from .gen_conv2d import CutlassConv2DProfiler
from .gen_gemm import CutlassGemmProfiler
from .library import ConvKind

logger = logging.getLogger("cutlass")
Expand Down Expand Up @@ -532,7 +534,7 @@ def _extract_relax_function_info(f):

def fvisit(e):
nonlocal op_attrs
if isinstance(e, relax.Call) and str(e.op) in ["relax.nn.conv2d"]:
if isinstance(e, relax.Call) and e.op.name in ["relax.nn.conv2d"]:
op_attrs = e.attrs

relax.analysis.post_order_visit(f.body, fvisit)
Expand Down
18 changes: 18 additions & 0 deletions src/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,24 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
});

TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs")
.set_body_typed([](BaseFunc func, Map<String, ObjectRef> attr_map) -> BaseFunc {
if (func->IsInstance<tir::PrimFuncNode>()) {
return WithAttrs(Downcast<tir::PrimFunc>(std::move(func)), attr_map);
}
if (const auto* f = runtime::Registry::Get("relay.ir.FuncWithAttrs")) {
if (Optional<BaseFunc> ret = (*f)(func, attr_map)) {
return ret.value();
}
}
if (const auto* f = runtime::Registry::Get("relax.FuncWithAttrs")) {
if (Optional<BaseFunc> ret = (*f)(func, attr_map)) {
return ret.value();
}
}
LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
});

TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr")
.set_body_typed([](BaseFunc func, String key) -> BaseFunc {
if (func->IsInstance<tir::PrimFuncNode>()) {
Expand Down
8 changes: 8 additions & 0 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,14 @@ TVM_REGISTER_GLOBAL("relax.FuncWithAttr")
return NullOpt;
});

TVM_REGISTER_GLOBAL("relax.FuncWithAttrs")
.set_body_typed([](BaseFunc func, Map<String, ObjectRef> attr_map) -> Optional<Function> {
if (func->IsInstance<relax::FunctionNode>()) {
return WithAttrs(Downcast<relax::Function>(std::move(func)), attr_map);
}
return NullOpt;
});

TVM_REGISTER_GLOBAL("relax.FuncWithoutAttr")
.set_body_typed([](BaseFunc func, String key) -> Optional<Function> {
if (func->IsInstance<relax::FunctionNode>()) {
Expand Down
8 changes: 8 additions & 0 deletions src/relay/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,14 @@ TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttr")
return NullOpt;
});

TVM_REGISTER_GLOBAL("relay.ir.FuncWithAttrs")
.set_body_typed([](BaseFunc func, Map<String, ObjectRef> attr_map) -> Optional<Function> {
if (func->IsInstance<relay::FunctionNode>()) {
return WithAttrs(Downcast<relay::Function>(std::move(func)), attr_map);
}
return NullOpt;
});

TVM_REGISTER_GLOBAL("relay.ir.FuncWithoutAttr")
.set_body_typed([](BaseFunc func, String key) -> Optional<Function> {
if (func->IsInstance<relay::FunctionNode>()) {
Expand Down
28 changes: 12 additions & 16 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,23 +667,19 @@ def foo(x: R.Tensor((10, 5), "float32")) -> R.Tensor((10, 5), "float32"):
s = R.add(x, x)
return s

with pytest.raises(OSError):

@I.ir_module
class ErrorMod:
@R.function
def main(x: R.Tensor((10, 5), "float32")):
inner = foo
gv1 = inner(x)
gv2 = foo(x)
return (inner, gv1, gv2)
@I.ir_module
class Mod2:
@R.function
def main(x: R.Tensor((10, 5), "float32")):
inner = foo
gv1 = inner(x)
gv2 = foo(x)
return (inner, gv1, gv2)

@R.function
def foo(
x: R.Tensor((10, 5), "float32")
): # need function ret info since it is parse later than `main`
s = R.add(x, x)
return s
@R.function
def foo(x: R.Tensor((10, 5), "float32")): # the return type is automatically inferred
s = R.add(x, x)
return s


def test_if_branch():
Expand Down

0 comments on commit 3cbe967

Please sign in to comment.