From 2963787ea5c3777c66c54acfcac98f1ecfa2d894 Mon Sep 17 00:00:00 2001 From: sung Date: Mon, 17 Oct 2022 18:09:34 -0700 Subject: [PATCH] Recover: [Bugfix] Couple of bug fixes to run TVM-gen code together with BYOC (#249) --- python/tvm/ir/function.py | 28 +++- python/tvm/ir/module.py | 37 ++++- python/tvm/relax/testing/transform.py | 9 +- .../transform/tuning_api/default_functions.py | 3 +- .../relax/transform/tuning_api/primitives.py | 49 ++++++ python/tvm/relax/vm.py | 3 +- src/ir/function.cc | 14 ++ src/ir/module.cc | 8 + src/relax/backend/task_extraction.cc | 14 +- src/relax/transform/meta_schedule.cc | 23 +-- src/relax/transform/run_codegen.cc | 7 +- src/relax/transform/tuning_api/primitives.cc | 4 +- src/runtime/relax_vm/executable.cc | 6 +- src/runtime/relax_vm/vm.cc | 10 ++ .../python/relax/test_autotir_integration.py | 2 - .../relax/test_transform_codegen_pass.py | 156 +++++++++++++++--- tests/python/relax/test_vm.py | 3 +- 17 files changed, 316 insertions(+), 60 deletions(-) diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index b254138ab5..a5a572b4b6 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -15,10 +15,14 @@ # specific language governing permissions and limitations # under the License. """Function defintiions.""" +from __future__ import annotations +from typing import Union, Dict from enum import IntEnum import tvm.runtime +from tvm.runtime.object import Object from .expr import RelayExpr +from .attrs import DictAttrs from . import _ffi_api @@ -38,7 +42,7 @@ def attrs(self): """Return the attrs member of the function.""" return _ffi_api.BaseFunc_Attrs(self) - def with_attr(self, attr_key_or_dict, attr_value=None): + def with_attr(self, attr_key_or_dict, attr_value=None) -> BaseFunc: """Create a new copy of the function and update the attribute. Parameters @@ -51,7 +55,7 @@ def with_attr(self, attr_key_or_dict, attr_value=None): Returns ------- - func : Function + func : BaseFunc A new copy of the function """ # make sure we first copy so that we can safely do copy on write @@ -67,7 +71,23 @@ def with_attr(self, attr_key_or_dict, attr_value=None): res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value) ) - def without_attr(self, attr_key: str): + def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> BaseFunc: + """Copy the IRModule and add the given attribute map to it. + Parameters + ---------- + attr_map: Union[DictAttrs, Dict[str, Object]] + The attribute map + Returns + ------- + func : BaseFunc + A new copy of the function + """ + if isinstance(attr_map, tvm.ir.DictAttrs): + attr_map = attr_map._dict() + + return _ffi_api.BaseFuncWithAttrs(self, attr_map) + + def without_attr(self, attr_key: str) -> BaseFunc: """Create a new copy of the function with an attribute without provided key. Parameters @@ -78,7 +98,7 @@ def without_attr(self, attr_key: str): Returns ------- - func : Function + func : BaseFunc A new copy of the function """ return _ffi_api.BaseFuncWithoutAttr(self, attr_key) diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 3ec67f76f4..36656a5b4a 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -15,13 +15,16 @@ # specific language governing permissions and limitations # under the License. 
"""IRModule that holds the functions and type definitions.""" -from typing import Optional +from __future__ import annotations +from typing import Optional, Union, Dict import ast from tvm._ffi.base import string_types import tvm._ffi +from tvm.runtime.object import Object from .base import Node from . import expr as _expr +from .attrs import DictAttrs from ..ir.function import BaseFunc from . import type as _ty from . import _ffi_api @@ -330,7 +333,7 @@ def get_attrs(self): return _ffi_api.Module_GetAttrs(self) - def with_attr(self, attr_key, attr_value): + def with_attr(self, attr_key, attr_value) -> IRModule: """Copy the IRModule and add an attribute to it. Parameters @@ -348,3 +351,33 @@ def with_attr(self, attr_key, attr_value): """ return _ffi_api.Module_WithAttr(self, attr_key, attr_value) + + def without_attr(self, attr_key: str) -> IRModule: + """Copy the IRModule and remove an attribute key and its associated value. + Parameters + ---------- + attr_key : str + The attribute key. + Returns + ------- + mod : IRModule + A new copy of the IRModule without the attribute + """ + + return _ffi_api.Module_WithoutAttr(self, attr_key) + + def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> IRModule: + """Copy the IRModule and add the given attribute map to it. + Parameters + ---------- + attr_map: Union[DictAttrs, Dict[str, Object]] + The attribute map + Returns + ------- + mod : IRModule + A new copy of the IRModule with the attribute + """ + if isinstance(attr_map, tvm.ir.DictAttrs): + attr_map = attr_map._dict() + + return _ffi_api.Module_WithAttrs(self, attr_map) diff --git a/python/tvm/relax/testing/transform.py b/python/tvm/relax/testing/transform.py index 6b0e664c3d..c26b15c860 100644 --- a/python/tvm/relax/testing/transform.py +++ b/python/tvm/relax/testing/transform.py @@ -74,6 +74,11 @@ class Lowerer(PyExprMutator): """Mutator that performs lowering.""" def visit_call_(self, call_node: Call): + # Ignore function calls + # We only target calls for operators + if isinstance(call_node.op, (relax.GlobalVar, relax.expr.ExternFunc)): + return call_node + # Current relax op name simply adds "relax." prefix to relay op name. # Thus, remove "relax." prefix to deduce relay op name. relay_op_name = call_node.op.name[6:] @@ -112,6 +117,8 @@ def transform(self): if isinstance(func, relax.Function): updated_func = self.visit_expr(func) self.builder_.update_func(gv, updated_func) - return self.builder_.get() + new_mod = self.builder_.get() + new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod + return new_mod return Lowerer().transform() diff --git a/python/tvm/relax/transform/tuning_api/default_functions.py b/python/tvm/relax/transform/tuning_api/default_functions.py index 0985108c6a..30b2d69b1d 100644 --- a/python/tvm/relax/transform/tuning_api/default_functions.py +++ b/python/tvm/relax/transform/tuning_api/default_functions.py @@ -16,7 +16,6 @@ # under the License. """Relax Tuning Pass API default functions""" from typing import Dict, List, Optional -import copy import sys import itertools import logging @@ -91,7 +90,7 @@ def default_generate_candidate( choice = knob.choices[decision] # Generate new candidate when this condition satisfies. 
if choice.check_constr(cur_trace.out_mod): - new_trace = copy.deepcopy(cur_trace) + new_trace = cur_trace.deepcopy() new_trace.add(knob, decision) candidates.append(new_trace) diff --git a/python/tvm/relax/transform/tuning_api/primitives.py b/python/tvm/relax/transform/tuning_api/primitives.py index fe34c2b43c..23b2101545 100644 --- a/python/tvm/relax/transform/tuning_api/primitives.py +++ b/python/tvm/relax/transform/tuning_api/primitives.py @@ -168,6 +168,9 @@ def from_json(json_obj: JSON_TYPE) -> "Choice": """ return _ffi_api.ChoiceFromJSON(json_obj) + def deepcopy(self): + return Choice.from_json(self.as_json()) + @register_object("relax.tuning_api.Knob") class Knob(Object): @@ -247,6 +250,9 @@ def __str__(self) -> str: msg += f" - {name}: {choice}\n" return msg + def deepcopy(self): + return Knob.from_json(self.as_json()) + @register_object("relax.tuning_api.Trace") class Trace(Object): @@ -346,6 +352,15 @@ def __str__(self) -> str: msg += f"[{idx+1}] {self.knobs[idx].name}: {self.decisions[idx]}\n" return msg + def deepcopy(self) -> "Trace": + new_in_mod = deepcopy_irmodule(self.in_mod) + new_knobs = [knob.deepcopy() for knob in self.knobs] + new_decisions = [str(decision) for decision in self.decisions] + new_trace = Trace(new_in_mod, new_knobs, new_decisions) + new_out_mod = deepcopy_irmodule(self.out_mod) + new_trace.set_out_mod(new_out_mod) + return new_trace + def get_trace(in_: Union[Trace, IRModule, Expr]) -> Trace: """ @@ -368,3 +383,37 @@ def get_trace(in_: Union[Trace, IRModule, Expr]) -> Trace: return Trace(tvm.IRModule.from_expr(in_)) raise Exception(f"Invalid input type for trace: {type(in_)}") + + +@tvm.register_func("relax.tuning_api.deepcopy_irmodule") +def deepcopy_irmodule(mod: IRModule) -> IRModule: + """ + Deepcopy for an IRModule. + Parameters + ---------- + mod: IRModule + input IRModule + Return + ---------- + copied_mod: IRModule + deep-copied IRModule + """ + func_save_json = tvm.get_global_func("node.SaveJSON") + func_load_json = tvm.get_global_func("node.LoadJSON") + new_mod = None + # Handle external modules separately if exist + # TODO(tvm-team): + # Serialization of IRModule with external mods is tricky. + # (1) External mod is runtime module. + # (2) Currently, `export_library` does not support serialization of + # runtime module without the host module + # Therefore, we simply pass around the compiled external modules without copy for now. + # Revisit later when we have a better solution. + if mod.attrs and "external_mods" in mod.attrs: + tmp_mod = mod.without_attr("external_mods") + new_mod = func_load_json(func_save_json(tmp_mod)) + new_mod = new_mod.with_attr("external_mods", mod.attrs["external_mods"]) + else: + new_mod = func_load_json(func_save_json(mod)) + + return new_mod diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index 46294b6b8b..177f2e1a9a 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -506,10 +506,11 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")): seq = tvm.transform.Sequential(passes) new_mod = seq(mod) - # split primfunc and relax function + # Split primfunc and relax function rx_mod, tir_mod = _split_tir_relax(new_mod) lib = tvm.build(tir_mod, target=target) + # Extract external runtime modules if exist. 
ext_libs = [] if mod.attrs and "external_mods" in mod.attrs: ext_libs = mod.attrs["external_mods"] diff --git a/src/ir/function.cc b/src/ir/function.cc index 597deb0f79..500d94d11c 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -52,6 +52,20 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr") } }); +TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") + .set_body_typed([](BaseFunc func, Map attr_map) -> BaseFunc { + if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } else if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } else if (func->IsInstance()) { + return WithAttrs(Downcast(std::move(func)), attr_map); + } else { + LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + return func; + } + }); + TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") .set_body_typed([](BaseFunc func, String key) -> BaseFunc { if (func->IsInstance()) { diff --git a/src/ir/module.cc b/src/ir/module.cc index eb7a89259e..3de19cfa9d 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -538,6 +538,14 @@ TVM_REGISTER_GLOBAL("ir.Module_WithAttr") return WithAttr(mod, key, value); }); +TVM_REGISTER_GLOBAL("ir.Module_WithoutAttr") + .set_body_typed([](IRModule mod, String key) -> IRModule { return WithoutAttr(mod, key); }); + +TVM_REGISTER_GLOBAL("ir.Module_WithAttrs") + .set_body_typed([](IRModule mod, Map attr_map) -> IRModule { + return WithAttrs(mod, attr_map); + }); + TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String key) -> ObjectRef { return mod->GetAttr(key); }); diff --git a/src/relax/backend/task_extraction.cc b/src/relax/backend/task_extraction.cc index 0bbbfca117..beb3950af1 100644 --- a/src/relax/backend/task_extraction.cc +++ b/src/relax/backend/task_extraction.cc @@ -45,14 +45,14 @@ using tvm::meta_schedule::ExtractedTask; class TaskExtractor : public ExprVisitor { public: static Array ExtractTask(IRModule mod, Target target) { - TaskExtractor extracor(mod, target); + TaskExtractor extractor(mod, target); // We go through each Relax function in the module. for (const auto& kv : mod->functions) { if (const auto* func = kv.second.as()) { - extracor(GetRef(func)); + extractor(GetRef(func)); } } - return std::move(extracor.tasks_); + return std::move(extractor.tasks_); } private: @@ -64,12 +64,20 @@ class TaskExtractor : public ExprVisitor { void VisitExpr_(const CallNode* call) final { static const Op& call_tir_op = Op::Get("relax.call_tir"); + + // TODO(@tvm-team): When we differentiate the call for tir function and packed function, + // this logic should be changed accordingly. if (!call->op.same_as(call_tir_op)) { // Since the Relax function is of A-normal form, the arguments of this call cannot be another // Calls. And hence we do not need to recurse into this Call. 
return; } + // Do not extract external function + if (call->args[0].as()) { + return; + } + const GlobalVar& global_var = Downcast(call->args[0]); const tir::PrimFunc& func = Downcast(mod_->Lookup(global_var)); diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index 6af1e95d1a..7af7b678be 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -88,23 +88,24 @@ class MetaScheduleTuner { Pass MetaScheduleApplyDatabase(Optional work_dir) { using tvm::meta_schedule::Database; Target target = Target::Current(false); - Database database; - if (Database::Current().defined()) { - database = Database::Current().value(); - } else { - ICHECK(work_dir.defined()); - String path_workload = work_dir.value() + "/database_workload.json"; - String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; - LOG(INFO) << "Creating JSONDatabase. Workload at: " << path_workload - << ", Tuning records at: " << path_tuning_record; - database = meta_schedule::Database::JSONDatabase(path_workload, path_tuning_record, true); - } const runtime::PackedFunc* normalize_mod_func_ = runtime::Registry::Get("tvm.meta_schedule.normalize_mod"); ICHECK(normalize_mod_func_) << "Normalization function is not found."; runtime::TypedPackedFunc pass_func = [=](IRModule mod, PassContext ctx) { + Database database; + if (Database::Current().defined()) { + database = Database::Current().value(); + } else { + ICHECK(work_dir.defined()); + String path_workload = work_dir.value() + "/database_workload.json"; + String path_tuning_record = work_dir.value() + "/database_tuning_record.json"; + LOG(WARNING) << "Creating JSONDatabase. Workload at: " << path_workload + << ", Tuning records at: " << path_tuning_record; + database = meta_schedule::Database::JSONDatabase(path_workload, path_tuning_record, true); + } + Map result; for (const auto& iter : mod->functions) { GlobalVar gv = iter.first; diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index 011d384d76..5d954cf055 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -90,7 +90,12 @@ class CodeGenRunner : ExprMutator { return Call(call_op, new_args, tvm::Attrs(), {func->ret_type}); } } - return GetRef(call_node); + Array new_args; + for (const auto& arg : call_node->args) { + new_args.push_back(VisitExpr(arg)); + } + + return Call(call_node->op, new_args, call_node->attrs, call_node->type_args, call_node->span); } Expr VisitExpr_(const FunctionNode* func_node) override { diff --git a/src/relax/transform/tuning_api/primitives.cc b/src/relax/transform/tuning_api/primitives.cc index 93f063e856..ef4a3d41bd 100644 --- a/src/relax/transform/tuning_api/primitives.cc +++ b/src/relax/transform/tuning_api/primitives.cc @@ -148,7 +148,9 @@ Trace::Trace() { data_ = make_object(); } Trace::Trace(IRModule in_mod, Array knobs, Array decisions) { ICHECK(knobs.size() == decisions.size()) << "Size of knobs and decisions should match"; // Deep-copy IRModule - IRModule out_mod = meta_schedule::DeepCopyIRModule(in_mod); + auto func_deepcopy = runtime::Registry::Get("relax.tuning_api.deepcopy_irmodule"); + ICHECK(func_deepcopy); + IRModule out_mod = (*func_deepcopy)(in_mod); // Apply the decision history if provided int size = knobs.size(); for (int i = 0; i < size; i++) { diff --git a/src/runtime/relax_vm/executable.cc b/src/runtime/relax_vm/executable.cc index 47d34e9de4..db9a278760 100644 --- a/src/runtime/relax_vm/executable.cc +++ 
b/src/runtime/relax_vm/executable.cc @@ -505,13 +505,11 @@ String Executable::AsText() const { break; } case Opcode::Ret: { - os << std::setw(6) << std::left << "ret" - << "ret " << RegNameToStr(instr.result) << "\n"; + os << std::setw(6) << std::left << "ret " << RegNameToStr(instr.result) << "\n"; break; } case Opcode::Goto: { - os << std::setw(6) << std::left << "goto" - << "goto " << instr.pc_offset << "\n"; + os << std::setw(6) << std::left << "goto" << instr.pc_offset << "\n"; break; } case Opcode::If: { diff --git a/src/runtime/relax_vm/vm.cc b/src/runtime/relax_vm/vm.cc index f710c716f4..65656c42c8 100644 --- a/src/runtime/relax_vm/vm.cc +++ b/src/runtime/relax_vm/vm.cc @@ -265,7 +265,15 @@ void VirtualMachine::LoadExecutable(ObjectPtr exec) { RegType VirtualMachine::Invoke(Index gf_idx, const std::vector& args) { const VMFunction& gfunc = exec_->global_funcs[gf_idx]; + // Get the curr instr which might be a potential caller. + Instruction curr_instr = exec_->GetInstruction(pc_); PushFrame(this->pc_, gfunc); + // Get new frame and set the caller info. + VMFrame* curr_frame = frames_.back().get(); + if (curr_instr.op == Opcode::Call) { + curr_frame->caller_return_register = curr_instr.dst; + } + // load arguments to the register file ICHECK_EQ(static_cast(gfunc.num_args), args.size()) << "ValueError: Invoking function " << gfunc.name << " requires " << gfunc.num_args @@ -372,9 +380,11 @@ void VirtualMachine::RunInstrCall(VMFrame* curr_frame, Instruction instr) { this->PrepareFuncTable(instr.func_idx); func_table_[instr.func_idx].CallPacked(args, &ret); + // save the return value to the register if (instr.dst != Instruction::kVoidArg) { WriteRegister(curr_frame, instr.dst, ret); } + // increment pc pc_++; } diff --git a/tests/python/relax/test_autotir_integration.py b/tests/python/relax/test_autotir_integration.py index 7420b76e9a..7a061b0bbd 100644 --- a/tests/python/relax/test_autotir_integration.py +++ b/tests/python/relax/test_autotir_integration.py @@ -84,8 +84,6 @@ class InputModule: @T.prim_func def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: T.func_attr({"global_symbol": "tir_matmul"}) - m = T.var("int32") - n = T.var("int32") k = T.var("int32") A = T.match_buffer(x, (32, 32)) B = T.match_buffer(y, (32, 32)) diff --git a/tests/python/relax/test_transform_codegen_pass.py b/tests/python/relax/test_transform_codegen_pass.py index 8462d14d74..55cb27bb28 100644 --- a/tests/python/relax/test_transform_codegen_pass.py +++ b/tests/python/relax/test_transform_codegen_pass.py @@ -23,7 +23,10 @@ from tvm import relax import numpy as np from tvm.script import relax as R -from tvm import transform +from tvm.relax.testing import transform +import tempfile +from tvm.relax.transform.tuning_api import Trace +from tvm import meta_schedule as ms env_checker_codegen = tvm.get_global_func("relax.ext.tensorrt", True) env_checker_runtime = tvm.get_global_func("relax.is_tensorrt_runtime_enabled", True) @@ -40,15 +43,18 @@ # Global variable in pytest that applies markers to all tests. pytestmark = [has_tensorrt_codegen, has_tensorrt_runtime] +# Target gpu +target_str = "nvidia/nvidia-t4" +target = tvm.target.Target(target_str) +dev = tvm.cuda() + def check_executable(exec, dev, inputs, expected): vm = relax.VirtualMachine(exec, dev) - # Measure the performance w/o tuning log out = vm["main"](*inputs) - tvm.testing.assert_allclose(out.numpy(), expected) + tvm.testing.assert_allclose(out.numpy(), expected.numpy(), atol=1e-5, rtol=1e-5) -# TODO(sunggg): Serialize TRT runtime module. 
This might be helpful: `module.export_library(file_name)`` def check_roundtrip(exec0, dev, inputs, expected): exec0.mod.export_library("exec.so") exec1 = relax.vm.Executable(tvm.runtime.load_module("exec.so")) @@ -60,23 +66,62 @@ def check_roundtrip(exec0, dev, inputs, expected): check_executable(exec1, dev, inputs, expected) +def gen_ground_truth(mod, target, dev, inputs): + # Lower and run tuning + # Since there is no default schedule for GPU in MS yet, this is necessary + with tempfile.TemporaryDirectory() as work_dir: + with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=0): + seq = tvm.transform.Sequential( + [ + transform.LowerWithRelayOpStrategyPass(target), + relax.transform.MetaScheduleTuneIRMod( + params={}, work_dir=work_dir, max_trials_global=8 + ), + relax.transform.MetaScheduleApplyDatabase(work_dir), + ] + ) + new_mod = seq(mod) + assert relax.analysis.well_formed(new_mod) + exec = relax.vm.build(new_mod, target, params={}) + vm = relax.VirtualMachine(exec, dev) + return vm["main"](*inputs) + + +@tvm.testing.requires_gpu def test_single_annot_func(): @tvm.script.ir_module class InputModule: @R.function - def relax_func(x: Tensor((2, 3), "float32"), y: Tensor((2, 3), "float32")) -> Tensor: - z1 = relax.add(x, y) + def relax_func( + x: Tensor((16, 16), "float32"), y: Tensor((16, 16), "float32") + ) -> Tensor((16, 16), "float32"): + z1 = relax.multiply(x, y) z2 = relax.add(z1, z1) z3 = relax.add(z1, z2) return z3 @R.function - def main(x: Tensor((2, 3), "float32"), y: Tensor((2, 3), "float32")) -> Tensor: - lv0 = relax_func(x, y) + def main( + x: Tensor((16, 16), "float32"), y: Tensor((16, 16), "float32") + ) -> Tensor((16, 16), "float32"): + lv0: Tensor((16, 16), "float32") = relax_func(x, y) return lv0 + # Prepare IRModule and its input mod = InputModule assert isinstance(mod, tvm.IRModule) + + np0 = np.random.rand(16, 16).astype(np.float32) + np1 = np.random.rand(16, 16).astype(np.float32) + data0 = tvm.nd.array(np0, dev) + data1 = tvm.nd.array(np1, dev) + inputs = [data0, data1] + + # Ground truth should be generated before annotation + # due to the conflict with MS task extraction + # TODO(@sunggg): Sort this out + expected = gen_ground_truth(mod, target, dev, inputs) + # TODO(@sunggg): Revisit when TVMScript supports annotation. # Annotate target function. new_relax_func = mod["relax_func"].with_attr("Codegen", "tensorrt") @@ -84,33 +129,90 @@ def main(x: Tensor((2, 3), "float32"), y: Tensor((2, 3), "float32")) -> Tensor: mod["relax_func"] = new_relax_func # Run Codegen pass - seq = transform.Sequential( + seq = tvm.transform.Sequential( [relax.transform.RunCodegen(), relax.transform.RemoveUnusedFunctions()] ) + new_mod = seq(mod) + ex0 = relax.vm.build(new_mod, target, params={}) - target_str = "cuda" - target = tvm.target.Target(target_str) - dev = tvm.device(target_str, 0) + # Sanity check for the correctness and rountrip + check_roundtrip(ex0, dev, inputs, expected) - with transform.PassContext(opt_level=0): - ex0 = relax.vm.build(new_mod, target, params={}) + # If the annotation does not match with the target codegen, do not perform the codegen process. + new_mod = relax.transform.RunCodegen(target_codegens=["INVALID_CODEGEN"])(mod) + # TODO(tvm-team): Currently disabled due to the lack of type annotation support during parser. + # Revisit when new version of parser is available. 
+ # tvm.ir.assert_structural_equal(mod, new_mod) + + +@tvm.testing.requires_gpu +def test_mix_use_tensorrt_and_tvm(): + @tvm.script.ir_module + class InputModule: + @R.function + def byoc_func( + x: Tensor((16, 16), "float32"), y: Tensor((16, 16), "float32") + ) -> Tensor((16, 16), "float32"): + z1 = relax.multiply(x, y) + z2 = relax.add(z1, z1) + z3 = relax.add(z1, z2) + return z3 - # Correct output: Current relax cannot lower relax.add. - # Use numpy baseline instead. - np0 = np.random.rand(2, 3).astype(np.float32) - np1 = np.random.rand(2, 3).astype(np.float32) - data0 = tvm.nd.array(np0, tvm.cpu()) - data1 = tvm.nd.array(np1, tvm.cpu()) + @R.function + def tvm_func( + x: Tensor((16, 16), "float32"), w: Tensor((16, 16), "float32") + ) -> Tensor((16, 16), "float32"): + gv0 = R.multiply(x, w) + gv1 = R.add(x, gv0) + return gv1 - tmp = np0 + np1 - out1 = tmp + tmp - expected = out1 + tmp - check_roundtrip(ex0, dev, [data0, data1], expected) + @R.function + def main( + x: Tensor((16, 16), "float32"), y: Tensor((16, 16), "float32") + ) -> Tensor((16, 16), "float32"): + lv0 = byoc_func(x, y) + lv1 = tvm_func(x, lv0) + return lv1 + + # Prepare IRModule and its inputs + mod = InputModule + assert isinstance(mod, tvm.IRModule) - # If the annotation does not match with the target codegen, do not perform the codegen process. - new_mod = relax.transform.RunCodegen(target_codegens=["INVALID_CODEGEN"])(mod) - tvm.ir.assert_structural_equal(mod, new_mod) + np0 = np.random.rand(16, 16).astype(np.float32) + np1 = np.random.rand(16, 16).astype(np.float32) + data0 = tvm.nd.array(np0, dev) + data1 = tvm.nd.array(np1, dev) + inputs = [data0, data1] + expected = gen_ground_truth(mod, target, dev, [data0, data1]) + + # TODO(@sunggg): Revisit when TVMScript supports annotation. + # Annotate target function. + new_byoc_func = mod["byoc_func"].with_attr("Codegen", "tensorrt") + new_byoc_func = new_byoc_func.with_attr("global_symbol", "trt_byoc_func") + mod["byoc_func"] = new_byoc_func + + # Run Codegen pass + with tempfile.TemporaryDirectory() as work_dir: + with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=3): + seq = tvm.transform.Sequential( + [ + relax.transform.RunCodegen(), + relax.transform.RemoveUnusedFunctions(), + transform.LowerWithRelayOpStrategyPass(target), + relax.transform.MetaScheduleTuneIRMod( + params={}, work_dir=work_dir, max_trials_global=8 + ), + relax.transform.MetaScheduleApplyDatabase(work_dir), + ] + ) + new_mod = seq(mod) + assert relax.analysis.well_formed(new_mod) + with transform.PassContext(opt_level=0): + ex0 = relax.vm.build(new_mod, target, params={}) + + # Sanity check for the correctness and rountrip + check_roundtrip(ex0, dev, inputs, expected) # TODO(@sunggg): test with more complex patterns (e.g., multiple annots, mixed codegens, different ops, const binding) diff --git a/tests/python/relax/test_vm.py b/tests/python/relax/test_vm.py index afc83d1130..292d1ebdfb 100644 --- a/tests/python/relax/test_vm.py +++ b/tests/python/relax/test_vm.py @@ -954,8 +954,9 @@ def recursion(n: Tensor((1,), "float32")) -> Tensor: gv0 = relax.call_packed( "test.vm.subtract_one", n, type_args=(Tensor(ndim=1, dtype="float32")) ) + tmp = recursion(gv0) res = relax.call_packed( - "test.vm.add", recursion(gv0), n, type_args=(Tensor(ndim=1, dtype="float32")) + "test.vm.add", tmp, tmp, type_args=(Tensor(ndim=1, dtype="float32")) ) return res
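
For reference, the attribute helpers added above in python/tvm/ir/module.py and python/tvm/ir/function.py can be exercised as follows. This is a minimal sketch assuming a TVM build that includes this patch; the key and value used here are placeholders with no special meaning.

    import tvm

    mod = tvm.IRModule()

    # with_attrs accepts a dict (or a DictAttrs) and returns a copied module.
    mod = mod.with_attrs({"example_key": tvm.runtime.String("example_value")})
    assert mod.attrs and "example_key" in mod.attrs

    # without_attr drops a single key, again returning a copy.
    mod = mod.without_attr("example_key")
    assert not mod.attrs or "example_key" not in mod.attrs

    # BaseFunc.with_attrs works the same way on individual functions,
    # e.g. func = mod["some_gv"].with_attrs({"global_symbol": "some_gv"})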
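
The tuning API now provides explicit deepcopy helpers instead of relying on copy.deepcopy. A minimal sketch, assuming a build with this patch: Trace.deepcopy() roundtrips its IRModules through node.SaveJSON/LoadJSON (via the registered relax.tuning_api.deepcopy_irmodule function) and, when an "external_mods" attribute is present, passes the compiled runtime modules through without copying them.

    import tvm
    from tvm.relax.transform.tuning_api import Trace

    mod = tvm.IRModule()
    trace = Trace(mod)

    # The copy owns freshly deserialized in_mod/out_mod plus copied knobs and
    # decisions, so mutating it (e.g. adding new decisions via copied.add(...))
    # does not affect the original trace.
    copied = trace.deepcopy()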
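
Putting the pieces together, the mixed TensorRT/TVM flow exercised by test_mix_use_tensorrt_and_tvm reduces to the pass sequence sketched below. This is a condensed restatement of the test, not a prescribed API: build_mixed is an illustrative helper name, and it assumes a CUDA target, a TensorRT-enabled build, and an input module whose BYOC function has already been annotated with "Codegen" and "global_symbol" as in the test.

    import tempfile

    import tvm
    from tvm import relax
    from tvm.relax.testing import transform
    from tvm.relax.transform.tuning_api import Trace


    def build_mixed(mod, target):
        with tempfile.TemporaryDirectory() as work_dir:
            with target, tvm.transform.PassContext(trace=Trace(mod), opt_level=3):
                seq = tvm.transform.Sequential(
                    [
                        # Offload the annotated function(s) to the external codegen.
                        relax.transform.RunCodegen(),
                        relax.transform.RemoveUnusedFunctions(),
                        # Lower the remaining relax ops to TIR, then tune and apply schedules.
                        transform.LowerWithRelayOpStrategyPass(target),
                        relax.transform.MetaScheduleTuneIRMod(
                            params={}, work_dir=work_dir, max_trials_global=8
                        ),
                        relax.transform.MetaScheduleApplyDatabase(work_dir),
                    ]
                )
                new_mod = seq(mod)
            # The resulting executable carries the external runtime modules via
            # the module-level "external_mods" attribute handled in relax/vm.py.
            return relax.vm.build(new_mod, target, params={})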