diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index 7e1f0015e704..81f10e4a7dfe 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -129,7 +129,7 @@ def call_builtin( func: Union[str, Expr], args: Union[RxTuple, List[Expr]], *, - type_args: Optional[List[Type]] = None, + type_args: Optional[Union[Type, List[Type]]] = None, int_args: Optional[List[int]] = None, dtype_arg: Optional[str] = None, str_args: Optional[List[str]] = None, @@ -145,7 +145,7 @@ def call_builtin( args : Union[RxTuple, List[Expr]] The input arguments. - type_args: Optional[List[Type]] + type_args: Optional[Union[Type, List[Type]]] The type arguments to the call node. int_args: Optional[List[int]] @@ -171,6 +171,9 @@ def call_builtin( if isinstance(args, (list, tuple)): args = RxTuple(args) + if type_args is not None and not isinstance(type_args, (list, tuple)): + type_args = [type_args] + return _ffi_api.call_builtin( # type: ignore func, args, type_args, int_args, dtype_arg, str_args, require_ctx # type: ignore ) diff --git a/src/relay/printer/relax_script_printer.cc b/src/relay/printer/relax_script_printer.cc index 1a9f14c8086d..d4fefc60dba5 100644 --- a/src/relay/printer/relax_script_printer.cc +++ b/src/relay/printer/relax_script_printer.cc @@ -140,8 +140,9 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::CallNode* op) { if (!op->type_args.empty()) { doc << ", type_args="; std::vector type_args = PrintTypeArgs(op->type_args); + if (type_args.size() == 1) { - doc << type_args[0]; + doc << "(" << type_args[0] << " ,)"; } else { doc << "(" << Doc::Concat(type_args, Doc::Text(", ")) << ")"; } @@ -229,7 +230,7 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::ConstantNode* op) { } else if (dtype == DataType::Bool()) { scalar_val = ScalarLiteral(dtype, static_cast(op->data->data)[0]); } - return doc << scalar_val; + return doc << "R.const(" << Doc::Concat({scalar_val, PrintDType(dtype)}) << ")"; } // default fall-back, record it as meta node. // Don't append optional_info. Because the entry function is Print, diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 9ae13a4d6f94..887bb0e2df1c 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -357,6 +357,7 @@ bool ReadIfCond(TVMArgValue cond) { } default: LOG(FATAL) << "Unknown scalar int type: " << DLDataType2String(arr->dtype); + throw; } return result != 0; } diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index 82e0c06589ef..cf2fd574e89f 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -204,7 +204,15 @@ def foo(x: R.Tensor((3, 3), "float32")): check_roundtrip(foo) -@pytest.mark.skip("Need to fix string ast expr") +def test_relax_base_op(): + @R.function + def foo(x: R.Tensor((2, 4), dtype="float32")): + gv = R.call_builtin("test_intrin", [x], type_args=R.Object) + return gv + + check_roundtrip(foo) + + def test_primexpr_arithmetic(): @R.function def foo(x: R.Tensor(("n", "m"), "float32")): @@ -273,6 +281,18 @@ def my_const(x: R.Tensor((2, 3), "float32")): check_roundtrip(my_const) +def test_scalar_const(): + x = relax.Var("x", relax.TensorStructInfo((8,))) + const_one = relax.const(1, "float32") + + bb = relax.BlockBuilder() + with bb.function("main", [x]): + v0 = bb.emit(R.add(x, const_one)) + bb.emit_func_output(v0) + + check_roundtrip(bb.get()) + + def test_const_meta(): def _get_meta_data(): @R.function