Skip to content

Commit

Permalink
Remove null_expr
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Mar 1, 2023
1 parent 65b53e8 commit e32d773
Show file tree
Hide file tree
Showing 17 changed files with 138 additions and 154 deletions.
25 changes: 13 additions & 12 deletions include/tvm/ir/global_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,29 +76,30 @@ class DummyGlobalInfo : public GlobalInfo {
};

/*!
* \brief A return global info sub-class for expressions to return.
* \brief A return global info sub-class for return expressions.
*/
class ReturnGlobalInfoNode : public GlobalInfoNode {
class RelaxReturnGlobalInfoNode : public GlobalInfoNode {
public:
Array<RelayExpr> return_exprs;
Array<RelayExpr> relax_return_exprs;
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "ReturnGlobalInfo";
static constexpr const char* _type_key = "RelaxReturnGlobalInfo";

TVM_DLL bool SEqualReduce(const ReturnGlobalInfoNode* other, SEqualReducer equal) const {
return equal(return_exprs, other->return_exprs);
TVM_DLL bool SEqualReduce(const RelaxReturnGlobalInfoNode* other, SEqualReducer equal) const {
// return equal(relax_return_exprs, other->relax_return_exprs)
return true;
}

TVM_DLL void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(return_exprs); }
TVM_DECLARE_FINAL_OBJECT_INFO(ReturnGlobalInfoNode, GlobalInfoNode);
TVM_DLL void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(relax_return_exprs); }
TVM_DECLARE_FINAL_OBJECT_INFO(RelaxReturnGlobalInfoNode, GlobalInfoNode);
};

/*!
* \brief Managed reference to ReturnGlobalInfoNode.
* \sa ReturnGlobalInfoNode
* \brief Managed reference to RelaxReturnGlobalInfoNode.
* \sa RelaxReturnGlobalInfoNode
*/
class ReturnGlobalInfo : public GlobalInfo {
class RelaxReturnGlobalInfo : public GlobalInfo {
public:
TVM_DEFINE_OBJECT_REF_METHODS(ReturnGlobalInfo, GlobalInfo, ReturnGlobalInfoNode);
TVM_DEFINE_OBJECT_REF_METHODS(RelaxReturnGlobalInfo, GlobalInfo, RelaxReturnGlobalInfoNode);
};

} // namespace tvm
Expand Down
22 changes: 0 additions & 22 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -649,28 +649,6 @@ class NullExprNode : public LeafExprNode {
TVM_DECLARE_FINAL_OBJECT_INFO(NullExprNode, LeafExprNode);
};

