Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
Recover: [Bugfix] Couple of bug fixes to run TVM-gen code together wi…
Browse files Browse the repository at this point in the history
…th BYOC (#249)
  • Loading branch information
sunggg committed Oct 18, 2022
1 parent 83de3b2 commit 2963787
Show file tree
Hide file tree
Showing 17 changed files with 316 additions and 60 deletions.
28 changes: 24 additions & 4 deletions python/tvm/ir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""Function definitions."""
from __future__ import annotations
from typing import Union, Dict
from enum import IntEnum
import tvm.runtime
from tvm.runtime.object import Object

from .expr import RelayExpr
from .attrs import DictAttrs
from . import _ffi_api


Expand All @@ -38,7 +42,7 @@ def attrs(self):
"""Return the attrs member of the function."""
return _ffi_api.BaseFunc_Attrs(self)

def with_attr(self, attr_key_or_dict, attr_value=None):
def with_attr(self, attr_key_or_dict, attr_value=None) -> BaseFunc:
"""Create a new copy of the function and update the attribute.
Parameters
Expand All @@ -51,7 +55,7 @@ def with_attr(self, attr_key_or_dict, attr_value=None):
Returns
-------
func : Function
func : BaseFunc
A new copy of the function
"""
# make sure we first copy so that we can safely do copy on write
Expand All @@ -67,7 +71,23 @@ def with_attr(self, attr_key_or_dict, attr_value=None):
res._move(), attr_key_or_dict, tvm.runtime.convert(attr_value)
)

def without_attr(self, attr_key: str):
def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> BaseFunc:
    """Copy the function and add the given attribute map to it.

    Parameters
    ----------
    attr_map: Union[DictAttrs, Dict[str, Object]]
        The attribute map to attach.

    Returns
    -------
    func : BaseFunc
        A new copy of the function with the attributes attached.
    """
    # DictAttrs wraps the underlying dict; unwrap it before crossing
    # the FFI boundary, which expects a plain map.
    if isinstance(attr_map, DictAttrs):
        attr_map = attr_map._dict()

    return _ffi_api.BaseFuncWithAttrs(self, attr_map)

def without_attr(self, attr_key: str) -> BaseFunc:
"""Create a new copy of the function with an attribute without provided key.
Parameters
Expand All @@ -78,7 +98,7 @@ def without_attr(self, attr_key: str):
Returns
-------
func : Function
func : BaseFunc
A new copy of the function
"""
return _ffi_api.BaseFuncWithoutAttr(self, attr_key)
37 changes: 35 additions & 2 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@
# specific language governing permissions and limitations
# under the License.
"""IRModule that holds the functions and type definitions."""
from typing import Optional
from __future__ import annotations
from typing import Optional, Union, Dict
import ast
from tvm._ffi.base import string_types
import tvm._ffi
from tvm.runtime.object import Object

from .base import Node
from . import expr as _expr
from .attrs import DictAttrs
from ..ir.function import BaseFunc
from . import type as _ty
from . import _ffi_api
Expand Down Expand Up @@ -330,7 +333,7 @@ def get_attrs(self):

return _ffi_api.Module_GetAttrs(self)

def with_attr(self, attr_key, attr_value):
def with_attr(self, attr_key, attr_value) -> IRModule:
"""Copy the IRModule and add an attribute to it.
Parameters
Expand All @@ -348,3 +351,33 @@ def with_attr(self, attr_key, attr_value):
"""

return _ffi_api.Module_WithAttr(self, attr_key, attr_value)

def without_attr(self, attr_key: str) -> IRModule:
    """Return a copy of this IRModule with the given attribute removed.

    Parameters
    ----------
    attr_key : str
        The key of the attribute to drop.

    Returns
    -------
    mod : IRModule
        A new copy of the IRModule that no longer carries the attribute.
    """
    stripped_mod = _ffi_api.Module_WithoutAttr(self, attr_key)
    return stripped_mod

def with_attrs(self, attr_map: Union[DictAttrs, Dict[str, Object]]) -> IRModule:
    """Copy the IRModule and add the given attribute map to it.

    Parameters
    ----------
    attr_map: Union[DictAttrs, Dict[str, Object]]
        The attribute map to attach.

    Returns
    -------
    mod : IRModule
        A new copy of the IRModule carrying the attributes.
    """
    # Unwrap DictAttrs into a plain dict before handing it to the FFI.
    raw_map = attr_map._dict() if isinstance(attr_map, tvm.ir.DictAttrs) else attr_map
    return _ffi_api.Module_WithAttrs(self, raw_map)
9 changes: 8 additions & 1 deletion python/tvm/relax/testing/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ class Lowerer(PyExprMutator):
"""Mutator that performs lowering."""

def visit_call_(self, call_node: Call):
# Ignore function calls
# We only target calls for operators
if isinstance(call_node.op, (relax.GlobalVar, relax.expr.ExternFunc)):
return call_node

# Current relax op name simply adds "relax." prefix to relay op name.
# Thus, remove "relax." prefix to deduce relay op name.
relay_op_name = call_node.op.name[6:]
Expand Down Expand Up @@ -112,6 +117,8 @@ def transform(self):
if isinstance(func, relax.Function):
updated_func = self.visit_expr(func)
self.builder_.update_func(gv, updated_func)
return self.builder_.get()
new_mod = self.builder_.get()
new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod
return new_mod

return Lowerer().transform()
3 changes: 1 addition & 2 deletions python/tvm/relax/transform/tuning_api/default_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
"""Relax Tuning Pass API default functions"""
from typing import Dict, List, Optional
import copy
import sys
import itertools
import logging
Expand Down Expand Up @@ -91,7 +90,7 @@ def default_generate_candidate(
choice = knob.choices[decision]
# Generate new candidate when this condition satisfies.
if choice.check_constr(cur_trace.out_mod):
new_trace = copy.deepcopy(cur_trace)
new_trace = cur_trace.deepcopy()
new_trace.add(knob, decision)
candidates.append(new_trace)

Expand Down
49 changes: 49 additions & 0 deletions python/tvm/relax/transform/tuning_api/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ def from_json(json_obj: JSON_TYPE) -> "Choice":
"""
return _ffi_api.ChoiceFromJSON(json_obj)

def deepcopy(self) -> "Choice":
    """Create an independent copy of this Choice.

    The copy is made by round-tripping through the JSON representation,
    so the result shares no mutable state with the original.
    """
    return Choice.from_json(self.as_json())


@register_object("relax.tuning_api.Knob")
class Knob(Object):
Expand Down Expand Up @@ -247,6 +250,9 @@ def __str__(self) -> str:
msg += f" - {name}: {choice}\n"
return msg

def deepcopy(self) -> "Knob":
    """Create an independent copy of this Knob.

    The copy is made by round-tripping through the JSON representation,
    so the result shares no mutable state with the original.
    """
    return Knob.from_json(self.as_json())


@register_object("relax.tuning_api.Trace")
class Trace(Object):
Expand Down Expand Up @@ -346,6 +352,15 @@ def __str__(self) -> str:
msg += f"[{idx+1}] {self.knobs[idx].name}: {self.decisions[idx]}\n"
return msg

def deepcopy(self) -> "Trace":
    """Create an independent copy of this Trace.

    Both the input and output IRModules are deep-copied via
    ``deepcopy_irmodule`` and each knob is copied through its JSON
    round-trip, so the new trace shares no mutable state with this one.
    """
    new_in_mod = deepcopy_irmodule(self.in_mod)
    new_knobs = [knob.deepcopy() for knob in self.knobs]
    # Re-create the decisions as fresh string objects.
    new_decisions = [str(decision) for decision in self.decisions]
    new_trace = Trace(new_in_mod, new_knobs, new_decisions)
    # The Trace constructor derives an output module by replaying the
    # decision history; overwrite it with a deep copy of the recorded
    # output module so the copy matches this trace exactly.
    new_out_mod = deepcopy_irmodule(self.out_mod)
    new_trace.set_out_mod(new_out_mod)
    return new_trace


def get_trace(in_: Union[Trace, IRModule, Expr]) -> Trace:
"""
Expand All @@ -368,3 +383,37 @@ def get_trace(in_: Union[Trace, IRModule, Expr]) -> Trace:
return Trace(tvm.IRModule.from_expr(in_))

raise Exception(f"Invalid input type for trace: {type(in_)}")


@tvm.register_func("relax.tuning_api.deepcopy_irmodule")
def deepcopy_irmodule(mod: IRModule) -> IRModule:
    """Deep-copy an IRModule by round-tripping it through JSON serialization.

    Parameters
    ----------
    mod: IRModule
        The input IRModule.

    Returns
    -------
    copied_mod: IRModule
        The deep-copied IRModule.
    """
    func_save_json = tvm.get_global_func("node.SaveJSON")
    func_load_json = tvm.get_global_func("node.LoadJSON")
    # Handle external modules separately if they exist.
    # TODO(tvm-team):
    # Serialization of an IRModule with external mods is tricky.
    # (1) An external mod is a runtime module.
    # (2) Currently, `export_library` does not support serialization of a
    #     runtime module without the host module.
    # Therefore, we simply pass around the compiled external modules without
    # copying for now. Revisit later when we have a better solution.
    if mod.attrs and "external_mods" in mod.attrs:
        # Strip the uncopyable external modules, copy the rest, then
        # re-attach the original external modules by reference.
        tmp_mod = mod.without_attr("external_mods")
        new_mod = func_load_json(func_save_json(tmp_mod))
        new_mod = new_mod.with_attr("external_mods", mod.attrs["external_mods"])
    else:
        new_mod = func_load_json(func_save_json(mod))

    return new_mod
3 changes: 2 additions & 1 deletion python/tvm/relax/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,11 @@ def foo(x: Tensor((3, 4), "float32"), y: Tensor((3, 4), "float32")):
seq = tvm.transform.Sequential(passes)
new_mod = seq(mod)

# split primfunc and relax function
# Split primfunc and relax function
rx_mod, tir_mod = _split_tir_relax(new_mod)
lib = tvm.build(tir_mod, target=target)

# Extract external runtime modules if exist.
ext_libs = []
if mod.attrs and "external_mods" in mod.attrs:
ext_libs = mod.attrs["external_mods"]
Expand Down
14 changes: 14 additions & 0 deletions src/ir/function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttr")
}
});

// FFI entry point backing BaseFunc.with_attrs on the Python side.
// Dispatches on the concrete function type, since each IR flavor has its
// own WithAttrs overload.
TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs")
    .set_body_typed([](BaseFunc func, Map<String, ObjectRef> attr_map) -> BaseFunc {
      if (func->IsInstance<tir::PrimFuncNode>()) {
        return WithAttrs(Downcast<tir::PrimFunc>(std::move(func)), attr_map);
      } else if (func->IsInstance<relay::FunctionNode>()) {
        return WithAttrs(Downcast<relay::Function>(std::move(func)), attr_map);
      } else if (func->IsInstance<relax::FunctionNode>()) {
        return WithAttrs(Downcast<relax::Function>(std::move(func)), attr_map);
      } else {
        LOG(FATAL) << "Do not support function type " << func->GetTypeKey();
        return func;  // not reached after LOG(FATAL); silences missing-return warnings
      }
    });

TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr")
.set_body_typed([](BaseFunc func, String key) -> BaseFunc {
if (func->IsInstance<tir::PrimFuncNode>()) {
Expand Down
8 changes: 8 additions & 0 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,14 @@ TVM_REGISTER_GLOBAL("ir.Module_WithAttr")
return WithAttr(mod, key, value);
});

// FFI entry point backing IRModule.without_attr: returns a copy of the
// module with the attribute for `key` removed.
TVM_REGISTER_GLOBAL("ir.Module_WithoutAttr")
    .set_body_typed([](IRModule mod, String key) -> IRModule { return WithoutAttr(mod, key); });

// FFI entry point backing IRModule.with_attrs: returns a copy of the
// module with every entry of `attr_map` attached.
TVM_REGISTER_GLOBAL("ir.Module_WithAttrs")
    .set_body_typed([](IRModule mod, Map<String, ObjectRef> attr_map) -> IRModule {
      return WithAttrs(mod, attr_map);
    });

TVM_REGISTER_GLOBAL("ir.Module_GetAttr").set_body_typed([](IRModule mod, String key) -> ObjectRef {
return mod->GetAttr<ObjectRef>(key);
});
Expand Down
14 changes: 11 additions & 3 deletions src/relax/backend/task_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ using tvm::meta_schedule::ExtractedTask;
class TaskExtractor : public ExprVisitor {
public:
/*!
 * \brief Extract tuning tasks from every Relax function in the module.
 * \param mod The module to scan.
 * \param target The compilation target the tasks are extracted for.
 * \return The tasks collected while visiting the module.
 */
static Array<ExtractedTask> ExtractTask(IRModule mod, Target target) {
  TaskExtractor extractor(mod, target);
  // We go through each Relax function in the module; the visitor
  // accumulates tasks into extractor.tasks_.
  for (const auto& kv : mod->functions) {
    if (const auto* func = kv.second.as<FunctionNode>()) {
      extractor(GetRef<Function>(func));
    }
  }
  return std::move(extractor.tasks_);
}

private:
Expand All @@ -64,12 +64,20 @@ class TaskExtractor : public ExprVisitor {

void VisitExpr_(const CallNode* call) final {
static const Op& call_tir_op = Op::Get("relax.call_tir");

// TODO(@tvm-team): When we differentiate the call for tir function and packed function,
// this logic should be changed accordingly.
if (!call->op.same_as(call_tir_op)) {
// Since the Relax function is of A-normal form, the arguments of this call cannot be another
// Calls. And hence we do not need to recurse into this Call.
return;
}

// Do not extract external function
if (call->args[0].as<ExternFuncNode>()) {
return;
}

const GlobalVar& global_var = Downcast<GlobalVar>(call->args[0]);
const tir::PrimFunc& func = Downcast<tir::PrimFunc>(mod_->Lookup(global_var));

Expand Down
23 changes: 12 additions & 11 deletions src/relax/transform/meta_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,24 @@ class MetaScheduleTuner {
Pass MetaScheduleApplyDatabase(Optional<String> work_dir) {
using tvm::meta_schedule::Database;
Target target = Target::Current(false);
Database database;
if (Database::Current().defined()) {
database = Database::Current().value();
} else {
ICHECK(work_dir.defined());
String path_workload = work_dir.value() + "/database_workload.json";
String path_tuning_record = work_dir.value() + "/database_tuning_record.json";
LOG(INFO) << "Creating JSONDatabase. Workload at: " << path_workload
<< ", Tuning records at: " << path_tuning_record;
database = meta_schedule::Database::JSONDatabase(path_workload, path_tuning_record, true);
}
const runtime::PackedFunc* normalize_mod_func_ =
runtime::Registry::Get("tvm.meta_schedule.normalize_mod");
ICHECK(normalize_mod_func_) << "Normalization function is not found.";

runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
PassContext ctx) {
Database database;
if (Database::Current().defined()) {
database = Database::Current().value();
} else {
ICHECK(work_dir.defined());
String path_workload = work_dir.value() + "/database_workload.json";
String path_tuning_record = work_dir.value() + "/database_tuning_record.json";
LOG(WARNING) << "Creating JSONDatabase. Workload at: " << path_workload
<< ", Tuning records at: " << path_tuning_record;
database = meta_schedule::Database::JSONDatabase(path_workload, path_tuning_record, true);
}

Map<GlobalVar, BaseFunc> result;
for (const auto& iter : mod->functions) {
GlobalVar gv = iter.first;
Expand Down
7 changes: 6 additions & 1 deletion src/relax/transform/run_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,12 @@ class CodeGenRunner : ExprMutator {
return Call(call_op, new_args, tvm::Attrs(), {func->ret_type});
}
}
return GetRef<Call>(call_node);
Array<Expr> new_args;
for (const auto& arg : call_node->args) {
new_args.push_back(VisitExpr(arg));
}

return Call(call_node->op, new_args, call_node->attrs, call_node->type_args, call_node->span);
}

Expr VisitExpr_(const FunctionNode* func_node) override {
Expand Down
4 changes: 3 additions & 1 deletion src/relax/transform/tuning_api/primitives.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,9 @@ Trace::Trace() { data_ = make_object<TraceNode>(); }
Trace::Trace(IRModule in_mod, Array<Knob> knobs, Array<String> decisions) {
ICHECK(knobs.size() == decisions.size()) << "Size of knobs and decisions should match";
// Deep-copy IRModule
IRModule out_mod = meta_schedule::DeepCopyIRModule(in_mod);
auto func_deepcopy = runtime::Registry::Get("relax.tuning_api.deepcopy_irmodule");
ICHECK(func_deepcopy);
IRModule out_mod = (*func_deepcopy)(in_mod);
// Apply the decision history if provided
int size = knobs.size();
for (int i = 0; i < size; i++) {
Expand Down
Loading

0 comments on commit 2963787

Please sign in to comment.