Patch vectorization on permuted inputs for PadOp (#3439)
What's in this PR:
* Avoid the vectorization validation check on the consumer of a vectorized
op. We believe this is safe, since the consumer's allocation domain is not
used for the actual allocation.
* Add a C++ test where the consumer of PadOp has an allocation domain that
is not consistent with its producer.

For future reference:
Alternatively, we could add a `set` after `PadOp` to mimic a cache on the
input. This would let us propagate the allocation domain from the input to
the output of `PadOp`, which is the consumer of the vectorized op, while
still preserving the allocation domain on the original output by propagating
it to the output of `set`. A sketch of this alternative follows below.
We decided not to pursue that, because the validation does not seem to be
the right check.
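
For illustration only, a minimal sketch of what that rejected alternative
might have looked like, reusing the fusion from the new test below. This is
an assumption-laden sketch, not code from this PR: it relies on nvfuser's
`set` alias op, and the name `tv1_copy` is hypothetical.

// Hypothetical sketch of the rejected alternative: insert a set() after
// PadOp to mimic a cache on the input. tv1 could then inherit tv0's
// allocation domain, while the permuted allocation domain of the original
// output would be propagated to the output of set() instead.
auto fusion_ptr = std::make_unique<Fusion>();
FusionGuard fg(fusion_ptr.get());
auto tv0 = TensorViewBuilder()
               .shape({1024, 1024})
               .dtype(DataType::Float)
               .contiguity(true)
               .strideOrder({0, 1})
               .build();
fusion_ptr->addInput(tv0);
auto tv1 = pad(tv0, {IrBuilder::create<Val>(4L), IrBuilder::create<Val>(4L)});
auto tv1_copy = set(tv1); // hypothetical extra copy after the PadOp
auto tv2 = relu(tv1_copy);
fusion_ptr->addOutput(tv2);
// the fusion output keeps its permuted allocation domain, as in the test
tv2->setAllocationDomain({tv2->axis(1), tv2->axis(0)}, true);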
jjsjann123 authored and Priya2698 committed Nov 20, 2024
1 parent 9bc3ecc commit 50e059e
Showing 2 changed files with 48 additions and 1 deletion.
csrc/device_lower/validation.cpp (5 additions, 1 deletion)
@@ -554,9 +554,13 @@ class VectorizeValidator : public OptInDispatch {
   auto ldst = dynamic_cast<LoadStoreOp*>(tv->definition());
   bool is_ldmatrix_trans =
       ldst != nullptr && mma_utils::isLdMatrixTranspose(ldst);
-  if (!is_ldmatrix_trans) {
+  if (!is_ldmatrix_trans && name.compare("consumer") != 0) {
     // ldmatrix.trans is a hardware transpose instruction that can do
     // "vectorized" read from discontiguous memory
+    // We don't think the allocation domain of the consumer is used in
+    // allocation, so we skip it in validation here. Note that this assert
+    // was hit for vectorized pad, because we do not propagate the
+    // allocation domain for PadOp. See: https://github.com/NVIDIA/Fuser/pull/3439
     NVF_CHECK(
         last_alloc_dim == vec_alloc_id,
         "Vectorized dim for ",
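
For readability, the condition after this change reads as follows when
assembled from the hunk above (the remaining NVF_CHECK arguments are
truncated in the diff, so they are elided here rather than guessed):

if (!is_ldmatrix_trans && name.compare("consumer") != 0) {
  // ldmatrix.trans and consumer-side tensors skip this check entirely;
  // see the comments in the hunk above.
  NVF_CHECK(
      last_alloc_dim == vec_alloc_id,
      "Vectorized dim for ",
      /* remaining arguments truncated in the diff above */);
}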
tests/cpp/test_pointwise.cpp (43 additions, 0 deletions)
@@ -730,4 +730,47 @@ INSTANTIATE_TEST_SUITE_P(
       ss << "_outer_unroll_" << std::get<2>(info.param);
       return sanitizeTestName(ss.str());
     });
+
+TEST_F(PointwiseTest, VectorizePadLoweringPermuted) {
+  // The pointwise scheduler applies a permutation to restore contiguous
+  // memory access on the reference TV. Vectorization validation requires
+  // vectorized operations to preserve the allocation domain of their inputs.
+  // This test checks that PadOp propagates the allocation domain properly.
+  auto fusion_ptr = std::make_unique<Fusion>();
+  auto& fusion = *fusion_ptr;
+  FusionGuard fg(fusion_ptr.get());
+
+  // input is permuted
+  auto tv0 = TensorViewBuilder()
+                 .shape({1024, 1024})
+                 .dtype(DataType::Float)
+                 .contiguity(true)
+                 .strideOrder({0, 1})
+                 .build();
+  fusion.addInput(tv0);
+  auto tv1 = pad(tv0, {IrBuilder::create<Val>(4L), IrBuilder::create<Val>(4L)});
+  auto tv2 = relu(tv1);
+  fusion.addOutput(tv2);
+  // output is permuted
+  tv2->setAllocationDomain({tv2->axis(1), tv2->axis(0)}, true);
+
+  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
+  auto t0 =
+      at::randn({1024 * 1024}, options).as_strided({1024, 1024}, {1, 1024});
+  std::vector<c10::IValue> aten_inputs({t0});
+
+  auto cg_outputs =
+      scheduleAndRun(&fusion, SchedulerType::PointWise, aten_inputs).outputs;
+  // check that we vectorize by a factor of 4
+  bool found_vectorize = false;
+  for (auto id : fusion.outputs().at(0)->as<TensorView>()->getLoopDomain()) {
+    if (id->getParallelType() == ParallelType::Vectorize) {
+      EXPECT_EQ(id->extent()->evaluate(), 4);
+      found_vectorize = true;
+      break;
+    }
+  }
+  EXPECT_TRUE(found_vectorize);
+  testValidate(&fusion, cg_outputs, aten_inputs, __LINE__, __FILE__);
+}
 } // namespace nvfuser
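
To run just the new test, an invocation along these lines should work
(assuming the googletest binary produced by the Fuser build is named
test_nvfuser; adjust the path to your build layout):

./bin/test_nvfuser --gtest_filter='PointwiseTest.VectorizePadLoweringPermuted'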
