Refactor MultiMatmulSchedulers (#3277)
The PR fixes #3266. 

* Future consolidation is possible, but I kept some duplicate functions in
`HopperMultipleMatmulScheduler` and `AmpereMultipleMatmulScheduler` for
flexibility.

### Changes:
- Move `isPowOf2` to `csrc/utils.h`
- Move `representativeId` to `scheduler/tools/abstract_tensor.h`
- Move `checkConcreteStaticDim` to `mma_utils.cpp` (see the sketch after this list)
- Add TODO to remove `swizzleSharedMemory` from
`HopperMultipleMatmulScheduler`
- Create base class `MultipleMatmulScheduler` to hold common functions like
`findPatterns`
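
For reference, here is a minimal sketch of where the relocated helpers could land, assuming they keep the signatures of the removed copies visible in the `ampere_multi_matmul.cpp` diff below; the target files themselves are not part of the excerpt shown here, so treat the exact placement as an assumption.

```cpp
// All of the following live under namespace nvfuser in the code base.

// csrc/utils.h (sketch): isPowOf2 becomes a shared utility.
// Returns true if the given number is a power of 2 (and greater than 1).
constexpr bool isPowOf2(int64_t x) {
  return x > 1 && (x & (x - 1)) == 0;
}

// csrc/scheduler/tools/abstract_tensor.h (sketch): representativeId picks a
// concrete IterDomain out of an AbstractId, recursing into ValGroups.
// (The ValGroup overload it recurses into is elided here.)
inline IterDomain* representativeId(const AbstractId& abs_id) {
  if (abs_id.is<IterDomain*>()) {
    return abs_id.as<IterDomain*>();
  }
  NVF_ERROR(abs_id.is<ValGroupAndItsGraph>());
  return representativeId(abs_id.as<ValGroupAndItsGraph>().group);
}

// csrc/scheduler/mma_utils.h / mma_utils.cpp (sketch): the static-dimension
// check is now called as mma_utils::checkConcreteStaticDim(...), as seen at
// the updated call sites in the diff below.
namespace mma_utils {
void checkConcreteStaticDim(const AbstractId& abs_id);
} // namespace mma_utils
```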

### Details:
- Create base class `MultipleMatmulScheduler`
- `HopperMultipleMatmulScheduler` and `AmpereMultipleMatmulScheduler` inherit
from `MultipleMatmulScheduler` and override the `run` function (a rough sketch of the hierarchy follows this list).
- `MultipleMatmulScheduler` implements `findPatterns`, `translatePatterns`,
`findRoles`, `countDims`, and `updateIdModel`. It also holds the
necessary data members for those functions.
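
A rough sketch of the resulting hierarchy is below. Member names are inferred from this description and from the `ampere_multi_matmul.h` diff shown later; the new `scheduler/multi_matmul.h` header is one of the nine changed files but is not shown in this excerpt, so the exact layout is an assumption.

```cpp
// scheduler/multi_matmul.h (sketch): state and pipeline steps shared by the
// per-architecture matmul schedulers.
class MultipleMatmulScheduler {
 public:
  MultipleMatmulScheduler(Fusion* fusion, const MatmulParams* params)
      : fusion_(fusion),
        params_(params),
        id_model_(fusion, /*build_graphs=*/false) {}
  virtual ~MultipleMatmulScheduler() = default;

  // Each architecture provides its own end-to-end scheduling pipeline.
  virtual void run() = 0;

 protected:
  // Steps that are identical for Ampere and Hopper.
  void findPatterns();
  void translatePatterns();
  void findRoles();
  void countDims();
  void updateIdModel();

  Fusion* fusion_;
  const MatmulParams* params_;
  IdModel id_model_;
  // Permissive graph of id_model_, rebuilt by updateIdModel()
  ValGraph* graph_ = nullptr;
  std::vector<mma_utils::MatmulPattern> patterns_;
  mma_utils::DimRolesMap id_roles_;
  mma_utils::TensorRolesMap tensor_roles_;
  mma_utils::MatmulOperandInnerDims inner_dims_;
  std::vector<TensorView*> as_, bs_, mma_results_;
  int64_t num_splitk_dims_ = 0, num_device_dims_ = 0,
      num_local_batch_dims_ = 0, num_device_and_batch_dims_ = 0;
};

// ampere_multi_matmul.h (as in the diff below): the subclass keeps only the
// architecture-specific members and overrides run().
class AmpereMultipleMatmulScheduler : public MultipleMatmulScheduler {
 public:
  AmpereMultipleMatmulScheduler(Fusion* fusion, const MatmulParams* params)
      : MultipleMatmulScheduler(fusion, params) {
    // Constructor additionally restricts use to Turing/Ampere
    // (compute capability 7.5 up to, but not including, 9.0).
  }
  void run() final;
};
```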
rdspring1 authored Oct 29, 2024
1 parent 5db18de commit 7a3b1a4
Showing 9 changed files with 194 additions and 450 deletions.
189 changes: 5 additions & 184 deletions csrc/scheduler/ampere_multi_matmul.cpp
@@ -10,7 +10,6 @@
#include <id_model/schedule.h>
#include <instrumentation.h>
#include <ir/utils.h>
#include <multidevice/utils.h>
#include <scheduler/ampere_multi_matmul.h>
#include <scheduler/debug_utils.h>
#include <scheduler/matmul.h>
@@ -19,6 +18,7 @@
#include <scheduler/tools/abstract_tensor.h>
#include <scheduler/tools/inlining.h>
#include <scheduler/utils.h>
#include <utils.h>
#include <val_graph.h>
#include <val_graph_visitor.h>

@@ -31,32 +31,6 @@ namespace nvfuser {

namespace {

// Returns true if given number is power of 2
constexpr bool isPowOf2(int64_t x) {
return x > 1 && (x & (x - 1)) == 0;
}

inline IterDomain* representativeId(const AbstractId& abs_id) {
if (abs_id.is<IterDomain*>()) {
return abs_id.as<IterDomain*>();
}
NVF_ERROR(abs_id.is<ValGroupAndItsGraph>());
return representativeId(abs_id.as<ValGroupAndItsGraph>().group);
}

// Utility to check concrete static size
inline void checkConcreteStaticDim(const AbstractId& abs_id) {
IterDomain* id = representativeId(abs_id);
NVF_ERROR(
!id->isBroadcast() && !id->isReduction(),
"no support for reduction or broadcast domains, but got ",
id->toString());
NVF_ERROR(
id->extent()->isConstInt(),
"swizzled dimension's extend must be known during scheduling, got ",
id->toString());
}

//! Automatically generates the shared memory swizzled data layout
//! for matmul mainloop and epilogue.
//! The shared mem data layout is always 2D currently, and this utility
@@ -76,8 +50,8 @@ AbstractTensor swizzleSharedMemory(TensorView* shared_mem_tv) {
(int64_t)swizzle_domain.size() >= 2,
"At least 2D input (excluding consecutive reduction domains starting from the innermost dim) needed for swizzling, but get ",
shared_mem_tv->toString());
checkConcreteStaticDim(swizzle_domain[-2]);
checkConcreteStaticDim(swizzle_domain[-1]);
mma_utils::checkConcreteStaticDim(swizzle_domain[-2]);
mma_utils::checkConcreteStaticDim(swizzle_domain[-1]);

// Extract the constant sizes of the swizzled tile
const int64_t tile_size_x =
@@ -522,98 +496,6 @@ void AmpereMultipleMatmulScheduler::cacheInputsAndOutputs() {
scheduler_utils::cacheAndForkOutputs(fusion_, /*unroll=*/true);
}

void AmpereMultipleMatmulScheduler::findPatterns() {
patterns_ = mma_utils::findMatmulPatterns(fusion_);
NVF_ERROR(!patterns_.empty(), "No matmul patterns were found");
}

void AmpereMultipleMatmulScheduler::countDims() {
NVF_ERROR(!patterns_.empty());
TensorView* mma_result = patterns_.front().output;
num_device_dims_ = numDeviceDims(mma_result);
for (const auto& it : id_roles_) {
if (it.second == MatmulDimRole::Batch &&
// Skip device dims
!std::any_of(it.first->begin(), it.first->end(), [](Val* v) {
return v->as<IterDomain>()->isDeviceDim();
})) {
// All batch dims will be merged into one, if any exist
num_local_batch_dims_ = 1;
}
}
num_splitk_dims_ = params_->splitk_factor > 1 ? 1 : 0;
// Subtract 6 for the [Mo, No, Ko, Mi, Ni, Ki]
num_device_and_batch_dims_ = num_device_dims_ + num_local_batch_dims_;
}

void AmpereMultipleMatmulScheduler::translatePatterns() {
mma_results_.reserve(patterns_.size());
for (mma_utils::MatmulPattern& pattern : patterns_) {
MmaOp* mma = pattern.translateToMmaOp();
mma_results_.push_back(mma->out()->as<TensorView>());
}

// Build IdModel graphs now since translateToMmaOp creates new TVs. Before
// this point the graphs are not yet built.
updateIdModel();
}

// Get tensor roles and id roles
// When there are multiple matmul patterns, we can have conflicting roles.
// For now we throw an error if this is the case.
// TODO: This should be checked in canScheduleCompileTime
void AmpereMultipleMatmulScheduler::findRoles() {
const auto roles_opt = mma_utils::allPatternRoles(id_model_, patterns_);
NVF_ERROR(
roles_opt.has_value(),
"Incompatible roles found between matmul patterns");
std::tie(id_roles_, tensor_roles_) = roles_opt.value();

mma_utils::MatmulOperandInnerDimsOpt inner_dims_opt =
mma_utils::getOperandInnerDims(id_model_, id_roles_, tensor_roles_);
NVF_ERROR(inner_dims_opt.isValid(), inner_dims_opt.getErrorMsg());
inner_dims_ = inner_dims_opt.getData();

as_ = tensor_roles_.at(MatmulTensorRole::OPERAND_A);
bs_ = tensor_roles_.at(MatmulTensorRole::OPERAND_B);

countDims();
}

// Including current tensor naming convention for reference,
// this is very temporary and will change over time and
// in fact the whole body of this function will
// eventually be a set of utility functions for different
// sections of matmul(fusion) kernels, with
// each having its own build out to do.
//
// Current naming convention is based on the following formula:
//
// d = alpha * (a x b) + beta * c
//
// and is defined in the following way:
//
// operands assumed in global memory : a, b, c
//
// registers staging global load : ar, br (short for a/b read)
//
// shared mem cache of operands : acw_smem, bcw_smem (short for a/b
// cache_write smem)
//
// registers at shared memory load output : acr, bcr (short for a/b cache
// read)
//
// register tensor input to the actual mma op: ab, bb (short for a/b
// broadcasted)
//
// accumulator register: mma_result
// - mma_result is MmaOp output if there is epilogue
// - mma_result is dc (short for d cache) if there is no epilogue
//
// result in global memory: d

// Currently the support is for a, b, c and d as fusion inputs/outputs
// aka. no prolog fusion yet.
void AmpereMultipleMatmulScheduler::defineOperandCaches() {
cacheOperandsToSmem(as_, acw_smems_, params_->supported_vec_size.a);
addSetsForCacheReads(acw_smems_, acrs_);
@@ -669,12 +551,6 @@ void AmpereMultipleMatmulScheduler::cacheOperandsToSmem(
}
}

// We add two LoadStore operators to the inputs of our fusions. The first
// one is for a read from global memory and the second one (below) is for a
// cache read. As an optimizaton, we avoid adding an operator if there's an
// existing LoadStoreOp present. Please note that for the second LoadStore
// we don't propagate the allocation domain, since the scheduler sets the
// allocation domain in the registers.
void AmpereMultipleMatmulScheduler::addSetsForCacheReads(
const std::vector<TensorView*>& tv_smems,
std::vector<TensorView*>& tv_rs) {
@@ -702,38 +578,6 @@ void AmpereMultipleMatmulScheduler::addSetsForCacheReads(
}
}

//! Rebuilds IdModel, then updates all ValGroups in abstract tensors to refer
//! to the new IdModel. This is necessary whenever we perform an operation
//! that creates a new TensorView, such as caching or rFactor
void AmpereMultipleMatmulScheduler::updateIdModel() {
// Build new IdModel
IdModel new_id_model(fusion_, /*build_graphs=*/false);
new_id_model.buildPermissiveGraph();

// Get new permissive graph
ValGraph& new_graph = new_id_model.idGraph(IdMappingMode::PERMISSIVE);

if (!id_roles_.empty()) {
// Update id_roles_ to have keys corresponding to ValGroups in the new
// IdModel
std::unordered_map<ValGroup, MatmulDimRole> new_id_roles;
for (auto& [k, v] : id_roles_) {
const ValGroup& new_group = new_graph.toGroup(k->front());
new_id_roles.emplace(new_group, v);
}
id_roles_ = new_id_roles;
}

graph_ = &new_id_model.idGraph(IdMappingMode::PERMISSIVE);

// Set id_model_ after we are done using the old one
id_model_ = std::move(new_id_model);
}

//! Swizzle the M and N outer dimensions after makeTile has been called.
//! This updates outer_dim_roles if we introduce a new dimension, which can
//! happen if tv is missing a merged axis, in which case we skip merging after
//! the split. This is analogous to forwarding during transform propagation.
void AmpereMultipleMatmulScheduler::swizzleBlockTiles(
TensorView* tv,
std::vector<MatmulDimRole>& outer_dim_roles) {
@@ -798,8 +642,6 @@ void AmpereMultipleMatmulScheduler::swizzleBlockTiles(
}
}

//! This calls orig->cacheAfter() and also updates the permissive graph to
//! reflect the new IterDomain mappings
TensorView* AmpereMultipleMatmulScheduler::cacheAfter(
TensorView* orig,
LoadStoreOpType op_type,
@@ -835,16 +677,6 @@ TensorView* AmpereMultipleMatmulScheduler::cacheAfter(
return c;
}

//! Do block tiling for a collection of TensorViews. The tensors should be
//! unscheduled before this method is called.
//! 1) Axes will be ordered according to canonicalDimOrdering, and then axes
//! with the same role will be merged.
//! 2) After that, we perform splits according to
//! params_->tile_sizes.cta_tile, e.g. [M, K] -> [Mo, Ko, Mi, Ki].
//! 3) Depending on the value of params_->grid_swizzle_factor, if the TV has
//! both M and N dimensions, we perform a 2D swizzle of the outer dimensions
//! Mo and No.
//! 4) Finally, we do a split-K split if the splitk_factor is not 1
std::vector<std::vector<MatmulDimRole>> AmpereMultipleMatmulScheduler::
blockTileTensors(const std::vector<TensorView*>& tvs) {
if (canonical_dim_ordering_.empty()) {
@@ -920,10 +752,6 @@ std::vector<std::vector<MatmulDimRole>> AmpereMultipleMatmulScheduler::
return all_merged_roles;
}

//! Schedule the loads of all operands from global memory to shared memory.
//! Starting from the basic tiled schedule, we swizzle the operand memory.
//! Note that the cache op and LoadStoreOpType are already set during
//! defineOperandCaches().
void AmpereMultipleMatmulScheduler::scheduleOperandSmemStores() {
auto scheduleBranch = [&](const std::vector<TensorView*>& gmem_operands,
const std::vector<TensorView*>& smem_operands,
@@ -989,8 +817,6 @@ void AmpereMultipleMatmulScheduler::scheduleMmaOperands(
}
}

// MmaOperand contains only A and B. If tvs are outputs (i.e. not operands),
// then operand_type should be std::nullopt.
void AmpereMultipleMatmulScheduler::scheduleMmaResults() {
auto all_merged_roles = blockTileTensors(mma_results_);
for (size_t i : c10::irange(mma_results_.size())) {
@@ -1238,8 +1064,8 @@ void AmpereMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) {
const MatMulTileOptions& gemm_tile = params_->tile_sizes;
const int64_t vectorization_factor = params_->supported_vec_size.epilogue;
// input tensor is in the form of [Mo,No,cta_tile_m,cta_tile_n]
checkConcreteStaticDim(c->axis(-2));
checkConcreteStaticDim(c->axis(-1));
mma_utils::checkConcreteStaticDim(c->axis(-2));
mma_utils::checkConcreteStaticDim(c->axis(-1));
const int64_t tile_size_m = c->axis(-2)->extent()->evaluate().as<int64_t>();
const int64_t tile_size_n = c->axis(-1)->extent()->evaluate().as<int64_t>();
NVF_ERROR(
@@ -1360,9 +1186,6 @@ void AmpereMultipleMatmulScheduler::scheduleEpilogue() {
scheduleFusionInputsForEpilogue();
}

//! Propagates transformations from fusion output to fusion tv inputs that are
//! producers in the epilogue. Transformations' propagation aims at input tvs
//! which are not assigned to core roles, that is, are not MMA inputs.
void AmpereMultipleMatmulScheduler::scheduleFusionInputsForEpilogue() {
std::vector<TensorView*> cached_tvs;

@@ -1463,8 +1286,6 @@ void AmpereMultipleMatmulScheduler::setUpInlining() {
}
}

// NOTE: this should be called after acw_smem, acr, ..., ab, and mma_result
// transforms have been applied and inlining
void AmpereMultipleMatmulScheduler::setUpCircularBuffering() {
// Propagate mma output swizzle and parallelization down the DAG
if (params_->circular_buffer_options.circular_buffer_smem_write) {
49 changes: 7 additions & 42 deletions csrc/scheduler/ampere_multi_matmul.h
@@ -8,9 +8,7 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <scheduler/mma_utils.h>
#include <val_graph.h>
#include <val_graph_visitor.h>
#include <scheduler/multi_matmul.h>

namespace nvfuser {

@@ -66,37 +64,23 @@ namespace nvfuser {
// Each of the named tensors above is scheduled differently. We schedule them
// by building AbstractTensors for each tensor category; these are held in
// AmpereMultipleMatmulScheduler::schedules_.
// TODO: Inheret from SchedulerEntry
class AmpereMultipleMatmulScheduler {
// TODO: Inherit from SchedulerEntry
class AmpereMultipleMatmulScheduler : public MultipleMatmulScheduler {
public:
AmpereMultipleMatmulScheduler(Fusion* fusion, const MatmulParams* params)
: fusion_(fusion),
params_(params),
id_model_(fusion, /*build_graphs=*/false) {
: MultipleMatmulScheduler(fusion, params) {
const auto device_prop = at::cuda::getCurrentDeviceProperties();
const int cc = device_prop->major * 10 + device_prop->minor;
NVF_ERROR(
cc >= 75 && cc < 90,
"This matmul scheduler is restricted to Ampere and Turing.");
}

void run();
void run() final;

private:
void cacheInputsAndOutputs();

void findPatterns();

void countDims();

void translatePatterns();

// Get tensor roles and id roles
// When there are multiple matmul patterns, we can have conflicting roles.
// For now we throw an error if this is the case.
// TODO: This should be checked in canScheduleCompileTime
void findRoles();

// Including current tensor naming convention for reference,
// this is very temporary and will change over time and
// in fact the whole body of this function will
@@ -148,11 +132,6 @@ class AmpereMultipleMatmulScheduler {
const std::vector<TensorView*>& tv_smems,
std::vector<TensorView*>& tv_rs);

//! Rebuilds IdModel, then updates all ValGroups in abstract tensors to refer
//! to the new IdModel. This is necessary whenever we perform an operation
//! that creates a new TensorView, such as caching or rFactor
void updateIdModel();

//! Swizzle the M and N outer dimensions after makeTile has been called.
//! This updates outer_dim_roles if we introduce a new dimension, which can
//! happen if tv is missing a merged axis, in which case we skip merging after
@@ -216,26 +195,12 @@ class AmpereMultipleMatmulScheduler {
void setUpCircularBuffering();

private:
Fusion* fusion_;
const MatmulParams* params_;
IdModel id_model_;
// Permissive graph of id_model_, which we modify at times using e.g.
// AbstractTensor.split or by mapping vals in cacheAfter and rFactor
ValGraph* graph_ = nullptr;
std::vector<mma_utils::MatmulPattern> patterns_;
mma_utils::DimRolesMap id_roles_;
mma_utils::TensorRolesMap tensor_roles_;
mma_utils::MatmulOperandInnerDims inner_dims_;

int64_t num_splitk_dims_ = 0, num_device_dims_ = 0, num_local_batch_dims_ = 0,
num_device_and_batch_dims_ = 0;

std::vector<std::pair<TensorView*, TensorView*>> cached_outputs_;

std::vector<ValGroup> canonical_dim_ordering_;

std::vector<TensorView*> as_, bs_, acw_smems_, bcw_smems_, acrs_, bcrs_, abs_,
bbs_, mma_results_, splitk_sums_, smem_epilogues_;
std::vector<TensorView*> acw_smems_, bcw_smems_, acrs_, bcrs_, abs_, bbs_,
splitk_sums_, smem_epilogues_;
};

} // namespace nvfuser