
Commit

[Transform] RewriteDataflowReshape to op and VMBuiltinLower handling (#415)

* [Transform] RewriteDataflowReshape to op and VMBuiltinLower handling

Prior to this PR, the pass transformed calls of reshape
PrimFuncs in dataflow blocks into direct calls of the runtime packed
func “vm.builtin.reshape.” The consequence of this behavior is that the
memory planning pass had to check for the reshape op by string comparison
of `ExternFunc.global_symbol`, which is not ideal.

Therefore, this PR changes RewriteDataflowReshape’s behavior to
transform calls of reshape PrimFuncs into calls of our high-level reshape
op “relax.reshape,” and lets the VMBuiltinLower pass lower the op to
calls of “vm.builtin.reshape.”
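As a sketch of the resulting two-step flow (adapted from the unit test added in this commit; it assumes a Relax-enabled TVM build from this branch, and `lowered.show()` is used purely for inspection):

    import tvm
    from tvm import relax
    from tvm.script import relax as R


    @tvm.script.ir_module
    class AfterRewrite:
        @R.function
        def main(x: R.Tensor((3, 4), "float32")):
            # What RewriteDataflowReshape now emits for a reshape-like
            # call_tir binding: the high-level op, not a packed call.
            y = R.reshape(x, (6, 2))
            return y


    # VMBuiltinLower then rewrites the op into a call of the runtime packed
    # function "vm.builtin.reshape", which creates a view instead of copying.
    lowered = relax.transform.VMBuiltinLower()(AfterRewrite)
    lowered.show()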
MasterJH5574 authored Feb 9, 2023
1 parent d85dda9 commit 1f84c7b
Showing 6 changed files with 72 additions and 22 deletions.
10 changes: 7 additions & 3 deletions include/tvm/relax/transform.h
@@ -102,12 +102,16 @@ TVM_DLL Pass ToNonDataflow();
TVM_DLL Pass CallTIRRewrite();

/*!
- * \brief Convert all reshape-like call_tir to VM reshape operator call.
- * The VM reshape operator calls will be further lowered to a CreateView
- * operation at runtime, instead of doing real data copy.
+ * \brief Convert all reshape-like call_tir whose corresponding binding
+ * vars are DataflowVars to relax.reshape operator calls. The relax.reshape
+ * calls will be lowered to an external builtin function call in a subsequent
+ * pass, where the external builtin function does a CreateView operation
+ * at runtime, instead of doing real data copy.
* Here "reshape-like" includes reshape, expand_dims, flatten, etc.
*
* \return The Pass.
+ * \note The pass is applied at the first stage of the Relax VM build, before
+ * rewriting call_tir, as this pass requires dataflow information.
*/
TVM_DLL Pass RewriteDataflowReshape();

14 changes: 11 additions & 3 deletions python/tvm/relax/transform/transform.py
@@ -102,14 +102,22 @@ def CallTIRRewrite() -> tvm.ir.transform.Pass:


def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
"""Convert all reshape-like call_tir to VM reshape operator call.
The VM reshape operator calls will be further lowered to a CreateView
operation at runtime, instead of doing real data copy.
"""Convert all reshape-like call_tir whose corresponding binding
vars are DataflowVars to relax.reshape operator calls. The relax.reshape
calls will be lowered an external builtin function call in a subsequent
pass, where the external builtin function does a CreateView operation
at runtime, instead of doing real data copy.
Here "reshape-like" includes reshape, expand_dims, flatten, etc.
Returns
-------
ret : tvm.ir.transform.Pass
+ Notes
+ -----
+ The pass is applied at the first stage of the Relax VM build, before
+ rewriting call_tir, as this pass requires dataflow information.
"""
return _ffi_api.RewriteDataflowReshape() # type: ignore
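As a hedged illustration of that ordering constraint, a pipeline along these lines could be assembled from passes exposed by this module (the actual VM build pipeline may differ; `mod` is a hypothetical IRModule whose dataflow blocks bind DataflowVars to reshape-like call_tir calls):

    import tvm
    from tvm import relax

    # `mod` is assumed to be an IRModule in which DataflowVars are bound
    # to reshape-like call_tir calls.
    seq = tvm.transform.Sequential(
        [
            # Runs first: it needs dataflow blocks to still be present ...
            relax.transform.RewriteDataflowReshape(),
            relax.transform.ToNonDataflow(),
            # ... and it must run before call_tir is rewritten away.
            relax.transform.CallTIRRewrite(),
            relax.transform.VMBuiltinLower(),
        ]
    )
    # mod = seq(mod)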

12 changes: 12 additions & 0 deletions src/relax/backend/vm/vm_builtin_lower.cc
@@ -43,6 +43,8 @@ class VMBuiltinLowerMutator : public ExprMutator {

if (call->op == call_tir_dyn_op_) {
return CallTIRDyn(call);
} else if (call->op == reshape_op_) {
return Reshape(call);
} else if (call->op == make_closure_op_) {
return MakeClosure(call);
} else if (call->op == invoke_closure_op_) {
@@ -102,6 +104,14 @@ class VMBuiltinLowerMutator : public ExprMutator {
return Call(builtin_call_tir_dyn_, args, Attrs(), {void_sinfo_});
}

Expr Reshape(const Call& call_node) {
ICHECK(call_node->args.size() == 2);
ICHECK(call_node->struct_info_.defined());
CHECK(call_node->args[1]->IsInstance<ShapeExprNode>())
<< "VMBuiltinLower expects the shape arg of reshape op to be a ShapeExpr";
return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}

Expr MakeClosure(const Call& call_node) {
ICHECK(call_node->args.size() == 2);
ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
@@ -142,6 +152,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
const StructInfo void_sinfo_ = TupleStructInfo(Array<StructInfo>({}));
// object to pattern match.
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& reshape_op_ = Op::Get("relax.reshape");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
@@ -151,6 +162,7 @@
// Function to compute allocated shape.
const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"};
const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
};
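An end-to-end sanity check of this lowering might look like the following sketch (the `relax.vm.build` and `relax.VirtualMachine` entry points are the ones used on the tlc-pack/relax branch at this time; treat them as assumptions on other TVM versions):

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.script import relax as R


    @tvm.script.ir_module
    class Mod:
        @R.function
        def main(x: R.Tensor((3, 4), "float32")):
            y = R.reshape(x, (6, 2))
            return y


    # The VM build pipeline applies VMBuiltinLower, so the reshape above
    # reaches the runtime as a "vm.builtin.reshape" (CreateView) call.
    ex = relax.vm.build(Mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    x = tvm.nd.array(np.arange(12, dtype="float32").reshape(3, 4))
    y = vm["main"](x)
    np.testing.assert_allclose(y.numpy(), x.numpy().reshape(6, 2))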
22 changes: 16 additions & 6 deletions src/relax/transform/rewrite_dataflow_reshape.cc
@@ -18,12 +18,14 @@
*/
/*!
* \file src/relax/transform/rewrite_dataflow_reshape.cc
- * \brief Transform all reshape within dataflow block to a specialized reshape operator
+ * \brief Transform all reshapes within dataflow blocks to the relax.reshape operator
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>

#include "../op/tensor/manipulate.h"

namespace tvm {
namespace relax {

@@ -32,6 +34,8 @@ class DataflowReshapeRewriter : public ExprMutator {
explicit DataflowReshapeRewriter(const IRModule& mod) : mod_(mod) {}

private:
using ExprMutator::VisitExpr_;

BindingBlock VisitBindingBlock(const BindingBlock& block) final {
// We only rewrite the bindings inside dataflow blocks.
if (const auto* dataflow_block = block.as<DataflowBlockNode>()) {
@@ -55,21 +59,27 @@
if (!IsCallingTIRReshape(call)) {
return GetRef<Call>(call);
}
- static const ExternFunc& vm_builtin_reshape = ExternFunc("vm.builtin.reshape");

+ // We bring the calls of reshape PrimFunc back to calls of the high-level
+ // relax.reshape op, which will be lowered to calls of the ExternFunc
+ // vm.builtin.reshape in the VMBuiltinLower pass.
Array<Expr> args = Downcast<Tuple>(call->args[1])->fields;
ICHECK_EQ(args.size(), 1);
TensorStructInfo res_sinfo = Downcast<TensorStructInfo>(call->struct_info_);
- ShapeExpr new_shape = Downcast<ShapeExpr>(res_sinfo->shape);
- return Call(vm_builtin_reshape, {args[0], new_shape}, Attrs(), {res_sinfo});
+ ICHECK(res_sinfo->shape.defined());
+ return reshape(args[0], res_sinfo->shape.value());
}

bool IsCallingTIRReshape(const CallNode* call) {
static const Op& call_tir_op = Op::Get("relax.call_tir");
if (call->op != call_tir_op) {
return false;
}
- GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
- const auto* func = mod_->functions.Get(gv).as<tir::PrimFuncNode>();
+ const auto* gv = call->args[0].as<GlobalVarNode>();
+ if (gv == nullptr) {
+   return false;
+ }
+ const auto* func = mod_->functions.Get(GetRef<GlobalVar>(gv)).as<tir::PrimFuncNode>();
ICHECK_NOTNULL(func);
return HasReshapePattern(GetRef<tir::PrimFunc>(func));
}
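To make the predicate concrete, here is a hedged sketch of a PrimFunc that HasReshapePattern should accept, checked through the Python binding `has_reshape_pattern` (assumed to be exposed under tvm.relax.analysis; the PrimFunc itself is hypothetical):

    from tvm.relax.analysis import has_reshape_pattern
    from tvm.script import tir as T


    # A pure data-movement PrimFunc: expand_dims from (8, 3) to (8, 1, 3).
    # Every output element copies exactly one input element at the same
    # flattened index, which is what the reshape pattern check looks for.
    @T.prim_func
    def expand_dims(
        A: T.Buffer((8, 3), "float32"),
        B: T.Buffer((8, 1, 3), "float32"),
    ):
        for i, j, k in T.grid(8, 1, 3):
            with T.block("expand_dims"):
                vi, vj, vk = T.axis.remap("SSS", [i, j, k])
                B[vi, vj, vk] = A[vi, vk]


    assert has_reshape_pattern(expand_dims)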
29 changes: 25 additions & 4 deletions tests/python/relax/test_transform.py
@@ -14,14 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest

import tvm
import tvm.script
import tvm.testing
from tvm import relax
from tvm.ir import structural_equal
from tvm.ir.base import assert_structural_equal

from tvm.script import tir as T, relax as R


@@ -320,5 +320,26 @@ def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
assert s3.op.global_symbol == "test.op.identity"


def test_vm_builtin_lower_reshape():
@tvm.script.ir_module
class TestVMReshape:
@R.function
def main(x: R.Tensor((3, 4), "float32")):
y = R.reshape(x, (6, 2))
return y

@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor((3, 4), "float32")):
y: R.Tensor((6, 2), "float32") = R.call_packed(
"vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2), "float32")
)
return y

mod = relax.transform.VMBuiltinLower()(TestVMReshape)
assert_structural_equal(mod, Expected)


if __name__ == "__main__":
- pytest.main([__file__])
+ tvm.testing.main()
@@ -114,12 +114,7 @@ def main(
x: R.Tensor((8, 3), dtype="float32")
) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"):
with R.dataflow():
- y = R.call_packed(
-     "vm.builtin.reshape",
-     x,
-     (2, 4, 3),
-     sinfo_args=R.Tensor((2, 4, 3), "float32"),
- )
+ y: R.Tensor((2, 4, 3), "float32") = R.reshape(x, (2, 4, 3))
# Note: `z` is the output var of the dataflow block, and is thus
# not expected to be rewritten.
z = R.call_tir(

1 comment on commit 1f84c7b

@junrushao (Member)


Rebase has been pushed to tlc-pack/relax:rebase-staging.

Next steps:

  • S1. Tweak the staging branch to make sure the CI passes.

  • S2. Force push the staging branch to tlc-pack/relax:relax by:

rm -rf /tmp/tvm-rebase
git clone git@github.com:tlc-pack/relax.git --recursive -b rebase-staging /tmp/tvm-rebase
cd /tmp/tvm-rebase
git checkout -b relax
git push origin relax --force
  • S3. Send out an announcement to the community.
Hi folks, just did another round of sync of `tlc-pack/relax:relax`
against `apache/tvm:main`. The original branch has been backed up to
`apache/tvm:auto-backup/2023-02-08`.

Please rebase your branches using:

   git rebase --onto upstream/relax upstream/auto-backup/2023-02-08

Thank you!
