diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index d263ea360d672..a00ea5768e234 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -236,7 +236,6 @@ class IRBuilder : public runtime::ObjectRef { * \sa IRBuilder::ExitWithScope * \sa tvm::support::With */ - static std::vector All(); static IRBuilder Current(); /*! \brief See if the current thread-local scope has an IRBuilder. */ static bool IsInScope(); diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index c99b9220520ce..6d08107167535 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -101,7 +101,6 @@ class FunctionFrameNode : public SeqExprFrameNode { /*! \brief The function attributes. */ Map attrs; - // todo(yongwww) Add Map> global_infos; /*! \brief The block builder to create Relax function. */ tvm::relax::BlockBuilder block_builder; diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index 469d94d3993d7..e65fa234e727b 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -145,8 +145,6 @@ class IRDocsifierNode : public Object { Array dispatch_tokens; /*! \brief Mapping from a var to its info */ std::unordered_map obj2info; - /*! \brief A binding table that maps var to value. */ - std::unordered_map binding_table_; /*! \brief Metadata printing */ std::unordered_map> metadata; /*! \brief Return exprs used to help tell whether or not an expr is a return*/ @@ -212,11 +210,11 @@ class IRDocsifierNode : public Object { Optional GetVarDoc(const ObjectRef& obj) const; /*! \brief Add a TVM object to the metadata section*/ ExprDoc AddMetadata(const ObjectRef& obj); - - Optional LookupBinding(const relax::Var& var); - + /*! + * \brief Add an expression into return expression set. + * \param ret_expr The return expression. + */ void AddReturnExpr(const RelayExpr& ret_expr); - /*! * \brief Check if a variable exists in the table. * \param obj The variable object. diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index f7fad0e664a28..6058b0ccdfeba 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -323,53 +323,17 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: def visit_return(self: Parser, node: doc.Assign) -> None: value = self.eval_expr(node.value) value = convert_to_expr(value) - """ - TODO (yongwww): - issue 1): Save all values into a global list, and add into global_info in the end of parsing -> Status: wip - => we can just have a single api like add_relax_return_global_info into the RelaxReturnGlobalInfo, - Solution: - [x]o1: Save all return values in a global list, and assembly it in the end of parsing, - don't allow user to provide it. Ignore if it exists - o2: Create an IRModuleNode::GetGlobalInfo(String name), plus UpdateGlobalInfo should help do the modification - But how to expose it to parser? doesn't work, hard to expose to ir_builder - o3: add ModuleGetGlobalInfos and ModuleUpdateGlobalInfos in src/script/ir_builder/ir/ir.cc - and python/tvm/script/ir_builder/ir/ir.py - how to reassembly the RelaxReturnGlobalInfo is a problem, before the fetch returnGlobalInfo is a runtime.Object - seems there is no way to update it, so give up o3 - - Solution: expose get elements of RelaxReturnGlobalInfo into IR-builder - - - issue 2): global issue was required explicitly at the beggining of the ir_module, - need to figure out a way to update/create a return global info at any point -> Status: todo - Solution: No matter if the tvmscript has explicitly feed the module_gloabl_info or not, and one for return! - - issue 3): need to hide the return global info, it shouldn't be visible to users, - it might crash the exiting test cases -> Status: todo - Solution: solution in 2) should help fix test cases, since we will have relax_return_global_info anyway, - the only concern is that the ordering of relax_return_exprs, topological ordering for relax func parsing - should fix it too. And it just potentially impact test structural_equal, no functionality impacted! - - Conclusion: - 1) The best way is to add "Bool return_body" in SeqExpr, but we need to keep IR constrained at this moment - 2) Introduce func_info in relax function level, similar to global info, but it will introduce return_func_info - into Function, and the IR is affected, then prefer option 1) - So, I decided to move forward with GlobalInfo, because it is already there. - """ - # "relax_return_exprs" was used as key for return exprs return_expr_key = "relax_return_exprs" if return_expr_key not in self.aux_dict: self.aux_dict[return_expr_key] = [] self.aux_dict[return_expr_key].append(value) + # update the return global info ginfos = I.module_get_global_infos() - ret_ginfo = I.relax_return_global_info(self.aux_dict[return_expr_key]) - ginfos[return_expr_key] = [ret_ginfo] I.module_update_global_infos(ginfos) - - R.ret_value(value) # TODO(yongwww): probably we can remove R.ret_value as well + R.ret_value(value) @dispatch.register(token="relax", type_name="If") @@ -379,10 +343,11 @@ def visit_if(self: Parser, node: doc.If) -> None: with R.If(self.eval_expr(node.test)) as if_frame: with self.var_table.with_frame(): with R.Then(): - print("Entering R.Then") self.visit_body(node.body) with self.var_table.with_frame(): with R.Else(): - print("Entering R.Else") self.visit_body(node.orelse) - self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True) + if not if_frame.var_name: + self.var_table.add(str(if_frame.var), if_frame.var, allow_shadowing=True) + else: + self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True) diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc index 96587de74b039..22623b168b903 100644 --- a/src/ir/global_info.cc +++ b/src/ir/global_info.cc @@ -39,5 +39,4 @@ TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() { auto n = DummyGlobalInfo(make_object()); return n; }); - } // namespace tvm diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 7c2282469931e..c4c57d0c2c2d3 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -60,7 +60,6 @@ void FunctionFrameNode::ExitWithScope() { body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); } else { // todo (yongwww): handle null for no return for func's body - //binding_blocks.pop_back() LOG(FATAL) << "ValueError: Cannot find the output for the function"; } diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 37fa1b053784e..a0ae45677d89c 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -109,18 +109,13 @@ void RetValue(const tvm::relax::Expr& value) { // Exit BlockFrame if (block_frame.defined()) { block_frame.value()->ExitWithScope(); - //ICHECK(!IRBuilder::Current()->FindFrame()) + // ICHECK(!IRBuilder::Current()->FindFrame()) // << "ValueError: Relax functions don't support return in true/false branch of If Node."; } // Step 2. Add the output value to the function frame. Array all_frames = IRBuilder::Current()->frames; - int i = 0; - for (auto f : all_frames) { - LOG(INFO) << "yongwww frame_" << i++ << " = " << f; - } // IfFrame if_frame = IRBuilder::Current()->FindFrame().value(); - // LOG(INFO) << "return if_frame: " << if_frame; IRBuilderFrame last_frame = all_frames[all_frames.size() - 1]; Optional then_frame = IRBuilder::Current()->GetLastFrame(); @@ -138,7 +133,6 @@ void RetValue(const tvm::relax::Expr& value) { LOG(INFO) << "return FunctionFrame frame: " << frame; frame->output = std::move(normalized_value); - // block_frame.value()->ExitWithScope(); } TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 33ba2b78995c9..4e4cf51592f22 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -67,12 +67,6 @@ void IRDocsifierNode::AddReturnExpr(const RelayExpr& ret_expr) { relax_return_exprs.insert(ret_expr); } -Optional IRDocsifierNode::LookupBinding(const relax::Var& var) { - auto it = binding_table_.find(var->vid); - if (it == binding_table_.end()) return NullOpt; - return it->second; -} - bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); } void IRDocsifierNode::RemoveVar(const ObjectRef& obj) { diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index 81e6d41c73c3e..8d48adc171478 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -27,67 +27,26 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& using relax::SeqExpr; ExprDoc cond = d->AsDoc(n->cond, n_p->Attr("cond")); std::vector> branches; - // todo(yongwww): looks the relax_return_exprs are the values, and normalizer adds a new binding - // need to figure out a way to get if the seqexpr.body was bound to one of relax_return_exprs, too - // complicated! - for (auto ret_expr : d->relax_return_exprs) { - LOG(INFO) << "yongwww 33 ret_expr: " << ret_expr; - } - auto true_seq_expr = Downcast(n->true_branch); - auto false_seq_expr = Downcast(n->false_branch); - if (const auto* var_node = true_seq_expr->body.as()) { - auto t_var = GetRef(var_node); - LOG(INFO) << "yongwww true_seq_expr->body: " << t_var << " -- val: " << d->LookupBinding(t_var); - } - - for (auto ele : d->binding_table_) { - LOG(INFO) << "ele k: " << ele.first << " - value: " << ele.second; - } - - if (const auto* var_node = false_seq_expr->body.as()) { - auto t_var = GetRef(var_node); - LOG(INFO) << "yongwww false_seq_expr->body: " << t_var - << " -- val: " << d->LookupBinding(t_var); - } - bool ret_true_branch = false; - bool ret_false_branch = false; - relax::BindingBlock last_block_true = true_seq_expr->blocks[true_seq_expr->blocks.size() - 1]; - relax::Binding last_binding_true = - last_block_true->bindings[last_block_true->bindings.size() - 1]; - if (auto* var_binding = last_binding_true.as()) { - auto last_var_binding_true = GetRef(var_binding); - if (last_var_binding_true->var.same_as(true_seq_expr->body) && - d->relax_return_exprs.find(last_var_binding_true->value) != d->relax_return_exprs.end()) { - ret_true_branch = true; - LOG(INFO) << "yongwww ret_true_branch true"; - } - } - - relax::BindingBlock last_block_false = false_seq_expr->blocks[false_seq_expr->blocks.size() - 1]; - relax::Binding last_binding_false = - last_block_false->bindings[last_block_false->bindings.size() - 1]; - if (auto* var_binding = last_binding_false.as()) { - auto last_var_binding_false = GetRef(var_binding); - if (last_var_binding_false->var.same_as(false_seq_expr->body) && - d->relax_return_exprs.find(last_var_binding_false->value) != d->relax_return_exprs.end()) { - ret_false_branch = true; - LOG(INFO) << "yongwww ret_false_branch true"; + // normalizer adds a new binding, need to figure out if the seqexpr.body was bound + auto is_return = [](const SeqExpr& seq_expr, const IRDocsifier& dd) { + relax::BindingBlock last_block = seq_expr->blocks[seq_expr->blocks.size() - 1]; + relax::Binding last_binding = last_block->bindings[last_block->bindings.size() - 1]; + if (auto* var_binding = last_binding.as()) { + auto last_var_binding = GetRef(var_binding); + if (last_var_binding->var.same_as(seq_expr->body) && + dd->relax_return_exprs.find(last_var_binding->value) != dd->relax_return_exprs.end()) { + return true; + } } - } + return false; + }; - if (d->relax_return_exprs.find(true_seq_expr->body) != d->relax_return_exprs.end()) { - branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch)); - } else { - branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch)); - } - - if (d->relax_return_exprs.find(false_seq_expr->body) != d->relax_return_exprs.end()) { - branches.push_back( - PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch)); - } else { - branches.push_back( - PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch)); - } + auto true_seq_expr = Downcast(n->true_branch); + auto false_seq_expr = Downcast(n->false_branch); + bool ret_true_branch = is_return(true_seq_expr, d); + bool ret_false_branch = is_return(false_seq_expr, d); + branches.push_back(PrintSeqExpr(true_seq_expr, n_p->Attr("true_branch"), d, ret_true_branch)); + branches.push_back(PrintSeqExpr(false_seq_expr, n_p->Attr("false_branch"), d, ret_false_branch)); if (var.defined()) { for (Array& stmts : branches) { @@ -95,7 +54,6 @@ IfDoc PrintIfExpr(const relax::If& n, const ObjectPath& n_p, const IRDocsifier& ExprDoc ret = Downcast(stmts.back())->expr; stmts.Set(stmts.size() - 1, AssignDoc(var.value(), ret, ann)); } - LOG(INFO) << "yongwww stmts.back() key: " << stmts.back()->GetTypeKey(); } } return IfDoc(cond, branches[0], branches[1]); @@ -117,7 +75,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](relax::VarBinding n, ObjectPath n_p, IRDocsifier d) -> Doc { - d->binding_table_[n->var->vid] = n->value; if (const auto if_ = n->value.as()) { Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); ExprDoc lhs = DefineVar(n->var, d->frames.back(), d); diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 8a6419e5fe3ed..298d58218b6f6 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -33,6 +33,7 @@ def _check( expect: Optional[Union[relax.Function, IRModule]] = None, ): test = parsed.script(show_meta=True) + print(test) roundtrip_mod = tvm.script.from_source(test) tvm.ir.assert_structural_equal(parsed, roundtrip_mod) if expect: @@ -1192,11 +1193,11 @@ def noelse(x: R.Tensor) -> R.Tensor: def foo0(x: R.Tensor) -> R.Tensor: y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") if y: - v = R.add(x, x) - return v + r = R.add(x, x) + return r else: - v = R.multiply(x, x) - return v + r = R.multiply(x, x) + return r @R.function def foo1(x: R.Tensor) -> R.Tensor: @@ -1211,19 +1212,27 @@ def foo1(x: R.Tensor) -> R.Tensor: def foo2(x: R.Tensor) -> R.Tensor: y: R.Tensor((), dtype="bool") = R.const(True, dtype="bool") if y: - v = R.add(x, x) + r = R.add(x, x) else: return R.multiply(x, x) - return v + return r MultiReturn.show() - print("yongwww get_global_info:", MultiReturn.get_global_info("relax_return_exprs")) + # print("yongwww get_global_info:", MultiReturn.get_global_info("relax_return_exprs")) _check(MultiReturn) roundtrip_mod = tvm.script.from_source(MultiReturn.script(show_meta=True)) tvm.ir.assert_structural_equal(MultiReturn, roundtrip_mod, True) +def test_meta_data(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor(None, "float32", ndim=2): + a = R.const([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], "float32") + g = R.add(x, a) + return g + + _check(foo) + + if __name__ == "__main__": - # tvm.testing.main() - # test_module_with_attr_and_global_info() - test_multi_return() + tvm.testing.main()