check vectorization factor of shared memory consumers to avoid illegal vectorization size (#3271)

**Issue**: The InnerOuter persistent scheduler uses shared memory to store
persistent buffers. The data flow is `input in gmem --> async copy to
smem --> vectorized load to registers (smem consumers)`; each `-->` is
simply a `LoadStoreOp`, and the same vectorization factor is used for both
copies. [CI](https://nv/e2E/118278383) found a case where the shared memory
persistent buffers have a data type of fp32 while the inputs are fp16 (when
there are view ops, projection to inputs is not used). The vectorization
factor was set to 8, which caused 32-byte vectorization when loading from
shared memory to registers.
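
For concreteness, a minimal standalone sketch (not nvFuser code; the constants mirror the CI case above) of why reusing the input vectorization factor for the `smem --> regs` copy is illegal:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t kMaxVectBytes = 16; // max legal width of one vectorized access
  constexpr int64_t kInputBytes = 2;    // fp16 inputs
  constexpr int64_t kSmemBytes = 4;     // fp32 shared memory persistent buffer
  constexpr int64_t kVectFactor = 8;    // chosen for the fp16 inputs: 8 * 2 = 16 bytes

  // gmem --> smem copy: 8 * 2 = 16 bytes, legal.
  std::cout << "gmem->smem: " << kVectFactor * kInputBytes << " bytes\n";
  // smem --> regs copy with the same factor: 8 * 4 = 32 bytes, illegal.
  std::cout << "smem->regs: " << kVectFactor * kSmemBytes << " bytes\n";
  // Largest legal factor for the fp32 copy: 16 / 4 = 4.
  std::cout << "max legal factor: " << kMaxVectBytes / kSmemBytes << "\n";
  return 0;
}
```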

**Changes**:
(1) Added code to handle the vectorization of smem consumers: an additional
split is added if the `smem --> regs` copy would otherwise lead to a
vectorization width larger than 16 bytes.
(2) Added a test.

**Results**: Ensures all vectorization widths are <= 16 bytes.
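
In essence, the new `sharedMemoryConsumerVectorization` pass (see `reduction_utils.cpp` below) clamps each smem consumer's vectorization width; here is a simplified sketch of that clamping logic, with nvFuser's scheduling calls reduced to plain arithmetic:

```cpp
#include <algorithm>
#include <cstdint>

// Returns the vectorization factor actually applied to a smem consumer:
// the I/O factor, clamped so that factor * dtype_bytes <= 16 bytes.
// When clamping kicks in, the real pass splits the innermost axis by this
// value and parallelizes only the inner part as Vectorize.
int64_t clampedVectFactor(int64_t io_vect_factor, int64_t dtype_bytes) {
  constexpr int64_t kMaxVectBytes = 16;
  return std::min(io_vect_factor, kMaxVectBytes / dtype_bytes);
}
// e.g. clampedVectFactor(8, /*fp32=*/4) == 4: the innermost extent of 8 is
// split into 2 x 4, and the inner extent-4 axis is vectorized.
```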

**Follow-up work**
See issue #3272

---------

Co-authored-by: Naoya Maruyama <[email protected]>
liqiangxl and naoyam authored Oct 29, 2024
1 parent f6975f3 commit 81d1667
Showing 6 changed files with 163 additions and 32 deletions.
16 changes: 9 additions & 7 deletions csrc/scheduler/normalization_inner_outer.cpp
@@ -967,11 +967,7 @@ void scheduleInnerOuterPersistentKernel(
scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet);
const auto& unroll_vectorizable_cached_tvs =
reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize(
inner_reference_tv,
is_vectorize,
cached_inputs,
cached_outputs,
smem_consumers);
inner_reference_tv, is_vectorize, cached_inputs, cached_outputs);
reduction_scheduler_utils::propagateParallelization(
inner_reduction_tvs[0],
inner_reference_tv,
@@ -998,8 +994,7 @@
outer_reference_tvs[i],
is_vectorize,
cached_inputs,
cached_outputs,
smem_consumers);
cached_outputs);
reduction_scheduler_utils::propagateParallelization(
outer_reduction_tvs[i],
outer_reference_tvs[i],
@@ -1044,6 +1039,13 @@
}
}

  // Needs special handling of vectorized loading from shared memory due to
  // potentially different data types of the inputs and the shared memory tensor.
if (is_vectorize) {
reduction_scheduler_utils::sharedMemoryConsumerVectorization(
smem_consumers, rparams->unroll_factor_inner_reduction);
}

// Remove dummy outputs as they can inadvertently affect CA positions
for (auto output : dummy_outputs) {
fusion->removeOutput(output);
1 change: 1 addition & 0 deletions csrc/scheduler/normalization_utils.cpp
@@ -1420,6 +1420,7 @@ void schedulePersistentKernel(
unroll,
vectorize,
is_outer_grid_persistence,
rparams->unroll_factor_inner_reduction,
reduction_tvs,
cached_inputs,
cached_outputs,
1 change: 1 addition & 0 deletions csrc/scheduler/reduction.cpp
@@ -1259,6 +1259,7 @@ void scheduleReduction(Fusion* fusion, const ReductionParams* rparams) {
unroll,
vectorize,
use_iter_grouped_reduction,
rparams->unroll_factor_inner_reduction,
reduction_tvs,
cached_inputs,
cached_outputs);
77 changes: 60 additions & 17 deletions csrc/scheduler/reduction_utils.cpp
@@ -5,14 +5,14 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <scheduler/reduction_utils.h>

#include <expr_evaluator.h>
#include <ir/cloner.h>
#include <ir/utils.h>
#include <multidevice/utils.h>
#include <ops/arith.h>
#include <scheduler/reduction_utils.h>
#include <scheduler/registry.h>
#include <scheduler/runtime_info.h>
#include <scheduler/tools/inlining.h>
#include <scheduler/tools/maxinfo_propagator.h>
#include <scheduler/utils.h>
@@ -348,6 +348,7 @@ void multiReductionInliner(
const bool is_unroll_or_vectorization,
const bool vectorize,
const bool use_grouped_reduction,
    const int64_t vectorization_factor,
std::vector<TensorView*> reduction_tvs,
std::vector<TensorView*> cached_inputs,
std::vector<std::pair<TensorView*, TensorView*>> cached_outputs,
@@ -361,7 +362,7 @@
}

const auto& unroll_vectorizable_cached_tvs = getCachedTvsToUnrollOrVectorize(
reference_tv, vectorize, cached_inputs, cached_outputs, smem_consumers);
reference_tv, vectorize, cached_inputs, cached_outputs);
reduction_scheduler_utils::propagateParallelization(
reduction_tv,
reference_tv,
@@ -370,6 +371,13 @@
reduction_tvs,
unroll_vectorizable_cached_tvs);

  // Needs special handling of vectorized loading from shared memory due to
  // potentially different data types of the inputs and the shared memory tensor.
if (vectorize) {
reduction_scheduler_utils::sharedMemoryConsumerVectorization(
        smem_consumers, vectorization_factor);
}

// Remove dummy outputs as they can inadvertently affect CA positions
for (auto output : dummy_outputs) {
fusion->removeOutput(output);
@@ -428,8 +436,7 @@ std::unordered_set<TensorView*> getCachedTvsToUnrollOrVectorize(
TensorView* reference_tv,
bool vectorize,
const std::vector<TensorView*>& cached_inputs,
const std::vector<std::pair<TensorView*, TensorView*>>& cached_outputs,
const std::vector<TensorView*>& smem_consumers) {
const std::vector<std::pair<TensorView*, TensorView*>>& cached_outputs) {
auto reduced_tv = ir_utils::getSoleProducerTv(reference_tv);
// Grab all tensor views that should be vectorized
auto vectorizable_inputs_outputs =
@@ -469,18 +476,6 @@ std::unordered_set<TensorView*> getCachedTvsToUnrollOrVectorize(
}
}

if (vectorize) {
for (auto tv : smem_consumers) {
// smem_consumers were added in schedule process
// movePersistentBufferToSmem() using cacheAfter()
NVF_ERROR(
vectorizable_expr(tv->definition()),
"Expected a vectorizable expression, but got: ",
tv->definition()->toString());
unroll_vectorizable_tvs.emplace(tv);
}
}

return unroll_vectorizable_tvs;
}

@@ -1009,5 +1004,53 @@ std::ostream& operator<<(std::ostream& os, ReductionType reduction_type) {
return os;
}

void sharedMemoryConsumerVectorization(
std::vector<TensorView*>& smem_consumers,
int64_t io_vectorization_factor) {
for (auto tv : smem_consumers) {
    // They were created with cacheAfter, so their definitions are LoadStoreOps.
NVF_ERROR(
tv->definition()->isA<LoadStoreOp>(),
"smem consumers should be LoadStoreOp. Got: ",
tv->definition()->toString());

    // Non-concretized broadcast domains are moved to the innermost positions
    // before transform propagation; these axes should be skipped.
int64_t vect_axis_pos = -1;
while (tv->axis(vect_axis_pos)->isBroadcast()) {
vect_axis_pos--;
NVF_ERROR(
vect_axis_pos + tv->nDims() >= 0,
"Out of bound access when visiting dim ",
vect_axis_pos,
" in Tv: ",
tv->toString());
}
    // They were transformed so that the innermost axis has an extent equal to
    // the vectorization factor set for the I/O tvs.
NVF_ERROR(
tv->axis(vect_axis_pos)->extent()->isConst(),
"Extent of the innermost axis of smem consumers should be constant. Got: ",
tv->toString());
auto innermost_extent =
tv->axis(vect_axis_pos)->extent()->evaluate().as<int64_t>();
NVF_ERROR(
innermost_extent == io_vectorization_factor,
"Extent of the innermost axis of smem consumers should be equal to the vectorization factor of fuion inputs and outputs. Got: ",
innermost_extent,
", expected: ",
io_vectorization_factor);
auto dtype_bytes = dataTypeSize(tv->getDataType().value());
auto max_vect_factor =
SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_bytes;
    // An additional split is added if the innermost extent is greater than the
    // max vectorization factor.
if (innermost_extent > max_vect_factor) {
tv->split(vect_axis_pos, max_vect_factor);
}
tv->axis(vect_axis_pos)->parallelize(ParallelType::Vectorize);
}
}

} // namespace reduction_scheduler_utils
} // namespace nvfuser
24 changes: 20 additions & 4 deletions csrc/scheduler/reduction_utils.h
@@ -35,6 +35,7 @@ void multiReductionInliner(
const bool unroll,
const bool vectorize,
const bool use_grouped_reduction,
    const int64_t vectorization_factor,
std::vector<TensorView*> reduction_tvs,
std::vector<TensorView*> cached_inputs,
std::vector<std::pair<TensorView*, TensorView*>> cached_outputs,
@@ -65,14 +66,11 @@ void propagateRFactor(
// is_vectorize: Indicates if vectorization is applied in the scheduler.
// cached_inputs: Inputs cached in registers or shared memory.
// cached_outputs: Outputs cached in registers.
// smem_consumers: Consumers of shared memory persistent buffers, they are
// register cached Tvs after the shared memory tv.
NVF_API std::unordered_set<TensorView*> getCachedTvsToUnrollOrVectorize(
TensorView* reference_tv,
bool is_vectorize,
const std::vector<TensorView*>& cached_inputs,
const std::vector<std::pair<TensorView*, TensorView*>>& cached_outputs,
const std::vector<TensorView*>& smem_consumers);
const std::vector<std::pair<TensorView*, TensorView*>>& cached_outputs);

// Propagate parallelization from the reference TensorView to other TensorViews.
// Unroll, Vectorize, and MisalignedVectorize types are explicitly handled for
@@ -139,5 +137,23 @@ std::string toString(ReductionType reduction_type);
ReductionType getReductionType(Fusion* fusion);
ReductionType getReductionType(const std::vector<TensorView*>& reduction_tvs);

/**
* @brief Vectorize shared memory consumers
*
 * Applies vectorization to shared memory consumers.
 * If the extent of the innermost dim multiplied by the data type size exceeds
 * the 16-byte hardware limit, an additional split is added.
*
* @param smem_consumers Vector of TensorView pointers representing shared
* memory consumers
* @param io_vectorization_factor Vectorization factor set for fusion inputs and
* outputs
* @note TODO: Optimize writing to shared memory and address bank conflicts for
* float32 with innermost extent of 8
*/
void sharedMemoryConsumerVectorization(
std::vector<TensorView*>& smem_consumers,
const int64_t io_vectorization_factor);

} // namespace reduction_scheduler_utils
} // namespace nvfuser
76 changes: 72 additions & 4 deletions tests/cpp/test_combined_inner_outer_reduction.cpp
@@ -610,15 +610,15 @@ TEST_F(CombinedSchedulerTest, CombinedReduction) {
false,
inner_reduction_tvs,
reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize(
reference_tv_inner, true, cached_inputs, cached_outputs, {}));
reference_tv_inner, true, cached_inputs, cached_outputs));
reduction_scheduler_utils::propagateParallelization(
outer_reduction_tv,
reference_tv_outer,
true,
false,
outer_reduction_tvs,
reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize(
reference_tv_outer, true, cached_inputs, cached_outputs, {}));
reference_tv_outer, true, cached_inputs, cached_outputs));

inlineMost();
LaunchParams launch_constraints;
@@ -773,7 +773,7 @@ TEST_F(CombinedSchedulerTest, CombinedReductionMultiPerBlock) {
false,
inner_reduction_tvs,
reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize(
reference_tv_inner, true, cached_inputs, cached_outputs, {}),
reference_tv_inner, true, cached_inputs, cached_outputs),
{selected_tvs_inner.begin(), selected_tvs_inner.end()});

const auto& selected_tvs_outer =
@@ -787,7 +787,7 @@
false,
outer_reduction_tvs,
reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize(
reference_tv_outer, true, cached_inputs, cached_outputs, {}),
reference_tv_outer, true, cached_inputs, cached_outputs),
{selected_tvs_outer.begin(), selected_tvs_outer.end()});

std::vector<TensorView*> cached_gmem_temp{partialResult};
@@ -926,4 +926,72 @@ TEST_F(CombinedSchedulerTest, InnerOuterNoOuterBroadcastTv) {
"",
persistent_params->lparams);
}

// Reproduce error found in:
// thunder/tests/test_torch_compile_executor.py::test_torch_compile_cat_nvfuser_phi2_tanh
// Only happens when shared memory persistent is used.
TEST_F(CombinedSchedulerTest, SharedMemoryPersistentVectFactor) {
Fusion fusion;
FusionGuard fg(&fusion);
// When the input is float16, the vectorization factor is set to 8.
// If the persistent buffer tv1 is stored in shared memory and is not
// projected to inputs, the scheduler adds a cacheAfter to load tv1 from
// shared memory to registers in a vectorized manner, avoiding bank conflicts.
// However, since tv1 is float32, we can't directly use the vectorization
// factor set for float16 inputs because the maximum allowed vectorization
// width is 16 bytes.
const int dim0 = 1024;
const int dim1 = 4096;
auto dtype = DataType::Half;
auto tv0 = makeContigTensor(2, dtype);
fusion.addInput(tv0);
auto tv1 = castOp(DataType::Float, tv0);
auto tv2 = sum(tv1, {1});
auto tv3 = broadcast(tv2, {false, true});
auto tv4 = add(tv3, tv1);
auto tv5 = sum(tv1, {0});
auto tv6 = castOp(DataType::Half, tv4);
auto tv7 = castOp(DataType::Half, tv5);
fusion.addOutput(tv6);
fusion.addOutput(tv7);

Fusion fusion_copy = fusion;

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA, 0);
at::Tensor t0 = at::randn({dim0, dim1}, options);
std::vector<c10::IValue> aten_inputs = {t0};

SchedulerRuntimeInfo runtime_info(&fusion, aten_inputs);
ASSERT_TRUE(Schedule::canSchedule(
SchedulerType::InnerOuterPersistent, &fusion, runtime_info));
auto scheduler = SchedulerEntry::makeSchedulerInstance(
SchedulerType::InnerOuterPersistent);
auto heuristic_params = scheduler->computeHeuristics(&fusion, runtime_info);

  // Disable projection to inputs, so the shared memory buffer uses float32.
heuristic_params->as<ReductionParams>()->project_persistent_buffers = false;
  // Set the vectorization factor to 8, so the innermost dimension spans
  // 32 bytes (8 x 4), exceeding the 16-byte limit.
heuristic_params->as<ReductionParams>()->unroll_factor_inner_reduction = 8;
  // When computing heuristics, the buffer is projected to inputs and the
  // shared memory persistent buffer is the input, tv0. Since we modified the
  // heuristics to disable projection to inputs, we need to update the buffer
  // stored in shared memory to the original unprojected buffer, tv1.
heuristic_params->as<ReductionParams>()->smem_persistent_buffers =
std::vector<TensorView*>{tv1};
scheduler->schedule(&fusion, heuristic_params.get());
FusionExecutor fe;
fe.compileFusion(&fusion, aten_inputs);

for (auto tv : fusion.allTvs()) {
if (tv->getMemoryType() == MemoryType::Shared) {
for (auto consumer : ir_utils::consumerTvsOf(tv)) {
EXPECT_TRUE(isVectorized(consumer));
}
}
}
auto cg_outputs = fe.runFusion(
aten_inputs, heuristic_params->as<ReductionParams>()->lparams);
testValidate(&fusion_copy, cg_outputs, aten_inputs, __LINE__, __FILE__);
}
} // namespace nvfuser
