Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accept axis mapping when defining MmaOp #3391

Merged
merged 17 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions csrc/device_lower/pass/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2087,8 +2087,8 @@ void IndexLowering::handle(const MmaOp* mma) {
}
const auto out = lowerDstIndex(
mma->out(), {}, false, getMmaOutType(mma->out()->as<TensorView>()));
auto mma_indexed =
IrBuilder::create<MmaOp>(out, a, b, mma->init(), mma->macro());
auto mma_indexed = IrBuilder::create<MmaOp>(
out, a, b, mma->init(), mma->axisMapping(), mma->macro());
pushBack(mma_indexed);
GpuLower::current()->propagateExprInfo(mma, back());
}
Expand Down
1 change: 1 addition & 0 deletions csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ class ReplaceExprInput : private kir::ExprMutator {
replaced_inputs->at(node->inA()),
replaced_inputs->at(node->inB()),
node->init(),
node->axisMapping(),
node->macro());
registerReplaceWithPredicate(node, replacement);
}
Expand Down
70 changes: 50 additions & 20 deletions csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1359,16 +1359,61 @@ class GroupedWelfordOp : public Expr {
class NVF_API MmaOp : public Expr {
public:
using AxesData = std::vector<int64_t>;
// AxisMapping denotes the pairing of two input dimensions to produce an
// output dimension. It holds two vectors of integers indicating the
// corresponding position of each output axis in either the A or B input.
// Positions refer to the noReductions logical domain of each input.
// NOTE: Axis positions are absolute, meaning you cannot specify them
// relative to the last dimension since -1 has special meaning.
// NOTE: -1 indicates that the axis does not exist, so Broadcast input
// domains should be listed with their actual position and not -1.
//
// Example 1:
// a [ K, 1, M ]
// b [ 1, N, K ]
// out [ M, N, rK ]
// axisMapping:
// a_axes = [ 2, 1, 0 ]
// b_axes = [ 0, 1, 2 ]
// This results in the following groups of mapped axes:
// { tv_a->axis(2), tv_b->axis(0), out->axis(0) }
// { tv_a->axis(1), tv_b->axis(1), out->axis(1) }
// { tv_a->axis(0), tv_b->axis(2), out->axis(2) }
//
// Example 2:
// a [ K, M ]
// b [ 1, N, K ]
// out [ M, N, rK ]
// axisMapping:
// a_axes = [ 1, -1, 0 ]
// b_axes = [ 0, 1, 2 ]
// This results in the following groups of mapped axes:
// { tv_a->axis(1), tv_b->axis(0), out->axis(0) }
// { tv_b->axis(1), out->axis(1) }
// { tv_a->axis(0), tv_b->axis(2), out->axis(2) }
struct AxisMapping {
  // For each output root axis, the position of the corresponding axis in
  // A's noReductions logical domain, or -1 if A has no such axis.
  AxesData a_axes;
  // For each output root axis, the position of the corresponding axis in
  // B's noReductions logical domain, or -1 if B has no such axis.
  AxesData b_axes;

  // Builds the identity mapping for `dimension` axes: output axis i pairs
  // with axis i of both A and B (no -1 entries).
  static AxisMapping trivialMapping(size_t dimension);
};
using Expr::Expr;

MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init);
MmaOp(
IrBuilderPasskey,
Val* out,
Val* in_a,
Val* in_b,
Val* init,
const AxisMapping& axis_mapping);

MmaOp(
IrBuilderPasskey,
Val* out,
Val* in_a,
Val* in_b,
Val* init,
const AxisMapping& axis_mapping,
const MmaMacro& options);

NVFUSER_DECLARE_CLONE_AND_CREATE
Expand Down Expand Up @@ -1426,32 +1471,17 @@ class NVF_API MmaOp : public Expr {

void setMacro(MmaMacro options);

const auto& mAxes() const {
return attribute<AxesData>(ATTR_POS_M_AXES);
}

const auto& nAxes() const {
return attribute<AxesData>(ATTR_POS_N_AXES);
}

const auto& kAxes() const {
return attribute<AxesData>(ATTR_POS_K_AXES);
}

const auto& batchAxes() const {
return attribute<AxesData>(ATTR_POS_BATCH_AXES);
const AxisMapping& axisMapping() const {
return attribute<AxisMapping>(ATTR_POS_AXIS_MAPPING);
}

private:
// Predefined idexes of attributes stored for this IR node, to avoid
// Predefined indices of attributes stored for this IR node, to avoid
// magic numbers, based on order in which attributes are initialized
// in constructor
static constexpr size_t ATTR_POS_INIT = 0;
static constexpr size_t ATTR_POS_MACRO = 1;
static constexpr size_t ATTR_POS_M_AXES = 2;
static constexpr size_t ATTR_POS_N_AXES = 3;
static constexpr size_t ATTR_POS_K_AXES = 4;
static constexpr size_t ATTR_POS_BATCH_AXES = 5;
static constexpr size_t ATTR_POS_AXIS_MAPPING = 2;
};

//! The semantics are identical to torch.broadcast_to.
Expand Down
50 changes: 26 additions & 24 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1980,12 +1980,24 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedWelfordOp)

//==============================================================================================================================

// Builds the identity axis mapping for `dimension` axes: output axis i is
// paired with axis i of both the A and B operands.
MmaOp::AxisMapping MmaOp::AxisMapping::trivialMapping(size_t dimension) {
  // Both operands use the same identity permutation [0, 1, ..., dimension-1],
  // so build it once and return it for A and B alike.
  AxesData identity;
  identity.reserve(dimension);
  for (int64_t axis = 0; axis < static_cast<int64_t>(dimension); ++axis) {
    identity.push_back(axis);
  }
  return {identity, identity};
}

MmaOp::MmaOp(
IrBuilderPasskey passkey,
Val* out,
Val* in_a,
Val* in_b,
Val* init)
Val* init,
const AxisMapping& axis_mapping)
: Expr(passkey) {
NVF_ERROR(
out->getValType().value() == ValType::TensorView ||
Expand All @@ -2002,35 +2014,24 @@ MmaOp::MmaOp(
in_b->getValType().value() == ValType::TensorIndex,
in_b->getValType().value());

NVF_ERROR(
axis_mapping.a_axes.size() == axis_mapping.b_axes.size(),
"Must have the same number of axis positions in axis mapping for each operand");

auto* out_tv = ir_utils::getTv(out);
NVF_ERROR(
axis_mapping.a_axes.size() == out_tv->getMaybeRootDomain().size(),
"Must have the same number of axis positions in axis mapping as output root dimensions");

addOutput(out);
addInput(in_a);
addInput(in_b);
// ATTR_POS_INIT
addAttribute(init);
// ATTR_POS_MACRO
addDataAttribute(MmaMacro::NoMMA);
// ATTR_POS_M_AXES
addDataAttribute(AxesData{});
// ATTR_POS_N_AXES
addDataAttribute(AxesData{});
// ATTR_POS_K_AXES
addDataAttribute(AxesData{});
// ATTR_POS_BATCH_AXES
addDataAttribute(AxesData{});

MmaOpUtils::MmaOpDetails mma_details;
// Detailed consistency checks for use case with TensorViews as
// inputs/output
if (in_a->isA<TensorView>() && in_b->isA<TensorView>() &&
out->isA<TensorView>()) {
mma_details = MmaOpUtils::getMmaOpDetails(
out->as<TensorView>(), in_a->as<TensorView>(), in_b->as<TensorView>());
}

attribute<AxesData>(ATTR_POS_M_AXES) = std::move(mma_details.m_axes);
attribute<AxesData>(ATTR_POS_N_AXES) = std::move(mma_details.n_axes);
attribute<AxesData>(ATTR_POS_K_AXES) = std::move(mma_details.k_axes);
attribute<AxesData>(ATTR_POS_BATCH_AXES) = std::move(mma_details.batch_axes);
// ATTR_POS_AXIS_MAPPING
addDataAttribute(axis_mapping);
}

MmaOp::MmaOp(
Expand All @@ -2039,8 +2040,9 @@ MmaOp::MmaOp(
Val* in_a,
Val* in_b,
Val* init,
const AxisMapping& axis_mapping,
const MmaMacro& macro)
: MmaOp(passkey, out, in_a, in_b, init) {
: MmaOp(passkey, out, in_a, in_b, init, axis_mapping) {
attribute<MmaMacro>(ATTR_POS_MACRO) = macro;
}

Expand Down
155 changes: 0 additions & 155 deletions csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1265,158 +1265,3 @@ int64_t getOperationCount(Val* val) {
}

} // namespace nvfuser::ir_utils

namespace nvfuser::MmaOpUtils {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was only needed for defining MmaOp::mAxes() and friends, but:

  1. Those methods are never used so I removed them and
  2. We can reconstruct that information easily using mma->axisMapping().


// A helper for gathering details about TensorView object
// Classifies every axis position of `dims` into one of three buckets —
// reduction, broadcast, or concrete — recording the positions (not the
// IterDomain pointers) in the returned TensorViewDetails.
TensorViewDetails getDetailsFor(const std::vector<IterDomain*>& dims) {
  TensorViewDetails result;
  int64_t position = 0;
  for (auto* id : dims) {
    // Reduction takes precedence, then broadcast; everything else is a
    // concrete (iteration) domain.
    if (id->isReduction()) {
      result.rdomains.push_back(position);
    } else if (id->isBroadcast()) {
      result.bcasts.push_back(position);
    } else {
      result.cdomains.push_back(position);
    }
    ++position;
  }
  return result;
}

// Derives the M/N/K/batch axis classification for an MmaOp from the logical
// domains of its inputs and the root domain of its output. Positions are
// indices into the device-dim-stripped domains (TensorDomain::noDevices).
// NOTE(review): this assumes A and B share an aligned axis numbering, i.e.
// position p refers to the "same" axis in A, B, and out — confirm against
// callers before reusing elsewhere.
MmaOpDetails getMmaOpDetails(
    TensorView* out,
    TensorView* in_a,
    TensorView* in_b) {
  // Bucket each tensor's axis positions into concrete / broadcast /
  // reduction sets (see getDetailsFor).
  const auto in_a_details =
      getDetailsFor(TensorDomain::noDevices(in_a->getLogicalDomain()));
  const auto in_b_details =
      getDetailsFor(TensorDomain::noDevices(in_b->getLogicalDomain()));
  const auto out_details =
      getDetailsFor(TensorDomain::noDevices(out->getMaybeRootDomain()));

  using AxesData = MmaOp::AxesData;

  // M (resp. N) axes: concrete in one operand, broadcast in the other, and
  // not reduced in the output.
  const auto getMOrNaxes = [](const AxesData& cdomains,
                              const AxesData& bcasts,
                              const AxesData& rdomains) {
    AxesData result;
    // For all concrete domains
    for (const auto& cdomain : cdomains) {
      // That are in broadcast domains but are not in reduction domains
      if ((std::find(bcasts.begin(), bcasts.end(), cdomain) != bcasts.end()) &&
          (std::find(rdomains.begin(), rdomains.end(), cdomain) ==
           rdomains.end())) {
        result.push_back(cdomain);
      }
    }
    return result;
  };

  // K axes: concrete in both operands and reduced in the output.
  const auto getKaxes = [](const AxesData& cdomains_a,
                           const AxesData& cdomains_b,
                           const AxesData& rdomains) {
    AxesData result;
    // For all concrete domains from in_a
    for (const auto& cdomain_a : cdomains_a) {
      // That are in concrete domains in in_b and are in reduction domains
      if ((std::find(cdomains_b.begin(), cdomains_b.end(), cdomain_a) !=
           cdomains_b.end()) &&
          (std::find(rdomains.begin(), rdomains.end(), cdomain_a) !=
           rdomains.end())) {
        result.push_back(cdomain_a);
      }
    }
    return result;
  };

  // Batch axes: positions that are concrete in A, B, and out, plus
  // positions that are broadcast in all three; returned sorted.
  const auto getBatchAxes = [](const TensorViewDetails& in_a_details,
                               const TensorViewDetails& in_b_details,
                               const TensorViewDetails& out_details) {
    AxesData result;
    // Batch candidates:
    // concrete domains that are in all of inputs and output
    for (const auto& domain : in_a_details.cdomains) {
      if ((std::find(
               in_b_details.cdomains.begin(),
               in_b_details.cdomains.end(),
               domain) != in_b_details.cdomains.end()) &&
          (std::find(
               out_details.cdomains.begin(),
               out_details.cdomains.end(),
               domain) != out_details.cdomains.end())) {
        result.push_back(domain);
      }
    }
    // Batch candidates:
    // broadcast domains that are in all of inputs and output
    for (const auto& domain : in_a_details.bcasts) {
      if ((std::find(
               in_b_details.bcasts.begin(),
               in_b_details.bcasts.end(),
               domain) != in_b_details.bcasts.end()) &&
          (std::find(
               out_details.bcasts.begin(), out_details.bcasts.end(), domain) !=
           out_details.bcasts.end())) {
        result.push_back(domain);
      }
    }
    // The two scans above can interleave positions out of order; sort so
    // batch axes are reported in ascending position.
    std::sort(result.begin(), result.end());
    return result;
  };

  // Inputs must be unreduced, carry at least one broadcast axis, and have
  // at least expected_gemm_cdomains concrete axes.
  const auto validateInputDetails = [](const TensorViewDetails& details,
                                       const std::string& desc) {
    NVF_ERROR(!details.bcasts.empty(), desc, ": has no broadcast domains.");
    NVF_ERROR(details.rdomains.empty(), desc, ": has reduction domains.");
    NVF_ERROR(
        details.cdomains.size() >= expected_gemm_cdomains,
        desc,
        ": has unsupported number of concrete domains, expected at least ",
        expected_gemm_cdomains,
        ", got ",
        details.cdomains.size());
  };

  // The output must reduce at least one axis (the K dimension) and keep at
  // least expected_gemm_cdomains concrete axes (M and N).
  const auto validateOutputDetails = [](const TensorViewDetails& details,
                                        const std::string& desc) {
    // TODO: revise rules when add support for batch gemms
    NVF_ERROR(!details.rdomains.empty(), desc, ": has no reduction domains.");
    NVF_ERROR(
        (details.cdomains.size() >= expected_gemm_cdomains),
        desc,
        ": has unsupported number of concrete domains, expected at least ",
        expected_gemm_cdomains,
        ", got ",
        details.cdomains.size());
  };

  validateInputDetails(in_a_details, "MmaOp input A");
  validateInputDetails(in_b_details, "MmaOp input B");
  validateOutputDetails(out_details, "MmaOp output");

  MmaOpDetails details;

  // For details, check MmaOpDetails
  // M: concrete in A, broadcast in B; N: concrete in B, broadcast in A.
  details.m_axes = getMOrNaxes(
      in_a_details.cdomains, in_b_details.bcasts, out_details.rdomains);
  details.n_axes = getMOrNaxes(
      in_b_details.cdomains, in_a_details.bcasts, out_details.rdomains);
  details.k_axes = getKaxes(
      in_a_details.cdomains, in_b_details.cdomains, out_details.rdomains);
  details.batch_axes = getBatchAxes(in_a_details, in_b_details, out_details);

  // A valid GEMM requires at least one axis of each of M, N, and K.
  NVF_ERROR(
      !details.m_axes.empty(),
      "MmaOp inputs must define at least a single M dimension");
  NVF_ERROR(
      !details.n_axes.empty(),
      "MmaOp inputs must define at least a single N dimension");
  NVF_ERROR(
      !details.k_axes.empty(),
      "MmaOp inputs must define at least a single K dimension");

  return details;
}

} // namespace nvfuser::MmaOpUtils
Loading
Loading