Handle empty tensors during definition of cat #3313

Draft
wants to merge 11 commits into base: main
126 changes: 101 additions & 25 deletions csrc/ops/alias.cpp
@@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <expr_evaluator.h>
#include <expr_simplifier.h>
#include <ir/builder.h>
#include <ir/utils.h>
@@ -13,6 +14,7 @@
#include <ops/utils.h>
#include <transform_view.h>
#include <type_promotion.h>
#include "polymorphic_value.h"

namespace nvfuser {

@@ -479,6 +481,21 @@ bool hasSimilarDtype(DataType base, DataType dt) {
NVF_THROW("Unrecognized base dtype.");
}

Val* zeroForDtype(DataType dtype) {
Collaborator: Can we use Fusion::zeroVal(DataType)?
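For reference, a minimal sketch of what that suggestion could look like at the call sites below, assuming a Fusion::zeroVal(DataType) overload that returns a zero Val of the requested dtype:

// Hypothetical call-site use of the reviewer's suggestion; assumes the
// Fusion::zeroVal(DataType) overload is available on the active Fusion.
Val* zero = FusionGuard::getCurFusion()->zeroVal(dtype);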

// Create a zero of the appropriate type
if (isComplexType(dtype)) {
return IrBuilder::create<Val>(std::complex<double>(0), dtype);
} else if (isFloatingPointType(dtype)) {
return IrBuilder::create<Val>(0.0, dtype);
} else if (isBooleanType(dtype)) {
return IrBuilder::create<Val>(false, dtype);
} else {
return IrBuilder::create<Val>(0L, dtype);
}
NVF_THROW("Unsupported dtype in zeroForDtype: ", dtype);
return nullptr;
}

// Padding widths are assumed to be non-negative. Currently there's no
// validation.
TensorView* pad(
@@ -488,17 +505,7 @@ TensorView* pad(
std::optional<IterType> iter_type_opt) {
DataType dt = inp->getDataType().value();
if (!value) {
value = zeroForDtype(dt);
}
NVF_CHECK(
hasSimilarDtype(dt, value->getDataType().value()),
@@ -541,6 +548,8 @@ TensorView* pad(
// Indicates if any dimension is actually padded. Can be false even
// when non-empty padding width vector is passed
bool is_padded_any = false;
// If the input tensor is empty to begin with (i.e., any dimension has zero
// extent), then we can replace this operation with full()
for (const auto idx : c10::irange(ndims)) {
auto inp_root_id = inp_dom.at(idx);
IterDomain* out_root_id = nullptr;
@@ -568,14 +577,26 @@ TensorView* pad(
return set(inp);
}

if (std::any_of(inp_dom.begin(), inp_dom.end(), [](IterDomain* id) {
Val* input_extent = id->getMaybeExpandedExtent();
return input_extent->isConstScalar() &&
input_extent->evaluate().as<int64_t>() == 0;
})) {
// We are padding an empty tensor. Instead of PadOp, use FullOp
std::vector<Val*> shape;
shape.reserve(logical_ids.size());
for (IterDomain* id : logical_ids) {
shape.push_back(id->getMaybeExpandedExtent());
}
return full(shape, value, dt);
}
auto out = IrBuilder::create<TensorView>(
IrBuilder::create<TensorDomain>(
root_ids,
logical_ids,
logical_ids,
TensorDomain::getContiguityFilledWith(logical_ids, true)),
*inp->getDataType());

IrBuilder::create<PadOp>(out, inp, normalized_pad_widths, value);

return out;
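Taken together with the emptiness check above, padding a tensor that is statically empty in any dimension now lowers to a FullOp rather than a PadOp. A minimal sketch of the user-visible effect, assuming the makeConcreteTensor helper and Fusion setup used by this repo's tests:

// Sketch only: assumes a Fusion guarded by FusionGuard, as in the tests below.
TensorView* tv0 = makeConcreteTensor({2, 0}); // statically empty in dim 1
// Pad the innermost dimension by one element on each side: {2, 0} -> {2, 2}
TensorView* tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()});
// Because the input is empty, tv1 is defined by a FullOp instead of a PadOp
EXPECT_TRUE(tv1->definition()->isA<FullOp>());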
@@ -586,28 +607,46 @@ TensorView* pad(
// output. All of the inputs to CatOp have the same shape as the
// output shape.
TensorView* cat(
const std::vector<TensorView*>& inputs,
const std::vector<TensorView*>& orig_inputs,
int64_t cat_dim,
std::optional<IterType> iter_type_opt,
bool manual_padding) {
NVF_CHECK(!orig_inputs.empty(), "No input tensor given");

const DataType dtype = orig_inputs.at(0)->getDataType().value();

ExpressionEvaluator expr_eval;
const auto extentIsEmpty = [&expr_eval](IterDomain* id) {
PolymorphicValue extent = expr_eval.evaluate(id->getMaybeExpandedExtent());
return extent.hasValue() && extent.as<int64_t>() == 0ll;
};

// Filter out TVs from orig_inputs that have zero size in cat dim
std::vector<TensorView*> inputs;
std::vector<std::vector<IterDomain*>> inp_doms;
int64_t ndims = -1;

bool all_inputs_empty = true;
for (TensorView* inp : orig_inputs) {
NVF_CHECK(
inp->getDataType().value() == dtype,
"Can't concatenate tensors with different data types: ",
dtype,
", ",
inp->getDataType().value());
const std::vector<IterDomain*> inp_dom =
TensorDomain::noReductions(inp->getLogicalDomain());

auto i_ndims = static_cast<int64_t>(inp_dom.size());
if (ndims == -1) {
ndims = i_ndims;
if (cat_dim < 0) {
cat_dim += ndims;
}
NVF_CHECK(
cat_dim >= 0 && cat_dim < ndims,
"Invalid dimension to cat: ",
cat_dim);
} else {
NVF_CHECK(
ndims == i_ndims,
@@ -616,16 +655,52 @@ TensorView* cat(
", expected: ",
ndims);
}
bool cat_dim_empty = false;
// Check whether this input is empty in any dimension, and in particular
// whether it is empty in the cat dimension
bool found_empty = false;
for (size_t dim : c10::irange(ndims)) {
if (extentIsEmpty(inp_dom[dim])) {
found_empty = true;
if (dim == cat_dim) {
cat_dim_empty = true;
}
}
}
all_inputs_empty = all_inputs_empty && found_empty;

if (cat_dim_empty) {
// Remove inputs that are empty in the cat dimension
continue;
}

inputs.push_back(inp);
inp_doms.emplace_back(inp_dom);
}

if (inputs.empty()) {
// All tensors are empty in cat dimension
return set(orig_inputs.at(0));
} else if (all_inputs_empty) {
// All tensors are empty in at least one non-cat dimension. That means we
// can generate the output using full(). The non-cat extents are taken from
// the first remaining input, and the cat-dim extent is the sum of the
// extents of `inputs` in that dimension.
std::vector<Val*> shape(ndims, nullptr);
for (const std::vector<IterDomain*>& inp_dom : inp_doms) {
for (size_t dim : c10::irange(ndims)) {
Val* extent = inp_dom.at(dim)->getMaybeExpandedExtent();
if (dim == cat_dim) {
shape[dim] = SimplifyingIrBuilder::addExpr(shape[dim], extent);
} else if (shape[dim] == nullptr) {
shape[dim] = extent;
Collaborator: I'm not sure what the correct behavior should be here. Can we just use the first extent for a non-cat dimension? Suppose the first cat input has size zero for that dimension; the output of the cat would also have zero for that dimension. Would that be expected?

Collaborator (author): In that case it should be caught by the inputs.empty() condition at line 678, since all of the inputs should be empty. Actually, though, I just realized that assumes that if one of the inputs has constant size 0 in one dimension, then all the other inputs will have constant size zero in the same dimension. If some are symbolic, this definition should serve as proof that they're equal by exact mapping, but they won't be detected at line 660. I guess what I should do instead is fire this condition if any input has a zero non-cat dimension.

Collaborator: > if one of the inputs has constant size 0 in one dimension, then all the other inputs will have constant size zero in the same dimension.

Yeah, this was what I was thinking about. I'm not sure if that's actually allowed. Is it?

Collaborator (author, @jacobhinkle, Oct 30, 2024): We could have something like this, I think:

tv0[ 2, 0 ]
tv1[ 3, i0 ]
tv2 = cat({tv0, tv1}, /*axis=*/0)

In this case we would normally exact-map tv0->axis(1) with tv1->axis(1), so that i0 must be 0 or else we'll hit an error in ExpressionEvaluator::propagateBoundValuesThroughExactMaps(). But now, since we'll just translate this to full({5, 0}), I don't think there's any such constraint, so it might be legal to pass in a tv1 with shape (3, 2) and we would not hit an error.

Collaborator: OK, makes sense.

}
}
}
return full(shape, /*fill_value=*/zeroForDtype(dtype), dtype);
}

// Special handling for the case where there's only one non-empty input
if (inputs.size() == 1) {
return set(inputs.at(0));
}
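To summarize the early-exit paths above, a hedged sketch of what cat now produces for empty inputs (mirroring the CatOfEmpty test added below):

// Sketch only: a is {1, 2, 0} and b is {1, 2, 3}, as in the new test.
TensorView* a = makeConcreteTensor({1, 2, 0});
TensorView* b = makeConcreteTensor({1, 2, 3});
// a is empty in the cat dim, so it is filtered out and c reduces to set(b)
TensorView* c = cat({a, b}, /*cat_dim=*/2);
// Every input is empty in a non-cat dim, so d is generated with full()
TensorView* d = cat({a, a}, /*cat_dim=*/0);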
@@ -703,7 +778,8 @@
: FusionGuard::getCurFusion()->zeroVal();
left_pad_i = left_pad;
right_pad_i = right_pad;
left_pad = SimplifyingIrBuilder::addExpr(
left_pad, inp_root_id->getMaybeExpandedExtent());
}
// The pad width argument to pad should be ordered such that the
// widths of inner dimensions come first.
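For a rank-2 cat along dimension 0, that ordering means the widths for dimension 1 (the inner dimension) come before the widths for dimension 0. A hedged sketch, reusing the surrounding variable names (zero stands for the fusion's zero Val):

// Inner-dimensions-first ordering: {left_dim1, right_dim1, left_dim0, right_dim0}.
// A rank-2 input concatenated along dim 0 is padded only in dim 0.
std::vector<Val*> pad_widths = {zero, zero, left_pad_i, right_pad_i};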
41 changes: 41 additions & 0 deletions tests/cpp/test_resize.cpp
@@ -4041,4 +4041,45 @@ TEST_F(ResizeTest, SliceSliceConcatConcat) {
NVF_CHECK(ref.equal(cg_outputs[0]));
}

// Test that cat handles inputs that are empty in one or more dimensions
// See https://github.com/NVIDIA/Fuser/issues/3292
TEST_F(ResizeTest, CatOfEmpty) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeConcreteTensor({1, 2, 0});
auto tv1 = makeConcreteTensor({1, 2, 0});
auto tv2 = makeConcreteTensor({1, 2, 3});
auto tv3 = makeConcreteTensor({1, 2, 4});

// Cat of tensors where some non-cat dimensions are zero results in a FullOp
// of the proper size
auto tv4 = cat({tv0, tv1}, 0);
EXPECT_TRUE(tv4->definition() != nullptr);
EXPECT_TRUE(tv4->definition()->isA<FullOp>());
ASSERT_TRUE(tv4->axis(0)->extent()->isConstScalar());
EXPECT_EQ(tv4->axis(0)->extent()->evaluate().as<int64_t>(), 2);

// Cat of tensors that are all empty in cat dim results in LoadStoreOp of the
// first input
auto tv5 = cat({tv0, tv1}, 2);
EXPECT_TRUE(tv5->definition() != nullptr);
EXPECT_TRUE(tv5->definition()->isA<LoadStoreOp>());
EXPECT_TRUE(tv5->definition()->input(0) == tv0);

// Cat of tensors, some empty and some not, results in ignoring the empty
// inputs. In this case that leaves only a set op
auto tv6 = cat({tv0, tv1, tv2}, 2);
EXPECT_TRUE(tv6->definition() != nullptr);
EXPECT_TRUE(tv6->definition()->isA<LoadStoreOp>());
EXPECT_TRUE(tv6->definition()->input(0) == tv2);

// Cat of tensors, some empty and some not, results in ignoring the empty
// inputs. In this case that leaves us with a CatOp with two inputs
auto tv7 = cat({tv0, tv1, tv2, tv3}, 2);
EXPECT_TRUE(tv7->definition() != nullptr);
EXPECT_TRUE(tv7->definition()->isA<CatOp>());
EXPECT_EQ(tv7->definition()->inputs().size(), 2);
}

} // namespace nvfuser