Skip to content

Commit

Permalink
Use deep evaluation of extents in remove_empty pass (#3301)
Browse files Browse the repository at this point in the history
For dynamic fusions, we detect empty tensors and set their extents to
immediate constant 0. Later, in the remove_empty preseg pass, we do a
shallow, syntactic check that extents are the constant zero so that we can
simplify the fusion.
When the fusion is not dynamic there is no concretization step where we
would do this extent replacement, so we might have constant 0 extents
that are compound scalars. This caused us to miss some empty tensors in
#3292, particularly one of the inputs to a `cat`.

This PR:
- Uses a deep evaluation of each `getMaybeExpandedExtent()` to determine
if an axis is empty
- Adds an `ExpressionEvaluator` field to `EmptyTensorRemover` that caches
previously evaluated extents, avoiding repeated deep evaluation when
possible. Note that this does not prevent repeated evaluation of symbolic
(non-constant) extents; those could instead be tracked in an `unordered_set`.

Fixes #3292

---------

Co-authored-by: Naoya Maruyama <[email protected]>
  • Loading branch information
jacobhinkle and naoyam authored Oct 30, 2024
1 parent a4465df commit 81dd1d2
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 27 deletions.
63 changes: 37 additions & 26 deletions csrc/preseg_passes/remove_empty.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
// clang-format on
#include <preseg_passes/remove_empty.h>

#include <expr_evaluator.h>
#include <ir/utils.h>
#include <iter_visitor.h>
#include <ops/alias.h>
#include <ops/arith.h>
#include <polymorphic_value.h>

#include <algorithm>
#include <limits>
Expand All @@ -21,29 +23,6 @@ namespace nvfuser::preseg_passes {

namespace {

//! Return the integer positions of axes in the given domain whose (possibly
//! expanded) extent is an immediate constant zero. Typical usage:
//! `emptyAxes(TensorDomain::noReductions(tv->getLogicalDomain()))`
std::vector<int64_t> emptyAxes(const std::vector<IterDomain*>& domain) {
  std::vector<int64_t> zero_extent_positions;
  int64_t position = 0;
  for (auto* id : domain) {
    auto* extent = id->getMaybeExpandedExtent();
    // Shallow check: the extent must itself be a constant scalar.
    if (extent->isConst() && extent->evaluate().as<int64_t>() == 0) {
      zero_extent_positions.push_back(position);
    }
    ++position;
  }
  return zero_extent_positions;
}

//! Check whether a TensorView is provably empty. During concretization, a
//! minimal set of TensorViews with zero extents is found and their extents
//! are set to an immediate constant 0; this detects those constant zero
//! extents on the logical (non-reduction) axes.
bool isTVEmpty(TensorView* tv) {
  const auto logical = TensorDomain::noReductions(tv->getLogicalDomain());
  return emptyAxes(logical).size() != 0;
}

//! EmptyTensorRemover performs a backward traversal of the Fusion. When it
//! detects a TensorView that has at least one extent that is zero, we do the
//! following:
Expand All @@ -69,9 +48,34 @@ class EmptyTensorRemover : public DeadCodeRemover {
public:
EmptyTensorRemover(Fusion* fusion) : DeadCodeRemover(fusion) {}

protected:
private:
using DeadCodeRemover::handle;

//! Return the integer positions of axes in the given domain whose (possibly
//! expanded) extent evaluates to zero. Typical usage:
//! `emptyAxes(TensorDomain::noReductions(tv->getLogicalDomain()))`
std::vector<int64_t> emptyAxes(const std::vector<IterDomain*>& domain) {
  std::vector<int64_t> zero_extent_positions;
  for (int64_t position : c10::irange((int64_t)domain.size())) {
    IterDomain* id = domain.at(position);
    // Deep evaluation through expr_eval_: this also recognizes compound
    // constant extents (e.g. fmin(ceilDiv(576, 9), 0)), not only immediate
    // constants, and caches results across calls.
    const PolymorphicValue extent_value =
        expr_eval_.evaluate(id->getMaybeExpandedExtent());
    if (extent_value.hasValue() && extent_value.as<int64_t>() == 0) {
      zero_extent_positions.push_back(position);
    }
  }
  return zero_extent_positions;
}

//! Check whether a TensorView is provably empty, i.e. whether any logical
//! (non-reduction) axis has an extent that evaluates to zero. During
//! concretization, detected empty tensors have their extents set to a
//! constant 0; here those (possibly compound) zero extents are recognized.
bool isTVEmpty(TensorView* tv) {
  const auto logical = TensorDomain::noReductions(tv->getLogicalDomain());
  return emptyAxes(logical).size() > 0;
}

//! If tv is a fusion output, we check whether it is empty and if so, replace
//! it with full(). For non-outputs that are not inputs, we simply check that
//! the tensor is not provably empty.
Expand Down Expand Up @@ -257,8 +261,9 @@ class EmptyTensorRemover : public DeadCodeRemover {
"Inputs to CatOp must be outputs of PadOps");
auto tv = inp->definition()->as<PadOp>()->in()->as<TensorView>();
auto cat_id = TensorDomain::noReductions(tv->getLogicalDomain()).at(dim);
if (cat_id->getMaybeExpandedExtent()->isConst() &&
cat_id->getMaybeExpandedExtent()->evaluate().as<int64_t>() == 0) {
PolymorphicValue extent =
expr_eval_.evaluate(cat_id->getMaybeExpandedExtent());
if (extent.hasValue() && extent.as<int64_t>() == 0) {
continue;
}
non_empty_inputs.push_back(tv);
Expand Down Expand Up @@ -312,6 +317,12 @@ class EmptyTensorRemover : public DeadCodeRemover {
registerReplacement(out, new_tv);
}
}

private:
// We use this ExpressionEvaluator without binding any inputs. This lets us
// quickly repeatedly evaluate compound constant expressions like
// ( fmax(0, ( fmin(( ceilDiv(576, 9) ), 0) )) )
ExpressionEvaluator expr_eval_;
};

} // namespace
Expand Down
3 changes: 2 additions & 1 deletion csrc/serde/fusion_record.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ python_frontend::RecordFunctor* deserializeOpRecord(
const RecordFunctor* buffer) {
NVF_ERROR(
str_to_func_map.find(buffer->name()->str()) != str_to_func_map.end(),
"Missing mapping from operation string to nvfuser function in serde deserialization.");
"Missing mapping from operation string to nvfuser function in serde deserialization: ",
buffer->name()->str());
return new python_frontend::OpRecord<Signature...>(
parseStateArgs(buffer->args()),
parseStateArgs(buffer->outputs()),
Expand Down
74 changes: 74 additions & 0 deletions tests/python/test_python_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4600,3 +4600,77 @@ def fusion_func(fd: FusionDefinition) -> None:
nvf_out, _ = self.exec_nvfuser(fusion_func, inputs)
for out in nvf_out:
self.assertTrue(out.allclose(x[:, 1:, 2:]))

    def test_issue_3292(self):
        """Regression test for nvfuser issue #3292.

        Builds a fusion containing a `cat` whose second input (T222) is
        sliced down to extent 0 in the last dimension. The zero extent here
        is a compound constant expression rather than an immediate constant,
        which the remove_empty preseg pass previously failed to detect.
        Passing means the fusion compiles and executes without error.
        """
        inputs = [
            torch.testing.make_tensor(
                (5, 5, 576), dtype=torch.float32, device="cuda:0"
            ),
        ]

        def fusion_func(fd: FusionDefinition) -> None:
            T2 = fd.define_tensor(
                shape=[5, 5, 576],
                contiguity=[True, True, True],
                dtype=DataType.Float,
                is_cpu=False,
                stride_order=[2, 1, 0],
            )
            T30 = fd.ops.reshape(T2, new_shape=[5, 5, 1, 9, 64])
            T31 = fd.ops.permute(T30, dims=[0, 2, 3, 1, 4])
            T50 = fd.ops.slice(
                T31,
                start_indices=[0, 0, 0, 0, 0],
                end_indices=[5, 1, 7, 5, 64],
                strides=[1, 1, 1, 1, 1],
                manual_normalization=0,
            )
            T108 = fd.ops.reshape(T50, new_shape=[5, 7, 5, 64])
            T136 = fd.ops.slice(
                T108,
                start_indices=[0, 0, 0, 0],
                end_indices=[5, 7, 5, 32],
                strides=[1, 1, 1, 1],
                manual_normalization=0,
            )
            T152 = fd.ops.slice(
                T108,
                start_indices=[0, 0, 0, 32],
                end_indices=[5, 7, 5, 64],
                strides=[1, 1, 1, 1],
                manual_normalization=0,
            )
            T153 = fd.ops.neg(T152)
            T154 = fd.ops.cat([T153, T136], dim=-1, manual_padding=0)
            T161 = fd.ops.mul(T108, T108)
            T168 = fd.ops.mul(T154, T154)
            T169 = fd.ops.add(T161, T168)
            T185 = fd.ops.slice(
                T108,
                start_indices=[0, 0, 0, 0],
                end_indices=[5, 7, 5, 32],
                strides=[1, 1, 1, 1],
                manual_normalization=0,
            )
            T201 = fd.ops.slice(
                T108,
                start_indices=[0, 0, 0, 32],
                end_indices=[5, 7, 5, 64],
                strides=[1, 1, 1, 1],
                manual_normalization=0,
            )
            T202 = fd.ops.neg(T201)
            T203 = fd.ops.cat([T202, T185], dim=-1, manual_padding=0)
            T205 = fd.ops.mul(T203, T203)
            # end index 0 in the last dim makes T222 an empty tensor; it then
            # feeds the cat below, which triggered the missed-empty bug.
            T222 = fd.ops.slice(
                T108,
                start_indices=[0, 0, 0, 0],
                end_indices=[5, 7, 5, 0],
                strides=[1, 1, 1, 1],
                manual_normalization=0,
            )
            T223 = fd.ops.cat([T169, T222], dim=-1, manual_padding=0)
            fd.add_output(T223)

        # is_clonable=False is because translation fails with missing ceilDiv
        nvf_out, _ = self.exec_nvfuser(fusion_func, inputs, is_clonable=False)

0 comments on commit 81dd1d2

Please sign in to comment.