Multiple matmul fusion example #2458

Closed
jacobhinkle wants to merge 162 commits into main from multiple_mma_fusion

162 commits

2af846b
Add test with non-segmented llama FFN
jacobhinkle Jun 26, 2024
b0db33d
Convert most of scheduleMatmul to handle multiple patterns
jacobhinkle Jun 26, 2024
1f94cc0
Compiles
jacobhinkle Jun 26, 2024
e8bb838
Fix test declaration
jacobhinkle Jun 26, 2024
1ec751c
Fixes for some runtime errors
jacobhinkle Jun 26, 2024
cb90e45
Split up propagation to all
jacobhinkle Jun 27, 2024
24b4a3d
Create MultipleMatmulScheduler
jacobhinkle Jun 28, 2024
ba8856e
Merge remote-tracking branch 'origin/main' into multiple_mma_fusion
jacobhinkle Jul 1, 2024
a30c119
Fix some compile errors
jacobhinkle Jul 1, 2024
3a3f53f
Reset rest of unconverted function to the old scheduleMatmul
jacobhinkle Jul 1, 2024
e680135
Remove some already converted code
jacobhinkle Jul 1, 2024
e167548
First draft
jacobhinkle Jul 2, 2024
1f2b4ad
Fill out. Still failing
jacobhinkle Jul 2, 2024
d35fc5a
Handle IterDomains in AbstractTensor if provided
jacobhinkle Jul 2, 2024
2d3d7e3
Revert test, accept graph optionally. Revert change to id_model
jacobhinkle Jul 2, 2024
95f4514
Pass through when only one input is computable
jacobhinkle Jul 2, 2024
14e7325
Fix test
jacobhinkle Jul 2, 2024
cb1232c
Remove debug prints
jacobhinkle Jul 2, 2024
f945b23
clang-tidy
jacobhinkle Jul 2, 2024
23b4fa2
Compile and execute example in test
jacobhinkle Jul 2, 2024
ffd9e1e
Avoid cycles and clean up comments
jacobhinkle Jul 2, 2024
fab9c81
clang-tidy
jacobhinkle Jul 2, 2024
776e04f
No longer inherit from IterVisitor
jacobhinkle Jul 2, 2024
9c1481d
WIP: AbstractMatmulTensor
jacobhinkle Jul 8, 2024
94ed0b9
Merge remote-tracking branch 'origin/main' into multiple_mma_fusion
jacobhinkle Jul 8, 2024
cb25581
Add example to comment on findNearestProducers
jacobhinkle Jul 8, 2024
35ba702
Add comment about uncomputable groups
jacobhinkle Jul 8, 2024
1b8a550
Update ValGraph when replaying abstract schedule
jacobhinkle Jul 8, 2024
1e9b992
Require AbstractId to be a ValGroupAndItsGraph
jacobhinkle Jul 8, 2024
65be841
Rename applyAbstractSchedule -> applyAbstractTransforms
jacobhinkle Jul 8, 2024
88ababf
Rename abstract->abstract_tensor, concrete->tv
jacobhinkle Jul 8, 2024
da8a96d
Change comment to no longer have diverging transform
jacobhinkle Jul 8, 2024
65d330e
Look at a single Expr* from ExprGroup
jacobhinkle Jul 8, 2024
088a6fe
Many small fixes
jacobhinkle Jul 8, 2024
76a974a
Merge remote-tracking branch 'origin/apply_abstract_schedule' into mu…
jacobhinkle Jul 8, 2024
bde0585
Finally compiles again
jacobhinkle Jul 8, 2024
ea17212
findNearestProducers -> mapScheduledGroupsToLoopIterDomains
jacobhinkle Jul 9, 2024
535e265
Enable scheduling a collection of TVs at once
jacobhinkle Jul 9, 2024
78bea23
lintrunner
jacobhinkle Jul 9, 2024
c5e24a8
Small fixes
jacobhinkle Jul 9, 2024
7ff8e59
Merge remote-tracking branch 'origin/apply_abstract_schedule' into mu…
jacobhinkle Jul 9, 2024
e1ba224
Introduce updateIdModel
jacobhinkle Jul 9, 2024
55387a5
Merge remote-tracking branch 'origin/main' into multiple_mma_fusion
jacobhinkle Jul 9, 2024
369484f
Fix tag propagation
jacobhinkle Jul 9, 2024
111bad9
Add swizzleBlockTiles
jacobhinkle Jul 9, 2024
20dfc3e
Update comment
jacobhinkle Jul 10, 2024
7a51c9a
Refactor to use ReplayTransformations, clarify traversal
jacobhinkle Jul 10, 2024
906c561
Use at_tiled_ to find dimensions for doSplitKRFactor
jacobhinkle Jul 10, 2024
1425b85
Add splitk and grid swizzling to params
jacobhinkle Jul 10, 2024
3c5075c
Merge remote-tracking branch 'origin/apply_abstract_schedule' into mu…
jacobhinkle Jul 11, 2024
4aad5e7
Remove unused variable
jacobhinkle Jul 11, 2024
20284aa
Add forwardAroundMissingAxes utility
jacobhinkle Jul 11, 2024
122d67b
Modify graph_ as we go. Introduces cacheAfter and rFactor
jacobhinkle Jul 12, 2024
83a1b15
Add NVFUSER_ENABLE=fuse_multiple_matmuls option
jacobhinkle Jul 12, 2024
bbaccd9
Merge remote-tracking branch 'origin/main' into multiple_mma_fusion
jacobhinkle Jul 12, 2024
1c1284b
Merge remote-tracking branch 'origin/apply_abstract_schedule' into mu…
jacobhinkle Jul 12, 2024
b4dbd85
Restore old makeTile() from main
jacobhinkle Jul 12, 2024
b3c87e6
Update comment
jacobhinkle Jul 12, 2024
20b74fc
Fix smem epilogue by mapping logical/root in rFactor
jacobhinkle Jul 12, 2024
a34898c
Finish swizzleSharedMemory
jacobhinkle Jul 15, 2024
feca720
Throw error instead of forwarding uncomputable merges
jacobhinkle Jul 15, 2024
2ebaea5
Use forwardAroundMissingAxes in test
jacobhinkle Jul 15, 2024
fc6ad18
Clarify logic in replayAbstractId
jacobhinkle Jul 15, 2024
8ca6a48
Add comment
jacobhinkle Jul 15, 2024
c51f488
Renaming variables for clarity
jacobhinkle Jul 15, 2024
c768b6f
Clean up cacheOperands and start scheduleWarpTileWithReduction
jacobhinkle Jul 15, 2024
7f08a6b
Finish scheduleWarpTileWithReduction
jacobhinkle Jul 15, 2024
d41f22e
Update csrc/abstract_tensor_schedule.cpp
jacobhinkle Jul 15, 2024
0bb5697
Clean up comment
jacobhinkle Jul 15, 2024
eee7fe0
Add more comments
jacobhinkle Jul 15, 2024
7dcf1b5
Check transforms on transformed tvs in test
jacobhinkle Jul 15, 2024
e9dcb20
Add set op on input. TEST IS FAILING
jacobhinkle Jul 15, 2024
8bda1eb
Add big comment and start to refactor
jacobhinkle Jul 16, 2024
08109cc
Prefer non-reduction non-bcast IDs in IdModel scheduling
jacobhinkle Jul 16, 2024
50454bf
Move inlining and circ buffering to end of return
jacobhinkle Jul 16, 2024
ac8ed6b
Parallelize BIDx/BIDy/BIDz
jacobhinkle Jul 16, 2024
7b36a2a
Remove forwardAroundMissingAxes
jacobhinkle Jul 16, 2024
a25471b
Merge remote-tracking branch 'origin/apply_abstract_schedule' into mu…
jacobhinkle Jul 16, 2024
5b62fe0
Remove uses of forwardAroundMissingAxes
jacobhinkle Jul 16, 2024
472096f
Add parametrized test for equality of scheduled fusions
jacobhinkle Jul 17, 2024
e5f7322
Switch to directly tiling.
jacobhinkle Jul 17, 2024
9c2d631
Make tests pass with smem operands
jacobhinkle Jul 17, 2024
a64daaa
Add SimpleMatmulBroadcastedInputs test
jacobhinkle Jul 18, 2024
0af238c
Clean up tests and add more fusions
jacobhinkle Jul 18, 2024
938ee87
Fix up parametrized test names
jacobhinkle Jul 18, 2024
e7baaea
Add vec_size_*, smem_epilogue, splitk_factor params
jacobhinkle Jul 18, 2024
9865340
Add separate parametrization. 1152->360 tests
jacobhinkle Jul 18, 2024
c71ae83
Add cta_order and grid_swizzle_factor params
jacobhinkle Jul 18, 2024
02ecc88
Do split-K split in blockTileTensors
jacobhinkle Jul 18, 2024
4a6a984
Fix splitk and grid swizzling
jacobhinkle Jul 18, 2024
33ebb23
Slight improvements to error messages
jacobhinkle Jul 18, 2024
c4f8936
Fix typo in B vectorization. All tests pass currently
jacobhinkle Jul 18, 2024
06ea417
Prepare cache_op computation for multiple operands
jacobhinkle Jul 18, 2024
9357683
Add BatchMatmul test, check mma_results
jacobhinkle Jul 18, 2024
8c5e04f
Schedule mma_results. Working except with split-K
jacobhinkle Jul 18, 2024
bade21f
mma_results passing all tests
jacobhinkle Jul 19, 2024
fd0fcec
Schedule prologue (first try)
jacobhinkle Jul 23, 2024
bdc3996
Give more informative errors in test
jacobhinkle Jul 23, 2024
8a25cb5
Schedule epilogue
jacobhinkle Jul 23, 2024
6072965
Fix misformatting
jacobhinkle Jul 23, 2024
780aa5c
Count dims after finding patterns
jacobhinkle Jul 23, 2024
8b558d5
Check circular buffering, ca_pos, and generated kernel
jacobhinkle Jul 23, 2024
dcfd964
Llama2FFNNoSegmentation now passes
jacobhinkle Jul 23, 2024
8af8ea0
Schedule split-K sum
jacobhinkle Jul 26, 2024
76e8067
Fix batch test by placing countDims() at end of findRoles
jacobhinkle Jul 26, 2024
a4ab34a
Switch to MNK for broadcasted inputs test
jacobhinkle Aug 28, 2024
73d14de
Small fix to test
jacobhinkle Aug 28, 2024
c0c2371
Store mma_results_ only instead of mma_ops_
jacobhinkle Aug 28, 2024
b853538
Only check paralleltype for IDs that were originally parallelized
jacobhinkle Aug 28, 2024
a6ae5ea
Initial try at scheduleOutputTensor
jacobhinkle Aug 29, 2024
afd5a91
Merge commit '7faa6' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
5a8be66
Merge commit '971f1f' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
d7d7ccd
Merge commit '98de77' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
2145e9a
Merge commit '15bdf9f214' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
1eb40a9
Merge commit 'e1a3019913' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
38c3a52
Merge commit '2e44502' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
f7b22df
Merge commit 'c23441aa' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
b369350
Merge commit 'af62096' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
8b48e69
Merge commit '671171f' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
8d77ff5
Merge commit '4a2987' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
01f5494
Merge commit 'a65f35' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
bc26892
Merge commit '3b6bc01f' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
bc82921
Merge commit 'd4dd6d' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
5e43c9d
Merge remote-tracking branch 'origin/main' into multiple_mma_fusion
jacobhinkle Sep 5, 2024
30898e3
Fix mistake in merge
jacobhinkle Sep 5, 2024
a41d01f
Update swizzleSharedMemory to remove "skip"
jacobhinkle Sep 6, 2024
e156662
Merge remote-tracking branch 'origin/main' into multiple_mma_fusion
jacobhinkle Sep 6, 2024
a9c0ae8
Don't swizzle mma_result. Propagate forward to smem_epilogue
jacobhinkle Sep 6, 2024
63b9047
Disable parallel type checking in test
jacobhinkle Sep 6, 2024
cc86668
Merge remote-tracking branch 'origin/main' into multiple_mma_fusion
jacobhinkle Sep 17, 2024
82f0eb8
Split scheduleMmaOperandOrOutputs into two methods
jacobhinkle Sep 17, 2024
a0885d0
Apply workaround to operands also
jacobhinkle Sep 17, 2024
9474808
Disable sameAs check since don't compare ParallelType
jacobhinkle Sep 17, 2024
d891ebe
Merge remote-tracking branch 'origin/main' into multiple_mma_fusion
jacobhinkle Sep 18, 2024
b4db6b8
Remove abstract tensor scheduling
jacobhinkle Sep 18, 2024
dff1a1d
Remove old include
jacobhinkle Sep 18, 2024
ae3b34b
Merge commit '92b635125ae509cc6b2ccbe29e957586a9cbb059' into multiple…
jacobhinkle Sep 23, 2024
c2dd249
Merge remote-tracking branch 'origin/main' into multiple_mma_fusion
jacobhinkle Sep 23, 2024
4304167
Finish merging origin/main
jacobhinkle Sep 23, 2024
aad380d
Enable new test file
jacobhinkle Sep 23, 2024
a04ff99
Merge remote-tracking branch 'origin/main' into multiple_mma_fusion
jacobhinkle Sep 23, 2024
f5f7805
Remove explicit reference to scheduleMultipleMatmuls in test
jacobhinkle Sep 23, 2024
085ca1a
Multiple Matmul Scheduler: schedule prologue
jacobhinkle Sep 23, 2024
5f2c2e6
Remove stuff merged in error
jacobhinkle Sep 23, 2024
cf0c815
Remove unneeded stuff
jacobhinkle Sep 23, 2024
f4d60b0
Merge branch 'multi_matmul_schedule_prologue' into multiple_mma_fusion
jacobhinkle Sep 23, 2024
1d30dec
Reinstate getOperandTv
jacobhinkle Sep 23, 2024
fae3feb
Make cacheOperandsToSmem a method again
jacobhinkle Sep 23, 2024
0f684ef
Remove changes to test_abstract_tensor.cpp
jacobhinkle Sep 23, 2024
540a782
Multiple Matmul Scheduler: schedule mma instruction
jacobhinkle Sep 23, 2024
5b20611
Multiple matmul scheduler: schedule epilogue
jacobhinkle Sep 23, 2024
aa0c7de
clang-tidy
jacobhinkle Sep 23, 2024
73b4d88
Merge branch 'multi_matmul_schedule_prologue' into multi_matmul_sched…
jacobhinkle Sep 23, 2024
4e53a0b
Merge branch 'multi_matmul_schedule_mma_result' into multi_matmul_sch…
jacobhinkle Sep 23, 2024
eda31aa
Merge branch 'multi_matmul_schedule_prologue' into multiple_mma_fusion
jacobhinkle Sep 23, 2024
f0a0efa
Merge branch 'multi_matmul_schedule_epilogue' into multiple_mma_fusion
jacobhinkle Sep 23, 2024
bca30bf
Remove unused functions
jacobhinkle Sep 23, 2024
b1a362c
Remove redundant makeTile declarations
jacobhinkle Sep 23, 2024
3cdaed8
Remove stale declaration
jacobhinkle Sep 23, 2024
69e9333
Merge branch 'multi_matmul_schedule_prologue' into multi_matmul_sched…
jacobhinkle Sep 23, 2024
7357184
Merge branch 'multi_matmul_schedule_mma_result' into multi_matmul_sch…
jacobhinkle Sep 23, 2024
b2d8731
Merge branch 'multi_matmul_schedule_epilogue' into multiple_mma_fusion
jacobhinkle Sep 23, 2024
15 changes: 15 additions & 0 deletions csrc/mma_type.cpp
@@ -142,4 +142,19 @@ size_t hash(const MatMulTileOptions& opts) {
(hash(opts.cta_tile) << 2);
}

std::string toString(const MatmulDimRole role) {
switch (role) {
case MatmulDimRole::Batch:
return "Batch";
case MatmulDimRole::M:
return "M";
case MatmulDimRole::N:
return "N";
case MatmulDimRole::K:
return "K";
}
// Unreachable
return "Unrecognized role";
}

} // namespace nvfuser
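
As a quick illustration of the new helper, here is a hypothetical debugging snippet (editor's sketch, not part of the diff; it assumes mma_type.h is included so that MatmulDimRole and toString are in scope):

#include <iostream>
#include <sstream>
#include <vector>

// Print a sequence of dimension roles, e.g. "Dim roles: M N K".
void printRoles(const std::vector<nvfuser::MatmulDimRole>& roles) {
  std::stringstream ss;
  for (nvfuser::MatmulDimRole role : roles) {
    ss << nvfuser::toString(role) << " ";
  }
  std::cout << "Dim roles: " << ss.str() << std::endl;
}
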
2 changes: 2 additions & 0 deletions csrc/mma_type.h
@@ -28,6 +28,8 @@ constexpr std::string_view MATMUL_LOG_PREFIX = "[MATMUL DEBUG] ";
//! Named descriptors of domains in matmul
enum class MatmulDimRole { M = 0, N, K, Batch };

std::string toString(MatmulDimRole role);

//! Named descriptors of TensorView roles in fusion
//! OPERAND_A - an input to the fusion that is a producer of a matmul "A" input
//! OPERAND_B - an input to the fusion that is a producer of a matmul "B" input
157 changes: 156 additions & 1 deletion csrc/scheduler/mma_utils.cpp
@@ -281,6 +281,143 @@ void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) {
// [Kwo Mwo Nwo Mw Nw Mi Ni Ki]
}

void scheduleWarpTile(
TensorView* tv,
MatMulTileOptions tile,
const std::vector<MatmulDimRole>& merged_roles) {
// Tiles are always given in [M, N, K]
auto cta_tile = tile.cta_tile;
auto warp_tile = tile.warp_tile;
auto instruction_tile = tile.instruction_tile;

// Do not split K dimension of CTA tile into multiple warp tiles
NVF_CHECK(
cta_tile.k == warp_tile.k,
"CTA tile and warp tile must have same K dimension");

// Find _inner_ tile dimensions present in tv
std::vector<int64_t> inner_dims;
std::vector<int64_t> required_inner_dim_sizes;
int64_t m_dim = -1, n_dim = -1, k_dim = -1;
int64_t min_inner_pos = tv->nDims(); // position of outermost inner dim
for (size_t i : c10::irange(merged_roles.size())) {
int64_t i_tv = tv->nDims() - 1 - (int64_t)i;
int64_t i_roles = (int64_t)merged_roles.size() - 1 - (int64_t)i;
if (i_tv < 0 || i_roles < 0) {
break;
}
MatmulDimRole role = merged_roles.at(i_roles);
switch (role) {
case MatmulDimRole::M:
m_dim = i_tv;
required_inner_dim_sizes.push_back(cta_tile.m);
break;
case MatmulDimRole::N:
n_dim = i_tv;
required_inner_dim_sizes.push_back(cta_tile.n);
break;
case MatmulDimRole::K:
k_dim = i_tv;
required_inner_dim_sizes.push_back(cta_tile.k);
break;
default:
// Skip dims (e.g. Batch) that are not tiled here, without recording them
// as inner dims
continue;
}
// Record this dim as an inner dim only for M/N/K roles, so that inner_dims
// and required_inner_dim_sizes stay the same length for checkDimSize below
inner_dims.push_back(i_tv);
min_inner_pos = std::min(min_inner_pos, i_tv);
}
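// Editor's illustration (not from the PR): for a tv with loop domain
// [B, M, N, K] and merged_roles = {Batch, M, N, K}, the scan above walks
// K, N, M, Batch from innermost outward, setting k_dim = 3, n_dim = 2,
// m_dim = 1, and leaving min_inner_pos = 1 so that the batch dimension is
// treated as an outer dim and left untouched.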

NVF_ERROR(
k_dim != -1,
"scheduleWarpTile should only be called on an operand with K dimension but found ",
tv->toString());

mma_utils::checkDimSize(tv, inner_dims, required_inner_dim_sizes);

// m_dim n_dim k_dim
// [... M, N, K]
// Distribute warp tile:
if (m_dim != -1) {
tv->split(m_dim, warp_tile.m);
if (n_dim > m_dim) {
n_dim++;
}
if (k_dim > m_dim) {
k_dim++;
}
}
if (n_dim != -1) {
tv->split(n_dim, warp_tile.n);
if (m_dim > n_dim) {
m_dim++;
}
if (k_dim > n_dim) {
k_dim++;
}
}

// m_dim m_dim+1 n_dim n_dim+1 k_dim
// [ Mwo Mw Nwo Nw K]
if (m_dim != -1) {
tv->split(m_dim + 1, instruction_tile.m);
if (n_dim > m_dim) {
n_dim++;
}
if (k_dim > m_dim) {
k_dim++;
}
}
if (n_dim != -1) {
tv->split(n_dim + 1, instruction_tile.n);
if (m_dim > n_dim) {
m_dim++;
}
if (k_dim > n_dim) {
k_dim++;
}
}
if (k_dim != -1) {
tv->split(k_dim, instruction_tile.k);
if (m_dim > k_dim) {
m_dim++;
}
if (n_dim > k_dim) {
n_dim++;
}
}

// m_dim m_dim+1 m_dim+2 n_dim n_dim+1 n_dim+2 k_dim k_dim+1
// [ Mwo Mw Mi Nwo Nw Ni Kwo Ki]

// Now we reorder. The new loop order will be
// [Kwo Mwo Nwo Mw Nw Mi Ni Ki]

std::vector<IterDomain*> new_loop;
new_loop.reserve(tv->nDims());
// Add all the outer dimensions as-is
for (size_t i : c10::irange(min_inner_pos)) {
new_loop.push_back(tv->axis((int64_t)i));
}
if (k_dim != -1) {
new_loop.push_back(tv->axis(k_dim));
}
for (size_t j : c10::irange(3)) {
if (m_dim != -1) {
new_loop.push_back(tv->axis(m_dim + (int64_t)j));
}
if (n_dim != -1) {
new_loop.push_back(tv->axis(n_dim + (int64_t)j));
}
}
if (k_dim != -1) {
new_loop.push_back(tv->axis(k_dim + 1));
}
NVF_ERROR((int64_t)new_loop.size() == tv->nDims());

tv->setLoopDomain(new_loop);
}
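// Editor's illustration (not part of the diff): for a tv whose loop domain
// is [M, N, K] with merged_roles = {M, N, K}, a call like
//
//   MatMulTileOptions tile;
//   tile.cta_tile = GemmTile(128, 128, 32);
//   tile.warp_tile = GemmTile(64, 64, 32); // K must equal cta_tile.k
//   tile.instruction_tile = GemmTile(16, 16, 16);
//   scheduleWarpTile(tv, tile,
//       {MatmulDimRole::M, MatmulDimRole::N, MatmulDimRole::K});
//
// (the tile sizes above are made-up values) transforms the loop domain as
//   [M, N, K]
//   -> [Mwo Mw Nwo Nw K]               warp-tile splits of M and N
//   -> [Mwo Mw Mi Nwo Nw Ni Kwo Ki]    instruction-tile splits
//   -> [Kwo Mwo Nwo Mw Nw Mi Ni Ki]    final reorder via setLoopDomain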

void scheduleWarpTileWithNoReduction(TensorView* tv, MatMulTileOptions tile) {
// Assumes
// [M, N, K]
@@ -434,7 +571,7 @@ void makeTile(
abten.reorder(reorder_map_old_to_new);
}

void makeTile(TensorView* tv, std::vector<int64_t> tile_sizes) {
void makeTile(TensorView* tv, const std::vector<int64_t>& tile_sizes) {
// We will create an AbstractMatmulTensor so that we can use the abstract
// makeTile implementation above.

@@ -2145,4 +2282,22 @@ std::optional<std::pair<DimRolesMap, TensorRolesMap>> allPatternRoles(

} // namespace mma_utils

std::string toString(const mma_utils::AbstractMatmulTensor& abten) {
std::ostringstream ss;
ss << "AbstractMatmulTensor (" << abten.size() << "):" << std::endl;
for (size_t i : c10::irange(abten.size())) {
const AbstractId& abs_id = abten[(int64_t)i];
const std::optional<MatmulDimRole> role = abten.getTag((int64_t)i);
ss << " " << (role.has_value() ? toString(role.value()) : "no role");
if (abs_id.is<ValGroupAndItsGraph>()) {
const ValGroup& g = abs_id.as<ValGroupAndItsGraph>().group;
for (Val* v : g->vector()) {
ss << " " << v->toString();
}
}
ss << std::endl;
}
return ss.str();
}

} // namespace nvfuser
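
For reference, the printer emits one line per abstract dimension: the role tag (or "no role") followed by the members of that dimension's ValGroup. Illustrative output (the IterDomain names below are invented):

AbstractMatmulTensor (3):
  M iS2{i0} iS5{i0}
  N iS3{i1}
  K iS4{i2}
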
17 changes: 14 additions & 3 deletions csrc/scheduler/mma_utils.h
@@ -49,6 +49,15 @@ NVF_API void scheduleWarpTileWithReduction(
TensorView* tv,
MatMulTileOptions tile);

//! Same as above, but takes a vector of dim roles that existed in tv before
//! splitting/reordering to create block tiles.
//! Note that if the inner dimensions of tv are ordered other than [M, N, K],
//! this utility will reorder them so that M/N and K are the innermost dims.
NVF_API void scheduleWarpTile(
TensorView* tv,
MatMulTileOptions tile,
const std::vector<MatmulDimRole>& merged_dim_roles);

//! Schedule utility for mma output in matmul main loop:
//! Realize the hierarchical tiling based on the given tiling options
//! on consumers of mma ops in epilog.
@@ -61,10 +70,10 @@ NVF_API void scheduleWarpTileWithNoReduction(
//! Eg.
//! A[B,I0,I1,I2] -> makeTile({1,2,3})
//! Gives A[B, I0o, I1o, I2o, I0i(1), I1i(2), I2i(3)]
void makeTile(TensorView* tv, std::vector<int64_t> tile_sizes);
void makeTile(TensorView* tv, const std::vector<int64_t>& tile_sizes);

//! The above call assumes the axes in TV are [(B), M, N, K]. In this version,
//! we provide the dimension roles that are present for this tensor.
//! The above call assumes the axes are [(B), M, N, K]. In this version, we
//! provide the dimension roles that are present for this tensor.
void makeTile(
TensorView* tv,
const GemmTile& tile_sizes,
@@ -465,4 +474,6 @@ std::optional<std::pair<DimRolesMap, TensorRolesMap>> allPatternRoles(

} // namespace mma_utils

std::string toString(const mma_utils::AbstractMatmulTensor& abten);

} // namespace nvfuser
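
To round out the picture, a usage sketch of the simpler makeTile overload (editor's illustration; tv and the tile sizes are assumptions, not taken from the PR):

// Assumes tv is a TensorView with loop domain [B, M, N, K] in an active Fusion.
nvfuser::mma_utils::makeTile(tv, {128, 128, 32});
// Per the A[B,I0,I1,I2] -> makeTile({1,2,3}) pattern documented in the header,
// tv's loop domain becomes [B, Mo, No, Ko, Mi(128), Ni(128), Ki(32)].

The GemmTile-based overload additionally takes the dimension roles present in the tensor, so it can tile tensors whose axes are not simply [(B), M, N, K].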