Accept axis mapping when defining MmaOp (#3391)
This keeps the default interface of `fusedMultiplySum` but also adds an
option to provide an `MmaOp::AxisMapping` object. This mapping defines,
for each output dimension, which axis in each operand (if any)
corresponds to that output dimension.
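
As a rough illustration of the new option, here is a minimal sketch of defining the op with an explicit mapping. The exact `fusedMultiplySum` overload and the `makeContigTensor` helper used here are assumptions for illustration, not taken from this diff:

```cpp
// Hypothetical sketch: A is [K, M] with no broadcast axes, B is [1, N, K],
// and the output is [M, N, rK] (see Example 2 in the AxisMapping comment).
TensorView* a = makeContigTensor(2, DataType::Half); // [K, M]
TensorView* b = makeContigTensor(3, DataType::Half); // [1, N, K]

// For each output axis, give its position in A and in B; -1 means the
// operand has no axis corresponding to that output dimension.
MmaOp::AxisMapping mapping{/*a_axes=*/{1, -1, 0}, /*b_axes=*/{0, 1, 2}};

// Assumed overload that accepts the mapping alongside the reduction axes.
TensorView* out = fusedMultiplySum(a, b, /*axes=*/{2}, /*init=*/nullptr, mapping);
```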

This PR does not alter the behavior of
`mma_utils::MatmulPattern::translateToMmaOp`, meaning we still have
BroadcastOp in translations for Hopper matmuls, but that change should
be relatively simple.

Fixes #3372 

The included test only checks that dimensions are properly mapped in an
MmaOp defined without broadcast axes. In followup PRs I plan to do the
following:
1. Demonstrate scheduling a Hopper matmul with unbroadcasted inputs
manually. This should surface any bugs in the lowering of the MmaOp
instruction when broadcasts are absent.
2. Ensure that we don't depend on having broadcast dims in the Hopper
matmul scheduler. For example, we will handle this case in
`moveInnerBroadcastLeft`, and we may also need to adjust the swizzling of
the TMA smem load TensorView. At this point we will be able to use our
automatic scheduler to schedule an `MmaOp` that was defined manually
without broadcasted inputs.
3. Add an option
`MatmulPattern::translateToMmaOp(/*avoid_intermediates=*/true)` and
enable that in the Hopper matmul scheduler. At this point it will be
safe for us to accept `MatmulOp` and `LinearOp` in the Hopper matmul
scheduler.
jacobhinkle authored Nov 12, 2024
1 parent a5022da commit 030c2ba
Showing 12 changed files with 269 additions and 320 deletions.
4 changes: 2 additions & 2 deletions csrc/device_lower/pass/index.cpp
@@ -2087,8 +2087,8 @@ void IndexLowering::handle(const MmaOp* mma) {
}
const auto out = lowerDstIndex(
mma->out(), {}, false, getMmaOutType(mma->out()->as<TensorView>()));
auto mma_indexed =
IrBuilder::create<MmaOp>(out, a, b, mma->init(), mma->macro());
auto mma_indexed = IrBuilder::create<MmaOp>(
out, a, b, mma->init(), mma->axisMapping(), mma->macro());
pushBack(mma_indexed);
GpuLower::current()->propagateExprInfo(mma, back());
}
1 change: 1 addition & 0 deletions csrc/device_lower/utils.cpp
@@ -633,6 +633,7 @@ class ReplaceExprInput : private kir::ExprMutator {
replaced_inputs->at(node->inA()),
replaced_inputs->at(node->inB()),
node->init(),
node->axisMapping(),
node->macro());
registerReplaceWithPredicate(node, replacement);
}
70 changes: 50 additions & 20 deletions csrc/ir/internal_nodes.h
@@ -1359,16 +1359,61 @@ class GroupedWelfordOp : public Expr {
class NVF_API MmaOp : public Expr {
public:
using AxesData = std::vector<int64_t>;
// AxisMapping denotes the pairing of two input dimensions to produce an
// output dimension. It holds two vectors of integers indicating the
// corresponding position of each output axis in either the A or B input.
// Positions refer to the noReductions logical domain of each input.
// NOTE: Axis positions are absolute, meaning you cannot specify them
// relative to the last dimension since -1 has special meaning.
// NOTE: -1 indicates that the axis does not exist, so Broadcast input
// domains should be listed with their actual position and not -1.
//
// Example 1:
// a [ K, 1, M ]
// b [ 1, N, K ]
// out [ M, N, rK ]
// axisMapping:
// a_axes = [ 2, 1, 0 ]
// b_axes = [ 0, 1, 2 ]
// This results in the following groups of mapped axes:
// { tv_a->axis(2), tv_b->axis(0), out->axis(0) }
// { tv_a->axis(1), tv_b->axis(1), out->axis(1) }
// { tv_a->axis(0), tv_b->axis(2), out->axis(2) }
//
// Example 2:
// a [ K, M ]
// b [ 1, N, K ]
// out [ M, N, rK ]
// axisMapping:
// a_axes = [ 1, -1, 0 ]
// b_axes = [ 0, 1, 2 ]
// This results in the following groups of mapped axes:
// { tv_a->axis(1), tv_b->axis(0), out->axis(0) }
// { tv_b->axis(1), out->axis(1) }
// { tv_a->axis(0), tv_b->axis(2), out->axis(2) }
struct AxisMapping {
AxesData a_axes;
AxesData b_axes;

static AxisMapping trivialMapping(size_t dimension);
};
using Expr::Expr;

MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init);
MmaOp(
IrBuilderPasskey,
Val* out,
Val* in_a,
Val* in_b,
Val* init,
const AxisMapping& axis_mapping);

MmaOp(
IrBuilderPasskey,
Val* out,
Val* in_a,
Val* in_b,
Val* init,
const AxisMapping& axis_mapping,
const MmaMacro& options);

NVFUSER_DECLARE_CLONE_AND_CREATE
@@ -1426,32 +1471,17 @@ class NVF_API MmaOp : public Expr {

void setMacro(MmaMacro options);

const auto& mAxes() const {
return attribute<AxesData>(ATTR_POS_M_AXES);
}

const auto& nAxes() const {
return attribute<AxesData>(ATTR_POS_N_AXES);
}

const auto& kAxes() const {
return attribute<AxesData>(ATTR_POS_K_AXES);
}

const auto& batchAxes() const {
return attribute<AxesData>(ATTR_POS_BATCH_AXES);
const AxisMapping& axisMapping() const {
return attribute<AxisMapping>(ATTR_POS_AXIS_MAPPING);
}

private:
// Predefined idexes of attributes stored for this IR node, to avoid
// Predefined indices of attributes stored for this IR node, to avoid
// magic numbers, based on order in which attributes are initialized
// in constructor
static constexpr size_t ATTR_POS_INIT = 0;
static constexpr size_t ATTR_POS_MACRO = 1;
static constexpr size_t ATTR_POS_M_AXES = 2;
static constexpr size_t ATTR_POS_N_AXES = 3;
static constexpr size_t ATTR_POS_K_AXES = 4;
static constexpr size_t ATTR_POS_BATCH_AXES = 5;
static constexpr size_t ATTR_POS_AXIS_MAPPING = 2;
};

//! The semantics are identical to torch.broadcast_to.
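
To make the grouping rule in the `AxisMapping` comment concrete, here is a small editorial sketch (not part of this commit) of how a consumer could recover the mapped-axis groups for operand A; it assumes only the `AxisMapping` struct defined above:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Collect {operand_axis, output_axis} pairs implied by a mapping for operand
// A; an entry of -1 means A has no axis for that output dimension. The same
// walk over b_axes yields the pairs for operand B.
std::vector<std::pair<int64_t, int64_t>> pairsForA(
    const MmaOp::AxisMapping& mapping) {
  std::vector<std::pair<int64_t, int64_t>> pairs;
  for (size_t out_pos = 0; out_pos < mapping.a_axes.size(); ++out_pos) {
    int64_t a_pos = mapping.a_axes.at(out_pos);
    if (a_pos != -1) {
      // tv_a->axis(a_pos) groups with out->axis(out_pos)
      pairs.emplace_back(a_pos, (int64_t)out_pos);
    }
  }
  return pairs;
}
```

For Example 2 above, `pairsForA` returns `{{1, 0}, {0, 2}}`: A's axis 1 (M) pairs with output axis 0, and A's axis 0 (K) pairs with output axis 2.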
50 changes: 26 additions & 24 deletions csrc/ir/nodes.cpp
@@ -1980,12 +1980,24 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedWelfordOp)

//==============================================================================================================================

MmaOp::AxisMapping MmaOp::AxisMapping::trivialMapping(size_t dimension) {
AxesData a_axes, b_axes;
a_axes.reserve(dimension);
b_axes.reserve(dimension);
for (size_t i : c10::irange(dimension)) {
a_axes.push_back((int64_t)i);
b_axes.push_back((int64_t)i);
}
return {a_axes, b_axes};
}

MmaOp::MmaOp(
IrBuilderPasskey passkey,
Val* out,
Val* in_a,
Val* in_b,
Val* init)
Val* init,
const AxisMapping& axis_mapping)
: Expr(passkey) {
NVF_ERROR(
out->getValType().value() == ValType::TensorView ||
@@ -2002,35 +2014,24 @@ MmaOp::MmaOp(
in_b->getValType().value() == ValType::TensorIndex,
in_b->getValType().value());

NVF_ERROR(
axis_mapping.a_axes.size() == axis_mapping.b_axes.size(),
"Must have the same number of axis positions in axis mapping for each operand");

auto* out_tv = ir_utils::getTv(out);
NVF_ERROR(
axis_mapping.a_axes.size() == out_tv->getMaybeRootDomain().size(),
"Must have the same number of axis positions in axis mapping as output root dimensions");

addOutput(out);
addInput(in_a);
addInput(in_b);
// ATTR_POS_INIT
addAttribute(init);
// ATTR_POS_MACRO
addDataAttribute(MmaMacro::NoMMA);
// ATTR_POS_M_AXES
addDataAttribute(AxesData{});
// ATTR_POS_N_AXES
addDataAttribute(AxesData{});
// ATTR_POS_K_AXES
addDataAttribute(AxesData{});
// ATTR_POS_BATCH_AXES
addDataAttribute(AxesData{});

MmaOpUtils::MmaOpDetails mma_details;
// Detailed consistency checks for use case with TensorViews as
// inputs/output
if (in_a->isA<TensorView>() && in_b->isA<TensorView>() &&
out->isA<TensorView>()) {
mma_details = MmaOpUtils::getMmaOpDetails(
out->as<TensorView>(), in_a->as<TensorView>(), in_b->as<TensorView>());
}

attribute<AxesData>(ATTR_POS_M_AXES) = std::move(mma_details.m_axes);
attribute<AxesData>(ATTR_POS_N_AXES) = std::move(mma_details.n_axes);
attribute<AxesData>(ATTR_POS_K_AXES) = std::move(mma_details.k_axes);
attribute<AxesData>(ATTR_POS_BATCH_AXES) = std::move(mma_details.batch_axes);
// ATTR_POS_AXIS_MAPPING
addDataAttribute(axis_mapping);
}

MmaOp::MmaOp(
Expand All @@ -2039,8 +2040,9 @@ MmaOp::MmaOp(
Val* in_a,
Val* in_b,
Val* init,
const AxisMapping& axis_mapping,
const MmaMacro& macro)
: MmaOp(passkey, out, in_a, in_b, init) {
: MmaOp(passkey, out, in_a, in_b, init, axis_mapping) {
attribute<MmaMacro>(ATTR_POS_MACRO) = macro;
}

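A brief editorial note on `trivialMapping`: as defined above, it returns the identity mapping, which corresponds to the familiar broadcast-aligned layout where both operands already match the output rank. A minimal sketch:

```cpp
// trivialMapping(3) pairs each output axis with the same position in both
// operands, e.g. a [M, 1, K], b [1, N, K], out [M, N, rK].
MmaOp::AxisMapping mapping = MmaOp::AxisMapping::trivialMapping(3);
// mapping.a_axes == {0, 1, 2}
// mapping.b_axes == {0, 1, 2}
```
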
155 changes: 0 additions & 155 deletions csrc/ir/utils.cpp
@@ -1265,158 +1265,3 @@ int64_t getOperationCount(Val* val) {
}

} // namespace nvfuser::ir_utils

namespace nvfuser::MmaOpUtils {

// A helper for gathering details about TensorView object
TensorViewDetails getDetailsFor(const std::vector<IterDomain*>& dims) {
TensorViewDetails details;
for (auto pos : c10::irange((int64_t)dims.size())) {
const auto axis = dims.at(pos);
if (axis->isReduction()) {
details.rdomains.push_back(pos);
} else if (axis->isBroadcast()) {
details.bcasts.push_back(pos);
} else {
details.cdomains.push_back(pos);
}
}
return details;
}

MmaOpDetails getMmaOpDetails(
TensorView* out,
TensorView* in_a,
TensorView* in_b) {
const auto in_a_details =
getDetailsFor(TensorDomain::noDevices(in_a->getLogicalDomain()));
const auto in_b_details =
getDetailsFor(TensorDomain::noDevices(in_b->getLogicalDomain()));
const auto out_details =
getDetailsFor(TensorDomain::noDevices(out->getMaybeRootDomain()));

using AxesData = MmaOp::AxesData;

const auto getMOrNaxes = [](const AxesData& cdomains,
const AxesData& bcasts,
const AxesData& rdomains) {
AxesData result;
// For all concrete domains
for (const auto& cdomain : cdomains) {
// That are in broadcast domains but are not in reduction domains
if ((std::find(bcasts.begin(), bcasts.end(), cdomain) != bcasts.end()) &&
(std::find(rdomains.begin(), rdomains.end(), cdomain) ==
rdomains.end())) {
result.push_back(cdomain);
}
}
return result;
};

const auto getKaxes = [](const AxesData& cdomains_a,
const AxesData& cdomains_b,
const AxesData& rdomains) {
AxesData result;
// For all concrete domains from in_a
for (const auto& cdomain_a : cdomains_a) {
// That are in concrete domains in in_b and are in reduction domains
if ((std::find(cdomains_b.begin(), cdomains_b.end(), cdomain_a) !=
cdomains_b.end()) &&
(std::find(rdomains.begin(), rdomains.end(), cdomain_a) !=
rdomains.end())) {
result.push_back(cdomain_a);
}
}
return result;
};

const auto getBatchAxes = [](const TensorViewDetails& in_a_details,
const TensorViewDetails& in_b_details,
const TensorViewDetails& out_details) {
AxesData result;
// Batch candidates:
// concrete domains that are in all of inputs and output
for (const auto& domain : in_a_details.cdomains) {
if ((std::find(
in_b_details.cdomains.begin(),
in_b_details.cdomains.end(),
domain) != in_b_details.cdomains.end()) &&
(std::find(
out_details.cdomains.begin(),
out_details.cdomains.end(),
domain) != out_details.cdomains.end())) {
result.push_back(domain);
}
}
// Batch candidates:
// broadcast domains that are in all of inputs and output
for (const auto& domain : in_a_details.bcasts) {
if ((std::find(
in_b_details.bcasts.begin(),
in_b_details.bcasts.end(),
domain) != in_b_details.bcasts.end()) &&
(std::find(
out_details.bcasts.begin(), out_details.bcasts.end(), domain) !=
out_details.bcasts.end())) {
result.push_back(domain);
}
}
std::sort(result.begin(), result.end());
return result;
};

const auto validateInputDetails = [](const TensorViewDetails& details,
const std::string& desc) {
NVF_ERROR(!details.bcasts.empty(), desc, ": has no broadcast domains.");
NVF_ERROR(details.rdomains.empty(), desc, ": has reduction domains.");
NVF_ERROR(
details.cdomains.size() >= expected_gemm_cdomains,
desc,
": has unsupported number of concrete domains, expected at least ",
expected_gemm_cdomains,
", got ",
details.cdomains.size());
};

const auto validateOutputDetails = [](const TensorViewDetails& details,
const std::string& desc) {
// TODO: revise rules when add support for batch gemms
NVF_ERROR(!details.rdomains.empty(), desc, ": has no reduction domains.");
NVF_ERROR(
(details.cdomains.size() >= expected_gemm_cdomains),
desc,
": has unsupported number of concrete domains, expected at least ",
expected_gemm_cdomains,
", got ",
details.cdomains.size());
};

validateInputDetails(in_a_details, "MmaOp input A");
validateInputDetails(in_b_details, "MmaOp input B");
validateOutputDetails(out_details, "MmaOp output");

MmaOpDetails details;

// For details, check MmaOpDetails
details.m_axes = getMOrNaxes(
in_a_details.cdomains, in_b_details.bcasts, out_details.rdomains);
details.n_axes = getMOrNaxes(
in_b_details.cdomains, in_a_details.bcasts, out_details.rdomains);
details.k_axes = getKaxes(
in_a_details.cdomains, in_b_details.cdomains, out_details.rdomains);
details.batch_axes = getBatchAxes(in_a_details, in_b_details, out_details);

NVF_ERROR(
!details.m_axes.empty(),
"MmaOp inputs must define at least a single M dimension");
NVF_ERROR(
!details.n_axes.empty(),
"MmaOp inputs must define at least a single N dimension");
NVF_ERROR(
!details.k_axes.empty(),
"MmaOp inputs must define at least a single K dimension");

return details;
}

} // namespace nvfuser::MmaOpUtils
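
The removed inference above deduced the M, N, K, and batch axes from each tensor's broadcast/concrete/reduction pattern, and `validateInputDetails` rejected operands with no broadcast domains. A schematic sketch of the broadcast-based definition this required (shapes chosen for illustration; `broadcast` and `fusedMultiplySum` are existing nvFuser ops, `makeContigTensor` is a test-style helper):

```cpp
// Broadcast-based definition required by the removed inference: each operand
// carries explicit broadcast domains so axis roles can be deduced.
TensorView* a = makeContigTensor(2); // [M, K]
TensorView* b = makeContigTensor(2); // [N, K]
TensorView* a_b = broadcast(a, {false, true, false}); // [M, 1, K]
TensorView* b_b = broadcast(b, {true, false, false}); // [1, N, K]
TensorView* out = fusedMultiplySum(a_b, b_b, /*axes=*/{2}); // [M, N, rK]
```

With the new `AxisMapping`, the same product can be expressed directly on the unbroadcast operands.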