diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp
index b5a41658101..1ac49f5f2b4 100644
--- a/csrc/device_lower/pass/index.cpp
+++ b/csrc/device_lower/pass/index.cpp
@@ -2087,8 +2087,8 @@ void IndexLowering::handle(const MmaOp* mma) {
   }
   const auto out = lowerDstIndex(
       mma->out(), {}, false, getMmaOutType(mma->out()->as<TensorView>()));
-  auto mma_indexed =
-      IrBuilder::create<MmaOp>(out, a, b, mma->init(), mma->macro());
+  auto mma_indexed = IrBuilder::create<MmaOp>(
+      out, a, b, mma->init(), mma->axisMapping(), mma->macro());
   pushBack(mma_indexed);
   GpuLower::current()->propagateExprInfo(mma, back());
 }
diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp
index 56eef906f0e..f77a7520e9a 100644
--- a/csrc/device_lower/utils.cpp
+++ b/csrc/device_lower/utils.cpp
@@ -633,6 +633,7 @@ class ReplaceExprInput : private kir::ExprMutator {
         replaced_inputs->at(node->inA()),
         replaced_inputs->at(node->inB()),
         node->init(),
+        node->axisMapping(),
         node->macro());
     registerReplaceWithPredicate(node, replacement);
   }
diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index 4d5c41e3ec2..0a0f4bcbe9a 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -1359,9 +1359,53 @@ class GroupedWelfordOp : public Expr {
 class NVF_API MmaOp : public Expr {
  public:
   using AxesData = std::vector<int64_t>;
+  // AxisMapping denotes the pairing of two input dimensions to produce an
+  // output dimension. It holds two vectors of integers indicating the
+  // corresponding position of each output axis in either the A or B input.
+  // Positions refer to the noReductions logical domain of each input.
+  // NOTE: Axis positions are absolute, meaning you cannot specify them
+  // relative to the last dimension since -1 has special meaning.
+  // NOTE: -1 indicates that the axis does not exist, so Broadcast input
+  // domains should be listed with their actual position and not -1.
+  //
+  // Example 1:
+  //   a [ K, 1, M ]
+  //   b [ 1, N, K ]
+  //   out [ M, N, rK ]
+  //   axisMapping:
+  //     a_axes = [ 2, 1, 0 ]
+  //     b_axes = [ 0, 1, 2 ]
+  //   This results in the following groups of mapped axes:
+  //     { tv_a->axis(2), tv_b->axis(0), out->axis(0) }
+  //     { tv_a->axis(1), tv_b->axis(1), out->axis(1) }
+  //     { tv_a->axis(0), tv_b->axis(2), out->axis(2) }
+  //
+  // Example 2:
+  //   a [ K, M ]
+  //   b [ 1, N, K ]
+  //   out [ M, N, rK ]
+  //   axisMapping:
+  //     a_axes = [ 1, -1, 0 ]
+  //     b_axes = [ 0, 1, 2 ]
+  //   This results in the following groups of mapped axes:
+  //     { tv_a->axis(1), tv_b->axis(0), out->axis(0) }
+  //     { tv_b->axis(1), out->axis(1) }
+  //     { tv_a->axis(0), tv_b->axis(2), out->axis(2) }
+  struct AxisMapping {
+    AxesData a_axes;
+    AxesData b_axes;
+
+    static AxisMapping trivialMapping(size_t dimension);
+  };
   using Expr::Expr;

-  MmaOp(IrBuilderPasskey, Val* out, Val* in_a, Val* in_b, Val* init);
+  MmaOp(
+      IrBuilderPasskey,
+      Val* out,
+      Val* in_a,
+      Val* in_b,
+      Val* init,
+      const AxisMapping& axis_mapping);

   MmaOp(
       IrBuilderPasskey,
@@ -1369,6 +1413,7 @@ class NVF_API MmaOp : public Expr {
       Val* in_a,
       Val* in_b,
       Val* init,
+      const AxisMapping& axis_mapping,
       const MmaMacro& options);

   NVFUSER_DECLARE_CLONE_AND_CREATE
@@ -1426,32 +1471,17 @@ class NVF_API MmaOp : public Expr {

   void setMacro(MmaMacro options);

-  const auto& mAxes() const {
-    return attribute<AxesData>(ATTR_POS_M_AXES);
-  }
-
-  const auto& nAxes() const {
-    return attribute<AxesData>(ATTR_POS_N_AXES);
-  }
-
-  const auto& kAxes() const {
-    return attribute<AxesData>(ATTR_POS_K_AXES);
-  }
-
-  const auto& batchAxes() const {
-    return attribute<AxesData>(ATTR_POS_BATCH_AXES);
+  const AxisMapping& axisMapping() const {
+    return attribute<AxisMapping>(ATTR_POS_AXIS_MAPPING);
   }

  private:
-  // Predefined idexes of attributes stored for this IR node, to avoid
+  // Predefined indices of attributes stored for this IR node, to avoid
   // magic numbers, based on order in which attributes are initialized
   // in constructor
   static constexpr size_t ATTR_POS_INIT = 0;
   static constexpr size_t ATTR_POS_MACRO = 1;
-  static constexpr size_t ATTR_POS_M_AXES = 2;
-  static constexpr size_t ATTR_POS_N_AXES = 3;
-  static constexpr size_t ATTR_POS_K_AXES = 4;
-  static constexpr size_t ATTR_POS_BATCH_AXES = 5;
+  static constexpr size_t ATTR_POS_AXIS_MAPPING = 2;
 };

 //! The semantics are identical to torch.broadcast_to.
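The struct above is deliberately small: two equal-length vectors of positions, indexed by output axis. As a quick intuition check, the invariants that the new MmaOp constructor and fusedMultiplySum enforce can be sketched in isolation. This is a standalone illustration, not nvfuser code; AxisMapping and isValidMapping here are local stand-ins:

#include <cassert>
#include <cstdint>
#include <vector>

// Local stand-in mirroring MmaOp::AxisMapping above.
struct AxisMapping {
  std::vector<int64_t> a_axes;
  std::vector<int64_t> b_axes;
};

// Sketch of the checks performed downstream: equal lengths, positions in
// bounds, and no output axis that is -1 (absent) in both operands.
bool isValidMapping(const AxisMapping& m, int64_t a_rank, int64_t b_rank) {
  if (m.a_axes.size() != m.b_axes.size()) {
    return false;
  }
  for (size_t i = 0; i < m.a_axes.size(); ++i) {
    int64_t a = m.a_axes[i];
    int64_t b = m.b_axes[i];
    if (a == -1 && b == -1) {
      return false; // output axis must exist in at least one operand
    }
    if (a < -1 || a >= a_rank || b < -1 || b >= b_rank) {
      return false; // -1 is the only out-of-range sentinel allowed
    }
  }
  return true;
}

int main() {
  // Example 2 above: a [K, M], b [1, N, K], out [M, N, rK]
  assert(isValidMapping({{1, -1, 0}, {0, 1, 2}}, /*a_rank=*/2, /*b_rank=*/3));
  return 0;
}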
diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp
index 1aaac849433..af3a02c3e70 100644
--- a/csrc/ir/nodes.cpp
+++ b/csrc/ir/nodes.cpp
@@ -1980,12 +1980,24 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedWelfordOp)

 //==============================================================================================================================

+MmaOp::AxisMapping MmaOp::AxisMapping::trivialMapping(size_t dimension) {
+  AxesData a_axes, b_axes;
+  a_axes.reserve(dimension);
+  b_axes.reserve(dimension);
+  for (size_t i : c10::irange(dimension)) {
+    a_axes.push_back((int64_t)i);
+    b_axes.push_back((int64_t)i);
+  }
+  return {a_axes, b_axes};
+}
+
 MmaOp::MmaOp(
     IrBuilderPasskey passkey,
     Val* out,
     Val* in_a,
     Val* in_b,
-    Val* init)
+    Val* init,
+    const AxisMapping& axis_mapping)
     : Expr(passkey) {
   NVF_ERROR(
       out->getValType().value() == ValType::TensorView ||
@@ -2002,6 +2014,15 @@ MmaOp::MmaOp(
       in_b->getValType().value() == ValType::TensorIndex,
       in_b->getValType().value());

+  NVF_ERROR(
+      axis_mapping.a_axes.size() == axis_mapping.b_axes.size(),
+      "Must have the same number of axis positions in axis mapping for each operand");
+
+  auto* out_tv = ir_utils::getTv(out);
+  NVF_ERROR(
+      axis_mapping.a_axes.size() == out_tv->getMaybeRootDomain().size(),
+      "Must have the same number of axis positions in axis mapping as output root dimensions");
+
   addOutput(out);
   addInput(in_a);
   addInput(in_b);
@@ -2009,28 +2030,8 @@ MmaOp::MmaOp(
   addAttribute(init);
   // ATTR_POS_MACRO
   addDataAttribute(MmaMacro::NoMMA);
-  // ATTR_POS_M_AXES
-  addDataAttribute(AxesData{});
-  // ATTR_POS_N_AXES
-  addDataAttribute(AxesData{});
-  // ATTR_POS_K_AXES
-  addDataAttribute(AxesData{});
-  // ATTR_POS_BATCH_AXES
-  addDataAttribute(AxesData{});
-
-  MmaOpUtils::MmaOpDetails mma_details;
-  // Detailed consistency checks for use case with TensorViews as
-  // inputs/output
-  if (in_a->isA<TensorView>() && in_b->isA<TensorView>() &&
-      out->isA<TensorView>()) {
-    mma_details = MmaOpUtils::getMmaOpDetails(
-        out->as<TensorView>(), in_a->as<TensorView>(), in_b->as<TensorView>());
-  }
-
-  attribute<AxesData>(ATTR_POS_M_AXES) = std::move(mma_details.m_axes);
-  attribute<AxesData>(ATTR_POS_N_AXES) = std::move(mma_details.n_axes);
-  attribute<AxesData>(ATTR_POS_K_AXES) = std::move(mma_details.k_axes);
-  attribute<AxesData>(ATTR_POS_BATCH_AXES) = std::move(mma_details.batch_axes);
+  // ATTR_POS_AXIS_MAPPING
+  addDataAttribute(axis_mapping);
 }

 MmaOp::MmaOp(
@@ -2039,8 +2040,9 @@ MmaOp::MmaOp(
     Val* in_a,
     Val* in_b,
     Val* init,
+    const AxisMapping& axis_mapping,
     const MmaMacro& macro)
-    : MmaOp(passkey, out, in_a, in_b, init) {
+    : MmaOp(passkey, out, in_a, in_b, init, axis_mapping) {
   attribute<MmaMacro>(ATTR_POS_MACRO) = macro;
 }
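trivialMapping above is the identity mapping used when no explicit mapping is supplied: output axis i pairs with position i in both operands, which presumes equal-rank inputs. A standalone equivalent for intuition (trivialMapping here is a local function returning a pair of vectors, not the member above):

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Identity mapping: output axis i maps to position i in both A and B.
std::pair<std::vector<int64_t>, std::vector<int64_t>> trivialMapping(
    size_t dimension) {
  std::vector<int64_t> axes(dimension);
  for (size_t i = 0; i < dimension; ++i) {
    axes[i] = (int64_t)i;
  }
  return {axes, axes};
}

int main() {
  auto [a_axes, b_axes] = trivialMapping(3);
  assert((a_axes == std::vector<int64_t>{0, 1, 2}));
  assert(a_axes == b_axes);
  return 0;
}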
diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp
index 868ba36144d..2a353c7cfdc 100644
--- a/csrc/ir/utils.cpp
+++ b/csrc/ir/utils.cpp
@@ -1265,158 +1265,3 @@ int64_t getOperationCount(Val* val) {
 }

 } // namespace nvfuser::ir_utils
-
-namespace nvfuser::MmaOpUtils {
-
-// A helper for gathering details about TensorView object
-TensorViewDetails getDetailsFor(const std::vector<IterDomain*>& dims) {
-  TensorViewDetails details;
-  for (auto pos : c10::irange((int64_t)dims.size())) {
-    const auto axis = dims.at(pos);
-    if (axis->isReduction()) {
-      details.rdomains.push_back(pos);
-    } else if (axis->isBroadcast()) {
-      details.bcasts.push_back(pos);
-    } else {
-      details.cdomains.push_back(pos);
-    }
-  }
-  return details;
-}
-
-MmaOpDetails getMmaOpDetails(
-    TensorView* out,
-    TensorView* in_a,
-    TensorView* in_b) {
-  const auto in_a_details =
-      getDetailsFor(TensorDomain::noDevices(in_a->getLogicalDomain()));
-  const auto in_b_details =
-      getDetailsFor(TensorDomain::noDevices(in_b->getLogicalDomain()));
-  const auto out_details =
-      getDetailsFor(TensorDomain::noDevices(out->getMaybeRootDomain()));
-
-  using AxesData = MmaOp::AxesData;
-
-  const auto getMOrNaxes = [](const AxesData& cdomains,
-                              const AxesData& bcasts,
-                              const AxesData& rdomains) {
-    AxesData result;
-    // For all concrete domains
-    for (const auto& cdomain : cdomains) {
-      // That are in broadcast domains but are not in reduction domains
-      if ((std::find(bcasts.begin(), bcasts.end(), cdomain) != bcasts.end()) &&
-          (std::find(rdomains.begin(), rdomains.end(), cdomain) ==
-           rdomains.end())) {
-        result.push_back(cdomain);
-      }
-    }
-    return result;
-  };
-
-  const auto getKaxes = [](const AxesData& cdomains_a,
-                           const AxesData& cdomains_b,
-                           const AxesData& rdomains) {
-    AxesData result;
-    // For all concrete domains from in_a
-    for (const auto& cdomain_a : cdomains_a) {
-      // That are in concrete domains in in_b and are in reduction domains
-      if ((std::find(cdomains_b.begin(), cdomains_b.end(), cdomain_a) !=
-           cdomains_b.end()) &&
-          (std::find(rdomains.begin(), rdomains.end(), cdomain_a) !=
-           rdomains.end())) {
-        result.push_back(cdomain_a);
-      }
-    }
-    return result;
-  };
-
-  const auto getBatchAxes = [](const TensorViewDetails& in_a_details,
-                               const TensorViewDetails& in_b_details,
-                               const TensorViewDetails& out_details) {
-    AxesData result;
-    // Batch candidates:
-    //  concrete domains that are in all of inputs and output
-    for (const auto& domain : in_a_details.cdomains) {
-      if ((std::find(
-               in_b_details.cdomains.begin(),
-               in_b_details.cdomains.end(),
-               domain) != in_b_details.cdomains.end()) &&
-          (std::find(
-               out_details.cdomains.begin(),
-               out_details.cdomains.end(),
-               domain) != out_details.cdomains.end())) {
-        result.push_back(domain);
-      }
-    }
-    // Batch candidates:
-    //  broadcast domains that are in all of inputs and output
-    for (const auto& domain : in_a_details.bcasts) {
-      if ((std::find(
-               in_b_details.bcasts.begin(),
-               in_b_details.bcasts.end(),
-               domain) != in_b_details.bcasts.end()) &&
-          (std::find(
-               out_details.bcasts.begin(), out_details.bcasts.end(), domain) !=
-           out_details.bcasts.end())) {
-        result.push_back(domain);
-      }
-    }
-    std::sort(result.begin(), result.end());
-    return result;
-  };
-
-  const auto validateInputDetails = [](const TensorViewDetails& details,
-                                       const std::string& desc) {
-    NVF_ERROR(!details.bcasts.empty(), desc, ": has no broadcast domains.");
-    NVF_ERROR(details.rdomains.empty(), desc, ": has reduction domains.");
-    NVF_ERROR(
-        details.cdomains.size() >= expected_gemm_cdomains,
-        desc,
-        ": has unsupported number of concrete domains, expected at least ",
-        expected_gemm_cdomains,
-        ", got ",
-        details.cdomains.size());
-  };
-
-  const auto validateOutputDetails = [](const TensorViewDetails& details,
-                                        const std::string& desc) {
-    // TODO: revise rules when add support for batch gemms
-    NVF_ERROR(!details.rdomains.empty(), desc, ": has no reduction domains.");
-    NVF_ERROR(
-        (details.cdomains.size() >= expected_gemm_cdomains),
-        desc,
-        ": has unsupported number of concrete domains, expected at least ",
-        expected_gemm_cdomains,
-        ", got ",
-        details.cdomains.size());
-  };
-
-  validateInputDetails(in_a_details, "MmaOp input A");
-  validateInputDetails(in_b_details, "MmaOp input B");
-  validateOutputDetails(out_details, "MmaOp output");
-
-  MmaOpDetails details;
-
-  // For details, check MmaOpDetails
-  details.m_axes = getMOrNaxes(
-      in_a_details.cdomains, in_b_details.bcasts, out_details.rdomains);
-  details.n_axes = getMOrNaxes(
-      in_b_details.cdomains, in_a_details.bcasts, out_details.rdomains);
-  details.k_axes = getKaxes(
-      in_a_details.cdomains, in_b_details.cdomains, out_details.rdomains);
-  details.batch_axes = getBatchAxes(in_a_details, in_b_details, out_details);
-
-  NVF_ERROR(
-      !details.m_axes.empty(),
-      "MmaOp inputs must define at least a single M dimension");
-  NVF_ERROR(
-      !details.n_axes.empty(),
-      "MmaOp inputs must define at least a single N dimension");
-  NVF_ERROR(
-      !details.k_axes.empty(),
-      "MmaOp inputs must define at least a single K dimension");
-
-  return details;
-}
-
-} // namespace nvfuser::MmaOpUtils
diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h
index 60062b0e440..d0cf1db7bbd 100644
--- a/csrc/ir/utils.h
+++ b/csrc/ir/utils.h
@@ -23,39 +23,6 @@ namespace nvfuser::MmaOpUtils {

 // The expected number of concrete domains for gemm
 constexpr size_t expected_gemm_cdomains = 2;

-// A helper structure used to gather all data created during analysis
-struct MmaOpDetails {
-  using AxesData = MmaOp::AxesData;
-  // Concrete axes from A that are broadcast in B and are not
-  // reduction in output
-  AxesData m_axes;
-  // Concrete axes from B that are broadcast in A and are not
-  // reduction in output
-  AxesData n_axes;
-  // Concrete axes from A that are concrete in B and are
-  // reduction in output
-  AxesData k_axes;
-  // Concrete or broadcast axes that are present in all inputs
-  // and output
-  AxesData batch_axes;
-};
-
-// A helper structure with pieces of information about TensorView
-struct TensorViewDetails {
-  using AxesData = MmaOp::AxesData;
-  // Broadcast domains
-  AxesData bcasts;
-  // Reduction domains
-  AxesData rdomains;
-  // Concrete domains
-  AxesData cdomains;
-};
-
-MmaOpDetails getMmaOpDetails(
-    TensorView* out,
-    TensorView* in_a,
-    TensorView* in_b);
-
 void verifyMmaOpForEvaluation(MmaOp* mma_op, DataType expected_input_dtype);

 struct MatmulInputs {
diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp
index 5f0202b4da2..f49d0a9466f 100644
--- a/csrc/logical_domain_map.cpp
+++ b/csrc/logical_domain_map.cpp
@@ -185,6 +185,24 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseLogicalDomainMap::map(
     }
   };

+  if (auto* mma = dynamic_cast<MmaOp*>(consumer_tv_->definition())) {
+    // producer_tv_ is either A or B
+    const MmaOp::AxesData& operand_axes = producer_tv_ == mma->inA()
+        ? mma->axisMapping().a_axes
+        : mma->axisMapping().b_axes;
+    NVF_ERROR(operand_axes.size() == consumer_root.size());
+    for (size_t idx : c10::irange(operand_axes.size())) {
+      int64_t operand_pos = operand_axes[idx];
+      if (operand_pos == -1) {
+        continue;
+      }
+      IterDomain* operand_id = producer_logical.at((size_t)operand_pos);
+      IterDomain* out_id = consumer_root.at(idx);
+      updatePairwiseLogicalDomainMap(operand_id, out_id);
+    }
+    return dom_map;
+  }
+
   // For MatmulOp, use the corresponding mapped input iterdomains.
   if (MatmulOp* op = dynamic_cast<MatmulOp*>(consumer_tv_->definition())) {
     // Check if the producer is lhs/rhs input
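The new MmaOp branch in PairwiseLogicalDomainMap::map reduces to one rule: output root axis idx maps to operand logical axis operand_axes[idx], and -1 entries contribute no pairing. A standalone sketch of that pairing over plain integers (the real code pairs IterDomain objects):

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Given one operand's side of the axis mapping, list (operand_pos, out_pos)
// pairs the way the new MmaOp branch does, skipping -1 entries where the
// operand has no axis corresponding to that output root axis.
std::vector<std::pair<int64_t, size_t>> pairOperandToOutput(
    const std::vector<int64_t>& operand_axes) {
  std::vector<std::pair<int64_t, size_t>> pairs;
  for (size_t idx = 0; idx < operand_axes.size(); ++idx) {
    int64_t operand_pos = operand_axes[idx];
    if (operand_pos == -1) {
      continue; // no producer axis maps to this consumer root axis
    }
    pairs.emplace_back(operand_pos, idx);
  }
  return pairs;
}

int main() {
  // Operand A of the test added below: a_axes = {0, 1, -1, 2} for an output
  // shaped [B, M, N, rK].
  auto pairs = pairOperandToOutput({0, 1, -1, 2});
  assert(pairs.size() == 3);
  assert((pairs[2] == std::pair<int64_t, size_t>{2, 3})); // K: A axis 2 -> out axis 3
  return 0;
}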
diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp
index df3cccc2109..7684e406fef 100644
--- a/csrc/ops/arith.cpp
+++ b/csrc/ops/arith.cpp
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include
 #include
 #include

@@ -2087,101 +2088,111 @@ TensorView* viewAsScalar(TensorView* inp) {
   return out;
 }

-namespace {
-
-//! Create new output for mma
-static TensorView* newForMma(
+TensorView* fusedMultiplySum(
     TensorView* tv_a,
     TensorView* tv_b,
-    const std::vector<unsigned int>& axes,
-    DataType data_type = DataType::Float) {
-  auto orig_domain_a = TensorDomain::noReductions(tv_a->getLogicalDomain());
-  auto orig_domain_b = TensorDomain::noReductions(tv_b->getLogicalDomain());
-
-  NVF_ERROR(
-      orig_domain_a.size() == orig_domain_b.size(),
-      "MMA op: need matching dim input");
-
-  std::set<unsigned int> axes_set(axes.begin(), axes.end());
-  std::vector<IterDomain*> new_domain;
-
-  NVF_ERROR(
-      !axes_set.empty(),
-      "Asked for output of reduction, but no reduction axis provided.");
-
-  NVF_ERROR(
-      (*(axes_set.rbegin())) < orig_domain_a.size(),
-      "Error setting up reduction, reduction axis (",
-      *(axes_set.rbegin()),
-      ") is outside nDims (",
-      orig_domain_a.size(),
-      "). Keep in mind reductions are relative to root domains, not modified views.");
-
-  auto axis_iter = axes_set.begin();
-  for (const auto dim : c10::irange(orig_domain_a.size())) {
-    bool is_reduction = false;
-    if (axis_iter != axes_set.end() && *axis_iter == dim) {
-      is_reduction = true;
-      axis_iter++;
-    }
+    const std::vector<int64_t>& axes,
+    Val* init,
+    const std::optional<MmaOp::AxisMapping>& axis_mapping_opt) {
+  const std::vector<IterDomain*>& a_logical =
+      TensorDomain::noReductions(tv_a->getLogicalDomain());
+  const std::vector<IterDomain*>& b_logical =
+      TensorDomain::noReductions(tv_b->getLogicalDomain());

-    const IterDomain* id = orig_domain_a[dim]->isBroadcast()
-        ? orig_domain_b[dim]
-        : orig_domain_a[dim];
+  NVF_CHECK(
+      !a_logical.empty() && !b_logical.empty(),
+      "Tried to reduce a 0-dim tensor");

+  std::unique_ptr<MmaOp::AxisMapping> axis_mapping_ptr;
+  if (!axis_mapping_opt.has_value()) {
     NVF_CHECK(
-        !(is_reduction && id->isBroadcast() && !id->isImplicitBroadcast()),
-        "Cannot reduce an axis that is marked as broadcasted as it has an undetermined size. Tried to reduce ID = ",
-        id,
-        " of tensor ",
-        tv_a,
-        "and",
-        tv_b);
-
-    new_domain.push_back(
-        IterDomainBuilder(id->start(), id->extent())
-            .stop_offset(id->stopOffset())
-            .iter_type(is_reduction ? IterType::Reduction : id->getIterType())
-            .build());
+        a_logical.size() == b_logical.size(),
+        "If tv_a and tv_b have different dimensions, axis_mapping_opt must be provided");
+    axis_mapping_ptr = std::make_unique<MmaOp::AxisMapping>(
+        MmaOp::AxisMapping::trivialMapping(a_logical.size()));
   }
+  const MmaOp::AxisMapping& axis_mapping =
+      axis_mapping_opt.has_value() ? *axis_mapping_opt : *axis_mapping_ptr;

-  TensorDomain* td = IrBuilder::create<TensorDomain>(
-      new_domain, TensorDomain::getContiguityFilledWith(new_domain, true));
-
-  return IrBuilder::create<TensorView>(td, data_type);
-}
-
-} // namespace
-
-TensorView* fusedMultiplySum(
-    TensorView* tv_a,
-    TensorView* tv_b,
-    const std::vector<int64_t>& axes,
-    Val* init) {
-  // TODO:
-  //  Validate axis relationships between a and b
-  NVF_CHECK(tv_a->nDims() > 0, "Tried to reduce a 0-dim tensor");
+  NVF_CHECK(
+      axis_mapping.a_axes.size() == axis_mapping.b_axes.size(),
+      "Axis mapping should contain same number of output axes for each operand");
+  const size_t out_dims = axis_mapping.a_axes.size();
+
+  std::unordered_set<size_t> axes_set;
+  for (int64_t axis : axes) {
+    if (axis < 0) {
+      axis += (int64_t)out_dims;
+    }
+    NVF_ERROR(axis >= 0 && axis < (int64_t)out_dims);
+    axes_set.insert((size_t)axis);
+  }

   // TODO:
   //  Add tf32 and other mma data types
   //  Add fallback path for non-mma data types.
   NVF_CHECK(
-      tv_a->getDataType().value() == DataType::Half ||
-      tv_a->getDataType().value() == DataType::BFloat16);
-  NVF_CHECK(tv_a->getDataType().value() == tv_b->getDataType().value());
+      tv_a->dtype() == DataType::Half || tv_a->dtype() == DataType::BFloat16);
+  NVF_CHECK(tv_a->dtype() == tv_b->dtype());
+  DataType out_dtype = DataType::Float;

-  NVF_CHECK(!axes.empty(), "No reduction axis specified");
+  // Prepare output domain based on domain mapping and IterTypes of inputs
+  std::vector<IterDomain*> out_domain;
+  out_domain.reserve(axis_mapping.a_axes.size());
+  for (size_t i : c10::irange(out_dims)) {
+    int64_t a_pos = axis_mapping.a_axes[i];
+    int64_t b_pos = axis_mapping.b_axes[i];
+    NVF_CHECK(
+        a_pos != -1 || b_pos != -1,
+        "Output axis ",
+        i,
+        " cannot be missing in both operands");
+    NVF_CHECK(
+        a_pos == -1 || (a_pos >= 0 && a_pos < (int64_t)a_logical.size()),
+        "Position ",
+        i,
+        " in output of axis mapping for operand A is ",
+        a_pos,
+        " which is out of bounds for A which has dimension ",
+        a_logical.size());
+    NVF_CHECK(
+        b_pos == -1 || (b_pos >= 0 && b_pos < (int64_t)b_logical.size()),
+        "Position ",
+        i,
+        " in output of axis mapping for operand B is ",
+        b_pos,
+        " which is out of bounds for B which has dimension ",
+        b_logical.size());
+    IterDomain* a_id = a_pos == -1 ? nullptr : a_logical[(size_t)a_pos];
+    IterDomain* b_id = b_pos == -1 ? nullptr : b_logical[(size_t)b_pos];
+
+    bool a_concrete = a_id == nullptr ? false : !a_id->isBroadcast();
+    bool b_concrete = b_id == nullptr ? false : !b_id->isBroadcast();
+    // NOTE: we can have !a_concrete && !b_concrete if there are broadcast
+    // batch dims
+
+    // Check for K dimensions
+    bool is_reduction = false;
+    if (axes_set.count(i)) {
+      NVF_CHECK(
+          a_concrete && b_concrete,
+          "Reduction dimensions must be concrete in both operands");
+      is_reduction = true;
+    }

-  // TODO:
-  //  will lift this in a follow up when we have a
-  //  more generic axes matching.
-  NVF_CHECK(
-      axes.size() == 1, "Single axis reduction only for mma op instantiation.")
+    IterDomain* orig_id = a_concrete ? a_id : b_id;
+    out_domain.push_back(
+        IterDomainBuilder(orig_id->start(), orig_id->extent())
+            .stop_offset(orig_id->stopOffset())
+            .iter_type(
+                is_reduction ? IterType::Reduction : orig_id->getIterType())
+            .build());
+  }

-  std::vector<unsigned int> uint_axes = ops::canonicalizeAxes(
-      axes, (int64_t)tv_a->domain()->noReductions().size());
+  TensorDomain* td = IrBuilder::create<TensorDomain>(
+      out_domain, TensorDomain::getContiguityFilledWith(out_domain, true));

-  TensorView* out = newForMma(tv_a, tv_b, uint_axes);
+  TensorView* out = IrBuilder::create<TensorView>(td, out_dtype);

   if (init == nullptr) {
     init = IrBuilder::create<Val>(0.0, out->dtype());
@@ -2193,11 +2204,8 @@ TensorView* fusedMultiplySum(
   NVF_CHECK(
       init->isConstScalar(),
       "Cannot create a reduction operation where the initial value is not a const scalar.");
-  NVF_CHECK(
-      init->dtype() == out->dtype(),
-      "Init value dtype for fusedMultiplySum must match output.");

-  IrBuilder::create<MmaOp>(out, tv_a, tv_b, init);
+  IrBuilder::create<MmaOp>(out, tv_a, tv_b, init, axis_mapping);

   return out;
 }
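The out_domain loop above implements a small decision table per output axis: a summed (K) axis must be concrete in both operands, an axis concrete in either operand becomes an Iteration axis, and an axis broadcast or absent on both sides stays Broadcast. A simplified standalone sketch of just that table, with bools standing in for the IterDomain broadcast checks:

#include <cassert>
#include <stdexcept>
#include <unordered_set>
#include <vector>

enum class AxisType { Iteration, Broadcast, Reduction };

// a_concrete/b_concrete mean "mapped and not broadcast" per output axis;
// reduction_axes are the already-canonicalized (non-negative) sum axes.
std::vector<AxisType> inferOutputAxes(
    const std::vector<bool>& a_concrete,
    const std::vector<bool>& b_concrete,
    const std::unordered_set<size_t>& reduction_axes) {
  std::vector<AxisType> out;
  for (size_t i = 0; i < a_concrete.size(); ++i) {
    if (reduction_axes.count(i)) {
      if (!a_concrete[i] || !b_concrete[i]) {
        throw std::runtime_error(
            "Reduction dimensions must be concrete in both operands");
      }
      out.push_back(AxisType::Reduction);
    } else if (a_concrete[i] || b_concrete[i]) {
      out.push_back(AxisType::Iteration);
    } else {
      out.push_back(AxisType::Broadcast); // e.g. broadcast batch dims
    }
  }
  return out;
}

int main() {
  // Out [B, M, N, rK] from the test below: every axis is concrete in at
  // least one operand, and K (axis 3) is concrete in both.
  auto out = inferOutputAxes(
      {true, true, false, true}, {true, false, true, true}, {3});
  assert(out[3] == AxisType::Reduction);
  assert(out[2] == AxisType::Iteration);
  return 0;
}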
diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h
index 4070151e0de..d8ea10038ad 100644
--- a/csrc/ops/arith.h
+++ b/csrc/ops/arith.h
@@ -684,8 +684,9 @@ TensorView* viewAsScalar(TensorView* inp);
 //!
 //! \param tv_a first multiply operand
 //! \param tv_b second multiply operand
-//! \param axes axes to sum over
+//! \param axes axes to sum over, relative to output loop domain
 //! \param init sum initial value
+//! \param axis_mapping_opt mapping from output axes to operand axes
 //!
 //! Note & TODO:
 //!   currently only support lowering to a mma op
@@ -696,7 +697,8 @@ NVF_API TensorView* fusedMultiplySum(
     TensorView* tv_a,
     TensorView* tv_b,
     const std::vector<int64_t>& axes,
-    Val* init = nullptr);
+    Val* init = nullptr,
+    const std::optional<MmaOp::AxisMapping>& axis_mapping_opt = std::nullopt);

 // Create a tensor view from the given value. The given value can be a single
 // scalar, an array of scalars, or a nested array of scalars.
diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp
index 2fd62cda21f..06474538cf7 100644
--- a/csrc/scheduler/mma_utils.cpp
+++ b/csrc/scheduler/mma_utils.cpp
@@ -1722,7 +1722,12 @@ MmaOp* MatmulPattern::translateToMmaOp() {
   } else if (output->definition()->isA<ReductionOp>()) {
     Val* init = IrBuilder::create<Val>(0.0, output->dtype());
     // This replaces the mul and sum by overwriting output->definition()
-    return IrBuilder::create<MmaOp>(output, A, B, init);
+    return IrBuilder::create<MmaOp>(
+        output,
+        A,
+        B,
+        init,
+        MmaOp::AxisMapping::trivialMapping(output->nDims()));
   }

   // This will hold the translated output from MatmulOp or LinearOp
diff --git a/csrc/tensor_view.cpp b/csrc/tensor_view.cpp
index 29d80556e7b..b85c4947b1c 100644
--- a/csrc/tensor_view.cpp
+++ b/csrc/tensor_view.cpp
@@ -852,6 +852,7 @@ TensorView* TensorView::rFactor(const std::vector<int64_t>& axes) {
       this_mma->inA(),
       this_mma->inB(),
       this_mma->init(),
+      this_mma->axisMapping(),
       this_mma->macro());

   // Remaining reduction that can be scheduled cross
for " << id->toString() << " in " << tv->toString(); + } + }; + checkAxisRoles( + tv0, {MatmulDimRole::Batch, MatmulDimRole::M, MatmulDimRole::K}); + checkAxisRoles( + tv1, {MatmulDimRole::Batch, MatmulDimRole::K, MatmulDimRole::N}); + checkAxisRoles( + tv2, + {MatmulDimRole::Batch, + MatmulDimRole::M, + MatmulDimRole::N, + MatmulDimRole::K}); +} + } // namespace nvfuser