
Commit

minor cleanup of comments
LucasWilkinson committed Aug 20, 2024
1 parent 7e8ceee commit bd4dc71
Showing 1 changed file with 31 additions and 40 deletions.
71 changes: 31 additions & 40 deletions csrc/quantization/machete/machete_mainloop.cuh
@@ -27,8 +27,6 @@

#include "cutlass_extensions/cute_utils.cuh"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace machete {

using namespace cute;
@@ -37,9 +35,6 @@ using namespace cutlass::gemm;
using namespace cutlass::gemm::collective;
using namespace cutlass::gemm::collective::detail;

/////////////////////////////////////////////////////////////////////////////////////////////////

// WarpSpecialized Mainloop that source A operand from registers
template <class ElementATuple_, class GmemLayoutA, int AlignmentA,
class ElementB_, class GmemLayoutB, int AlignmentB,
class ElementAccumulator_, class TileShape_MNK,
@@ -63,7 +58,7 @@ struct MacheteCollectiveMma {

// Prepacked block shape (N is M in the transposed problem)
using PPBlockShape_MK = typename GmemLayoutA::PPBlockShape_NK;
// Prepacked blocks per dim in each dimension
// Prepacked blocks per dim for a single MMA tile
using PPBlocksPerTile_MK = decltype(make_shape(
size<0>(TileShape_MNK{}) / size<0>(PPBlockShape_MK{}),
size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{})));
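
For reference, the blocks-per-tile arithmetic above can be checked in isolation with a host-compilable CuTe sketch; the 128x128x64 tile and 64x16 prepacked block sizes below are illustrative assumptions, not values taken from this file.

#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  // Illustrative sizes: a 128x128x64 MMA tile and a 64x16 prepacked block.
  using TileShape_MNK   = Shape<_128, _128, _64>;
  using PPBlockShape_MK = Shape<_64, _16>;
  // Same arithmetic as PPBlocksPerTile_MK above: divide the M and K extents
  // of the MMA tile by the prepacked block extents.
  using PPBlocksPerTile_MK = decltype(make_shape(
      size<0>(TileShape_MNK{}) / size<0>(PPBlockShape_MK{}),
      size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{})));
  static_assert(size<0>(PPBlocksPerTile_MK{}) == 2);  // 128 / 64
  static_assert(size<1>(PPBlocksPerTile_MK{}) == 4);  //  64 / 16
  return 0;
}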
@@ -97,7 +92,6 @@ struct MacheteCollectiveMma {
KernelTmaWarpSpecializedCooperativeMixedInput>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;

// Required by kernel
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
TileShape_MNK, GMMA::Major::K, GmmaMajorB>(),
@@ -119,7 +113,6 @@
sm90_smem_capacity_bytes, ElementA, ElementB, ElementScale,
ElementZero, TileShape_MNK>(StageCountType{});

// Required by kernel
struct DispatchPolicy {
constexpr static int Stages = PipelineStages;
using ClusterShape = ClusterShape_MNK;
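
The stage count above is derived from how many per-stage buffers (A, B, scales, zeros) fit in SM90 shared memory. A hedged scalar model of that computation; all byte counts below are illustrative assumptions, not the values this file computes:

// Scalar sketch of the stage-count derivation:
// stages = smem capacity / bytes needed per pipeline stage.
constexpr int smem_capacity_bytes = 232448;       // SM90 capacity, illustrative
constexpr int stage_bytes_A       = 128 * 64 / 2; // e.g. 4-bit A tile
constexpr int stage_bytes_B       = 128 * 64 * 2; // e.g. fp16 B tile
constexpr int stage_bytes_extra   = 1024;         // scales/zeros, illustrative
constexpr int PipelineStagesSketch =
    smem_capacity_bytes / (stage_bytes_A + stage_bytes_B + stage_bytes_extra);
static_assert(PipelineStagesSketch > 0);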
@@ -312,7 +305,7 @@ struct MacheteCollectiveMma {
static_assert(size<1>(SmemLayoutAtomScale{}) == 1,
"size<1>(SmemLayoutAtomScale) must be 1.");

public: // TODO: make private
private:
static constexpr ConversionMode get_conversion_mode() {
if constexpr (cute::is_void_v<ElementScale>) {
return ConversionMode::DirectConvert;
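
The dispatch above keys entirely off which optional element types are void; the diff shows only the first branch. A self-contained sketch of the full decision, assuming the remaining branches follow the usual CUTLASS mixed-input pattern (the branch names beyond DirectConvert are assumptions, not quoted from this file):

#include <type_traits>

enum class ConversionMode {
  DirectConvert, ConvertAndScale, ConvertAndScaleWithZero
};

template <class ElementScale, class ElementZero>
constexpr ConversionMode get_conversion_mode_sketch() {
  if constexpr (std::is_void_v<ElementScale>) {
    return ConversionMode::DirectConvert;           // no scales: pure type cast
  } else if constexpr (std::is_void_v<ElementZero>) {
    return ConversionMode::ConvertAndScale;         // scales, no zero points
  } else {
    return ConversionMode::ConvertAndScaleWithZero; // scales and zero points
  }
}

static_assert(get_conversion_mode_sketch<void, void>() ==
              ConversionMode::DirectConvert);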
@@ -649,8 +642,8 @@ struct MacheteCollectiveMma {
static constexpr uint32_t TmaTransactionBytes =
TmaTransactionBytesMK + TmaTransactionBytesNK;

/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
/// performance
// Issue Tma Descriptor Prefetch -- ideally from a single thread for best
// performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(
@@ -676,13 +669,15 @@
}
}
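
The single-thread guidance in the comment above is the standard CuTe elect-one idiom. A hedged device-side sketch; TmaA/TmaB are illustrative stand-ins for this mainloop's TMA copy objects, and the header paths assume a CUTLASS 3.x checkout:

#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>

template <class TmaA, class TmaB>
__device__ void prefetch_descriptors_sketch(TmaA const& tma_a,
                                            TmaB const& tma_b) {
  if (cute::elect_one_sync()) {  // one elected lane issues the prefetches
    cute::prefetch_tma_descriptor(tma_a.get_tma_descriptor());
    cute::prefetch_tma_descriptor(tma_b.get_tma_descriptor());
  }
}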

/// Set up the data needed by this collective for load and mma.
/// Returns a tuple of tensors. The collective and the kernel layer have the
/// contract Returned tuple must contain at least two elements, with the first
/// two elements being: gA_mkl - The tma tensor, A after a local tile so it
/// has shape (BLK_M,BLK_K,m,k,l) gB_nkl - The tma tensor, B after a local
/// tile so it has shape (BLK_N,BLK_K,n,k,l) The rest of the tensors can be
/// specified as needed by this collective.
// Set up the data needed by this collective for load and mma.
// Returns a tuple of tensors. The collective and the kernel layer have the
// contract Returned tuple must contain at least two elements, with the first
// two elements being: gA_mkl - The tma tensor, A after a local tile so it
// has shape (TILE_V,TILE_B,m,k,l) gB_nkl - The tma tensor, B after a local
// tile so it has shape (TILE_N,TILE_K,n,k,l) The rest of the tensors can be
// specified as needed by this collective.
// NOTE: TILE_B is the prepacked block index within a tile. TILE_V is the
// values within a prepacked block.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL,
Params const& mainloop_params) const {
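
At the kernel layer, the contract above boils down to: only the first two entries of the returned tuple are inspected. A hedged call-site sketch; collective, problem_shape_MNKL, and mainloop_params are illustrative names, not this file's API:

#include <cute/container/tuple.hpp>

template <class Collective, class ProblemShape, class Params>
__device__ void load_init_contract_sketch(Collective const& collective,
                                          ProblemShape const& problem_shape_MNKL,
                                          Params const& mainloop_params) {
  auto load_inputs = collective.load_init(problem_shape_MNKL, mainloop_params);
  auto gA_mkl = cute::get<0>(load_inputs);  // A view: (TILE_V,TILE_B,m,k,l)
  auto gB_nkl = cute::get<1>(load_inputs);  // B view: (TILE_N,TILE_K,n,k,l)
  // Any further tuple entries (scales, zeros, ...) are collective-private.
  (void)gA_mkl;
  (void)gB_nkl;
}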
@@ -739,9 +734,9 @@ struct MacheteCollectiveMma {
}
}

/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
/// This overload gets triggered when we have scales.
// Perform a collective-scoped matrix multiply-accumulate
// Producer Perspective
// This overload gets triggered when we have scales.
template <class... Ts, class KTileIterator, class BlockCoord>
CUTLASS_DEVICE void load(Params const& mainloop_params,
MainloopPipeline pipeline,
@@ -899,7 +894,7 @@ struct MacheteCollectiveMma {
}
}

/// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
@@ -916,8 +911,8 @@
}
}

/// Perform a collective-scoped matrix multiply-accumulate
/// Consumer Perspective
// Perform a collective-scoped matrix multiply-accumulate
// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void mma(MainloopPipeline pipeline,
PipelineState smem_pipe_read, FrgTensorC& accum,
@@ -1151,7 +1146,7 @@ struct MacheteCollectiveMma {
warpgroup_fence_operand(accum);
}
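
For orientation, the consumer ("mma") side above follows the standard CUTLASS pipeline shape: wait for a full stage, compute from it, release it back to the producer. A simplified, hedged sketch of that loop; names are illustrative, and the real code also overlaps dequantization with the WGMMAs and staggers the release:

template <class Pipeline, class PipeState>
__device__ void consumer_loop_sketch(Pipeline& pipeline,
                                     PipeState& smem_pipe_read,
                                     int k_tile_count) {
  while (k_tile_count > 0) {
    pipeline.consumer_wait(smem_pipe_read);     // block until the stage is full
    // ... copy A from smem, dequantize in registers, issue WGMMAs ...
    pipeline.consumer_release(smem_pipe_read);  // hand the stage back
    ++smem_pipe_read;
    --k_tile_count;
  }
}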

/// Perform a Consumer Epilogue to release all buffers
// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_release,
int k_tile_count) {
@@ -1173,7 +1168,7 @@
}

private:
/// Utilities for any additional inputs inside of the TMA load
// Utilities for any additional inputs inside of the TMA load
template <class... Ts>
CUTLASS_DEVICE auto partition_extra_tma_inputs(
Params const& mainloop_params, cute::tuple<Ts...> const& load_inputs,
@@ -1215,8 +1210,8 @@ struct MacheteCollectiveMma {
}
}

/// Utilities for partitioning extra inputs for loading from smem in the
/// mainloop.
// Utilities for partitioning extra inputs for loading from smem in the
// mainloop.
template <class ThreadMma>
CUTLASS_DEVICE auto partition_extra_mma_info(ThreadMma const& thread_mma,
TensorStorage& shared_tensors) {
@@ -1250,7 +1245,7 @@ struct MacheteCollectiveMma {
}
}

/// Returns the tiled copy and copy views for the extra inputs.
// Returns the tiled copy and copy views for the extra inputs.
template <class TiledMma, class... Ts>
CUTLASS_DEVICE auto retile_extra_mma_info(
TiledMma const& tiled_mma, cute::tuple<Ts...>& partitioned_extra_info,
@@ -1284,7 +1279,7 @@ struct MacheteCollectiveMma {
}
}

/// Utilities to copy A and extra inputs from smem to RF
// Utilities to copy A and extra inputs from smem to RF
template <class... Ts, class... Us>
CUTLASS_DEVICE void load_extra_info_to_registers(
cute::tuple<Ts...> const& partitioned_mma_extra_info,
@@ -1319,7 +1314,7 @@ struct MacheteCollectiveMma {
}
}

/// Utilities to transform A.
// Utilities to transform A.
template <class TCrA_load, int VectorWidthA, class TCrA_mma, class... Ts>
CUTLASS_DEVICE void transform_A_kblock(
TCrA_load const& tCrA_load, cute::Int<VectorWidthA> vec_A,
@@ -1378,20 +1373,20 @@
}
}

/// Utilities for transforming the A operand prior to issuing tensorcore math.
// Utilities for transforming the A operand prior to issuing tensorcore math.
template <typename IlvdBlkLayout, class EngineIn, class EngineOut,
class TensorLayout,
int ConversionVectorWidth = cosize_v<TensorLayout>>
CUTLASS_DEVICE void convert_tensor(
Tensor<EngineIn, TensorLayout> const& in,
Tensor<EngineOut, TensorLayout>& out,
cute::Int<ConversionVectorWidth> width = {}) {
/// This is an element-wise conversion where we expect both tensors to have
/// the same layout. As a result, we can cast as a cutlass array to use the
/// fast numeric converters without worrying about indexing into the layout.
// This is an element-wise conversion where we expect both tensors to have
// the same layout. As a result, we can cast as a cutlass array to use the
// fast numeric converters without worrying about indexing into the layout.
constexpr int N = cosize_v<TensorLayout>;

/// The inputs must be backed by registers & be statically sized.
// The inputs must be backed by registers & be statically sized.
static_assert(is_rmem<EngineIn>::value,
"Input tensor for A conversion must come from registers");
static_assert(is_rmem<EngineOut>::value,
Expand Down Expand Up @@ -1428,8 +1423,4 @@ struct MacheteCollectiveMma {
}
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace machete

/////////////////////////////////////////////////////////////////////////////////////////////////