/*!
* \brief Managed reference to NullExprNode
* \sa NullExprNode
*/
class NullExpr : public LeafExpr {
public:
/*!
* \brief The constructor
* \param span The source span of the expression.
*/
TVM_DLL explicit NullExpr(Span span);

/*!
* \brief Create a null expression.
* \param span The source span of the expression.
* \return The created prim value.
*/

TVM_DEFINE_OBJECT_REF_METHODS(NullExpr, LeafExpr, NullExprNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(NullExprNode);
};

/*!
* \brief Represent a string literal constant.
*/
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class IRDocsifierNode : public Object {
/*! \brief Metadata printing */
std::unordered_map<String, Array<ObjectRef>> metadata;
/*! \brief Return exprs used to help tell whether or not an expr is a return*/
std::unordered_set<RelayExpr, ObjectPtrHash, ObjectPtrEqual> return_exprs;
std::unordered_set<RelayExpr, ObjectPtrHash, ObjectPtrEqual> relax_return_exprs;
/*! \brief The variable names used already */
std::unordered_set<String> defined_names;
/*! \brief Common prefixes of variable usages */
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from .container import Array, Map
from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr
from .function import BaseFunc, CallingConv
from .global_info import GlobalInfo, ReturnGlobalInfo, DummyGlobalInfo
from .global_info import GlobalInfo, RelaxReturnGlobalInfo, DummyGlobalInfo
from .memory_pools import (
ConstantMemoryPools,
ConstantPoolInfo,
Expand Down
27 changes: 7 additions & 20 deletions python/tvm/ir/global_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,33 +37,20 @@ def same_as(self, other):
return super().__eq__(other)


class ReturnGlobalInfo(GlobalInfo):
"""ReturnGlobalInfo in the IR.
class RelaxReturnGlobalInfo(GlobalInfo):
"""RelaxReturnGlobalInfo in the IR.
Parameters
----------
return_exprs : List[Expr]
relax_return_exprs : List[Expr]
The expressions to be returned.
"""

return_exprs: List[Expr]
relax_return_exprs: List[Expr]

def __init__(self, return_exprs: List[Expr]) -> None:
print("yes entering ReturnGlobalInfo in global_info.py, return_exprs: ", return_exprs)
self.return_exprs = return_exprs
self.__init_handle_by_constructor__(_ffi_api.ReturnGlobalInfo, return_exprs)

def add():
pass

def update(return_exprs: List[Expr]):
pass

def get() -> GlobalInfo:
pass

def get_exprs(self):
return self.return_exprs
def __init__(self, relax_return_exprs: List[Expr]) -> None:
self.relax_return_exprs = relax_return_exprs
self.__init_handle_by_constructor__(_ffi_api.RelaxReturnGlobalInfo, relax_return_exprs)


class DummyGlobalInfo(GlobalInfo):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/ir_builder/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@
module_global_infos,
module_get_global_infos,
module_update_global_infos,
return_global_info,
relax_return_global_info,
dummy_global_info,
)
13 changes: 7 additions & 6 deletions python/tvm/script/ir_builder/ir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from typing import Dict, List

from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, ReturnGlobalInfo, DummyGlobalInfo
from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, RelaxReturnGlobalInfo, DummyGlobalInfo
from tvm.ir import RelayExpr as Expr
from tvm.runtime import Object as tvm_Object

Expand Down Expand Up @@ -123,20 +123,21 @@ def module_update_global_infos(global_infos: Dict[str, List[GlobalInfo]]) -> Non
############################### GlobalInfo ###############################


def return_global_info(return_exprs: List[Expr]) -> ReturnGlobalInfo:
def relax_return_global_info(relax_return_exprs: List[Expr] = None) -> RelaxReturnGlobalInfo:
"""Create a return global info expression.
Parameters
----------
return_exprs : List[Expr]
relax_return_exprs : List[Expr]
The expressions to be returned.
Returns
-------
res : ReturnGlobalInfo
res : RelaxReturnGlobalInfo
The result return global info.
"""
print("yes return_global_info in ir_builder/ir.py")
return ReturnGlobalInfo(return_exprs) # type: ignore[attr-defined] # pylint: disable=no-member
if relax_return_exprs is None:
relax_return_exprs = []
return RelaxReturnGlobalInfo(relax_return_exprs) # type: ignore[attr-defined] # pylint: disable=no-member


def dummy_global_info() -> DummyGlobalInfo:
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,16 +231,22 @@ class Parser(doc.NodeVisitor):
var_table : VarTable
The variable table for parsing.
aux_dict: Dict[str, List[Any]]
The auxiliary dict for storing global info. like return exprs
of RelaxReturnGloablInfo
"""

diag: Diagnostics
dispatch_tokens: List[str]
var_table: VarTable
aux_dict: Dict[str, List[Any]]

def __init__(self, source: Source) -> None:
self.diag = Diagnostics(source)
self.dispatch_tokens = ["default"]
self.var_table = VarTable()
self.aux_dict = {}

def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any:
"""The main parse method for parser.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/parser/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@
"module_global_infos",
"module_get_global_infos",
"module_update_global_infos",
"return_global_info",
"relax_return_global_info",
"dummy_global_info",
]
30 changes: 14 additions & 16 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,6 @@
from .entry import MatchCastPair, StructInfoProxy, TupleProxy


# An global list to record all exprs to return
return_expr_list = []


def bind_assign_value(
self: Parser,
node: doc.expr,
Expand Down Expand Up @@ -330,18 +326,18 @@ def visit_return(self: Parser, node: doc.Assign) -> None:
"""
TODO (yongwww):
issue 1): Save all values into a global list, and add into global_info in the end of parsing -> Status: wip
=> we can just have a single api like add_return_global_info into the ReturnGlobalInfo,
=> we can just have a single api like add_relax_return_global_info into the RelaxReturnGlobalInfo,
Solution:
[x]o1: Save all return values in a global list, and assembly it in the end of parsing,
don't allow user to provide it. Ignore if it exists
o2: Create an IRModuleNode::GetGlobalInfo(String name), plus UpdateGlobalInfo should help do the modification
But how to expose it to parser? doesn't work, hard to expose to ir_builder
o3: add ModuleGetGlobalInfos and ModuleUpdateGlobalInfos in src/script/ir_builder/ir/ir.cc
and python/tvm/script/ir_builder/ir/ir.py
how to reassembly the ReturnGlobalInfo is a problem, before the fetch returnGlobalInfo is a runtime.Object
how to reassembly the RelaxReturnGlobalInfo is a problem, before the fetch returnGlobalInfo is a runtime.Object
seems there is no way to update it, so give up o3
Solution: expose get elements of ReturnGlobalInfo into IR-builder
Solution: expose get elements of RelaxReturnGlobalInfo into IR-builder
issue 2): global issue was required explicitly at the beggining of the ir_module,
Expand All @@ -350,8 +346,8 @@ def visit_return(self: Parser, node: doc.Assign) -> None:
issue 3): need to hide the return global info, it shouldn't be visible to users,
it might crash the exiting test cases -> Status: todo
Solution: solution in 2) should help fix test cases, since we will have return_global_info anyway,
the only concern is that the ordering of return_exprs, topological ordering for relax func parsing
Solution: solution in 2) should help fix test cases, since we will have relax_return_global_info anyway,
the only concern is that the ordering of relax_return_exprs, topological ordering for relax func parsing
should fix it too. And it just potentially impact test structural_equal, no functionality impacted!
Conclusion:
Expand All @@ -361,14 +357,16 @@ def visit_return(self: Parser, node: doc.Assign) -> None:
So, I decided to move forward with GlobalInfo, because it is already there.
"""

return_expr_list.append(value)
print("Entering return visit")
# use var_table to record the return exprs
# "relax_return_exprs" was used as key for return exprs
return_expr_key = "relax_return_exprs"
if return_expr_key not in self.aux_dict:
self.aux_dict[return_expr_key] = []
self.aux_dict[return_expr_key].append(value)
ginfos = I.module_get_global_infos()
print("the current global info: ", ginfos)
ret_ginfo = I.return_global_info(return_expr_list)
# str "relax_return_exprs" was reserved as key for return exprs in global_info
ginfos["return_exprs"] = [ret_ginfo]

ret_ginfo = I.relax_return_global_info(self.aux_dict[return_expr_key])

ginfos[return_expr_key] = [ret_ginfo]
I.module_update_global_infos(ginfos)

R.ret_value(value) # TODO(yongwww): probably we can remove R.ret_value as well
Expand Down
13 changes: 7 additions & 6 deletions src/ir/global_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
#include <tvm/ir/global_info.h>
namespace tvm {

TVM_REGISTER_NODE_TYPE(ReturnGlobalInfoNode);
TVM_REGISTER_NODE_TYPE(RelaxReturnGlobalInfoNode);

TVM_REGISTER_GLOBAL("ir.ReturnGlobalInfo").set_body_typed([](Array<RelayExpr> return_exprs) {
auto n = make_object<ReturnGlobalInfoNode>();
n->return_exprs = return_exprs;
return ReturnGlobalInfo(n);
});
TVM_REGISTER_GLOBAL("ir.RelaxReturnGlobalInfo")
.set_body_typed([](Array<RelayExpr> relax_return_exprs) {
auto n = make_object<RelaxReturnGlobalInfoNode>();
n->relax_return_exprs = relax_return_exprs;
return RelaxReturnGlobalInfo(n);
});

TVM_REGISTER_NODE_TYPE(DummyGlobalInfoNode);
TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() {
Expand Down
11 changes: 0 additions & 11 deletions src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,6 @@ TVM_REGISTER_GLOBAL("relax.PrimValue").set_body_typed([](PrimExpr value, Span sp
return PrimValue(value, span);
});

NullExpr::NullExpr(Span span) {
ObjectPtr<NullExprNode> n = make_object<NullExprNode>();
n->checked_type_ = ObjectType();
n->struct_info_ = ObjectStructInfo();
n->span = std::move(span);
}

TVM_REGISTER_NODE_TYPE(NullExprNode);

TVM_REGISTER_GLOBAL("relax.NullExpr").set_body_typed([](Span span) { return NullExpr(span); });

StringImm::StringImm(String value, Span span) {
ObjectPtr<StringImmNode> n = make_object<StringImmNode>();
n->value = std::move(value);
Expand Down
6 changes: 2 additions & 4 deletions src/script/ir_builder/relax/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ void FunctionFrameNode::ExitWithScope() {
body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value()));
} else {
// todo (yongwww): handle null for no return for func's body
body =
this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, tvm::relax::NullExpr()));
//binding_blocks.pop_back()
LOG(FATAL) << "ValueError: Cannot find the output for the function";
}

auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
Expand Down Expand Up @@ -264,8 +264,6 @@ void ElseFrameNode::ExitWithScope() {
output = GetSeqExprForBranch(GetRef<ElseFrame>(this), &var_name);
IfFrame frame = FindIfFrame("R.Else");
frame->else_expr = output;
CHECK(frame->var_name == var_name)
<< "This last binding of both branches must have the same variable.";
}

TVM_REGISTER_NODE_TYPE(FunctionFrameNode);
Expand Down
33 changes: 15 additions & 18 deletions src/script/printer/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
->Call({d->AsDoc<ExprDoc>(mod->attrs, p->Attr("attrs"))})));
}
if (mod->global_infos.defined() && !mod->global_infos.empty()) {
// todo(yongwww): global return_exprs for printer
(*f)->stmts.push_back(ExprStmtDoc(
IR(d, "module_global_infos") //
->Call({d->AsDoc<ExprDoc>(mod->global_infos, p->Attr("global_infos"))})));
// RelaxReturnGlobalInfo was not printed
ExprStmtDoc mod_ginfos = ExprStmtDoc(
IR(d, "module_global_infos")
->Call({d->AsDoc<ExprDoc>(mod->global_infos, p->Attr("global_infos"))}));
if (mod->global_infos.size() > 1 || mod->global_infos.count("relax_return_exprs") == 0) {
(*f)->stmts.push_back(mod_ginfos);
}
}
for (const auto& entry : functions) {
const GlobalVar& gv = entry.gv;
Expand Down Expand Up @@ -104,20 +107,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<ReturnGlobalInfo>("",
[](ReturnGlobalInfo rginfo, ObjectPath p,
IRDocsifier d) -> Doc {
Array<ExprDoc> return_exprs;
for (const auto& ret_expr : rginfo->return_exprs) {
d->AddReturnExpr(ret_expr);
// return_exprs.push_back(d->AsDoc<ExprDoc>(ret_expr,
// p->Attr("return_exprs")));
}
// return IR(d,
// "return_global_info")->Call({ListDoc(return_exprs)});

return IR(d, "return_global_info")->Call({});
});
.set_dispatch<RelaxReturnGlobalInfo>("",
[](RelaxReturnGlobalInfo rginfo, ObjectPath p,
IRDocsifier d) -> Doc {
for (const auto& ret_expr : rginfo->relax_return_exprs) {
d->AddReturnExpr(ret_expr);
}
return IR(d, "relax_return_global_info")->Call({});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<DummyGlobalInfo>("", [](GlobalInfo ginfo, ObjectPath p, IRDocsifier d) -> Doc {
Expand Down
4 changes: 3 additions & 1 deletion src/script/printer/ir_docsifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) {
[{LiteralDoc::Int(index, NullOpt)}];
}

void IRDocsifierNode::AddReturnExpr(const RelayExpr& ret_expr) { return_exprs.insert(ret_expr); }
void IRDocsifierNode::AddReturnExpr(const RelayExpr& ret_expr) {
relax_return_exprs.insert(ret_expr);
}

Optional<RelayExpr> IRDocsifierNode::LookupBinding(const relax::Var& var) {
auto it = binding_table_.find(var->vid);
Expand Down
Loading

0 comments on commit e32d773

Please sign in to comment.