Skip to content

Commit

Permalink
[TVMScript][Bugfix] Minor fix for type_args and constant printing (a…
Browse files Browse the repository at this point in the history
…pache#349)

* [TVMScript][Bugfix] Minor fix for `type_args` and constant printing
  • Loading branch information
Hzfengsy authored Jan 10, 2023
1 parent f7cfdb5 commit 3cd616a
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 5 deletions.
7 changes: 5 additions & 2 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def call_builtin(
func: Union[str, Expr],
args: Union[RxTuple, List[Expr]],
*,
type_args: Optional[List[Type]] = None,
type_args: Optional[Union[Type, List[Type]]] = None,
int_args: Optional[List[int]] = None,
dtype_arg: Optional[str] = None,
str_args: Optional[List[str]] = None,
Expand All @@ -145,7 +145,7 @@ def call_builtin(
args : Union[RxTuple, List[Expr]]
The input arguments.
type_args: Optional[List[Type]]
type_args: Optional[Union[Type, List[Type]]]
The type arguments to the call node.
int_args: Optional[List[int]]
Expand All @@ -171,6 +171,9 @@ def call_builtin(
if isinstance(args, (list, tuple)):
args = RxTuple(args)

if type_args is not None and not isinstance(type_args, (list, tuple)):
type_args = [type_args]

return _ffi_api.call_builtin( # type: ignore
func, args, type_args, int_args, dtype_arg, str_args, require_ctx # type: ignore
)
Expand Down
5 changes: 3 additions & 2 deletions src/printer/relax_script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::CallNode* op) {
if (!op->type_args.empty()) {
doc << ", type_args=";
std::vector<Doc> type_args = PrintTypeArgs(op->type_args);

if (type_args.size() == 1) {
doc << type_args[0];
doc << "(" << type_args[0] << " ,)";
} else {
doc << "(" << Doc::Concat(type_args, Doc::Text(", ")) << ")";
}
Expand Down Expand Up @@ -229,7 +230,7 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::ConstantNode* op) {
} else if (dtype == DataType::Bool()) {
scalar_val = ScalarLiteral(dtype, static_cast<const uint8_t*>(op->data->data)[0]);
}
return doc << scalar_val;
return doc << "R.const(" << Doc::Concat({scalar_val, PrintDType(dtype)}) << ")";
}
// default fall-back, record it as meta node.
// Don't append optional_info. Because the entry function is Print,
Expand Down
1 change: 1 addition & 0 deletions src/runtime/relax_vm/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ bool ReadIfCond(TVMArgValue cond) {
}
default:
LOG(FATAL) << "Unknown scalar int type: " << DLDataType2String(arr->dtype);
throw;
}
return result != 0;
}
Expand Down
22 changes: 21 additions & 1 deletion tests/python/relax/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,15 @@ def foo(x: R.Tensor((3, 3), "float32")):
check_roundtrip(foo)


@pytest.mark.skip("Need to fix string ast expr")
def test_relax_base_op():
@R.function
def foo(x: R.Tensor((2, 4), dtype="float32")):
gv = R.call_builtin("test_intrin", [x], type_args=R.Object)
return gv

check_roundtrip(foo)


def test_primexpr_arithmetic():
@R.function
def foo(x: R.Tensor(("n", "m"), "float32")):
Expand Down Expand Up @@ -273,6 +281,18 @@ def my_const(x: R.Tensor((2, 3), "float32")):
check_roundtrip(my_const)


def test_scalar_const():
x = relax.Var("x", relax.TensorStructInfo((8,)))
const_one = relax.const(1, "float32")

bb = relax.BlockBuilder()
with bb.function("main", [x]):
v0 = bb.emit(R.add(x, const_one))
bb.emit_func_output(v0)

check_roundtrip(bb.get())


def test_const_meta():
def _get_meta_data():
@R.function
Expand Down

0 comments on commit 3cd616a

Please sign in to comment.