Skip to content
This repository has been archived by the owner on May 22, 2023. It is now read-only.

Commit

Permalink
[Op] Remove MatchCast requirement for manipulation ops (#436)
Browse files Browse the repository at this point in the history
Prior to this PR, the structure info inference of some manipulation
operators rejects unknown ndim or unknown shape for safety reasons.
As many people have pointed out, this requirement increases the
overhead of our frontend importers by forcing them to use MatchCast,
which turns out to be ineffective and troublesome.

Therefore, this PR removes such requirements, switching to the
behavior of optimistically trusting that the input has the desired
ndim or shape when those properties are unknown.

Nevertheless, this PR leaves TODO items at such places, which serve
as reminders for us to support the corresponding runtime ndim and
shape checks in the future.
  • Loading branch information
MasterJH5574 authored Feb 14, 2023
1 parent 69a9700 commit ce5c7f4
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 103 deletions.
91 changes: 28 additions & 63 deletions src/relax/op/tensor/manipulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,16 @@ StructInfo InferStructInfoBroadcastTo(const Call& call, const BlockBuilder& ctx)
const auto* old_len_int = old_len.as<IntImmNode>();
if (old_len_int != nullptr && old_len_int->value == 1) {
continue;
} else if (!analyzer->CanProveEqual(old_len, tgt_len)) {
// We would like to ensure safety, and therefore placed a stronger requirement for user to
// use MatchCast.
// Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit
// this requirement to reduce the workload of importers and better support dynamic shapes.
} else if (analyzer->CanProve(old_len != tgt_len)) {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "broadcast_to expects the input tensor shape is broadcastable to the target shape. "
"The target shape at dim "
<< tgt_ndim - i - 1 << " is " << tgt_len << " while the input tensor shape at dim "
<< old_ndim - i - 1 << " is " << old_len
<< ", where the broadcastability cannot be determined at compile-time. If the old shape "
"at this dim is symbolic and will be 1 at runtime, to ensure safety, please use "
"MatchCast to explicitly make the shape at this dimension be static 1.");
<< old_ndim - i - 1 << " is " << old_len << ", which are not equal.");
}
// Todo(relax-team): revisit here for better check on if the tensor length
// is consistent with the length in the given shape.
}
return TensorStructInfo(/*shape=*/call->args[1], data_sinfo->dtype);
}
Expand Down Expand Up @@ -200,7 +195,7 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
}

const auto* attrs = call->attrs.as<ConcatAttrs>();
int output_ndim = kUnknownNDim;
int output_ndim = attrs->axis.defined() ? kUnknownNDim : 1;
DataType output_dtype = DataType::Void();
bool shape_unknown = false;
bool is_void_dtype = false;
Expand All @@ -221,19 +216,9 @@ StructInfo InferStructInfoConcat(const Call& call, const BlockBuilder& ctx) {
}

// Update the output ndim.
if (!attrs->axis.defined() && sinfo->ndim != 1) {
// To ensure safety, we require all tensors to explicitly have ndim 1 when the concat axis
// is not specified.
// Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit
// this requirement to reduce the workload of importers and better support dynamic shapes.
ctx->ReportFatal(
Diagnostic::Error(call)
<< "Concat expects all input tensors to be flattened 1-dimensional tensor when the axis "
"is not specified. However, the input contains a tensor with ndim dimension "
<< (sinfo->ndim == kUnknownNDim ? "unknown" : std::to_string(sinfo->ndim))
<< ". If the ndim is unknown, please use MatchCast to match it to 1-dimensional tensor "
"first.");
} else if (output_ndim == kUnknownNDim) {
// Todo(relax-team): revisit here for better check on if the input tensor has
// ndim 1 when the input axis is undefined.
if (output_ndim == kUnknownNDim) {
output_ndim = sinfo->ndim;
} else if (sinfo->ndim != kUnknownNDim && sinfo->ndim != output_ndim) {
ctx->ReportFatal(Diagnostic::Error(call)
Expand Down Expand Up @@ -421,22 +406,10 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder&
}
}

// We would like to ensure safety, and therefore placed a stronger requirement for user to
// use MatchCast before layout_transform if the shape of input is not known at compile time.
// Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit
// this requirement to reduce the workload of importers and better support dynamic shapes.
auto report_error_for_unknown_shape =
[&]() {
ctx->ReportFatal(
Diagnostic::Error(call)
<< "layout_transform expects the input tensor to have known rank (expected rank = "
<< index_map->initial_indices.size()
<< ") and shape. For input tensors, whose shape cannot be determined at compile time, "
"please use MatchCast to get input with symbolic shape.");
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
};

if (data_sinfo->IsUnknownNdim()) return report_error_for_unknown_shape();
if (data_sinfo->IsUnknownNdim()) {
// Todo(relax-team): revisit here for better check on if the input tensor has desired ndim.
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
}

// If rank is known, check that it is compatible with the index_map, i.e., #dims match.
if (index_map->initial_indices.size() != static_cast<size_t>(data_sinfo->ndim)) {
Expand All @@ -446,15 +419,16 @@ StructInfo InferStructInfoLayoutTransform(const Call& call, const BlockBuilder&
<< data_sinfo->ndim << " != " << index_map->initial_indices.size());
}

if (!data_sinfo->shape.defined()) return report_error_for_unknown_shape();

// If input shape is known, get the ShapeStructInfo of the shape expr.
const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value());
if (!data_sinfo->shape.defined()) {
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
}

