Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/NVIDIA/Fuser into compiled_…
Browse files Browse the repository at this point in the history
…kernel
  • Loading branch information
csarofeen committed Oct 7, 2024
2 parents acfe78a + 61a77e0 commit ab83fb7
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 122 deletions.
4 changes: 2 additions & 2 deletions csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,10 +508,10 @@ class AllocationInserter : public kir::ExprMutator {
init = default_val;
}

if (ir_utils::isCpAsyncOp(expr)) {
if (ir_utils::isCpAsyncOp(expr) || ir_utils::isCpAsyncBulk(expr)) {
NVF_CHECK(
init == nullptr || init->isZero(),
"cp.async initialized with non-zero is not supported");
"cp.async and cp.async.bulk initialized with non-zero is not supported");
// cp.async and cp.async.bulk will automatically fill zero when out of bound
init = nullptr;
}
Expand Down
4 changes: 1 addition & 3 deletions csrc/ir/interface_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,7 @@ class NVF_API TensorView : public Val {
//! Transforms the innermost iterdomains according to the given mma swizzle,
//! this should be used on the tvs that are inputs of a MmaOp or are loaded
//! using TMA.
void applyMmaSwizzleForTMALoad(
MmaInputSmemSwizzle swizzle,
bool permute_outer_dim = true);
void applyMmaSwizzleForTMALoad(MmaInputSmemSwizzle swizzle);

//! Returns if this tensor view has swizzle operator on its tensor domain.
//! This is the temporary flag for indicating that the new swizzle
Expand Down
18 changes: 2 additions & 16 deletions csrc/scheduler/mma_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -929,14 +929,9 @@ void MmaSwizzler::parallelizeAsBulkSkippingFirstIDs(
}
}

// Please note that we currently do not fully support
// not splitting the outer dimension. This only works when
// the inner-dimension is not split, that is the inner dim
// is less or equal to the swizzle size (in bytes).
void MmaSwizzler::scheduleTMALoadForMma(
TensorView* tv,
MmaInputSmemSwizzle swizzle,
bool permute_outer_dim) {
MmaInputSmemSwizzle swizzle) {
// In the comments below I have kept K as the outer dimension. That is
// just to have a concrete running example - it can be inner or outer.

Expand Down Expand Up @@ -968,16 +963,7 @@ void MmaSwizzler::scheduleTMALoadForMma(
// [NO, K, NI] ->
// [NO, KO(2), KIO(2), KII(4), NIO(2), NII(8)]
tv->swizzleTMABox(swizzle);

// If the outer dim is split, then we pull out KO to be outside NO
// and KO and NO are both not marked bulk parallel, else NO is outer
// and only NO is not marked bulk parallel.
if (permute_outer_dim) {
// [NO, KO(2), KIO(2), KII(4), NIO(2), NII(8)] ->
// [KO(2), NO(2), KIO(2), KII(4), NIO(2), NII(8)]
tv->reorder({{-6, -5}});
}
num_ids_to_skip += permute_outer_dim ? 2 : 1;
num_ids_to_skip += 1;
}

parallelizeAsBulkSkippingFirstIDs(tv, num_ids_to_skip);
Expand Down
3 changes: 1 addition & 2 deletions csrc/scheduler/mma_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,7 @@ class MmaSwizzler {
//! outermost.
static void scheduleTMALoadForMma(
TensorView* tv,
MmaInputSmemSwizzle swizzle,
bool permute_outer_dim = true);
MmaInputSmemSwizzle swizzle);

//! Parallelize all dims as bulk except the first dims mentioned in the second
//! param.
Expand Down
7 changes: 2 additions & 5 deletions csrc/tensor_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1388,18 +1388,15 @@ void TensorView::swizzleTMABox(MmaInputSmemSwizzle swizzle) {
this->swizzle(SwizzleType::XOR, -4, -2);
}

void TensorView::applyMmaSwizzleForTMALoad(
MmaInputSmemSwizzle swizzle,
bool permute_outer_dim) {
// Applies the MMA shared-memory swizzle scheduling to this tensor for a TMA
// load. Validates the preconditions (shared-memory residency, TMA-load
// definition) and then delegates the actual scheduling to
// MmaSwizzler::scheduleTMALoadForMma.
void TensorView::applyMmaSwizzleForTMALoad(MmaInputSmemSwizzle swizzle) {
  // Swizzling only applies to operands staged in shared memory.
  NVF_ERROR(
      getMemoryType() == MemoryType::Shared,
      "Shared memory swizzle is only supported for shared memory");
  // The tensor must be produced by a TMA tile load (cp.async.bulk.tensor);
  // other load/store op types do not use the TMA box layout this assumes.
  NVF_ERROR(
      definition()->as<LoadStoreOp>()->opType() ==
          LoadStoreOpType::CpAsyncBulkTensorTile,
      "Operation requires a TMA operation");
  // Delegate the split/swizzle/bulk-parallelization to the scheduler helper.
  mma_utils::MmaSwizzler::scheduleTMALoadForMma(this, swizzle);
}

void TensorView::commitLeafToLogical() {
Expand Down
Loading

0 comments on commit ab83fb7

Please sign in to comment.