This repository has been archived by the owner on May 22, 2023. It is now read-only.

[USMP] Initial implementation of liveness analysis for Relax + TIR #250

Open
wants to merge 170 commits into base: relax

Changes from 1 commit (170 commits total)
bf03d1b
disable GH
tqchen May 21, 2021
4863b8e
Relax Virtual Machine
ZihengJiang Apr 19, 2021
6dffcf7
Relax AST (#2)
jroesch Aug 12, 2021
448507d
Implementation of CallDPS (#3)
ZihengJiang Aug 21, 2021
8b06254
Update AST and Shape() implementation (#5)
ZihengJiang Aug 30, 2021
738ce46
Relax IRBuilder (#4)
YuchenJin Aug 31, 2021
fe23a23
Relax IR Parser (#6)
altanh Sep 13, 2021
7ccf467
Shape and type deduction (#7)
YuchenJin Sep 21, 2021
ac572d3
Relax pretty printer (#8)
altanh Sep 23, 2021
8c1383e
[Parser][Printer] Switch to output annotation for dataflow blocks (#9)
altanh Sep 24, 2021
b7a781d
Update MatchShape AST Node (#11)
ZihengJiang Sep 27, 2021
f8519c1
[Parser][Printer] More parser/printer improvements (#12)
altanh Sep 27, 2021
fd8f28d
Relax IRVisitor/IRMutator (#10)
YuchenJin Sep 27, 2021
f3c1d02
[Parser][Printer] update parser and printer for match_shape (#13)
altanh Sep 27, 2021
f952881
Reorganize source code. (#14)
ZihengJiang Sep 28, 2021
4a6750a
[Parser][Printer] Add class -> IRModule parsing, and extern func supp…
altanh Sep 29, 2021
97192ce
[Parser][Printer] relax call_packed arity, return IRModule factory, p…
altanh Sep 29, 2021
057416b
[PASS] Shape lowering (#16)
ZihengJiang Sep 30, 2021
debc524
[Parser][Printer] explicitly parse and print attrs_type_key in calls …
altanh Sep 30, 2021
175b9da
VM compiler. (#18)
YuchenJin Oct 6, 2021
4a1d3da
Add type hint. (#20)
YuchenJin Oct 6, 2021
9efbd96
Redesign IRBuilder to BlockBuilder (#22)
YuchenJin Oct 17, 2021
f5ff3bb
End2End Lowering Stage2: Enable Lowering from ShapeExpr to VM Executa…
ZihengJiang Oct 18, 2021
ab47e0a
End2End Lowering (#23)
ZihengJiang Oct 22, 2021
d4a6db9
Fixes and improvements (#24)
altanh Oct 30, 2021
390d239
Update fixes for rebase
tqchen Oct 30, 2021
b6fb268
rebase is green
altanh Nov 3, 2021
ab35a8e
VM compiler refactor (#25)
YuchenJin Nov 4, 2021
99401ac
Fix vm build. (#35)
YuchenJin Nov 9, 2021
35d1ae3
ExprMutator refactor & Normalizer (#32)
YuchenJin Nov 10, 2021
88a6228
Migrate passes to Pass Infra (#37)
YuchenJin Nov 12, 2021
d958b06
Generic dispatching in Visitor (#39)
YuchenJin Nov 15, 2021
d6168d5
Update Shape lowering pass (#38)
YuchenJin Nov 16, 2021
e7f55b1
fix IRModule parsing by resolving GlobalVars later (#41)
altanh Nov 18, 2021
d2028a7
TE Integration (#36)
ZihengJiang Nov 21, 2021
ca726ca
Visit shape in Visitor/Mutator (#45)
YuchenJin Nov 24, 2021
252715f
Call topi and external library through emit_te and add MLP example (#50)
YuchenJin Nov 24, 2021
a5fb98b
[EmitTE] EmitTE Symbolic Shape (#53)
hypercubestart Dec 1, 2021
681dd82
Update vm build. (#55)
YuchenJin Dec 1, 2021
22fdee3
[TESTING] pytorch-like nn.Module API to build neural network (#54)
YuchenJin Dec 7, 2021
f968748
[EmitTe] Dynamic TIR Function (w/ unbound TIR vars) (#57)
hypercubestart Dec 15, 2021
5eb7576
call_dps -> call_tir (#60)
electriclilies Jan 5, 2022
4a8d252
[VM] Add control flow to relax vm (#61)
YuchenJin Jan 14, 2022
ab3dd5a
[Refactor] Format; Simplify PackedFunc Registration; Unify parameter …
junrushao Jan 16, 2022
4dff592
[Relax/IR] Disallow creating Binding directly (#66)
junrushao Jan 17, 2022
9d4de10
[EmitTE] multi-output semantics for call_tir, Tuples (#62)
hypercubestart Jan 17, 2022
9a3d9af
xfail => pytest.raises; fix a unittest (#67)
junrushao Jan 18, 2022
c4b585d
[CreatePrimFunc] Support multi-source ReduceNode (#64)
hypercubestart Jan 24, 2022
900c259
[BlockBuilder] Avoid generating duplicated PrimFunc (#68)
YuchenJin Jan 24, 2022
84ba831
[CI] Set up CI; format and lint relax code to pass CI (#72)
YuchenJin Jan 25, 2022
2f3c297
[VM] Add control flow in VmCodeGen, add vm.builtin.copy (#69)
yongwww Jan 26, 2022
70468fe
[BlockBuilder] Emit TupleGetItem (#73)
YuchenJin Feb 2, 2022
0a552da
[Pass] Refactor shape lowering pass to better handle static shape cas…
YuchenJin Feb 2, 2022
ea8cc43
Parse relax TupleGetItem (#77)
yongwww Feb 9, 2022
d01c616
Relay->Relax translator (ResNet example) (#75)
YuchenJin Feb 14, 2022
6345945
AutoTIR integration (#58)
sunggg Feb 15, 2022
08aa45f
Bug fix; print ShapeExpr (#82)
YuchenJin Feb 25, 2022
45a6af2
[TESTS] Enable Tests (#78)
ZihengJiang Feb 25, 2022
a15b3e1
Make offset type specific to avoid errors on non-linux systems. (#84)
jwfromm Mar 2, 2022
79e39d7
Rebase.
YuchenJin Mar 2, 2022
fdb8837
[Bugfix] Fix bb multi-function creation bug (#86)
Robslhc Mar 4, 2022
364d3e8
Add metadata section, support constant and metadata in parser & print…
yongwww Mar 10, 2022
3314a66
Fix bug in relax.vm.build to pass target argument. (#91)
Mar 11, 2022
ce70b0f
Clean up task extraction (#92)
masahi Mar 11, 2022
0f0a87c
Change call_tir convention; Unify shape/type deduction rule (#94)
YuchenJin Mar 18, 2022
0300c31
[VM] Enhance VM Executable as a Subclass of runtime::Module (#95)
MasterJH5574 Mar 20, 2022
01cc71f
[VM] Refactor and improve vm. (#96)
tqchen Mar 21, 2022
6086b0d
[VM][Refactor] Move VM files to TVM runtime directory (#98)
MasterJH5574 Mar 22, 2022
fd77c23
Improve printer for DynTensorType and ShapeExpr (#97)
LeshengJin Mar 23, 2022
e2a5f17
Fix after rebase
yongwww Mar 23, 2022
8118fb4
[VM] Initialize VM through packed function (#101)
MasterJH5574 Mar 24, 2022
d37f1eb
[VM] Fix hardcoded device type in memory lowering (#106)
YuchenJin Mar 25, 2022
da397f9
[Bugfix] Fix call_tir parsing bug (#109)
YuchenJin Mar 25, 2022
50866dc
[FIX] fix structural_equal_hash (#107)
Hzfengsy Mar 26, 2022
5998d07
introduce blockbuilder call_te (#110)
Hzfengsy Mar 28, 2022
9272d85
[FIX] Fix structure equal hash for MatchShape (#112)
tqchen Mar 28, 2022
424f0f9
[CI] Enable GPU tests; Add AutoTIR cuda test. (#115)
YuchenJin Mar 30, 2022
03c4191
[BlockBuilder] Deduce and fill shape/type for Expr in Normalize. (#116)
YuchenJin Apr 1, 2022
839ec89
Temporary remove function type deduction in normalizer. (#119)
tqchen Apr 2, 2022
e553000
[PASS] Fold constant & Bind Params (#113)
jinhongyii Apr 2, 2022
cc4ef90
Add a new Expr to represent runtime dependent shapes. (#117)
psrivas2 Apr 4, 2022
5ee3d91
Remove type annotation from Var. (#121)
YuchenJin Apr 5, 2022
b1e7b93
DataflowBlockPass (#114)
LeshengJin Apr 6, 2022
6cec9f3
[VM] Copy constant tensors to device (#124)
MasterJH5574 Apr 12, 2022
69bd401
[VM] Support sub function call and recursion. (#125)
YuchenJin Apr 12, 2022
37b6fbd
Update autotir integration after rebase
yongwww Apr 12, 2022
a4376d5
Add tune_relax to integrate with task scheduler (#127)
yongwww Apr 13, 2022
9ed287a
Deprecate `[]` in favor `()` in Tensor annotation. (#123)
psrivas2 Apr 14, 2022
7bf563a
[Relay Translator] Use OpStrategy for lowering (#130)
sunggg Apr 21, 2022
0151bc5
[Relax][MS] Task extraction with proper weights (#129)
MasterJH5574 Apr 21, 2022
906c7be
[Pass] Python pass decorator and ExprFunctor (#126)
LeshengJin Apr 22, 2022
a8dd6f2
[Printer][Parser] Modify Tensor annotation printing and parsing. (#128)
psrivas2 Apr 24, 2022
94425fc
[AST][BlockBuilder] Normalize relax.Function; Refactor BlockBuilder t…
YuchenJin Apr 26, 2022
58ca128
[AST][Type] Introduce ObjectType; Infer the type of call_packed by ty…
YuchenJin Apr 28, 2022
9a13879
[BlockBuilder] Sub function call shape deduction: constant shape case…
YuchenJin May 1, 2022
8621e55
Introduce Relax function attribute and drop name field in Relax funct…
sunggg May 2, 2022
ab700bf
Add ShapeType to ShapeExpr.checked_type during construction (#139)
Hzfengsy May 3, 2022
a0c7e21
Add `relax.unique` operator in Relax. (#135)
psrivas2 May 5, 2022
f0c2228
FuseOps for relax (#141)
Hzfengsy May 10, 2022
91bb40c
Change after rebase
yongwww May 11, 2022
3af8be8
[Analysis] IRModule well-formed check (#142)
LeshengJin May 13, 2022
d1dd274
Support Closure (#140)
yongwww May 13, 2022
3e82f9e
Refactor shape lowering pass and Blockbuilder. (#145)
YuchenJin May 17, 2022
08e96e6
Add support to import relay models with Any dim. (#146)
psrivas2 May 20, 2022
315b947
Print/parse tir cast/max operations in Relax shape (#149)
psrivas2 May 20, 2022
434247e
[Mutator] Separate unnormalized-form and normal-form mutators (#148)
YuchenJin May 25, 2022
045b706
[Pass] Relax Transform FuseTIR (#150)
Hzfengsy May 26, 2022
a632b4e
add test cases for FuseTIR (#156)
Hzfengsy May 29, 2022
7891b90
[REFACTOR] Move TIR op kind analysis to relax as it is relax oriented…
tqchen May 29, 2022
e20448e
[Pass Infra] Tuning Pass API (#144)
sunggg May 31, 2022
7b8897b
[PASS] Remove Unused Functions in IRModule (#151)
sunggg Jun 1, 2022
8226b48
[Parser] Add FuncType support (#154)
yongwww Jun 2, 2022
5b89370
Fix shape lowering pass bug for non i64 dims. (#152)
psrivas2 Jun 2, 2022
1938e97
[E2E] End-to-End tuning e2e_script (#153)
Hzfengsy Jun 7, 2022
1affb41
[Pass] Lambda Lifting (#99)
yongwww Jun 11, 2022
fb4fb42
[Relay translator] Allow replacing default topi function with user-pr…
YuchenJin Jun 12, 2022
8bb760d
Update after rebase
yongwww Jun 12, 2022
ef2462c
WellFormed Instrument (#165)
LeshengJin Jun 16, 2022
18c96bf
[VM] Add set_input interface; Fix e2e tuning script. (#166)
YuchenJin Jun 23, 2022
1c1a9a4
[CI] Enable Hexagon CI in Jenkins. (#169)
psrivas2 Jun 24, 2022
dcd9acd
Run static/dynamic models over Hexagon using Relax VM RPC (#167)
psrivas2 Jun 28, 2022
873a139
[BYOC][PASS] Prototype implementation of modular compilation w/ Tenso…
sunggg Jun 29, 2022
96d11fa
Update tests to use `set_input` for rpc calls. (#173)
psrivas2 Jun 29, 2022
c13a0e4
[Parser] Enable R.parser.pretty_print to print TIR PrimFunc (#174)
psrivas2 Jun 30, 2022
631c231
[Refactor] Generic dispatching for `IsBaseOf`; Simplify Type/Expr ini…
YuchenJin Jul 1, 2022
38bed4e
[VM] Deprecate API to save/load executable to file (#176)
psrivas2 Jul 12, 2022
f22bc5a
[Pass] Python ExprMutatorBase/ExprMutator (#172)
LeshengJin Jul 14, 2022
8888c24
fix print twice issue (#181)
Hzfengsy Jul 17, 2022
873656a
[Fix] fix windows build issue (#182)
Hzfengsy Jul 17, 2022
949e098
[Pass Infra] Tuning API serialization and database support (#168)
sunggg Jul 18, 2022
5557391
Fix after rebase
yongwww Jul 15, 2022
3acb42a
[Bugfix][VM] Ensure set_input works over RPC by not returning an arra…
slyubomirsky Jul 19, 2022
21263ad
[Pass] Enhance BindParams to take numpy dict as input (#184)
Hzfengsy Jul 23, 2022
f3cc123
[UX] Highlight TVMScript with Pygments (#185)
ganler Jul 27, 2022
5170f80
Dataflow Pattern Lang: Core Matching Features (#163)
ganler Aug 1, 2022
c96643c
[UX] Adopt changes from tvm-main and render code with IPython.display…
ganler Aug 3, 2022
f33313d
[Testing][AST] Add a simple AST printer for debugging (#198)
slyubomirsky Aug 3, 2022
decb8f6
Fix BlockBuilder Scope Recovery in Misuse (#199)
tqchen Aug 3, 2022
951cc3a
[Example][UX] Make the RPC timeout configurable in the `e2e_auto_tir`…
slyubomirsky Aug 4, 2022
d173755
[Pass] Introduce metaschedule as a tuning pass (#188)
sunggg Aug 4, 2022
96d2761
[Pass] Implement legacy lowering pass that leverages relay op strateg…
sunggg Aug 5, 2022
232efff
[Op][Debugging] Add a print operator (#201)
slyubomirsky Aug 9, 2022
b2c5679
[VM][UX] Implement stateful API (#207)
slyubomirsky Aug 10, 2022
582470a
Clean warning messages by Clang and Pylint (#215)
ganler Aug 11, 2022
7f59681
[Pass][UX] Statement rewriter for DataflowBlock (#210)
ganler Aug 12, 2022
5d84a5e
[VM][UX] Allow for saving closures to avoid extra dictionary lookups …
slyubomirsky Aug 12, 2022
41f2d0c
[Bugfix][VM] Fix var binding to a ConstantNode; Force VM if.cond regi…
YuchenJin Aug 14, 2022
8551e3a
Update with rebase
yongwww Aug 15, 2022
44fa1c3
[FIX] Fix windows build issue when allocating a dynamic array (#219)
Hzfengsy Aug 17, 2022
ce71864
[BugFix] Expose `relax.expr.Constant` to `relax.Constant` (#230)
MasterJH5574 Aug 18, 2022
1ff111d
[Pass] New Python ExprVisitor/ExprMutator! (#190)
LeshengJin Aug 19, 2022
2d783f6
[Hexagon] Use uploaded path to load module. (#238)
psrivas2 Sep 1, 2022
4d31359
[Pass] Canonicalizing Bindings (#233)
slyubomirsky Sep 8, 2022
dade30d
[BugFix] Enable emit global MatchShape (#246)
MasterJH5574 Sep 14, 2022
0e64490
[VM][Benchmarking] Add option for saving e2e results as CSV file (#247)
slyubomirsky Sep 14, 2022
b669e58
[Bugfix][Op] Register attributes for unique and print (#248)
slyubomirsky Sep 14, 2022
39448c3
[Call TIR] Fix bug when invoking call_tir with scalar values. (#254)
Sep 22, 2022
4dc634d
[Bugfix][VM] Properly convert tensor inputs in `save_function` (#257)
slyubomirsky Sep 24, 2022
af8f50d
[Expr] Allow annotating return shape on function nodes (#253)
slyubomirsky Oct 10, 2022
ceb6048
[Analysis] Expose analyses related to vars in Python (#265)
slyubomirsky Oct 11, 2022
b46ced7
[Pass] Support Function and If in Normalize pass. (#268)
YuchenJin Oct 12, 2022
64dbec3
[Op][Debugging] Add `assert` operator (#260)
slyubomirsky Oct 14, 2022
985a123
Recover dropped commits
junrushao Oct 15, 2022
935a3dc
Enable Hexagon tests
junrushao Oct 17, 2022
83de3b2
Recover: [Pass] Separate ApplyHistoryBest from tuning passes (#226)
sunggg Oct 18, 2022
2963787
Recover: [Bugfix] Couple of bug fixes to run TVM-gen code together wi…
sunggg Oct 18, 2022
c6d6a06
Reenable autotvm silencer; fix e2e_auto_tir.py; fix lint.
sunggg Oct 18, 2022
32a03f8
[TVMScript] Update Type Annotation Behavior of the Parser (#269)
Hzfengsy Oct 20, 2022
4e3a3c2
[USMP] Initial implementation of liveness analysis for Relax + TIR
gigiblender Sep 15, 2022
9885a8e
[USMP] Implement AssignPoolInfo pass
gigiblender Oct 3, 2022
[BlockBuilder] Avoid generating duplicated PrimFunc (#68)
* Avoid generating duplicated primfunc.

* Move logic to c++.

* Update method names.
YuchenJin authored and junrushao committed Oct 14, 2022
commit 900c259bbaf3e32abc3fcdf38b481a4608e556b2
25 changes: 25 additions & 0 deletions include/tvm/relax/block_builder.h
@@ -141,6 +141,22 @@ class BlockBuilderNode : public Object {
*/
NameTable* name_table();

/*!
* \brief Add a Relax function or a TIR PrimFunc to \p context_mod_.
* \param func The function to be added.
* \param func_name The name of the function to be added.
* \note If the function to be added already exists in \p context_mod_, return its
* GlobalVar directly.
* \return The global var bound to the added function.
*/
GlobalVar AddFuncToContext(const BaseFunc& func, const String& func_name);

/*!
* \brief Get the context IRModule being built.
* \return The IRModule being built by BlockBuilder.
*/
IRModule GetContextIRModule() const;

void VisitAttrs(AttrVisitor* v) {}

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
@@ -150,6 +166,15 @@ class BlockBuilderNode : public Object {
private:
Var Emit(const Expr& expr, bool is_dataflow, std::string name_hint);

/*! \brief The IRModule being built by the BlockBuilder. */
IRModule context_mod_;

/*!
 * \brief A hashmap to store the mapping of Relax functions and TIR PrimFuncs
 * in \p context_mod_ to their GlobalVar to avoid generating duplicated functions.
 */
std::unordered_map<BaseFunc, GlobalVar, StructuralHash, StructuralEqual> func_map_;

protected:
/*!
* \brief A representation of a block frame.
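The new `func_map_` gives the BlockBuilder memoized function registration: before a function is added to the context module, it is looked up by structural equality, and on a hit the previously bound GlobalVar is returned instead of creating a duplicate. A minimal Python sketch of that idea follows; it uses a plain dict and a string key as a stand-in for TVM's `StructuralHash`/`StructuralEqual`, and all names here are hypothetical, not TVM's actual API:

```python
class ToyBlockBuilder:
    """Toy stand-in for BlockBuilderNode::AddFuncToContext's dedup logic."""

    def __init__(self):
        self.module = {}    # global var name -> function body
        self.func_map = {}  # structural key -> global var name

    @staticmethod
    def structural_key(func_body):
        # Stand-in for StructuralHash/StructuralEqual: a "function" here is
        # just its source text, so the text itself serves as the key.
        return func_body

    def add_func(self, func_body, func_name):
        key = self.structural_key(func_body)
        if key in self.func_map:
            # Duplicate: reuse the GlobalVar bound earlier.
            return self.func_map[key]
        self.module[func_name] = func_body
        self.func_map[key] = func_name
        return func_name

bb = ToyBlockBuilder()
g1 = bb.add_func("compute: A[i] + B[i]", "te_func")
g2 = bb.add_func("compute: A[i] + B[i]", "te_func1")  # structurally identical
assert g1 == g2 == "te_func"  # second add is deduplicated
assert len(bb.module) == 1    # only one function stored
```

Note that in the real change, `global_symbol` is attached inside `AddFuncToContext` after the lookup, so two otherwise-identical functions are not made distinct by their names before the structural comparison runs.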
46 changes: 30 additions & 16 deletions python/tvm/relax/block_builder.py
@@ -123,7 +123,6 @@ def current():

def __init__(self):
self._blocks = []
self._context_mod = tvm.IRModule()
# a boolean flag that tracks if emit_func_output has been called
self._is_emit_func_output_called = False
self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate)
@@ -259,22 +258,22 @@ def dataflow(self) -> DataflowScope:
"""
return DataflowScope(self)

def emit(self, call: relay.Call) -> Var:
"""Emit a call node.
This infers the shape and type of the CallNode, create a variable,
and bind the CallNode to the variable.
def emit(self, expr: Expr) -> Var:
"""Emit an expr.
This infers the shape and type of the expr, create a variable,
and bind the expr to the variable.

Parameters
----------
call : tvm.relax.Call
The call node to be emitted.
expr : tvm.relax.Expr
The Expr to be emitted.

Returns
-------
ret : tvm.relax.Var
A newly created variable that gets binded to the call code.
A newly created variable bound to the input expr.
"""
return _ffi_api.BlockBuilderEmit(self, call)
return _ffi_api.BlockBuilderEmit(self, expr)

def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
"""Emit a call node according to the te function.
@@ -403,9 +402,7 @@ def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"]) -> Ten
inputs = [*te_args] + outs
tir_func = tvm.te.create_prim_func(inputs, unbound_tir_vars)
func_name = self.get_unique_name(func.__name__)
tir_func = tir_func.with_attr("global_symbol", func_name)
gvar = GlobalVar(func_name)
self._context_mod[gvar] = tir_func
gvar = self.add_func(tir_func, func_name)

call_args = [x.op.value for x in te_args]
output_shape = (
@@ -418,7 +415,7 @@ def rx_func(x: Tensor[(n,), "float32"], y: Tensor[((n + 1),), "float32"]) -> Ten
call = call_tir(output_shape, gvar, call_args, tir_vars=ShapeExpr(unbound_tir_vars))
else:
call = call_tir(output_shape, gvar, call_args)
return _ffi_api.BlockBuilderEmit(self, call)
return self.emit(call)

def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
"""Emit a MatchShape.
@@ -506,8 +503,7 @@ def emit_func_output(
func = rx.Function(
self._func_params, seqe, rx.DynTensorType(-1), rx.GlobalVar(self._func_name)
)
gvar = rx.GlobalVar(self._func_name)
self._context_mod[gvar] = func
self.add_func(func, self._func_name)

def normalize(self, expr: Expr) -> Expr:
"""Normalize an Expr to complete its shape and type.
@@ -532,7 +528,7 @@ def get(self) -> tvm.IRModule:
ret : tvm.IRModule
An IRModule with Relax and TIR functions being built.
"""
return self._context_mod
return _ffi_api.BlockBuilderGetContextIRModule(self)

def get_unique_name(self, name_prefix: str) -> str:
"""Generate a unique name with a specified prefix.
@@ -548,3 +544,21 @@ def get_unique_name(self, name_prefix: str) -> str:
The generated name.
"""
return _ffi_api.BlockBuilderGetUniqueName(self, name_prefix)

def add_func(self, func: BaseFunc, func_name: str) -> GlobalVar:
"""Add a Relax function or a TIR PrimFunc to the IRModule being built.

Parameters
----------
func : BaseFunc
The function to be added.

func_name : str
The name of the function to be added.

Returns
-------
gvar : GlobalVar
The global var bound to the added function.
"""
return _ffi_api.BlockBuilderAddFuncToContext(self, func, func_name)
31 changes: 29 additions & 2 deletions src/relax/ir/block_builder.cc
@@ -27,6 +27,7 @@
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/type.h>
#include <tvm/relay/op.h>
#include <tvm/tir/function.h>

namespace tvm {
namespace relax {
@@ -525,6 +526,26 @@ BlockBuilderNode::BlockFrame* BlockBuilderNode::CurrentFrame() {

NameTable* BlockBuilderNode::name_table() { return name_table_.get(); }

GlobalVar BlockBuilderNode::AddFuncToContext(const BaseFunc& func, const String& func_name) {
auto it = func_map_.find(func);
if (it == func_map_.end()) {
GlobalVar gvar = GlobalVar(func_name);
if (const tir::PrimFuncNode* prim_func = func.as<tir::PrimFuncNode>()) {
tir::PrimFunc fn = GetRef<tir::PrimFunc>(prim_func);
fn = WithAttr(std::move(fn), "global_symbol", runtime::String(func_name));
context_mod_->Add(gvar, fn);
} else {
context_mod_->Add(gvar, func);
}
func_map_.emplace(func, gvar);
return gvar;
} else {
return it->second;
}
}

IRModule BlockBuilderNode::GetContextIRModule() const { return context_mod_; }

BlockBuilder BlockBuilder::Create() { return BlockBuilder(make_object<BlockBuilderNode>()); }

TVM_REGISTER_GLOBAL("relax.BlockBuilderCreate").set_body_typed(BlockBuilder::Create);
@@ -541,8 +562,8 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEndBlock")
TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize")
.set_body_method<BlockBuilder>(&BlockBuilderNode::Normalize);

TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit").set_body_typed([](BlockBuilder builder, Call call) {
return builder->Emit(call);
TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit").set_body_typed([](BlockBuilder builder, Expr expr) {
return builder->Emit(expr);
});

TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchShape")
@@ -560,5 +581,11 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName")
return builder->name_table()->GetUniqueName(name_hint);
});

TVM_REGISTER_GLOBAL("relax.BlockBuilderAddFuncToContext")
.set_body_method<BlockBuilder>(&BlockBuilderNode::AddFuncToContext);

TVM_REGISTER_GLOBAL("relax.BlockBuilderGetContextIRModule")
.set_body_method<BlockBuilder>(&BlockBuilderNode::GetContextIRModule);

} // namespace relax
} // namespace tvm
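Besides the new registrations, the C++ side also exposes `BlockBuilderGetUniqueName`, which `emit_te` calls to pick a fresh function name per emission. A name table of this kind is commonly implemented as a prefix counter; the sketch below is a hypothetical illustration, not TVM's actual `NameTable` code:

```python
class NameTable:
    """Toy prefix-counter name table, sketching GetUniqueName's behavior."""

    def __init__(self):
        self.counts = {}

    def get_unique_name(self, prefix):
        n = self.counts.get(prefix, 0)
        self.counts[prefix] = n + 1
        # First request returns the bare prefix; later ones get a suffix.
        return prefix if n == 0 else f"{prefix}{n}"

nt = NameTable()
assert nt.get_unique_name("te_func") == "te_func"
assert nt.get_unique_name("te_func") == "te_func1"
assert nt.get_unique_name("te_func") == "te_func2"
```

This matches the old test expectation (`te_func`, then `te_func1`); after this PR the second unique name is still requested, but the duplicate PrimFunc is never stored, so both call sites resolve to the original `te_func` GlobalVar.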
31 changes: 21 additions & 10 deletions tests/python/relax/test_blockbuilder.py
@@ -21,6 +21,7 @@
from tvm import tir, te
from tvm import relay
from tvm import relax as rx
from tvm.tir.function import PrimFunc

from tvm.ir.base import assert_structural_equal
from tvm.relax import ExternFunc, ShapeExpr, op
@@ -277,18 +278,18 @@ def test_emit_te():
x = rx.Var("x", [n, m], type_anno)
y = rx.Var("y", [n, m], type_anno)
z = rx.Var("z", [n, m], type_anno)

def te_func(args, args_dict, msg):
A, B = args
C = args_dict["C"]
D = te.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
E = te.compute((128, 128), lambda i, j: D[i, j] - C[i, j])
return E

with bb.function("rx_func", [x, y, z]):
out = bb.emit_te(te_func, [x, y], {"C": z}, msg="hello")
bb.emit_func_output(out)

mod = bb.get()
rx_func = mod["rx_func"]

@@ -334,10 +335,20 @@ def te_func(A):
x1 = bb.emit_te(te_func, x)
y1 = bb.emit_te(te_func, y)
bb.emit_func_output(y1)

func = bb.get()["rx_func"]
assert func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
assert func.body.blocks[0].bindings[1].value.args[1].name_hint == "te_func1"

mod = bb.get()
rx_func = mod["rx_func"]

prim_func = []
for gv in mod.get_global_vars():
if isinstance(mod[gv], PrimFunc):
prim_func.append(mod[gv])

# only one PrimFunc is generated
assert len(prim_func) == 1
assert rx_func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
assert rx_func.body.blocks[0].bindings[1].value.args[1].name_hint == "te_func"


def test_emit_te_multiple_output():
bb = rx.BlockBuilder()
@@ -366,6 +377,7 @@ def te_func(A):
assert isinstance(rx_func.body.blocks[0].bindings[0].value.args[0][0], rx.ShapeExpr)
assert isinstance(rx_func.body.blocks[0].bindings[0].value.args[0][1], rx.ShapeExpr)


def test_emit_te_extern():
bb = rx.BlockBuilder()
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
@@ -376,10 +388,10 @@ def test_emit_te_extern():
with bb.function("rx_cblas_matmul", [x, y]):
out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False)
bb.emit_func_output(out)

mod = bb.get()
rx_func = mod["rx_cblas_matmul"]

# check Relax function calls TIR function with call_tir call
assert rx_func.params[0] == x
assert rx_func.params[1] == y
@@ -471,4 +483,3 @@ def test_no_func_params_fail():
test_emit_func_output_twice_fail()
test_func_params_twice_fail()
test_no_func_params_fail()
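Putting the two pieces together, the updated test expects that emitting the same TE computation twice stores a single PrimFunc and points both call-site bindings at the same GlobalVar. A toy end-to-end sketch of that behavior, with plain Python standing in for TVM objects and all names hypothetical:

```python
class ToyBuilder:
    """Combines unique naming with structural dedup, mimicking emit_te."""

    def __init__(self):
        self.module = {}    # name -> function body
        self.func_map = {}  # structural key -> name
        self.counts = {}    # name-table state
        self.bindings = []  # recorded call sites

    def unique_name(self, prefix):
        n = self.counts.get(prefix, 0)
        self.counts[prefix] = n + 1
        return prefix if n == 0 else f"{prefix}{n}"

    def emit_te(self, func_body, prefix):
        name = self.unique_name(prefix)
        key = func_body  # stand-in for structural hashing
        if key in self.func_map:
            gvar = self.func_map[key]  # reuse the earlier GlobalVar
        else:
            self.module[name] = func_body
            self.func_map[key] = name
            gvar = name
        self.bindings.append(gvar)
        return gvar

bb = ToyBuilder()
bb.emit_te("compute: sqrt(A)", "te_func")
bb.emit_te("compute: sqrt(A)", "te_func")
assert len(bb.module) == 1                    # only one PrimFunc generated
assert bb.bindings == ["te_func", "te_func"]  # both bindings share it
```

The unused name `te_func1` is still reserved by the second emission, mirroring how the real builder requests a unique name before the dedup lookup short-circuits the add.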