if (!shape_sinfo->values.defined()) return report_error_for_unknown_shape();
ShapeStructInfo shape_sinfo = Downcast<ShapeStructInfo>(data_sinfo->shape.value()->struct_info_);
if (!shape_sinfo->values.defined()) {
return TensorStructInfo(data_sinfo->dtype, /*ndim=*/index_map->final_indices.size());
}

Array<PrimExpr> input_shape = shape_sinfo->values.value();
Array<PrimExpr> output_shape = index_map->MapShape(input_shape);
Array<PrimExpr> output_shape = index_map->MapShape(shape_sinfo->values.value());
return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype);
}

Expand Down Expand Up @@ -490,22 +464,16 @@ StructInfo InferStructInfoPermuteDims(const Call& call, const BlockBuilder& ctx)
TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);

const auto* attrs = call->attrs.as<PermuteDimsAttrs>();
if (data_sinfo->IsUnknownNdim()) {
if (attrs->axes.defined()) {
// Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit
// this requirement to reduce the workload of importers and better support dynamic shapes.
ctx->ReportFatal(Diagnostic::Error(call)
<< "PermuteDims cannot be performed when the input tensor " << data_sinfo
<< " ndim is unknown while the given number of axes " << attrs->axes.value()
<< " is clear. Please use MatchCast to match the input tensor to a specific "
"ndim before doing PermuteDims.");
}

// Todo(relax-team): revisit here for better check on if the input tensor has
// ndim same as the number of input axes.
if (!attrs->axes.defined() && data_sinfo->IsUnknownNdim()) {
return TensorStructInfo(data_sinfo->dtype, kUnknownNDim);
}

if (attrs->axes.defined()) {
int n_axis = attrs->axes.value().size();
if (n_axis != data_sinfo->ndim) {
if (!data_sinfo->IsUnknownNdim() && n_axis != data_sinfo->ndim) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "PermuteDims expects the number of input axes to equal the ndim of the "
"input tensor. However, the tensor ndim is "
Expand Down Expand Up @@ -811,13 +779,10 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim - axes.size());
}
for (int i = 0; i < static_cast<int>(axes.size()); ++i) {
// When `axis` is given, the dim lengths at the axes must be static constant integer 1.
// Todo(relax-team): revisit here for better check on if the axis being squeezed has length 1.
// When `axis` is given, the dim lengths at the axes must be integer 1 when it is not symbolic
const auto* int_len = shape_value.value()[axes[i]].as<IntImmNode>();
if (int_len == nullptr || int_len->value != 1) {
// We would like to ensure safety, and therefore placed a stronger requirement for user to
// use MatchCast.
// Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit
// this requirement to reduce the workload of importers and better support dynamic shapes.
if (int_len != nullptr && int_len->value != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Squeeze expects the input tensor shape values at the given axis "
"positions to be all 1. However, the tensor shape at axis "
Expand Down
88 changes: 48 additions & 40 deletions tests/python/relax/test_op_manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,12 +707,18 @@ def test_layout_transform_infer_struct_info_unknown_shape():
tiling_transform = lambda a, b: (a, b // 2, b % 2)

x_unknown_shape = relax.Var("x", R.Tensor("float32", ndim=2))
with pytest.raises(TVMError):
bb.normalize(relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform))
_check_inference(
bb,
relax.op.layout_transform(x_unknown_shape, index_map=tiling_transform),
relax.TensorStructInfo(dtype="float32", ndim=3),
)

x_unknown_rank_dtype = relax.Var("x", R.Tensor())
with pytest.raises(TVMError):
bb.normalize(relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform))
_check_inference(
bb,
relax.op.layout_transform(x_unknown_rank_dtype, index_map=tiling_transform),
relax.TensorStructInfo(dtype="", ndim=3),
)


def test_layout_transform_infer_struct_info_symbolic_shape():
Expand Down Expand Up @@ -743,13 +749,19 @@ def test_layout_transform_infer_struct_info_shape_var():

s_unknown_shape = relax.Var("s", relax.ShapeStructInfo(ndim=2))
x_unknown_shape = relax.Var("x", relax.TensorStructInfo(s_unknown_shape, "float32"))
with pytest.raises(TVMError):
bb.normalize(relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform))
_check_inference(
bb,
relax.op.layout_transform(x_unknown_shape, index_map=tiling_padding_transform),
relax.TensorStructInfo(dtype="float32", ndim=3),
)

s_unknown_rank = relax.Var("s", relax.ShapeStructInfo())
x_unknown_rank = relax.Var("x", relax.TensorStructInfo(s_unknown_rank, "float32"))
with pytest.raises(TVMError):
bb.normalize(relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform))
_check_inference(
bb,
relax.op.layout_transform(x_unknown_rank, index_map=tiling_padding_transform),
relax.TensorStructInfo(dtype="float32", ndim=3),
)

a = tir.Var("a", "int64")
b = tir.Var("b", "int64")
Expand Down Expand Up @@ -924,12 +936,10 @@ def test_squeeze_infer_struct_info_axis_length_not_one():

with pytest.raises(TVMError):
bb.normalize(relax.op.squeeze(x0, [0]))
with pytest.raises(TVMError):
bb.normalize(relax.op.squeeze(x1, [0]))
_check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo((3, 4), "float32"))
with pytest.raises(TVMError):
bb.normalize(relax.op.squeeze(x2, [0]))
with pytest.raises(TVMError):
bb.normalize(relax.op.squeeze(x3, [0]))
_check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorStructInfo(dtype="float32", ndim=2))


