Enable translation of Hopper matmuls (#3440)
This enables Hopper matmuls in our automatic scheduler by translating
them without introducing new broadcasts. Specifically:

1. Update `mma_utils::MatmulPattern::translateToMmaOp` to optionally
avoid intermediates by using an `MmaOp::AxisMapping`. Enable this option
when the target arch is not Ampere or Turing.
2. Unguard some tests in `test_translate_mma.cpp`.

This does not update the default heuristic or change the `canSchedule`
checks. See #3579 for that follow-up PR.
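
For orientation, a minimal sketch of the two translation strategies for a
`linear(A, B)` pattern with A = [M, K] and B = [N, K] (shapes and axis
numbers here are illustrative, not taken from this diff):

    // Ampere/Turing: materialize broadcasts so the operands align with
    // the output:
    //   A   [ iM, bN, iK ]
    //   B   [ bM, iN, iK ]
    //   out [ iM, iN, rK ]
    //
    // Hopper (avoid_intermediates = true): leave A and B unchanged and
    // record, for each output axis, the matching operand axis (-1 = absent):
    //   out [ iM, iN, rK ]
    //   a_axes = { 0, -1, 1 }  // A has no N axis
    //   b_axes = { -1, 0, 1 }  // B has no M axis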

---------

Co-authored-by: Ryan Spring <[email protected]>
Co-authored-by: Naoya Maruyama <[email protected]>
Co-authored-by: Jingyue Wu <[email protected]>
Co-authored-by: nsarka <[email protected]>
Co-authored-by: Protonu <[email protected]>
Co-authored-by: samnordmann <[email protected]>
7 people authored Dec 16, 2024
1 parent 7ddfffe commit f91d816
Showing 5 changed files with 218 additions and 45 deletions.
13 changes: 12 additions & 1 deletion csrc/scheduler/matmul_utils.cpp
@@ -800,9 +800,10 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
// scheduler.
// 6. Check if the fusion is resharding.

const auto device_prop = at::cuda::getCurrentDeviceProperties();

// #0
{
// Use a dummy problem shape to determine whether this is a supported
// device.
const auto mma_op = getMmaOp(
@@ -824,6 +825,16 @@ std::string getMatmulCompileTimeRejectReason(Fusion* fusion) {
{
for (const mma_utils::MatmulPattern& pattern : patterns) {
Expr* op = pattern.output->definition();
if (device_prop->major >= 9 && op->isA<ReductionOp>()) {
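// For example, an output ordered [ iM, iN, rK ] passes this check,
// while [ iM, rK, iN ] does not.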
bool found_reduction = false;
for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
if (found_reduction &&
!pattern.output->axis((int64_t)dim)->isReduction()) {
return "Mul+Sum patterns can only be translated to MmaOp "
"on Hopper if the reduction dim is innermost";
}
found_reduction = found_reduction ||
pattern.output->axis((int64_t)dim)->isReduction();
}
}
if (op->isA<MatmulOp>() || op->isA<LinearOp>()) {
if (!isOptionEnabled(EnableOption::FuseMatmul)) {
// Check for MatmulOp or LinearOp. If found, then only fuse if option
176 changes: 149 additions & 27 deletions csrc/scheduler/mma_utils.cpp
@@ -1777,7 +1777,7 @@ std::string MatmulPattern::toString() const {
return ss.str();
}

MmaOp* MatmulPattern::translateToMmaOp() {
MmaOp* MatmulPattern::translateToMmaOp(bool avoid_intermediates) {
if (auto mma_op = dynamic_cast<MmaOp*>(output->definition())) {
// No translation needed
return mma_op;
@@ -1804,6 +1804,13 @@ MmaOp* MatmulPattern::translateToMmaOp() {
// - bias, if present, can be zero or one dimensional. Bias can only be
// present if weight is 2D
//
// When A has dimension greater than two, all the preceding dimensions
// are essentially also M dimensions. The output is shaped like
//
// A [ ... iS0{M} iS1{K} ]
// B [ iS2{N} iS3{K} ]
// out [ ... iS4{M} iS5{N} rS6{K} ]
//
// We translate by broadcasting input, weight, and bias such that the
// contracted dimension K is in the last position (this is true of the
// logical domains in input and weight already). Then we form an MmaOp and
@@ -1812,15 +1819,51 @@ MmaOp* MatmulPattern::translateToMmaOp() {
NVF_ERROR(
A->nDims() > 1 && B->nDims() > 1,
"Cannot translate LinearOp with 1D input");
NVF_ERROR(
B->nDims() == 2, "Cannot translate LinearOp without 2D weight tensor");
if (avoid_intermediates) {
MmaOp::AxisMapping axis_mapping;
int64_t out_dim = A->nDims() + 1L;
axis_mapping.a_axes.reserve(out_dim);
for (int64_t d : c10::irange(out_dim - 2L)) {
axis_mapping.a_axes.push_back(d);
}
axis_mapping.a_axes.push_back(-1); // missing N dimension
axis_mapping.a_axes.push_back(A->nDims() - 1); // K dimension

axis_mapping.b_axes.resize(out_dim, -1);
axis_mapping.b_axes[out_dim - 2] = 0; // N
axis_mapping.b_axes[out_dim - 1] = 1; // K
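// Worked example (hypothetical 2D input): for A [ iM, iK ] and
// B [ iN, iK ], out_dim == 3, so
// a_axes = { 0, -1, 1 } (A has no N axis)
// b_axes = { -1, 0, 1 }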

int64_t num_M_dims = 1 + A->nDims() - B->nDims();

// Add loop broadcasts to A and B to mimic logical broadcasts for simpler
// scheduling
A->broadcast(-2); // There's always a single N dimension

for ([[maybe_unused]] size_t i : c10::irange((size_t)num_M_dims)) {
// Broadcast B for every M dimension in A
B->broadcast(0);
}

fms = fusedMultiplySum(A, B, {-1}, /*init=*/nullptr, axis_mapping);
} else {
std::vector<bool> bcast_dim(A->nDims() + 1, false);
bcast_dim[bcast_dim.size() - 2] = true; // N
A = broadcast(A, bcast_dim);

bcast_dim[bcast_dim.size() - 2] = false; // reset N
std::fill(bcast_dim.begin(), bcast_dim.end() - 2, true);
B = broadcast(B, bcast_dim);

fms = fusedMultiplySum(A, B, {-1});
}

mma_op = fms->definition()->as<MmaOp>();

auto* bias = dynamic_cast<TensorView*>(lop->bias());
@@ -1835,19 +1878,82 @@ MmaOp* MatmulPattern::translateToMmaOp() {
// Also note that the output of MatmulOp is a tensor of shape [..., M, N]
// whose dtype matches that of the inputs. We will most commonly then also
// need to cast the output of the MmaOp to produce the output TensorView.
//
// There are two possibilities:
//
// Case 1: A->nDims() > B->nDims():
//
// A [ ..., B1, ..., Bn, M, K ]
// B [ B1, ..., Bn, K, N ]
//
// All the preceding dimensions in A are additional M dimensions. There
// are batch dimensions in between those and "M".
//
// Case 2: A->nDims() <= B->nDims():
//
// A [ B1, ..., Bn, M, K ]
// B [ ..., B1, ..., Bn, K, N ]
//
// All the preceding dimensions in B are additional N dimensions. There
// are batch dimensions in between those and "N".
//
// In either case, to form the output we transpose B in the last two dims,
// and prepend broadcasts to the lower dimensional input as needed.
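//
// For example (Case 1, hypothetical shapes): A [ B1, M, K ] and
// B [ K, N ] produce an MmaOp output shaped [ B1, M, N, rK ].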
NVF_ERROR(
A->nDims() > 1 && B->nDims() > 1,
"Cannot translate MatmulOp with 1D input");
if (avoid_intermediates) {
MmaOp::AxisMapping axis_mapping;
int64_t out_dims = std::max(A->nDims(), B->nDims()) + 1;

axis_mapping.a_axes.resize((size_t)out_dims, -1);
axis_mapping.b_axes.resize((size_t)out_dims, -1);

for (size_t a_axis : c10::irange((size_t)A->nDims() - 1)) {
// Output is [ ... M, N, K ]
// This loop maps everything but N and K to A
int64_t out_axis = (int64_t)a_axis + (out_dims - 1 - A->nDims());
axis_mapping.a_axes.at((size_t)out_axis) = (int64_t)a_axis;
}
// Map the K dim, skipping the N position
axis_mapping.a_axes.at((size_t)out_dims - 1) = A->nDims() - 1;

for (size_t b_axis : c10::irange((size_t)B->nDims() - 2)) {
// Output is [ ... M, N, K ]
// This loop maps everything before M to B, skipping the output M dim
int64_t out_axis = (int64_t)b_axis + (out_dims - B->nDims()) - 1;
axis_mapping.b_axes.at((size_t)out_axis) = (int64_t)b_axis;
}
// Skip the M dim and map N and K
axis_mapping.b_axes.at((size_t)out_dims - 2) = B->nDims() - 1;
axis_mapping.b_axes.at((size_t)out_dims - 1) = B->nDims() - 2;
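// Worked example (hypothetical 2D inputs): for A [ iM, iK ] and
// B [ iK, iN ], out_dims == 3 and the output is [ iM, iN, rK ], so
// a_axes = { 0, -1, 1 }
// b_axes = { -1, 1, 0 } (N and K are transposed relative to B)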

fms = fusedMultiplySum(A, B, {-1}, /*init=*/nullptr, axis_mapping);

int64_t num_M_dims = std::max(1 + A->nDims() - B->nDims(), (int64_t)1);

// Add loop broadcasts to A and B to mimic logical broadcasts for simpler
// scheduling
A->broadcast(-2);

// Reorder to BMNK
B->reorder({{-2, -1}});
for ([[maybe_unused]] size_t i : c10::irange((size_t)num_M_dims)) {
// Broadcast B for every M dimension in A
B->broadcast(-3);
}
} else {
TensorView* Btrans = transpose(B, -2, -1);
A = unsqueeze(A, -2);
B = unsqueeze(Btrans, -3);
// A and B might have different dimensions. If so, broadcast the smaller
// one up to the size of the larger.
int64_t out_dims = std::max(A->nDims(), B->nDims());
// Add new outer broadcast dimensions if necessary
A = ops::maybe_broadcast_inner_to_rank(A, out_dims);
B = ops::maybe_broadcast_inner_to_rank(B, out_dims);
fms = fusedMultiplySum(A, B, {-1});
}
mma_op = fms->definition()->as<MmaOp>();
} else {
NVF_THROW(
@@ -1986,17 +2092,18 @@ DimRolesMap MatmulPattern::getDimRoles(IdModel& id_model) const {
// for each valgroup, store flags recording in which of A, B, and the
// output the group has a concrete (i.e. not reduction, broadcast, or
// device) IterDomain.
std::unordered_map<ValGroup, DimPresence> flags;
const auto recordPresence = [&graph, &flags](
TensorView* tv, size_t tensor_num) {
for (IterDomain* id : tv->getLogicalDomain()) {
const ValGroup& g = graph.toGroup(id);
DimPresence& group_flags = flags[g];
// Note: broadcast or device dims will be initialized to have all false
// flags above
if (id->isReduction() || id->isBroadcast() || id->isDeviceDim()) {
continue;
}
group_flags.set(tensor_num);
}
};
recordPresence(A, 0);
@@ -2005,8 +2112,7 @@ DimRolesMap MatmulPattern::getDimRoles(IdModel& id_model) const {

DimRolesMap dim_roles;

for (const auto& [g, concrete_flags] : flags) {
if (concrete_flags.all() || concrete_flags.none()) {
// Batch dimensions are any of those that are not concretized or reduced.
// These could be all Iteration or all Broadcast
@@ -2019,9 +2125,25 @@ DimRolesMap MatmulPattern::getDimRoles(IdModel& id_model) const {
dim_roles[g] = MatmulDimRole::N;
} else {
NVF_THROW(
"IterDomain ValGroup should be present in at least two of A, B, output.",
" present_flags: ",
present_flags);
"IterDomain ValGroup should be concrete in at least two of A, B, output.",
" concrete_flags: ",
concrete_flags);
}
}

// NOTE: For Hopper, we create loop broadcasts to mimic logical broadcasts
// when translating MatmulOp and LinearOp. Here we detect these and map them
// appropriately.
for (IterDomain* id : A->getLoopDomain()) {
const ValGroup& g = graph.toGroup(id);
if (dim_roles.count(g) == 0) {
dim_roles[g] = MatmulDimRole::N;
}
}
for (IterDomain* id : B->getLoopDomain()) {
const ValGroup& g = graph.toGroup(id);
if (dim_roles.count(g) == 0) {
dim_roles[g] = MatmulDimRole::M;
}
}

6 changes: 5 additions & 1 deletion csrc/scheduler/mma_utils.h
@@ -327,7 +327,11 @@ struct MatmulPattern {
//! there is a MatmulOp instead, this function modifies the fusion to insert
//! an MmaOp. TensorViews A and B are unchanged, but this->output might be
//! updated to reflect the replacement tensor.
//!
//! If avoid_intermediates is true, this function will use an
//! MmaOp::AxisMapping instead of broadcasting and permuting axes, in order to
//! avoid introducing unnecessary copies on Hopper and above.
MmaOp* translateToMmaOp(bool avoid_intermediates = false);
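// Example (illustrative; the Hopper path in multi_matmul.cpp passes
// avoid_intermediates=true):
//   MmaOp* mma = pattern.translateToMmaOp(/*avoid_intermediates=*/true);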

//! Given an IdModel, map groups of IterDomains to dimension roles
//! (MatmulDimRole). Note that ValGroup is a shared_ptr to a
19 changes: 18 additions & 1 deletion csrc/scheduler/multi_matmul.cpp
@@ -7,6 +7,7 @@
// clang-format on

#include <ATen/cuda/CUDAContext.h>
#include <c10/util/irange.h>
#include <multidevice/utils.h>
#include <scheduler/ampere_multi_matmul.h>
#include <scheduler/hopper_multi_matmul.h>
@@ -21,7 +22,23 @@ void MultipleMatmulScheduler::findPatterns() {
void MultipleMatmulScheduler::translatePatterns() {
mma_results_.reserve(patterns_.size());
for (mma_utils::MatmulPattern& pattern : patterns_) {
// TODO: properly handle all mul+sum patterns for Hopper. For now, these
// should work fine as long as the inner dimensions are the ones being
// reduced.
if (!isAmpere(params_->mma_macro) && !isTuring(params_->mma_macro) &&
pattern.output->definition()->isA<ReductionOp>()) {
bool found_reduction = false;
for (size_t dim : c10::irange((size_t)pattern.output->nDims())) {
NVF_ERROR(
!found_reduction ||
pattern.output->axis((int64_t)dim)->isReduction(),
"Mul+Sum patterns can only be translated on Hopper if the reduction dim is innermost");
found_reduction = found_reduction ||
pattern.output->axis((int64_t)dim)->isReduction();
}
}

MmaOp* mma = pattern.translateToMmaOp(
/*avoid_intermediates=*/!isAmpere(params_->mma_macro) &&
!isTuring(params_->mma_macro));
mma_results_.push_back(mma->out()->as<TensorView>());
}
