Skip to content

Commit

Permalink
[ARCH] NestedMsg util functions (apache#347)
Browse files Browse the repository at this point in the history
Add two utility functions of NestedMsg for robust and easy tuple handling.
  • Loading branch information
spectrometerHBH authored Jan 10, 2023
1 parent e793c75 commit f7cfdb5
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 10 deletions.
84 changes: 82 additions & 2 deletions include/tvm/relax/nested_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
#define TVM_RELAX_NESTED_MSG_H_

#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/optional.h>

#include <utility>
#include <vector>

namespace tvm {
namespace relax {
Expand Down Expand Up @@ -123,7 +125,7 @@ class NestedMsg : public ObjectRef {
NestedMsg<T>& operator=(NestedMsg<T>&&) = default;
/*!
* \brief Construct from an ObjectPtr
* whose type already satiesfies the constraint
* whose type already satisfies the constraint
* \param ptr
*/
explicit NestedMsg(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
Expand Down Expand Up @@ -159,7 +161,7 @@ class NestedMsg : public ObjectRef {
}

// delete the int constructor
// since NestedMsg<Integer>(0) is ambiguious
// since NestedMsg<Integer>(0) is ambiguous
// 0 can be implicitly casted to nullptr_t
explicit NestedMsg(int val) = delete;
NestedMsg<T>& operator=(int val) = delete;
Expand Down Expand Up @@ -275,6 +277,34 @@ NestedMsg<T> MapToNestedMsg(Expr expr, FType fmapleaf) {
}
}

/*!
* \brief Map structinfo with possible nested-sinfo to nested message.
*
* This function will unpack recursive sinfo and run fmapleaf for each leaf,
* then recursively combines the results together into a NestedMsg.
*
* The nesting structure will corresponds to the tuple structure.
*
* \param sinfo The input struct info.
* \param fmapleaf The mapping function for each leaf with signature
* NestedMsg<T> fmap(StructInfo)
* \tparam T the content type of nested msg
* \tparam FType The mapping function type
*/
template <typename T, typename FType>
NestedMsg<T> MapToNestedMsg(StructInfo sinfo, FType fmapleaf) {
if (auto* tuple = sinfo.as<TupleStructInfoNode>()) {
Array<NestedMsg<T>> res;
res.reserve(tuple->fields.size());
for (StructInfo x : tuple->fields) {
res.push_back(MapToNestedMsg<T, FType>(x, fmapleaf));
}
return res;
} else {
return fmapleaf(sinfo);
}
}

/*!
* \brief Recursively combine two nested message into one.
*
Expand Down Expand Up @@ -341,6 +371,56 @@ void DecomposeNestedMsg(Expr expr, NestedMsg<T> msg, FType fvisitleaf) {
}
}

/*!
* \brief Recursively transform the tuple structure in expr and msgs along with it.
*
* This function will call ftransleaf for each leaf expression in expr.
* This function will throw an error if the nesting structure in msg does not
* match the tuple nesting structure in expr.
*
* \param expr The input expression to be transform. 
* \param msgs The input messages to guide the transformation.
* \param ftransleaf with signature ftransleaf(Expr, Array<NestedMsg<T>>)->Expr
* \tparam T the content type of nested msg
* \tparam N the number of messages
* \tparam FType The visit function type.
*/
template <typename T, std::size_t N, typename FType>
Expr TransformTupleLeaf(Expr expr, std::array<NestedMsg<T>, N> msgs, FType ftransleaf) {
StructInfo sinfo = GetStructInfo(expr);
if (const auto* tuple = sinfo.as<TupleStructInfoNode>()) {
std::array<Array<NestedMsg<T>>, N> msg_arrays;
for (size_t i = 0; i < N; ++i) {
ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
msg_arrays[i] = msgs[i].NestedArray();
}
bool same = true;
Array<Expr> fields;
fields.reserve(tuple->fields.size());
for (size_t i = 0; i < tuple->fields.size(); ++i) {
Expr field;
if (const auto* expr_tuple = expr.as<TupleNode>()) {
field = expr_tuple->fields[i];
} else {
field = TupleGetItem(expr, i);
UpdateStructInfo(field, tuple->fields[i]);
}
std::array<NestedMsg<T>, N> sub_msgs;
for (size_t j = 0; j < N; ++j) {
sub_msgs[j] = msg_arrays[j][i];
}
fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf));
same &= (fields.back().same_as(field));
}
return same ? expr : Tuple(fields);
} else {
for (const auto& msg : msgs) {
ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
}
return ftransleaf(expr, msgs);
}
}

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_NESTED_MSG_H_
68 changes: 60 additions & 8 deletions tests/cpp/nested_msg_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/relax/nested_msg.h>
#include <tvm/relax/struct_info.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/expr.h>

#include <algorithm>
Expand All @@ -38,7 +40,7 @@ using namespace tvm::relax;

TEST(NestedMsg, Basic) {
// start with no annotation
relax::Var x("x", NullOpt, NullOpt), y("y", NullOpt, NullOpt);
relax::Var x("x", NullOpt), y("y", NullOpt);

// constructor from array, T and nullopt.
NestedMsg<relax::Expr> msg({x, NullOpt, x});
Expand Down Expand Up @@ -79,7 +81,7 @@ TEST(NestedMsg, Basic) {
}

TEST(NestedMsg, ForEachLeaf) {
relax::Var x("x", NullOpt, NullOpt), y("y", NullOpt, NullOpt);
relax::Var x("x", NullOpt), y("y", NullOpt);
NestedMsg<Expr> msg = {x, {x, y}, NullOpt, {x, {x, y}}};

int x_count = 0, y_count = 0;
Expand All @@ -93,8 +95,8 @@ TEST(NestedMsg, ForEachLeaf) {
}

TEST(NestedMsg, Equal) {
relax::Var x("x", NullOpt, NullOpt), y("y", NullOpt, NullOpt);
relax::Var z("z", NullOpt, NullOpt);
relax::Var x("x", NullOpt), y("y", NullOpt);
relax::Var z("z", NullOpt);

auto fequal = [](Expr lhs, Expr rhs) { return lhs.same_as(rhs); };

Expand Down Expand Up @@ -129,11 +131,13 @@ TEST(NestedMsg, Equal) {
}

TEST(NestedMsg, MapAndDecompose) {
relax::Var x("x", NullOpt, NullOpt), y("y", NullOpt, NullOpt);
relax::Var z("z", NullOpt, NullOpt);
relax::Var x("x", PrimStructInfo(runtime::DataType::Int(16)));
relax::Var y("y", PrimStructInfo(runtime::DataType::Int(32)));
relax::Var z("z", PrimStructInfo(runtime::DataType::Int(64)));

relax::Tuple t0({x, y});
relax::Tuple t1({t0, x, z, t0});
BlockBuilder bb = BlockBuilder::Create(NullOpt);
relax::Expr t0 = bb->Normalize(Tuple({x, y}));
relax::Expr t1 = bb->Normalize(Tuple({t0, x, z, t0}));

auto c0 = Integer(0);
auto c1 = Integer(1);
Expand All @@ -150,6 +154,20 @@ TEST(NestedMsg, MapAndDecompose) {
EXPECT_TRUE(Equal(output, expected,
[](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; }));

auto output2 =
MapToNestedMsg<Integer>(GetStructInfo(t1), [&](StructInfo sinfo) -> NestedMsg<Integer> {
const auto* prim_sinfo = sinfo.as<PrimStructInfoNode>();
if (prim_sinfo == nullptr) return NullOpt;
int bits = prim_sinfo->dtype.bits();
if (bits == 16) return c0;
if (bits == 32) return c1;
if (bits == 64) return c2;
return NullOpt;
});

EXPECT_TRUE(Equal(output2, expected,
[](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; }));

int x_count = 0, y_count = 0, z_count = 0;

DecomposeNestedMsg(t1, expected, [&](Expr value, NestedMsg<Integer> msg) {
Expand Down Expand Up @@ -186,3 +204,37 @@ TEST(NestedMsg, CombineNestedMsg) {
EXPECT_TRUE(Equal(output, expected,
[](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; }));
}

TEST(NestedMsg, TransformTupleLeaf) {
auto c0 = Integer(0);
auto c1 = Integer(1);
auto c2 = Integer(2);
using NInt = NestedMsg<Integer>;

NInt msg1 = {c0, {c0, c1}, c2, {c0, {c1, c2}}};
NInt msg2 = {c1, {c2, c0}, c2, {c1, {c2, c0}}};

PrimStructInfo s = PrimStructInfo(runtime::DataType::Int(32));
relax::Var x("x", s), y("y", s), z("z", s);
BlockBuilder bb = BlockBuilder::Create(NullOpt);
Expr expr = bb->Normalize(Tuple({x, Tuple({x, x}), x, Tuple({x, Tuple({x, x})})}));

auto ftransleaf = [&](Expr value, std::array<NInt, 2> msgs) -> Expr {
int lhs = Downcast<Integer>(msgs[0].LeafValue())->value;
int rhs = Downcast<Integer>(msgs[1].LeafValue())->value;
if (lhs > rhs)
return z;
else if (lhs == rhs)
return value;
else
return y;
};

Expr expected = Tuple({y, Tuple({y, z}), x, Tuple({y, Tuple({y, z})})});

EXPECT_TRUE(StructuralEqual()(
TransformTupleLeaf(expr, std::array<NInt, 2>({msg1, msg2}), ftransleaf), expected));

EXPECT_TRUE(
expr.same_as(TransformTupleLeaf(expr, std::array<NInt, 2>({msg1, msg1}), ftransleaf)));
}

0 comments on commit f7cfdb5

Please sign in to comment.