TMA: Don't merge dims when proven wrong (#3472)
In our TMA support, there is a mechanism called "Define box by compositing"
which is used to reduce the dimensionality of TMA when possible. The
reason for doing so is that our hardware has a limitation that a tensor
cannot have more than 5 dimensions, and by collapsing some dimensions,
we can support tensors with more than 5 dimensions by viewing them as
tensors with at most 5 dimensions.
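
As an illustration, here is a minimal standalone sketch of the idea (not
the nvFuser implementation; `collapseTo5D` is a hypothetical name, and the
real mechanism only merges dims that are contiguous in memory):

```cpp
#include <cstdint>
#include <vector>

// Sketch only: view a contiguous shape with more than 5 dims as an
// equivalent shape with at most 5 dims by repeatedly merging the two
// innermost dims.
std::vector<int64_t> collapseTo5D(std::vector<int64_t> shape) {
  while (shape.size() > 5) {
    int64_t merged = shape[shape.size() - 2] * shape.back();
    shape.pop_back();
    shape.back() = merged;
  }
  return shape;
}

// e.g. a contiguous (4, 3, 2, 8, 8, 8) tensor can be viewed as
// (4, 3, 2, 8, 64), which fits within the 5D hardware limit.
```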

While I do believe this mechanism makes great sense in general,
mindlessly collapsing as much as possible can cause trouble. For
example, if the box size is `(2, 256)`, collapsing it into `(512,)` is a
mistake, because our hardware has another limitation: the box size in
each dimension cannot be larger than `256`.

This PR makes the "Define box by compositing" mechanism smarter: instead
of always collapsing, we check whether collapsing could cause trouble,
and if so, we do not collapse. Note that this check cannot guarantee
100% safety. For example, if the box size is a symbolic value `v1`,
there is no way to get its value during lowering, so it is not possible
to know whether merging it with another dim would make the box size
greater than 256. I believe this is fine in practice: to my knowledge,
we are not interested in symbolic box sizes right now, so the box size
is almost always constant.
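
The conservative pattern is roughly the following standalone sketch (not
nvFuser code; here a symbolic extent is modeled as `std::nullopt`): a merge
is refused only when it is *provably* too large, so a symbolic extent falls
through to "merge".

```cpp
#include <cstdint>
#include <optional>

// Sketch only: an extent is either a known constant or symbolic
// (std::nullopt). We refuse a merge only when the merged extent is
// provably larger than the hardware limit; a symbolic extent cannot be
// proven too large, so the merge is optimistically allowed.
bool shouldMergeSketch(std::optional<int64_t> a, std::optional<int64_t> b) {
  constexpr int64_t kMaxDimSize = 256;
  if (a.has_value() && b.has_value()) {
    return *a * *b <= kMaxDimSize;
  }
  return true; // symbolic: cannot prove it is too large
}
```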

In the future, suppose we move towards analytical heuristics, where the
box size is a `Val` that is a function of our input shapes and
represented by our IR nodes. In that world, I think dynamic
concretization is necessary to make "Define box by compositing"
bug-free. For example, say our box size is `(v1, v2, v3, v4, v5, v6)`;
depending on the values of these variables, we might want to do
different things (see the sketch after this list):
- `(v1, v2, v3, v4, v5, v6) = (256, 2, 128, 256, 256, 256)` --> merge
`v2` with `v3` and use 5D TMA
- `(v1, v2, v3, v4, v5, v6) = (256, 256, 2, 128, 256, 256)` --> merge
`v3` with `v4` and use 5D TMA
- `(v1, v2, v3, v4, v5, v6) = (256, 256, 256, 128, 256, 256)` --> raise
an error, not supported
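
A minimal sketch of what such runtime concretization could look like
(purely illustrative; `concretizeBox` and `kMaxDimSize` are hypothetical
names, and real logic would also have to account for contiguity and
swizzling):

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch only: reduce a box to at most 5 dims by merging an adjacent
// pair whose merged size still respects the <= 256 limit, raising an
// error when no such pair exists.
std::vector<int64_t> concretizeBox(std::vector<int64_t> box) {
  constexpr int64_t kMaxDimSize = 256;
  while (box.size() > 5) {
    bool merged = false;
    for (size_t i = 0; i + 1 < box.size(); ++i) {
      if (box[i] * box[i + 1] <= kMaxDimSize) {
        box[i] *= box[i + 1];
        box.erase(box.begin() + static_cast<std::ptrdiff_t>(i) + 1);
        merged = true;
        break;
      }
    }
    if (!merged) {
      throw std::runtime_error("box cannot be reduced to 5D: not supported");
    }
  }
  return box;
}

// (256, 2, 128, 256, 256, 256)   -> merges 2*128, yields 5D
// (256, 256, 2, 128, 256, 256)   -> merges 2*128, yields 5D
// (256, 256, 256, 128, 256, 256) -> no pair fits, throws
```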
zasdfgbnm authored Nov 26, 2024
1 parent 58e1514 commit 9c9c34c
Showing 2 changed files with 121 additions and 15 deletions.
87 changes: 73 additions & 14 deletions csrc/device_lower/analysis/tma.cpp
@@ -800,16 +800,22 @@ class DomainMerger {
std::unordered_set<ValGroup>& bulk_groups_;
std::unordered_set<ValGroup>& nonbulk_groups_;
std::list<TMADim>& dim_info_;
MmaInputSmemSwizzle swizzle_;
int64_t item_size_bytes_;

public:
DomainMerger(
std::list<std::tuple<ValGroup, bool, Val*>> raw_tma_domain,
std::unordered_set<ValGroup>& bulk_groups,
std::unordered_set<ValGroup>& nonbulk_groups,
std::list<TMADim>& dim_info)
std::list<TMADim>& dim_info,
MmaInputSmemSwizzle swizzle,
int64_t item_size_bytes)
: bulk_groups_(bulk_groups),
nonbulk_groups_(nonbulk_groups),
dim_info_(dim_info) {
dim_info_(dim_info),
swizzle_(swizzle),
item_size_bytes_(item_size_bytes) {
ValGraph& id_graph = GpuLower::current()->tensorIndexer().traversalGraph();
contiguity_and_stride_.reserve(raw_tma_domain.size());
for (auto& item : raw_tma_domain) {
@@ -868,6 +874,49 @@ class DomainMerger {
return C;
}

bool shouldMerge(int64_t i) {
auto type0 = type(i);
auto type1 = type(i + 1);

bool may_increase_box_size = (type0 == CB && type1 == CB);
if (!may_increase_box_size) {
return true;
}

auto extent0 = (*this)[i]->front()->as<IterDomain>()->extent();
auto extent1 = (*this)[i + 1]->front()->as<IterDomain>()->extent();
Val* merged_extent = SimplifyingIrBuilder::mulExpr(extent0, extent1);

bool merging_innermost = ((int64_t)size() == i + 2);

// If merging makes the size of a dimension larger than 256, we should not
// merge.
constexpr int64_t largest_dim_size =
256; // Dimension size must be <= 256 as limited by hardware.
Val* too_large_after_merge = SimplifyingIrBuilder::gtExpr(
merged_extent, IrBuilder::create<Val>(largest_dim_size));
if (simplifyExpr(too_large_after_merge)->isTrue()) {
return false;
}

// If merging makes the inner size larger than the swizzle size,
// we should not merge
if (merging_innermost && swizzle_ != MmaInputSmemSwizzle::None) {
const int64_t swizzle_size =
getBytesFromSwizzle(swizzle_) / item_size_bytes_;
Val* merging_makes_gt_swizzle_size = SimplifyingIrBuilder::gtExpr(
merged_extent, IrBuilder::create<Val>(swizzle_size));
if (simplifyExpr(merging_makes_gt_swizzle_size)->isTrue()) {
return false;
}
}

// Because the shape is dynamic, we don't know if we should merge or
// not. For this case, we always assume merging is better than not
// merging.
return true;
}

void merge(int64_t i) {
auto type0 = type(i);
auto type1 = type(i + 1);
@@ -941,18 +990,26 @@ std::vector<TMADim> run(
std::unordered_set<ValGroup>& bulk_groups,
std::unordered_set<ValGroup>& nonbulk_groups,
std::list<TMADim>& dim_info,
int64_t item_size_bytes) {
int64_t item_size_bytes,
MmaInputSmemSwizzle swizzle) {
DomainMerger tma_domain(
std::move(raw_tma_domain), bulk_groups, nonbulk_groups, dim_info);
std::move(raw_tma_domain),
bulk_groups,
nonbulk_groups,
dim_info,
swizzle,
item_size_bytes);
// merge contiguous C groups and CB groups
for (int64_t i = 0; i < (int64_t)tma_domain.size() - 1; i++) {
if (!tma_domain.contiguity(i)) {
continue;
}
if ((tma_domain.type(i) == C && tma_domain.type(i + 1) == C) ||
(tma_domain.type(i) == CB && tma_domain.type(i + 1) == CB)) {
tma_domain.merge(i);
i--;
if (tma_domain.shouldMerge(i)) {
tma_domain.merge(i);
i--;
}
}
}
// merge contiguous C with SB/CB
@@ -962,8 +1019,10 @@
}
if (tma_domain.type(i) == C &&
(tma_domain.type(i + 1) == SB || tma_domain.type(i + 1) == CB)) {
tma_domain.merge(i);
i--;
if (tma_domain.shouldMerge(i)) {
tma_domain.merge(i);
i--;
}
}
}

@@ -1056,19 +1115,19 @@ TMAInfo getTMAInfo(LoadStoreOp* ldst) {
"(this is always the case for nvFuser now)",
", the first element of elementStrides must be one.");

MmaInputSmemSwizzle swizzle = getSwizzleFromBytes(
getCpAsyncBulkTensorSwizzleSize(smem_tv) * core_matrix_width_bytes);

// Handle "defining box by compositing" by collapsing some dimensions in the
// raw TMA domain to get the final TMA domain.
auto final_tma_domain = collapse_tma_domain::run(
std::move(raw_tma_domain),
bulk_groups,
nonbulk_groups,
inferred_dims,
dataTypeSize(gmem_tv->dtype()));
return TMAInfo(
std::move(final_tma_domain),
getSwizzleFromBytes(
getCpAsyncBulkTensorSwizzleSize(smem_tv) * core_matrix_width_bytes),
gmem_tv);
dataTypeSize(gmem_tv->dtype()),
swizzle);
return TMAInfo(std::move(final_tma_domain), swizzle, gmem_tv);
}

} // namespace
49 changes: 48 additions & 1 deletion tests/cpp/test_memory.cpp
@@ -881,7 +881,6 @@ TEST_F(TMAIndexingTest, DefineBoxByCompositing2) {
}
// Parallelize the tile axes
tv1->axis(1)->parallelize(ParallelType::Bulk);
// tv2->axis(1)->parallelize(ParallelType::TIDx);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({32, 4, 2, 8, 8, 8, 2, 8, 4}, options);
@@ -895,6 +894,54 @@
testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
}

TEST_F(TMAIndexingTest, DefineBoxByCompositingShouldNotMerge) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeContigConcreteTensor({2, 256, 2, 32});
fusion.addInput(tv0);
auto tv1 = set(tv0);
auto tv2 = set(tv1);
fusion.addOutput(tv2);

tv1->setMemoryType(MemoryType::Shared);
tv1->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);

// Use 1 thread and a single instruction to load the entire tensor to smem
for (auto id : tv1->getLoopDomain()) {
id->parallelize(ParallelType::Bulk);
}

// Then use 32 threads to dump results out
tv2->axis(3)->parallelize(ParallelType::TIDx);

// Schedule the allocation domain of tv1 to use 128B swizzle
AbstractTensor alloc1(tv1->getLoopDomain());
alloc1.merge(0);
alloc1.merge(0);
// [1024, 32]
alloc1.split(1, 4);
alloc1.split(0, 8);
// [128, 8, 8, 4]
alloc1.swizzle(SwizzleType::XOR, 1, 2);
tv1->setAllocationDomain(alloc1.as<IterDomain*>(), true);

auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
auto t0 = at::randn({2, 256, 2, 32}, options);
KernelExecutor ke;
ke.compile(&fusion, {t0}, {}, matmul_cparams);

// Because merging dims would violate the hardware requirement that each
// box dimension be <= 256, we do not merge dims.
EXPECT_EQ(TMADimChecker::getDim(ke.kernel()), 4);

EXPECT_TRUE(PredicatedChecker::isPredicated(tv1, ke.kernel()));

auto cg_outputs = ke.run({t0});
testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
}

TEST_F(TMAIndexingTest, DefineBoxByRotation1) {
Fusion fusion;
FusionGuard fg(&fusion);
