Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Mar 2, 2023
1 parent e32d773 commit 6e40d0e
Show file tree
Hide file tree
Showing 10 changed files with 48 additions and 135 deletions.
1 change: 0 additions & 1 deletion include/tvm/script/ir_builder/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ class IRBuilder : public runtime::ObjectRef {
* \sa IRBuilder::ExitWithScope
* \sa tvm::support::With
*/
static std::vector<IRBuilder> All();
static IRBuilder Current();
/*! \brief See if the current thread-local scope has an IRBuilder. */
static bool IsInScope();
Expand Down
1 change: 0 additions & 1 deletion include/tvm/script/ir_builder/relax/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ class FunctionFrameNode : public SeqExprFrameNode {
/*! \brief The function attributes. */
Map<String, ObjectRef> attrs;

// todo(yongwww) Add Map<String, Array<GlobalInfo>> global_infos;
/*! \brief The block builder to create Relax function. */
tvm::relax::BlockBuilder block_builder;

Expand Down
10 changes: 4 additions & 6 deletions include/tvm/script/printer/ir_docsifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,6 @@ class IRDocsifierNode : public Object {
Array<String> dispatch_tokens;
/*! \brief Mapping from a var to its info */
std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> obj2info;
/*! \brief A binding table that maps var to value. */
std::unordered_map<relax::Id, RelayExpr, ObjectPtrHash, ObjectPtrEqual> binding_table_;
/*! \brief Metadata printing */
std::unordered_map<String, Array<ObjectRef>> metadata;
/*! \brief Return exprs used to help tell whether or not an expr is a return*/
Expand Down Expand Up @@ -212,11 +210,11 @@ class IRDocsifierNode : public Object {
Optional<ExprDoc> GetVarDoc(const ObjectRef& obj) const;
/*! \brief Add a TVM object to the metadata section*/
ExprDoc AddMetadata(const ObjectRef& obj);

Optional<RelayExpr> LookupBinding(const relax::Var& var);

/*!
* \brief Add an expression into return expression set.
* \param ret_expr The return expression.
*/
void AddReturnExpr(const RelayExpr& ret_expr);

/*!
* \brief Check if a variable exists in the table.
* \param obj The variable object.
Expand Down
47 changes: 6 additions & 41 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,53 +323,17 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
def visit_return(self: Parser, node: doc.Assign) -> None:
value = self.eval_expr(node.value)
value = convert_to_expr(value)
"""
TODO (yongwww):
issue 1): Save all values into a global list, and add into global_info in the end of parsing -> Status: wip
=> we can just have a single api like add_relax_return_global_info into the RelaxReturnGlobalInfo,
Solution:
[x]o1: Save all return values in a global list, and assembly it in the end of parsing,
don't allow user to provide it. Ignore if it exists
o2: Create an IRModuleNode::GetGlobalInfo(String name), plus UpdateGlobalInfo should help do the modification
But how to expose it to parser? doesn't work, hard to expose to ir_builder
o3: add ModuleGetGlobalInfos and ModuleUpdateGlobalInfos in src/script/ir_builder/ir/ir.cc
and python/tvm/script/ir_builder/ir/ir.py
how to reassembly the RelaxReturnGlobalInfo is a problem, before the fetch returnGlobalInfo is a runtime.Object
seems there is no way to update it, so give up o3
Solution: expose get elements of RelaxReturnGlobalInfo into IR-builder
issue 2): global issue was required explicitly at the beggining of the ir_module,
need to figure out a way to update/create a return global info at any point -> Status: todo
Solution: No matter if the tvmscript has explicitly feed the module_gloabl_info or not, and one for return!
issue 3): need to hide the return global info, it shouldn't be visible to users,
it might crash the exiting test cases -> Status: todo
Solution: solution in 2) should help fix test cases, since we will have relax_return_global_info anyway,
the only concern is that the ordering of relax_return_exprs, topological ordering for relax func parsing
should fix it too. And it just potentially impact test structural_equal, no functionality impacted!
Conclusion:
1) The best way is to add "Bool return_body" in SeqExpr, but we need to keep IR constrained at this moment
2) Introduce func_info in relax function level, similar to global info, but it will introduce return_func_info
into Function, and the IR is affected, then prefer option 1)
So, I decided to move forward with GlobalInfo, because it is already there.
"""

# "relax_return_exprs" was used as key for return exprs
return_expr_key = "relax_return_exprs"
if return_expr_key not in self.aux_dict:
self.aux_dict[return_expr_key] = []
self.aux_dict[return_expr_key].append(value)
# update the return global info
ginfos = I.module_get_global_infos()

ret_ginfo = I.relax_return_global_info(self.aux_dict[return_expr_key])

ginfos[return_expr_key] = [ret_ginfo]
I.module_update_global_infos(ginfos)

R.ret_value(value) # TODO(yongwww): probably we can remove R.ret_value as well
R.ret_value(value)


@dispatch.register(token="relax", type_name="If")
Expand All @@ -379,10 +343,11 @@ def visit_if(self: Parser, node: doc.If) -> None:
with R.If(self.eval_expr(node.test)) as if_frame:
with self.var_table.with_frame():
with R.Then():
print("Entering R.Then")
self.visit_body(node.body)
with self.var_table.with_frame():
with R.Else():
print("Entering R.Else")
self.visit_body(node.orelse)
self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True)
if not if_frame.var_name:
self.var_table.add(str(if_frame.var), if_frame.var, allow_shadowing=True)
else:
self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True)
1 change: 0 additions & 1 deletion src/ir/global_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,4 @@ TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() {
auto n = DummyGlobalInfo(make_object<DummyGlobalInfoNode>());
return n;
});

} // namespace tvm
1 change: 0 additions & 1 deletion src/script/ir_builder/relax/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ void FunctionFrameNode::ExitWithScope() {
body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value()));
} else {
// todo (yongwww): handle null for no return for func's body
//binding_blocks.pop_back()
LOG(FATAL) << "ValueError: Cannot find the output for the function";
}

Expand Down
8 changes: 1 addition & 7 deletions src/script/ir_builder/relax/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,13 @@ void RetValue(const tvm::relax::Expr& value) {
// Exit BlockFrame
if (block_frame.defined()) {
block_frame.value()->ExitWithScope();
//ICHECK(!IRBuilder::Current()->FindFrame<BlockFrame>())
// ICHECK(!IRBuilder::Current()->FindFrame<BlockFrame>())
// << "ValueError: Relax functions don't support return in true/false branch of If Node.";
}
// Step 2. Add the output value to the function frame.
Array<IRBuilderFrame> all_frames = IRBuilder::Current()->frames;
int i = 0;
for (auto f : all_frames) {
LOG(INFO) << "yongwww frame_" << i++ << " = " << f;
}

// IfFrame if_frame = IRBuilder::Current()->FindFrame<IfFrame>().value();
// LOG(INFO) << "return if_frame: " << if_frame;

IRBuilderFrame last_frame = all_frames[all_frames.size() - 1];
Optional<ThenFrame> then_frame = IRBuilder::Current()->GetLastFrame<ThenFrame>();
Expand All @@ -138,7 +133,6 @@ void RetValue(const tvm::relax::Expr& value) {
LOG(INFO) << "return FunctionFrame frame: " << frame;

frame->output = std::move(normalized_value);
// block_frame.value()->ExitWithScope();
}

TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function);
Expand Down
6 changes: 0 additions & 6 deletions src/script/printer/ir_docsifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,6 @@ void IRDocsifierNode::AddReturnExpr(const RelayExpr& ret_expr) {
relax_return_exprs.insert(ret_expr);
}

Optional<RelayExpr> IRDocsifierNode::LookupBinding(const relax::Var& var) {
auto it = binding_table_.find(var->vid);
if (it == binding_table_.end()) return NullOpt;
return it->second;
}

bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); }

void IRDocsifierNode::RemoveVar(const ObjectRef& obj) {
Expand Down
79 changes: 18 additions & 61 deletions src/script/printer/relax/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,75 +27,33 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier&
using relax::SeqExpr;
ExprDoc cond = d->AsDoc<ExprDoc>(n->cond, n_p->Attr("cond"));
std::vector<Array<StmtDoc>> branches;
// todo(yongwww): looks the relax_return_exprs are the values, and normalizer adds a new binding
// need to figure out a way to get if the seqexpr.body was bound to one of relax_return_exprs, too
// complicated!
for (auto ret_expr : d->relax_return_exprs) {
LOG(INFO) << "yongwww 33 ret_expr: " << ret_expr;
}
auto true_seq_expr = Downcast<SeqExpr>(n->true_branch);
auto false_seq_expr = Downcast<SeqExpr>(n->false_branch);
if (const auto* var_node = true_seq_expr->body.as<relax::VarNode>()) {
auto t_var = GetRef<relax::Var>(var_node);
LOG(INFO) << "yongwww true_seq_expr->body: " << t_var << " -- val: " << d->LookupBinding(t_var);
}

for (auto ele : d->binding_table_) {
LOG(INFO) << "ele k: " << ele.first << " - value: " << ele.second;
}

if (const auto* var_node = false_seq_expr->body.as<relax::VarNode>()) {
auto t_var = GetRef<relax::Var>(var_node);
LOG(INFO) << "yongwww false_seq_expr->body: " << t_var
<< " -- val: " << d->LookupBinding(t_var);
}
bool ret_true_branch = false;
bool ret_false_branch = false;
relax::BindingBlock last_block_true = true_seq_expr->blocks[true_seq_expr->blocks.size() - 1];
relax::Binding last_binding_true =
last_block_true->bindings[last_block_true->bindings.size() - 1];
if (auto* var_binding = last_binding_true.as<relax::VarBindingNode>()) {
auto last_var_binding_true = GetRef<relax::VarBinding>(var_binding);
if (last_var_binding_true->var.same_as(true_seq_expr->body) &&
d->relax_return_exprs.find(last_var_binding_true->value) != d->relax_return_exprs.end()) {
ret_true_branch = true;
LOG(INFO) << "yongwww ret_true_branch true";
}
}

relax::BindingBlock last_block_false = false_seq_expr->blocks[false_seq_expr->blocks.size() - 1];
relax::Binding last_binding_false =
last_block_false->bindings[last_block_false->bindings.size() - 1];
if (auto* var_binding = last_binding_false.as<relax::VarBindingNode>()) {
auto last_var_binding_false = GetRef<relax::VarBinding>(var_binding);
if (last_var_binding_false->var.same_as(false_seq_expr->body) &&
d->relax_return_exprs.find(last_var_binding_false->value) != d->relax_return_exprs.end()) {
ret_false_branch = true;
LOG(INFO) << "yongwww ret_false_branch true";
// normalizer adds a new binding, need to figure out if the seqexpr.body was bound
auto is_return = [](const SeqExpr& seq_expr, const IRDocsifier& dd) {
relax::BindingBlock last_block = seq_expr->blocks[seq_expr->blocks.size() - 1];
relax::Binding last_binding = last_block->bindings[last_block->bindings.size() - 1];
if (auto* var_binding = last_binding.as<relax::VarBindingNode>()) {
auto last_var_binding = GetRef<relax::VarBinding>(var_binding);
if (last_var_binding->var.same_as(seq_expr->body) &&
dd->relax_return_exprs.find(last_var_binding->value) != dd->relax_return_exprs.end()) {
return true;
}
}
}
return false;
};

if (d->relax_return_exprs.find(true_seq_expr->body) != d->relax_return_exprs.end()) {
branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch));
} else {
branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch));
}

if (d->relax_return_exprs.find(false_seq_expr->body) != d->relax_return_exprs.end()) {
branches.push_back(
PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch));
} else {
branches.push_back(
PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch));
}
auto true_seq_expr = Downcast<SeqExpr>(n->true_branch);
auto false_seq_expr = Downcast<SeqExpr>(n->false_branch);
bool ret_true_branch = is_return(true_seq_expr, d);
bool ret_false_branch = is_return(false_seq_expr, d);
branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch));
branches.push_back(PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch));

if (var.defined()) {
for (Array<StmtDoc>& stmts : branches) {
if (!stmts.back()->IsInstance<ReturnDocNode>()) {
ExprDoc ret = Downcast<ExprStmtDoc>(stmts.back())->expr;
stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann));
}
LOG(INFO) << "yongwww stmts.back() key: " << stmts.back()->GetTypeKey();
}
}
return IfDoc(cond, branches[0], branches[1]);
Expand All @@ -117,7 +75,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<relax::VarBinding>( //
"", [](relax::VarBinding n, ObjectPath n_p, IRDocsifier d) -> Doc {
d->binding_table_[n->var->vid] = n->value;
if (const auto if_ = n->value.as<relax::IfNode>()) {
Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
ExprDoc lhs = DefineVar(n->var, d->frames.back(), d);
Expand Down
29 changes: 19 additions & 10 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _check(
expect: Optional[Union[relax.Function, IRModule]] = None,
):
test = parsed.script(show_meta=True)
print(test)
roundtrip_mod = tvm.script.from_source(test)
tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
if expect:
Expand Down Expand Up @@ -1192,11 +1193,11 @@ def noelse(x: R.Tensor) -> R.Tensor:
def foo0(x: R.Tensor) -> R.Tensor:
y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool")
if y:
v = R.add(x, x)
return v
r = R.add(x, x)
return r
else:
v = R.multiply(x, x)
return v
r = R.multiply(x, x)
return r

@R.function
def foo1(x: R.Tensor) -> R.Tensor:
Expand All @@ -1211,19 +1212,27 @@ def foo1(x: R.Tensor) -> R.Tensor:
def foo2(x: R.Tensor) -> R.Tensor:
y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool")
if y:
v = R.add(x, x)
r = R.add(x, x)
else:
return R.multiply(x, x)
return v
return r

MultiReturn.show()
print("yongwww get_global_info:", MultiReturn.get_global_info("relax_return_exprs"))
# print("yongwww get_global_info:", MultiReturn.get_global_info("relax_return_exprs"))
_check(MultiReturn)
roundtrip_mod = tvm.script.from_source(MultiReturn.script(show_meta=True))
tvm.ir.assert_structural_equal(MultiReturn, roundtrip_mod, True)


def test_meta_data():
@R.function
def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2):
a = R.const([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], "float32")
g = R.add(x, a)
return g

_check(foo)


if __name__ == "__main__":
# tvm.testing.main()
# test_module_with_attr_and_global_info()
test_multi_return()
tvm.testing.main()

0 comments on commit 6e40d0e

Please sign in to comment.