def test_squeeze_infer_struct_info_wrong_input_type():
Expand Down Expand Up @@ -1513,14 +1523,12 @@ def test_concat_infer_struct_info_without_axis_but_tensor_not_one_dimensional():
bb.normalize(relax.op.concat([x0], axis=None))
with pytest.raises(TVMError):
bb.normalize(relax.op.concat([x1], axis=None))
with pytest.raises(TVMError):
bb.normalize(relax.op.concat([x2], axis=None))
_check_inference(bb, relax.op.concat([x2], axis=None), relax.TensorStructInfo(dtype="float32"))
with pytest.raises(TVMError):
bb.normalize(relax.op.concat([x3], axis=None))
with pytest.raises(TVMError):
bb.normalize(relax.op.concat([x4], axis=None))
with pytest.raises(TVMError):
bb.normalize(relax.op.concat([x5], axis=None))
_check_inference(bb, relax.op.concat([x5], axis=None), relax.TensorStructInfo(s2, "float32"))


def test_concat_infer_struct_info_inconsistent_dtype():
Expand Down Expand Up @@ -2323,30 +2331,30 @@ def test_broadcast_to_infer_struct_info_not_broadcastable_symbolic():
stgt1 = relax.Var("stgt", relax.ShapeStructInfo((2, 1)))
stgt2 = relax.Var("stgt", relax.ShapeStructInfo((b, a)))

with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x0, (2, b)))
with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x0, (2, 1)))
with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x0, (b, a)))
with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x0, stgt0))
with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x0, stgt1))
with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x0, stgt2))
with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x1, (2, b)))
with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x1, (2, 1)))
with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x1, (b, a)))
with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x1, stgt0))
with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x1, stgt1))
with pytest.raises(TVMError):
bb.normalize(relax.op.broadcast_to(x1, stgt2))
_check_inference(
bb, relax.op.broadcast_to(x0, (2, b)), relax.TensorStructInfo((2, b), "float32")
)
_check_inference(
bb, relax.op.broadcast_to(x0, (2, 1)), relax.TensorStructInfo((2, 1), "float32")
)
_check_inference(
bb, relax.op.broadcast_to(x0, (b, a)), relax.TensorStructInfo((b, a), "float32")
)
_check_inference(bb, relax.op.broadcast_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32"))
_check_inference(bb, relax.op.broadcast_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32"))
_check_inference(bb, relax.op.broadcast_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32"))
_check_inference(
bb, relax.op.broadcast_to(x1, (2, b)), relax.TensorStructInfo((2, b), "float32")
)
_check_inference(
bb, relax.op.broadcast_to(x1, (2, 1)), relax.TensorStructInfo((2, 1), "float32")
)
_check_inference(
bb, relax.op.broadcast_to(x1, (b, a)), relax.TensorStructInfo((b, a), "float32")
)
_check_inference(bb, relax.op.broadcast_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32"))
_check_inference(bb, relax.op.broadcast_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32"))
_check_inference(bb, relax.op.broadcast_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32"))


def test_broadcast_to_infer_struct_info_wrong_input_type():
Expand Down

0 comments on commit ce5c7f4

Please sign in to comment.