Handle empty tensors during definition of cat #3313
base: main
@@ -5,6 +5,7 @@
  * SPDX-License-Identifier: BSD-3-Clause
  */
 // clang-format on
+#include <expr_evaluator.h>
 #include <expr_simplifier.h>
 #include <ir/builder.h>
 #include <ir/utils.h>
@@ -13,6 +14,7 @@
 #include <ops/utils.h>
 #include <transform_view.h>
 #include <type_promotion.h>
+#include "polymorphic_value.h"
 
 namespace nvfuser {
 
@@ -479,6 +481,21 @@ bool hasSimilarDtype(DataType base, DataType dt) {
   NVF_THROW("Unrecognized base dtype.");
 }
 
+Val* zeroForDtype(DataType dtype) {
+  // Create a zero of the appropriate type
+  if (isComplexType(dtype)) {
+    return IrBuilder::create<Val>(std::complex<double>(0), dtype);
+  } else if (isFloatingPointType(dtype)) {
+    return IrBuilder::create<Val>(0.0, dtype);
+  } else if (isBooleanType(dtype)) {
+    return IrBuilder::create<Val>(false, dtype);
+  } else {
+    return IrBuilder::create<Val>(0L, dtype);
+  }
+  NVF_THROW("Unsupported dtype in zeroForDtype: ", dtype);
+  return nullptr;
+}
+
 // Padding widths are assumed to be non-negative. Currently there's no
 // validation.
 TensorView* pad(
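The helper centralizes the dtype-to-zero dispatch that pad previously inlined. One observation: since the final `else` covers every remaining dtype, the trailing `NVF_THROW` and `return nullptr` are unreachable as written. A minimal standalone sketch of the same dispatch, using an illustration-only `DataType` enum and `Zero` variant rather than nvFuser's types:

```cpp
#include <complex>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <variant>

// Illustration-only stand-ins for nvFuser's dtype categories.
enum class DataType { ComplexDouble, Double, Bool, Int };

// A zero literal whose C++ type matches the requested dtype, mirroring the
// IrBuilder::create<Val>(...) calls in the diff.
using Zero = std::variant<std::complex<double>, double, bool, int64_t>;

Zero zeroForDtype(DataType dtype) {
  switch (dtype) {
    case DataType::ComplexDouble:
      return std::complex<double>(0); // complex types get a complex zero
    case DataType::Double:
      return 0.0; // floating-point types get a double zero
    case DataType::Bool:
      return false; // boolean zero is false
    case DataType::Int:
      return int64_t{0}; // integral types get an int64_t zero
  }
  // Unlike the diff, this throw is reachable (for an invalid enum value).
  throw std::runtime_error("Unsupported dtype");
}

int main() {
  std::cout << std::get<double>(zeroForDtype(DataType::Double)) << '\n'; // 0
}
```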
@@ -488,17 +505,7 @@ TensorView* pad(
     std::optional<IterType> iter_type_opt) {
   DataType dt = inp->getDataType().value();
   if (!value) {
-    // Create a zero of the appropriate type
-    if (isComplexType(dt)) {
-      value = static_cast<Val*>(
-          IrBuilder::create<Val>(std::complex<double>(0), dt));
-    } else if (isFloatingPointType(dt)) {
-      value = static_cast<Val*>(IrBuilder::create<Val>(0.0, dt));
-    } else if (isBooleanType(dt)) {
-      value = static_cast<Val*>(IrBuilder::create<Val>(false, dt));
-    } else {
-      value = static_cast<Val*>(IrBuilder::create<Val>(0L, dt));
-    }
+    value = zeroForDtype(dt);
   }
   NVF_CHECK(
       hasSimilarDtype(dt, value->getDataType().value()),
@@ -541,6 +548,8 @@ TensorView* pad(
   // Indicates if any dimension is actually padded. Can be false even
   // when non-empty padding width vector is passed
   bool is_padded_any = false;
+  // If all of the padded dimensions are actually empty to begin with, then we
+  // can replace this operation with full()
   for (const auto idx : c10::irange(ndims)) {
     auto inp_root_id = inp_dom.at(idx);
     IterDomain* out_root_id = nullptr;
@@ -568,14 +577,26 @@ TensorView* pad(
     return set(inp);
   }
-
+  if (std::any_of(inp_dom.begin(), inp_dom.end(), [](IterDomain* id) {
+        Val* input_extent = id->getMaybeExpandedExtent();
+        return input_extent->isConstScalar() &&
+            input_extent->evaluate().as<int64_t>() == 0;
+      })) {
+    // We are padding an empty tensor. Instead of PadOp, use FullOp
+    std::vector<Val*> shape;
+    shape.reserve(logical_ids.size());
+    for (IterDomain* id : logical_ids) {
+      shape.push_back(id->getMaybeExpandedExtent());
+    }
+    return full(shape, value, dt);
+  }
   auto out = IrBuilder::create<TensorView>(
       IrBuilder::create<TensorDomain>(
           root_ids,
           logical_ids,
           logical_ids,
           TensorDomain::getContiguityFilledWith(logical_ids, true)),
       *inp->getDataType());
 
   IrBuilder::create<PadOp>(out, inp, normalized_pad_widths, value);
 
   return out;
@@ -586,28 +607,46 @@ TensorView* pad(
 // output. All of the inputs to CatOp have the same shape as the
 // output shape.
 TensorView* cat(
-    const std::vector<TensorView*>& inputs,
+    const std::vector<TensorView*>& orig_inputs,
     int64_t cat_dim,
     std::optional<IterType> iter_type_opt,
     bool manual_padding) {
-  NVF_CHECK(!inputs.empty(), "No input tensor given");
+  NVF_CHECK(!orig_inputs.empty(), "No input tensor given");
 
-  const auto dtype = inputs.at(0)->getDataType().value();
+  const DataType dtype = orig_inputs.at(0)->getDataType().value();
 
+  ExpressionEvaluator expr_eval;
+  const auto extentIsEmpty = [&expr_eval](IterDomain* id) {
+    PolymorphicValue extent = expr_eval.evaluate(id->getMaybeExpandedExtent());
+    return extent.hasValue() && extent.as<int64_t>() == 0ll;
+  };
+
+  // Filter out TVs from orig_inputs that have zero size in cat dim
+  std::vector<TensorView*> inputs;
   std::vector<std::vector<IterDomain*>> inp_doms;
   int64_t ndims = -1;
 
-  for (auto inp : inputs) {
+  bool all_inputs_empty = true;
+  for (TensorView* inp : orig_inputs) {
     NVF_CHECK(
         inp->getDataType().value() == dtype,
         "Can't concatenate tensors with different data types: ",
         dtype,
         ", ",
         inp->getDataType().value());
-    inp_doms.emplace_back(TensorDomain::noReductions(inp->getLogicalDomain()));
-    auto i_ndims = static_cast<int64_t>(inp_doms.back().size());
+    const std::vector<IterDomain*> inp_dom =
+        TensorDomain::noReductions(inp->getLogicalDomain());
+
+    auto i_ndims = static_cast<int64_t>(inp_dom.size());
     if (ndims == -1) {
       ndims = i_ndims;
+      if (cat_dim < 0) {
+        cat_dim += ndims;
+      }
+      NVF_CHECK(
+          cat_dim >= 0 && cat_dim < ndims,
+          "Invalid dimension to cat: ",
+          cat_dim);
     } else {
       NVF_CHECK(
           ndims == i_ndims,
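The filtering pass this hunk begins (and the next hunk completes) can be sketched at the shape level, assuming emptiness is read directly off an `int64_t` extent; the real code instead asks `ExpressionEvaluator` whether the extent is provably zero, so symbolic extents are conservatively treated as non-empty. Names below are hypothetical:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct FilterResult {
  std::vector<std::vector<int64_t>> kept;
  bool all_inputs_empty = true;
};

// Drop inputs whose extent in the cat dimension is zero; track whether every
// input is empty in at least one dimension, mirroring the loop in the diff.
FilterResult filterEmptyInputs(
    const std::vector<std::vector<int64_t>>& orig_inputs,
    size_t cat_dim) {
  FilterResult result;
  for (const std::vector<int64_t>& shape : orig_inputs) {
    bool found_empty = false;
    bool cat_dim_empty = false;
    for (size_t dim = 0; dim < shape.size(); ++dim) {
      if (shape[dim] == 0) {
        found_empty = true;
        if (dim == cat_dim) {
          cat_dim_empty = true;
        }
      }
    }
    result.all_inputs_empty = result.all_inputs_empty && found_empty;
    if (!cat_dim_empty) {
      result.kept.push_back(shape); // non-empty in the cat dim: keep it
    }
  }
  return result;
}

int main() {
  // cat along dim 0: {0, 3} is dropped; {2, 3} is kept and has no zero dim.
  FilterResult r = filterEmptyInputs({{0, 3}, {2, 3}}, /*cat_dim=*/0);
  assert(r.kept.size() == 1);
  assert(!r.all_inputs_empty);
}
```

Note that an input which is empty only in a non-cat dimension is kept, but still counts toward `all_inputs_empty`.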
@@ -616,16 +655,52 @@ TensorView* cat(
           ", expected: ",
           ndims);
     }
+
+    bool cat_dim_empty = false;
+    // Check whether this input is possibly non-empty in any dimension
+    bool found_empty = false;
+    for (size_t dim : c10::irange(ndims)) {
+      if (extentIsEmpty(inp_dom[dim])) {
+        found_empty = true;
+        if (dim == cat_dim) {
+          cat_dim_empty = true;
+        }
+      }
+    }
+    all_inputs_empty = all_inputs_empty && found_empty;
+
+    if (cat_dim_empty) {
+      // Remove inputs that are empty in the cat dimension
+      continue;
+    }
+
+    inputs.push_back(inp);
+    inp_doms.emplace_back(inp_dom);
   }
 
-  if (cat_dim < 0) {
-    cat_dim += ndims;
-  }
-
-  NVF_CHECK(
-      cat_dim >= 0 && cat_dim < ndims, "Invalid dimension to cat: ", cat_dim);
+  if (inputs.empty()) {
+    // All tensors are empty in cat dimension
+    return set(orig_inputs.at(0));
+  } else if (all_inputs_empty) {
+    // All tensors are empty in at least one non-cat dimension. That means we
+    // can generate the output using full. The output size is computed using
+    // the size of the first original input except in the cat dim which is the
+    // sum of the extents of `inputs` in that dimension.
+    std::vector<Val*> shape(ndims, nullptr);
+    for (const std::vector<IterDomain*>& inp_dom : inp_doms) {
+      for (size_t dim : c10::irange(ndims)) {
+        Val* extent = inp_dom.at(dim)->getMaybeExpandedExtent();
+        if (dim == cat_dim) {
+          shape[dim] = SimplifyingIrBuilder::addExpr(shape[dim], extent);
+        } else if (shape[dim] == nullptr) {
+          shape[dim] = extent;
+        }
+      }
+    }
+    return full(shape, /*fill_value=*/zeroForDtype(dtype), dtype);
+  }
 
-  // Special handling for the case where there's only one input
+  // Special handling for the case where there's only one non-empty input
   if (inputs.size() == 1) {
     return set(inputs.at(0));
   }

Review thread on `shape[dim] = extent;`:

Comment: I'm not sure what the correct behavior should be here. Can we just use the first extent for a non-cat dimension? Suppose the first cat input has size zero for that dimension; the output of the cat would also have zero for that dimension. Would that be expected?

Reply: In that case it should be caught by the …

Comment: Yeah, this was what I was thinking about. I'm not sure if that's actually allowed. Is it?

Reply: We could have something like this I think: …

In this case we would normally exact map …

Comment: OK, makes sense.
@@ -703,7 +778,8 @@ TensorView* cat(
           : FusionGuard::getCurFusion()->zeroVal();
       left_pad_i = left_pad;
       right_pad_i = right_pad;
-      left_pad = add(left_pad, inp_root_id->getMaybeExpandedExtent());
+      left_pad = SimplifyingIrBuilder::addExpr(
+          left_pad, inp_root_id->getMaybeExpandedExtent());
     }
     // The pad width argument to pad should be ordered such that the
     // widths of inner dimensions come first.
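For context on the line being changed: each input's left pad in the cat dimension is the running sum of the preceding inputs' extents, so every padded input spans the whole output extent. Using `SimplifyingIrBuilder::addExpr` lets constant extents in that running sum fold at definition time. A sketch with a hypothetical `catPadWidths` helper over plain extents:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct PadWidths {
  int64_t left;
  int64_t right;
};

// Input i is padded on the left by the combined extent of inputs 0..i-1 and
// on the right by the combined extent of inputs i+1.., so each padded input
// takes the full output extent in the cat dimension.
std::vector<PadWidths> catPadWidths(const std::vector<int64_t>& extents) {
  int64_t total = 0;
  for (int64_t e : extents) {
    total += e;
  }
  std::vector<PadWidths> widths;
  int64_t left = 0; // running sum of preceding extents
  for (int64_t e : extents) {
    widths.push_back({left, total - left - e});
    left += e;
  }
  return widths;
}

int main() {
  // Extents 2, 3, 4 along the cat dim: the middle input is padded (2, 4).
  std::vector<PadWidths> w = catPadWidths({2, 3, 4});
  assert(w[1].left == 2 && w[1].right == 4);
}
```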
Review comment on `zeroForDtype`:

Can we use `Fusion::zeroVal(DataType)`?