From 07be8b9b3a2279d86c8287b1411fbc46262831aa Mon Sep 17 00:00:00 2001 From: "Gao, Xiang" <qasdfgtyuiop@gmail.com> Date: Mon, 9 Dec 2024 14:24:41 -0800 Subject: [PATCH 1/3] Update all reduction primitives to take a block_dim parameter (#3536) Stacked on https://github.com/NVIDIA/Fuser/pull/3541 This PR is mostly mechanical. It just update all the reduction primitives with a new parameter `block_dim`, which is just `blockDim` for a kernel without warp specialization. For a kernel with warp specialization, we use a value for `block_dim`, which is the block dimension of the compute warps. For better performance, it is beneficial if we can know whether we are using the `blockDim` or a custom value at compile time. To achieve this, I added a new helper class: ```C++ struct DefaultBlockDim { const uint32_t x, y, z; __device__ DefaultBlockDim() : x(blockDim.x), y(blockDim.y), z(blockDim.z) {} __device__ operator dim3() const { return blockDim; } }; ``` When a function takes `block_dim`, its type will be a template type `BlockDimT`, which is either `DefaultBlockDim` or `dim3`, where `DefaultBlockDim` means we are using `blockDim`, and `dim3` means we are using the provided custom value. With this PR, I would consider all our reductions and broadcasts compatible with warp specialization. I expect them to just work with warp specialization. In the future, we may want to use TMA + warp specialization for reduction and persistent kernels. --- csrc/codegen.cpp | 34 ++++ csrc/parallel_dimension_map.h | 5 + runtime/basic_type_traits.cu | 2 + runtime/block_reduction.cu | 123 +++++++------ runtime/block_sync_default.cu | 22 ++- runtime/block_welford_outer.cu | 16 +- runtime/broadcast.cu | 20 ++- runtime/fused_reduction.cu | 244 +++++++++++++++++++++----- runtime/fused_welford_impl.cu | 137 +++++++++++---- runtime/fused_welford_impl_outer.cu | 33 +++- runtime/grid_broadcast.cu | 17 +- runtime/grid_reduction.cu | 159 ++++++++++++----- runtime/grid_sync.cu | 44 +++-- runtime/warp.cu | 43 +++-- runtime/welford.cu | 79 ++++++--- tests/cpp/test_circular_buffering.cpp | 13 -- tests/cpp/test_gpu2.cpp | 9 +- 17 files changed, 733 insertions(+), 267 deletions(-) diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 030de26c84d..0bfb72f1a7a 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1214,6 +1214,20 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } } + std::string genComputeBlockDim() { + std::stringstream ss; + const auto& pdim_map = kernel_->summary().parallel_dimension_map; + if (!pdim_map.hasWarpSpecialization()) { + ss << "DefaultBlockDim()"; + } else { + ss << "dim3(" + << genInlineOrOne(pdim_map.getRawCompute(ParallelType::TIDx)) << ", " + << genInlineOrOne(pdim_map.getRawCompute(ParallelType::TIDy)) << ", " + << genInlineOrOne(pdim_map.getRawCompute(ParallelType::TIDz)) << ")"; + } + return ss.str(); + } + std::string genReductionOp(BinaryOpType op_type, DataType data_type) { std::stringstream lambda; lambda << "[](" << data_type << " &a, " << data_type << " b) " @@ -1252,6 +1266,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genStaticCast(genPtrType(data_type), "shared_mem")); NVF_ERROR(stmt->predicate() != nullptr && stmt->predicate()->hasValue()); func_args.arg(genInline(stmt->predicate())); + func_args.arg(genComputeBlockDim()); indent() << genCall("broadcast::blockBroadcast", template_args, func_args) << ";\n"; @@ -1284,6 +1299,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { NVF_ERROR(read_pred != nullptr && read_pred->hasValue()); func_args.arg(genInline(read_pred)); func_args.arg(genStaticCast(output->dtype(), genInline(init))); + func_args.arg(genComputeBlockDim()); ArgumentBuilder template_args; if (reduction_dims.first->getParallelType() == ParallelType::TIDx && @@ -1349,6 +1365,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genInline(write_pred)); } func_args.arg(genCall(data_type, genInline(init))); + func_args.arg(genComputeBlockDim()); indent() << genCall("blockReduce", template_args, func_args) << ";\n"; } @@ -1578,6 +1595,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genInline(wop->writePredicate())); } func_args.arg(genStaticCast(data_type, 0)); + func_args.arg(genComputeBlockDim()); indent() << genCall("blockWelford", template_args, func_args) << ";\n"; } @@ -1781,6 +1799,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genCall(data_type, genInline(grop->init()))); func_args.arg(genInline(grop->entrance_index())); func_args.arg(genInline(grop->entrances())); + func_args.arg(genComputeBlockDim()); addProfileArguments(func_args, grop); @@ -1915,6 +1934,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(read_pred).arg(write_pred); // init_val func_args.arg(genCall("LocalTuple", data_type, genInline(grop->init()))); + // block_dim + func_args.arg(genComputeBlockDim()); // reduction_op func_args.arg(genReductionOp(op_type, out->dtype())); @@ -1971,6 +1992,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { } // Init val func_args.arg(genCall(data_type, genInline(grop->initVal(0)))); + // block_dim + func_args.arg(genComputeBlockDim()); addProfileArguments(func_args, grop); @@ -2059,6 +2082,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genInline(grouped_grop->entrance_index())); func_args.arg(genInline(grouped_grop->entrances())); + func_args.arg(genComputeBlockDim()); addProfileArguments(func_args, grouped_grop); @@ -2271,6 +2295,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genCall("ConstRefTuple", types, inputs)); func_args.arg(genCall("VolatilePtrTuple", types, work_bufs)); func_args.arg(genCall("LocalTuple", types, init_vals)); + func_args.arg(genComputeBlockDim()); // global_sync_buffer const auto sync_buffer = @@ -2407,6 +2432,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genCall("LocalTuple", data_types, init_args[0])); func_args.arg(genCall("LocalTuple", data_types, init_args[1])); func_args.arg(genCall("LocalTuple", index_types, init_args[2])); + // block_dim + func_args.arg(genComputeBlockDim()); // work buffer func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[0])); func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[1])); @@ -2498,6 +2525,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genVariableNameConvertAlignedArray(input.get(1))); func_args.arg(genVariableNameConvertAlignedArray(input.get(2))) .append("[0]"); + // block_dim + func_args.arg(genComputeBlockDim()); // global buf for (const auto i : c10::irange(3)) { @@ -2652,6 +2681,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genStaticCast(data_type, 0)); func_args.arg(genInline(gwop->entrance_index())); func_args.arg(genInline(gwop->entrances())); + func_args.arg(genComputeBlockDim()); indent() << genCall("welford::gridWelford", template_args, func_args) << ";\n"; @@ -2751,6 +2781,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(read_pred).arg(write_pred); // init_val func_args.arg(genCall("LocalTuple", data_type_args, init_args)); + // block_dim + func_args.arg(genComputeBlockDim()); // reduction_op func_args.arg(genTemplate( "welfordCombine", ArgumentBuilder().arg(data_type).arg(index_type))); @@ -2877,6 +2909,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { func_args.arg(genInline(write_pred)); } func_args.arg(genCall(data_type, genInline(init))); + func_args.arg(genComputeBlockDim()); indent() << genCall("blockIterGroupedYdimReduce", template_args, func_args) << ";\n"; @@ -3315,6 +3348,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { .append(sync_idx) .append("]"); sync_call_args.arg(sync_segment_size); + sync_call_args.arg(genComputeBlockDim()); auto sync_call = genCall("grid_sync::sync", sync_call_template_parms, sync_call_args); diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index f618de16021..e2e9de423c1 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -61,6 +61,11 @@ class ParallelDimensionMap { //! buffer tensors. Val* getNumComputeThreadsEachBlock() const; + //! Get if the kernel uses warp specialization + bool hasWarpSpecialization() const { + return !warp_specialized_types_.empty(); + } + bool has(ParallelType pt) const { return dim_map_.count(pt) > 0; } diff --git a/runtime/basic_type_traits.cu b/runtime/basic_type_traits.cu index 98eb2695ce1..b2f299cbd73 100644 --- a/runtime/basic_type_traits.cu +++ b/runtime/basic_type_traits.cu @@ -47,6 +47,8 @@ template <class _Tp, class _Up> struct is_same : public false_type {}; template <class _Tp> struct is_same<_Tp, _Tp> : public true_type {}; +template <class T, class U> +constexpr bool is_same_v = is_same<T, U>::value; // is_integral, for some types. template <class _Tp> diff --git a/runtime/block_reduction.cu b/runtime/block_reduction.cu index 817093eb63b..648b0b18714 100644 --- a/runtime/block_reduction.cu +++ b/runtime/block_reduction.cu @@ -21,7 +21,8 @@ template < bool Z_REDUCE, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void blockReduce( T& out, const T& inp_val, @@ -29,28 +30,32 @@ __device__ void blockReduce( T* shared_mem, bool read_pred, bool write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // If this thread will output a final result bool should_write = index_utils::maskedIsZero<X_REDUCE, Y_REDUCE, Z_REDUCE>(threadIdx); // Size of the reduction segments unsigned int reduction_size = - index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(blockDim); + index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(block_dim); // Index into the reduction segment unsigned int reduction_tid = index_utils::maskedOffset<X_REDUCE, Y_REDUCE, Z_REDUCE>( - threadIdx, blockDim); + threadIdx, block_dim); // Index of the reduction segment unsigned int reduction_idx = index_utils::maskedOffset<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>( - threadIdx, blockDim); + threadIdx, block_dim); // number of reductions per block unsigned int reduction_num = - index_utils::maskedSize<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>(blockDim); + index_utils::maskedSize<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>(block_dim); // smem_offset is the offset into shared memory for the current thread. // To ensure coalesced access to shared memory, we need to ensure @@ -65,15 +70,15 @@ __device__ void blockReduce( // To avoid this, we should always use the offset based on the indexing of // threads within a block. // Offset into smem for the current thread - unsigned int smem_offset = threadIdx.x + threadIdx.y * blockDim.x + - threadIdx.z * blockDim.x * blockDim.y; + unsigned int smem_offset = threadIdx.x + threadIdx.y * block_dim.x + + threadIdx.z * block_dim.x * block_dim.y; // The peer stride represents the distance between the current element and its // nearest reduction peer. It depends on the reduction dimension. A reduction // peer refers to elements that belong to the same reduction segment. For // example, if the reduction is across TIDy, all the elements in the same // column (with the same TIDx) are considered peers of each other. The - // distance between an element and its nearest peer is blockDim.x. + // distance between an element and its nearest peer is block_dim.x. constexpr int num_redu_dims = (int)X_REDUCE + (int)Y_REDUCE + (int)Z_REDUCE; constexpr bool xz_reduce = (num_redu_dims == 2 && !Y_REDUCE); // reduction in 3 dimensions, XYZ, stride is 1 @@ -82,27 +87,27 @@ __device__ void blockReduce( // Reduction only in 1 dimension, X or Y or Z // e.g. inner or outer reduction // If X_REDUCE, reducing in neighbor cols in smem, peer_stride is 1 - // If Y_REDUCE, reducing in neighbor rows in smem, peer_stride is blockDim.x - // If Z_REDUCE, reducing in neighbor planes in smem, peer_stride is - // blockDim.x * blockDim.y + // If Y_REDUCE, reducing in neighbor rows in smem, peer_stride is + // block_dim.x If Z_REDUCE, reducing in neighbor planes in smem, peer_stride + // is block_dim.x * block_dim.y peer_stride = X_REDUCE ? 1 - : Y_REDUCE ? blockDim.x - : blockDim.x * blockDim.y; + : Y_REDUCE ? block_dim.x + : block_dim.x * block_dim.y; } else if (num_redu_dims == 2) { // Reduction in 2 dimensions, only one dimension is not reduced, !X, !Y, !Z // If !Z_REDUCE, merge XY, reducing neighbor cols, peer_stride is 1 - // If !X_REDUCE, merge ZY, reducing neighbor rows, peer_stride is blockDim.x - // If !Y_REDUCE, if blockDim.y == 1, merge XZ, peer_stride is 1. - // otherwise, needs carefully calculate offset to the reduction peer: + // If !X_REDUCE, merge ZY, reducing neighbor rows, peer_stride is + // block_dim.x If !Y_REDUCE, if block_dim.y == 1, merge XZ, peer_stride + // is 1. otherwise, needs carefully calculate offset to the reduction peer: // (1) redu_offset = reduction_tid + tree_fold_factor - // (2) idz = redu_offset / blockDim.x - // (3) idx = redu_offset % blockDim.x - // (4) smem_offset = idx + threadIdx.y * blockDim.x + idz * blockDim.x * - // blockDim.y + // (2) idz = redu_offset / block_dim.x + // (3) idx = redu_offset % block_dim.x + // (4) smem_offset = idx + threadIdx.y * block_dim.x + idz * block_dim.x * + // block_dim.y if (!Y_REDUCE) { peer_stride = 1; } else { - peer_stride = !Z_REDUCE ? 1 : blockDim.x; + peer_stride = !Z_REDUCE ? 1 : block_dim.x; } } @@ -112,41 +117,41 @@ __device__ void blockReduce( } else { shared_mem[smem_offset] = init_val; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to nearest power of 2 for the tree reduction: int np2 = 1 << (31 - __clz(reduction_size)); if (reduction_tid < np2 && reduction_tid + np2 < reduction_size) { int peer_offset = smem_offset + np2 * peer_stride; if constexpr (xz_reduce) { - if (blockDim.y > 1) { + if (block_dim.y > 1) { int redu_offset = reduction_tid + np2; - int idz = redu_offset / blockDim.x; - int idx = redu_offset % blockDim.x; + int idz = redu_offset / block_dim.x; + int idx = redu_offset % block_dim.x; peer_offset = - idx + threadIdx.y * blockDim.x + idz * blockDim.x * blockDim.y; + idx + threadIdx.y * block_dim.x + idz * block_dim.x * block_dim.y; } } reduction_op(shared_mem[smem_offset], shared_mem[peer_offset]); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // loop peel the final iteration to save one syncthread for the end for (int factor = np2 / 2; factor > 1; factor >>= 1) { if (reduction_tid < factor) { int peer_offset = smem_offset + factor * peer_stride; if constexpr (xz_reduce) { - if (blockDim.y > 1) { + if (block_dim.y > 1) { int redu_offset = reduction_tid + factor; - int idz = redu_offset / blockDim.x; - int idx = redu_offset % blockDim.x; + int idz = redu_offset / block_dim.x; + int idx = redu_offset % block_dim.x; peer_offset = - idx + threadIdx.y * blockDim.x + idz * blockDim.x * blockDim.y; + idx + threadIdx.y * block_dim.x + idz * block_dim.x * block_dim.y; } } reduction_op(shared_mem[smem_offset], shared_mem[peer_offset]); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } if (should_write && write_pred) { @@ -157,7 +162,7 @@ __device__ void blockReduce( } out = result; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // Use the same pred for both reads and writes @@ -167,14 +172,19 @@ template < bool Z_REDUCE, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void blockReduce( T& out, const T& inp_val, Func reduction_op, T* shared_mem, bool read_write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { blockReduce<X_REDUCE, Y_REDUCE, Z_REDUCE, Aligned, T, Func>( out, inp_val, @@ -182,7 +192,8 @@ __device__ void blockReduce( shared_mem, read_write_pred, read_write_pred, - init_val); + init_val, + block_dim); } // Each thread in the iteration dimension processes N elements @@ -195,7 +206,8 @@ template < bool Aligned, int N, // Number of elements per input array typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void blockIterGroupedYdimReduce( T out[N], const T inp_val[N], @@ -203,14 +215,18 @@ __device__ void blockIterGroupedYdimReduce( T* shared_mem, bool read_pred, bool write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // N should be a valid vectorization factor static_assert( N == 2 || N == 4 || N == 8 || N == 16, "N should be a valid vectorization factor, one of (2, 4, 8, 16)!"); bool should_write = threadIdx.y == 0; - unsigned int reduction_size = blockDim.y; + unsigned int reduction_size = block_dim.y; unsigned int reduction_tid = threadIdx.y; // In shared memory, each row has 128 bytes, if sizeof(T) * N = 32 bytes, each @@ -228,11 +244,12 @@ __device__ void blockIterGroupedYdimReduce( // assume TIDy is the reduction dimension, TIDx is the iteration dimension // TIDz is not used - unsigned int peer_stride = elements_per_load * blockDim.x; + unsigned int peer_stride = elements_per_load * block_dim.x; - unsigned int smem_offset_inter = blockDim.x * blockDim.y * elements_per_load; + unsigned int smem_offset_inter = + block_dim.x * block_dim.y * elements_per_load; unsigned int smem_offset_intra = - (threadIdx.y * blockDim.x + threadIdx.x) * elements_per_load; + (threadIdx.y * block_dim.x + threadIdx.x) * elements_per_load; // load to [total_loads] sections of shared memory #pragma unroll @@ -241,7 +258,7 @@ __device__ void blockIterGroupedYdimReduce( shared_mem + smem_offset_inter * i + smem_offset_intra, const_cast<T*>(inp_val) + i * elements_per_load); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to nearest power of 2 for the tree reduction: // Perform parallel reduction for each element in the array @@ -272,7 +289,7 @@ __device__ void blockIterGroupedYdimReduce( shared_mem + self_offset, self + i * elements_per_load); } } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Tree reduction for (int factor = np2 / 2; factor > 1; factor >>= 1) { @@ -302,7 +319,7 @@ __device__ void blockIterGroupedYdimReduce( shared_mem + self_offset, self + i * elements_per_load); } } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // last reduction @@ -347,7 +364,7 @@ __device__ void blockIterGroupedYdimReduce( out[i] = result[i]; } } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // Use the same pred for both reads and writes @@ -355,14 +372,19 @@ template < bool Aligned, int N, // Number of elements per input array typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void blockIterGroupedYdimReduce( T out[N], const T inp_val[N], Func reduction_op, T* shared_mem, bool read_write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { blockIterGroupedYdimReduce<Aligned, N, T, Func>( out, inp_val, @@ -370,5 +392,6 @@ __device__ void blockIterGroupedYdimReduce( shared_mem, read_write_pred, read_write_pred, - init_val); + init_val, + block_dim); } diff --git a/runtime/block_sync_default.cu b/runtime/block_sync_default.cu index 2a3048c1a8b..1c865320fa6 100644 --- a/runtime/block_sync_default.cu +++ b/runtime/block_sync_default.cu @@ -6,18 +6,34 @@ */ // clang-format on +// Basically just blockDim, but wrapped as a struct so that we have a mechanism +// to know at compile time that whether we are just using blockDim or some +// custom value. For a kernel without warp specialization, we just use blockDim, +// but for a kernel with warp specialization, we use a custom block_dim whose +// dimension are the dimensions of the compute warps. +struct DefaultBlockDim { + const uint32_t x, y, z; + __device__ DefaultBlockDim() : x(blockDim.x), y(blockDim.y), z(blockDim.z) {} + __device__ operator dim3() const { + return blockDim; + } +}; + // Default block synchronization. Just use __barrier_sync namespace block_sync { __forceinline__ __device__ void init() {} // Thread-block synchronization -template <bool aligned> -__forceinline__ __device__ void sync() { +template <bool aligned, typename BlockDimT> +__forceinline__ __device__ void sync(BlockDimT block_dim) { if constexpr (aligned) { __syncthreads(); - } else { + } else if constexpr (std::is_same_v<BlockDimT, DefaultBlockDim>) { __barrier_sync(0); + } else { + uint32_t num_threads = block_dim.x * block_dim.y * block_dim.z; + asm volatile("bar.sync 0, %0;" : : "r"(num_threads) : "memory"); } } diff --git a/runtime/block_welford_outer.cu b/runtime/block_welford_outer.cu index 3b758fe61d3..ded37e75098 100644 --- a/runtime/block_welford_outer.cu +++ b/runtime/block_welford_outer.cu @@ -56,11 +56,21 @@ namespace impl { // registers than just returing the output. Results would vary // depending on compiler versions, but it seems safer to return outputs // as a new value. -template <bool Aligned, int NumVals, typename DataType, int BDIMX, int BDIMY> +template < + bool Aligned, + int NumVals, + typename DataType, + int BDIMX, + int BDIMY, + typename BlockDimT> __inline__ __device__ WelfordTriplet<DataType> blockWelfordOuter( DataType* inp_avg, DataType* inp_var, nvfuser_index_t inp_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, DataType* smem) { constexpr int num_warps = BDIMX * BDIMY / 32; static_assert(num_warps >= 1, "There must be at least a single warp"); @@ -188,7 +198,7 @@ __inline__ __device__ WelfordTriplet<DataType> blockWelfordOuter( } } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // The next step is to let each thread of a warp independently // accumulate the partial results on the shared memory @@ -245,7 +255,7 @@ __inline__ __device__ WelfordTriplet<DataType> blockWelfordOuter( } } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Nothing to do for warps whose wid is larger than NunVals if (wid >= NumVals) { diff --git a/runtime/broadcast.cu b/runtime/broadcast.cu index 127fc6ea762..2bc82b9bd97 100644 --- a/runtime/broadcast.cu +++ b/runtime/broadcast.cu @@ -16,30 +16,40 @@ namespace broadcast { // inp_val: Per-thread source value. Only valid when the thread is a source. // out: Per-thread output location // -template <bool X_THREAD, bool Y_THREAD, bool Z_THREAD, bool Aligned, typename T> +template < + bool X_THREAD, + bool Y_THREAD, + bool Z_THREAD, + bool Aligned, + typename T, + typename BlockDimT> __device__ void blockBroadcast( T& out, const T& inp_val, T* shared_mem, - bool read_write_pred) { + bool read_write_pred, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { const bool has_valid_data = (!X_THREAD || threadIdx.x == 0) && (!Y_THREAD || threadIdx.y == 0) && (!Z_THREAD || threadIdx.z == 0); const auto shared_offset = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); if (has_valid_data && read_write_pred) { shared_mem[shared_offset] = inp_val; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (read_write_pred) { out = shared_mem[shared_offset]; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } } // namespace broadcast diff --git a/runtime/fused_reduction.cu b/runtime/fused_reduction.cu index 731403067d9..9967078fc24 100644 --- a/runtime/fused_reduction.cu +++ b/runtime/fused_reduction.cu @@ -204,6 +204,7 @@ template < bool FORWARD_PROTECT_SMEM, bool Aligned, typename LocalTupleT, + typename BlockDimT, typename... Funcs> struct BlockReduceEach { __inline__ __device__ static void reduce( @@ -215,9 +216,20 @@ struct BlockReduceEach { int num_threads_per_reduction, int num_elements_per_reduction, int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Funcs... funcs) { // Finish the reduction of each tuple value with a smaller offset - BlockReduceEach<idx - 1, BROADCAST, true, Aligned, LocalTupleT, Funcs...>:: + BlockReduceEach< + idx - 1, + BROADCAST, + true, + Aligned, + LocalTupleT, + BlockDimT, + Funcs...>:: reduce( block_result, partial_result, @@ -227,6 +239,7 @@ struct BlockReduceEach { num_threads_per_reduction, num_elements_per_reduction, reduction_idx, + block_dim, funcs...); if (num_elements_per_reduction == 1) { @@ -252,7 +265,7 @@ struct BlockReduceEach { copyTuple(shared_buf, smem_offset, block_result_i); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (tid_in_reduction < np2 && tid_in_reduction + np2 < num_elements_per_reduction) { @@ -265,7 +278,7 @@ struct BlockReduceEach { } // Always sync when communicating across smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to 2 values, last thread will do the final reduction and // can save a syncthreads this way @@ -278,7 +291,7 @@ struct BlockReduceEach { smem_offset + factor, funcs...); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } copyTuple(block_result_i, shared_buf, smem_offset); @@ -300,7 +313,7 @@ struct BlockReduceEach { } // Sync threads to make sure result is in smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); copyTuple( block_result_i, @@ -311,7 +324,7 @@ struct BlockReduceEach { block_result.val<idx>(0) = block_result_i.val<0>(0); if (FORWARD_PROTECT_SMEM) { - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } } }; @@ -322,6 +335,7 @@ template < bool FORWARD_PROTECT_SMEM, bool Aligned, typename LocalTupleT, + typename BlockDimT, typename... Funcs> struct BlockReduceEach< -1, @@ -329,6 +343,7 @@ struct BlockReduceEach< FORWARD_PROTECT_SMEM, Aligned, LocalTupleT, + BlockDimT, Funcs...> { __inline__ __device__ static void reduce( LocalTupleT& block_result, @@ -339,6 +354,10 @@ struct BlockReduceEach< int num_threads_per_reduction, int num_elements_per_reduction, int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Funcs... funcs) {} }; @@ -360,6 +379,7 @@ template < bool FORWARD_PROTECT_SMEM, bool Aligned, typename LocalTupleT, + typename BlockDimT, typename... Funcs> __inline__ __device__ void blockReduceEach( LocalTupleT& block_result, @@ -370,6 +390,10 @@ __inline__ __device__ void blockReduceEach( int num_threads_per_reduction, int num_elements_per_reduction, int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Funcs... reduction_ops) { BlockReduceEach< LocalTupleT::num_vals - 1, @@ -377,6 +401,7 @@ __inline__ __device__ void blockReduceEach( FORWARD_PROTECT_SMEM, Aligned, LocalTupleT, + BlockDimT, Funcs...>:: reduce( block_result, @@ -387,6 +412,7 @@ __inline__ __device__ void blockReduceEach( num_threads_per_reduction, num_elements_per_reduction, reduction_idx, + block_dim, reduction_ops...); } @@ -466,7 +492,7 @@ class ParallelReduce { // reduceGroup does not support Welford-style reductions that reduce // all values of a tuple together, so this is the only entry point // for Welford for now. - template <bool Aligned, typename Func, typename... Types> + template <bool Aligned, typename Func, typename BlockDimT, typename... Types> __device__ __inline__ void reduce( RefTuple<Types...> out, const ConstRefTuple<Types...>& inp, @@ -477,10 +503,14 @@ class ParallelReduce { bool read_pred, // Prevent reading from out of bounds memory bool write_pred, // Prevent from writing out of bounds const LocalTuple<Types...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Func reduction_op); //! Profiled version - template <bool Aligned, typename Func, typename... Types> + template <bool Aligned, typename Func, typename BlockDimT, typename... Types> __device__ __inline__ void reduce( RefTuple<Types...> out, const ConstRefTuple<Types...>& inp, @@ -491,6 +521,10 @@ class ParallelReduce { bool read_pred, // Prevent reading from out of bounds memory bool write_pred, // Prevent from writing out of bounds const LocalTuple<Types...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Func reduction_op, int64_t& cycles, int64_t& count); @@ -505,6 +539,7 @@ class ParallelReduce { //! no need to accumulate into the out parameter. template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -513,6 +548,10 @@ class ParallelReduce { const ConstRefTuple<DataTypes...>& inp, VolatilePtrTuple<DataTypes...> global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t* global_sync_buffer, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, @@ -522,6 +561,7 @@ class ParallelReduce { //! Profiled version template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -530,6 +570,10 @@ class ParallelReduce { const ConstRefTuple<DataTypes...>& inp, VolatilePtrTuple<DataTypes...> global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t* global_sync_buffer, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, @@ -551,7 +595,12 @@ class ParallelReduce { // simplicity. In practice, it should be really uncommon to group // welford ops with different data types, so this restriction // shouldn't be an issue. - template <bool Aligned, int NumArgs, typename DataType, typename IndexType> + template < + bool Aligned, + int NumArgs, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ void welfordGroup( typename MakeRefTuple<NumArgs, DataType>::type out_avg, typename MakeRefTuple<NumArgs, DataType>::type out_var, @@ -562,6 +611,10 @@ class ParallelReduce { const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg, const typename MakeLocalTuple<NumArgs, DataType>::type& init_var, const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, typename MakeVolatilePtrTuple<NumArgs, DataType>::type global_work_buffer_avg, typename MakeVolatilePtrTuple<NumArgs, DataType>::type @@ -574,7 +627,12 @@ class ParallelReduce { const typename MakeLocalTuple<NumArgs, bool>::type& write_preds); //! Profiled version - template <bool Aligned, int NumArgs, typename DataType, typename IndexType> + template < + bool Aligned, + int NumArgs, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ void welfordGroup( typename MakeRefTuple<NumArgs, DataType>::type out_avg, typename MakeRefTuple<NumArgs, DataType>::type out_var, @@ -585,6 +643,10 @@ class ParallelReduce { const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg, const typename MakeLocalTuple<NumArgs, DataType>::type& init_var, const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, typename MakeVolatilePtrTuple<NumArgs, DataType>::type global_work_buffer_avg, typename MakeVolatilePtrTuple<NumArgs, DataType>::type @@ -601,7 +663,13 @@ class ParallelReduce { // This is highly specific to the outer-reduction pattern. All the // assumptions should be asserted with static_assert at the begging of // the fuction. - template <bool Aligned, int NumVals, typename DataType, int BDIMX, int BDIMY> + template < + bool Aligned, + int NumVals, + typename DataType, + int BDIMX, + int BDIMY, + typename BlockDimT> __device__ __inline__ void welfordGroupOuter( DataType out_avg[NumVals], DataType out_var[NumVals], @@ -609,6 +677,10 @@ class ParallelReduce { const DataType in_avg[NumVals], const DataType in_var[NumVals], nvfuser_index_t in_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, DataType* global_buf_avg, DataType* global_buf_var, nvfuser_index_t* global_buf_N, @@ -616,7 +688,13 @@ class ParallelReduce { int64_t* global_sync_buffer); // Profiled version - template <bool Aligned, int NumVals, typename DataType, int BDIMX, int BDIMY> + template < + bool Aligned, + int NumVals, + typename DataType, + int BDIMX, + int BDIMY, + typename BlockDimT> __device__ __inline__ void welfordGroupOuter( DataType out_avg[NumVals], DataType out_var[NumVals], @@ -624,6 +702,10 @@ class ParallelReduce { const DataType in_avg[NumVals], const DataType in_var[NumVals], nvfuser_index_t in_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, DataType* global_buf_avg, DataType* global_buf_var, nvfuser_index_t* global_buf_N, @@ -651,12 +733,17 @@ class ParallelReduce { template < bool BLOCK_BROADCAST, bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> __device__ __inline__ static LocalTuple<DataTypes...> reduceGroupBlock( const ConstRefTuple<DataTypes...>& inp, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, bool block_reduce_participate, @@ -668,6 +755,7 @@ class ParallelReduce { //! but it isn't synchronized when returning from this function. template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -675,6 +763,10 @@ class ParallelReduce { RefTuple<DataTypes...>& out, const VolatilePtrTuple<DataTypes...>& global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, void* shared_mem, nvfuser_index_t block_red_idx_offset, nvfuser_index_t num_thread_iters, @@ -692,21 +784,35 @@ class ParallelReduce { bool Aligned, int NumVals, typename DataType, - typename IndexType> + typename IndexType, + typename BlockDimT> __device__ __inline__ static void welfordGroupBlock( LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result, const ConstRefWelfordTripletTuple<NumVals, DataType, IndexType>& inp, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, PtrTuple<DataType, DataType, IndexType> shared_buf, const typename MakeLocalTuple<NumVals, bool>::type& read_preds, bool block_reduce_participate); //! Welford version of reduceGrouplLastBlock - template <bool Aligned, int NumVals, typename DataType, typename IndexType> + template < + bool Aligned, + int NumVals, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ static void welfordGroupLastBlock( RefWelfordTripletTuple<NumVals, DataType, IndexType>& out, const VolatilePtrWelfordTripletTuple<NumVals, DataType, IndexType>& global_work_buffer, const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, PtrTuple<DataType, DataType, IndexType> shared_buf, nvfuser_index_t block_red_idx_offset, nvfuser_index_t num_thread_iters, @@ -729,7 +835,7 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, typename Func, typename... Types> +template <bool Aligned, typename Func, typename BlockDimT, typename... Types> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -749,6 +855,10 @@ __device__ __inline__ void ParallelReduce< bool read_pred, // Prevent reading from out of bounds memory bool write_pred, // Prevent from writing out of bounds const LocalTuple<Types...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Func reduction_op) { // If no reduction needed, just return input if (!BLOCK_REDUCE && !GRID_REDUCE) { @@ -785,14 +895,14 @@ __device__ __inline__ void ParallelReduce< // to number of threads int block_reduction_size = index_utils:: maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( - blockDim); + block_dim); // Index in the reduction segment, can be an int since it's limited to // number of threads int tid_in_block_reduction = index_utils::maskedOffset< isReduce(X_THREAD), isReduce(Y_THREAD), - isReduce(Z_THREAD)>(threadIdx, blockDim); + isReduce(Z_THREAD)>(threadIdx, block_dim); // ID of the block reduction this thread is participating in // @@ -802,7 +912,7 @@ __device__ __inline__ void ParallelReduce< // dimension int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // Shared memory buffer is 2D // [iter dimension, reduction dimension] @@ -817,7 +927,7 @@ __device__ __inline__ void ParallelReduce< } // Sync to make sure smem is completely initialized - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Round reduction size down to nearest power of 2 int np2 = 1 << (31 - __clz(block_reduction_size)); @@ -834,7 +944,7 @@ __device__ __inline__ void ParallelReduce< } // Always need to sync while operating on shared memory - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down until 2 values, leaving 2 values allows us to manually // perform the last reduction and avoid a syncthreads @@ -847,7 +957,7 @@ __device__ __inline__ void ParallelReduce< block_reduce_smem_offset + factor, reduction_op); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // Accumulate that last valid result @@ -883,7 +993,7 @@ __device__ __inline__ void ParallelReduce< } // Sync threads to make sure result is in smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // If the thread is participating, and is not attempting to write out // of bounds, return the broadcasted value. if (block_reduce_participate && write_pred) { @@ -898,7 +1008,7 @@ __device__ __inline__ void ParallelReduce< // // This could be avoided in some cases if we added thread syncs from // block reductions in the syncthread insertion pass. - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); return; } } @@ -929,7 +1039,7 @@ __device__ __inline__ void ParallelReduce< gridDim) * index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim) * + block_dim) * grid_red_size; global_work_buffer += global_buffer_size; } @@ -948,13 +1058,13 @@ __device__ __inline__ void ParallelReduce< // How many grid reductions have to be performed, in the block dimension const auto num_thread_iters = index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim); + block_dim); // Which grid reduction does this thread participate in, in the block // dimension const auto thread_red_idx_offset = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // 3D buffer of reductions: // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] @@ -989,7 +1099,10 @@ __device__ __inline__ void ParallelReduce< isReduce(Z_BLOCK), PERSISTENT_REDUCTION, Aligned>( - global_sync_buffer[block_red_idx_offset], grid_red_size, last_block); + global_sync_buffer[block_red_idx_offset], + grid_red_size, + last_block, + block_dim); } // -- START BLOCK CLEANUP -- // @@ -1011,12 +1124,12 @@ __device__ __inline__ void ParallelReduce< int tid_in_block_reduction_2 = index_utils::maskedOffset< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(threadIdx, blockDim); + activeNotIter(Z_THREAD)>(threadIdx, block_dim); int block_reduction_size_2 = index_utils::maskedSize< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(blockDim); + activeNotIter(Z_THREAD)>(block_dim); // 3D buffer of reductions: // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] @@ -1049,7 +1162,7 @@ __device__ __inline__ void ParallelReduce< // Which block reduction this thread is participating in int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // Offset in smem for this thread's result auto smem_offset = @@ -1064,7 +1177,7 @@ __device__ __inline__ void ParallelReduce< copyTuple(shared_buf, smem_offset, last_block_result); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (tid_in_block_reduction_2 < np2 && tid_in_block_reduction_2 + np2 < @@ -1078,7 +1191,7 @@ __device__ __inline__ void ParallelReduce< } // Always sync when communicating across smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to 2 values, last thread will do the final reduction and // can save a syncthreads this way @@ -1091,7 +1204,7 @@ __device__ __inline__ void ParallelReduce< smem_offset + factor, reduction_op); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // If this thread in each block has the final result before broadcasting @@ -1113,7 +1226,7 @@ __device__ __inline__ void ParallelReduce< if (grid_reduce_participate && PERSISTENT_REDUCTION) { // If persistent reduction, always broadcast reduced values copyTuple(shared_buf, smem_offset, last_block_result); - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (write_pred && block_reduce_participate) { copyTuple( out, shared_buf, block_reduction_idx * block_reduction_size_2); @@ -1133,7 +1246,7 @@ __device__ __inline__ void ParallelReduce< } } // Forward protect the smem used in this reduction - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } } @@ -1147,7 +1260,7 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, typename Func, typename... Types> +template <bool Aligned, typename Func, typename BlockDimT, typename... Types> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -1167,6 +1280,10 @@ __device__ __inline__ void ParallelReduce< bool read_pred, // Prevent reading from out of bounds memory bool write_pred, // Prevent from writing out of bounds const LocalTuple<Types...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, Func reduction_op, int64_t& cycles, int64_t& count) { @@ -1186,6 +1303,7 @@ __device__ __inline__ void ParallelReduce< read_pred, write_pred, init_val, + block_dim, reduction_op); if (isLastBlockInGrid() && @@ -1206,6 +1324,7 @@ template < bool BROADCAST> template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -1223,6 +1342,10 @@ __device__ __inline__ void ParallelReduce< const ConstRefTuple<DataTypes...>& inp, VolatilePtrTuple<DataTypes...> global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t* global_sync_buffer, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, @@ -1257,6 +1380,7 @@ __device__ __inline__ void ParallelReduce< const auto block_result = reduceGroupBlock < !GRID_REDUCE && BROADCAST, Aligned > (inp, init_val, + block_dim, shared_mem, read_preds, block_reduce_participate, @@ -1273,7 +1397,7 @@ __device__ __inline__ void ParallelReduce< // forward-protect the smem buffer. This block sync is not // necessary when a grid reduction follows since a block sync is // done just before the grid sync. - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); return; } @@ -1309,13 +1433,13 @@ __device__ __inline__ void ParallelReduce< // How many grid reductions have to be performed, in the block dimension const auto num_thread_iters = index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim); + block_dim); // Which grid reduction does this thread participate in, in the block // dimension const auto thread_red_idx_offset = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // 3D buffer of reductions: // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] @@ -1336,7 +1460,7 @@ __device__ __inline__ void ParallelReduce< gridDim) * index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim) * + block_dim) * grid_red_size; global_work_buffer += global_buffer_size; } @@ -1362,7 +1486,10 @@ __device__ __inline__ void ParallelReduce< isReduce(Z_BLOCK), PERSISTENT_REDUCTION, Aligned>( - global_sync_buffer[block_red_idx_offset], grid_red_size, last_block); + global_sync_buffer[block_red_idx_offset], + grid_red_size, + last_block, + block_dim); } // -- START BLOCK CLEANUP -- // @@ -1370,6 +1497,7 @@ __device__ __inline__ void ParallelReduce< out, global_work_buffer, init_val, + block_dim, shared_mem, block_red_idx_offset, num_thread_iters, @@ -1382,7 +1510,7 @@ __device__ __inline__ void ParallelReduce< funcs...); // Forward protect the smem buffer - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } template < @@ -1396,6 +1524,7 @@ template < bool BROADCAST> template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -1413,6 +1542,10 @@ __device__ __inline__ void ParallelReduce< const ConstRefTuple<DataTypes...>& inp, VolatilePtrTuple<DataTypes...> global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t* global_sync_buffer, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, @@ -1432,6 +1565,7 @@ __device__ __inline__ void ParallelReduce< inp, global_work_buffer, init_val, + block_dim, global_sync_buffer, shared_mem, read_preds, @@ -1457,6 +1591,7 @@ template < template < bool BLOCK_BROADCAST, bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -1472,6 +1607,10 @@ __device__ __inline__ LocalTuple<DataTypes...> ParallelReduce< reduceGroupBlock( const ConstRefTuple<DataTypes...>& inp, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, void* shared_mem, const LocalTuple<BoolTypes...>& read_preds, bool block_reduce_participate, @@ -1489,13 +1628,13 @@ __device__ __inline__ LocalTuple<DataTypes...> ParallelReduce< // to number of threads const int block_reduction_size = index_utils:: maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( - blockDim); + block_dim); // Index in the reduction segment, can be an int since it's limited to // number of threads const int tid_in_block_reduction = index_utils:: maskedOffset<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // ID of the block reduction this thread is participating in // @@ -1505,7 +1644,7 @@ __device__ __inline__ LocalTuple<DataTypes...> ParallelReduce< // dimension const int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // Do not protect the smem buffer as it's not always necessary. impl::blockReduceEach< @@ -1513,6 +1652,7 @@ __device__ __inline__ LocalTuple<DataTypes...> ParallelReduce< false, Aligned, LocalTuple<DataTypes...>, + BlockDimT, Funcs...>( block_result, block_result, @@ -1522,6 +1662,7 @@ __device__ __inline__ LocalTuple<DataTypes...> ParallelReduce< block_reduction_size, block_reduction_size, block_reduction_idx, + block_dim, funcs...); return block_result; @@ -1538,6 +1679,7 @@ template < bool BROADCAST> template < bool Aligned, + typename BlockDimT, typename... DataTypes, typename... Funcs, typename... BoolTypes> @@ -1554,6 +1696,10 @@ __device__ __inline__ void ParallelReduce< RefTuple<DataTypes...>& out, const VolatilePtrTuple<DataTypes...>& global_work_buffer, const LocalTuple<DataTypes...>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, void* shared_mem, nvfuser_index_t block_red_idx_offset, nvfuser_index_t num_thread_iters, @@ -1583,12 +1729,12 @@ __device__ __inline__ void ParallelReduce< int tid_in_block_reduction = index_utils::maskedOffset< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(threadIdx, blockDim); + activeNotIter(Z_THREAD)>(threadIdx, block_dim); int block_reduction_size = index_utils::maskedSize< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(blockDim); + activeNotIter(Z_THREAD)>(block_dim); bool has_block_result = index_utils::maskedIsZero< activeNotIter(X_THREAD), @@ -1620,13 +1766,14 @@ __device__ __inline__ void ParallelReduce< // Which block reduction this thread is participating in int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); impl::blockReduceEach< BROADCAST, false, Aligned, LocalTuple<DataTypes...>, + BlockDimT, Funcs...>( last_block_result, last_block_result, @@ -1636,6 +1783,7 @@ __device__ __inline__ void ParallelReduce< block_reduction_size, min(grid_red_size, block_reduction_size), block_reduction_idx, + block_dim, reduction_ops...); copyTupleIf( diff --git a/runtime/fused_welford_impl.cu b/runtime/fused_welford_impl.cu index 314a8b405b5..eeeebc36bdc 100644 --- a/runtime/fused_welford_impl.cu +++ b/runtime/fused_welford_impl.cu @@ -72,7 +72,8 @@ template < bool Aligned, int NumVals, typename DataType, - typename IndexType> + typename IndexType, + typename BlockDimT> struct BlockWelfordEach { __inline__ __device__ static void reduce( LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result, @@ -83,7 +84,11 @@ struct BlockWelfordEach { int tid_in_reduction, int num_threads_per_reduction, int num_elements_per_reduction, - int reduction_idx) { + int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // Finish the reduction of each tuple value with a smaller offset BlockWelfordEach< idx - 1, @@ -92,7 +97,8 @@ struct BlockWelfordEach { Aligned, NumVals, DataType, - IndexType>:: + IndexType, + BlockDimT>:: reduce( block_result, partial_result, @@ -101,7 +107,8 @@ struct BlockWelfordEach { tid_in_reduction, num_threads_per_reduction, num_elements_per_reduction, - reduction_idx); + reduction_idx, + block_dim); if (num_elements_per_reduction == 1) { if (has_block_result) { @@ -125,7 +132,7 @@ struct BlockWelfordEach { copyTuple(shared_buf, smem_offset, block_result_i); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (tid_in_reduction < np2 && tid_in_reduction + np2 < num_elements_per_reduction) { impl::reduceTuple( @@ -141,7 +148,7 @@ struct BlockWelfordEach { } // Always sync when communicating across smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to 2 values, last thread will do the final reduction and // can save a syncthreads this way @@ -154,7 +161,7 @@ struct BlockWelfordEach { smem_offset + factor, welfordCombine<DataType, IndexType>); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } copyTuple(block_result_i, shared_buf, smem_offset); @@ -180,7 +187,7 @@ struct BlockWelfordEach { } // Sync threads to make sure result is in smem - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); copyTuple( block_result_i, @@ -193,7 +200,7 @@ struct BlockWelfordEach { block_result.N.val<idx>(0) = block_result_i.val<2>(0); if (FORWARD_PROTECT_SMEM) { - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } } }; @@ -205,7 +212,8 @@ template < bool Aligned, int NumVals, typename DataType, - typename IndexType> + typename IndexType, + typename BlockDimT> struct BlockWelfordEach< -1, BROADCAST, @@ -213,7 +221,8 @@ struct BlockWelfordEach< Aligned, NumVals, DataType, - IndexType> { + IndexType, + BlockDimT> { __inline__ __device__ static void reduce( LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result, const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& @@ -223,7 +232,11 @@ struct BlockWelfordEach< int tid_in_reduction, int num_threads_per_reduction, int num_elements_per_reduction, - int reduction_idx) {} + int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) {} }; //! Welford version of blockReduceEach. Perform block-parallel Welford @@ -234,7 +247,8 @@ template < bool Aligned, int NumVals, typename DataType, - typename IndexType> + typename IndexType, + typename BlockDimT> __inline__ __device__ void blockWelfordEach( LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result, const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& @@ -244,7 +258,11 @@ __inline__ __device__ void blockWelfordEach( int tid_in_reduction, int num_threads_per_reduction, int num_elements_per_reduction, - int reduction_idx) { + int reduction_idx, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { BlockWelfordEach< NumVals - 1, BROADCAST, @@ -252,7 +270,8 @@ __inline__ __device__ void blockWelfordEach( Aligned, NumVals, DataType, - IndexType>:: + IndexType, + BlockDimT>:: reduce( block_result, partial_result, @@ -261,7 +280,8 @@ __inline__ __device__ void blockWelfordEach( tid_in_reduction, num_threads_per_reduction, num_elements_per_reduction, - reduction_idx); + reduction_idx, + block_dim); } } // namespace impl @@ -275,7 +295,12 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, int NumArgs, typename DataType, typename IndexType> +template < + bool Aligned, + int NumArgs, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -295,6 +320,10 @@ __device__ __inline__ void ParallelReduce< const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg, const typename MakeLocalTuple<NumArgs, DataType>::type& init_var, const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, typename MakeVolatilePtrTuple<NumArgs, DataType>::type global_work_buffer_avg, typename MakeVolatilePtrTuple<NumArgs, DataType>::type @@ -338,7 +367,12 @@ __device__ __inline__ void ParallelReduce< NumArgs, DataType, IndexType>( - block_result, inp, shared_buf, read_preds, block_reduce_participate); + block_result, + inp, + block_dim, + shared_buf, + read_preds, + block_reduce_participate); // If block reduction only, save to out and exit if (!GRID_REDUCE) { @@ -352,7 +386,7 @@ __device__ __inline__ void ParallelReduce< // forward-protect the smem buffer. This block sync is not // necessary when a grid reduction follows since a block sync is // done just before the grid sync. - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); return; } @@ -388,13 +422,13 @@ __device__ __inline__ void ParallelReduce< // How many grid reductions have to be performed, in the block dimension const auto num_thread_iters = index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim); + block_dim); // Which grid reduction does this thread participate in, in the block // dimension const auto thread_red_idx_offset = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // 3D buffer of reductions: // [reduction_offset(grid), iter_offset(grid), iter_offset(block)] @@ -419,7 +453,7 @@ __device__ __inline__ void ParallelReduce< gridDim) * index_utils:: maskedSize<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - blockDim) * + block_dim) * grid_red_size; global_work_buffer += global_buffer_size; } @@ -445,7 +479,10 @@ __device__ __inline__ void ParallelReduce< isReduce(Z_BLOCK), PERSISTENT_REDUCTION, Aligned>( - global_sync_buffer[block_red_idx_offset], grid_red_size, last_block); + global_sync_buffer[block_red_idx_offset], + grid_red_size, + last_block, + block_dim); } // -- START BLOCK CLEANUP -- // @@ -454,6 +491,7 @@ __device__ __inline__ void ParallelReduce< global_work_buffer, LocalWelfordTripletTuple<NumArgs, DataType, IndexType>( init_avg, init_var, init_N), + block_dim, shared_buf, block_red_idx_offset, num_thread_iters, @@ -465,7 +503,7 @@ __device__ __inline__ void ParallelReduce< grid_reduce_participate); // Forward protect the smem buffer - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } template < @@ -477,7 +515,12 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, int NumArgs, typename DataType, typename IndexType> +template < + bool Aligned, + int NumArgs, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -497,6 +540,10 @@ __device__ __inline__ void ParallelReduce< const typename MakeLocalTuple<NumArgs, DataType>::type& init_avg, const typename MakeLocalTuple<NumArgs, DataType>::type& init_var, const typename MakeLocalTuple<NumArgs, IndexType>::type& init_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, typename MakeVolatilePtrTuple<NumArgs, DataType>::type global_work_buffer_avg, typename MakeVolatilePtrTuple<NumArgs, DataType>::type @@ -526,6 +573,7 @@ __device__ __inline__ void ParallelReduce< init_avg, init_var, init_N, + block_dim, global_work_buffer_avg, global_work_buffer_var, global_work_buffer_N, @@ -555,7 +603,8 @@ template < bool Aligned, int NumVals, typename DataType, - typename IndexType> + typename IndexType, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -568,6 +617,10 @@ __device__ __inline__ void ParallelReduce< welfordGroupBlock( LocalWelfordTripletTuple<NumVals, DataType, IndexType>& block_result, const ConstRefWelfordTripletTuple<NumVals, DataType, IndexType>& inp, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, PtrTuple<DataType, DataType, IndexType> shared_buf, const typename MakeLocalTuple<NumVals, bool>::type& read_preds, bool block_reduce_participate) { @@ -582,13 +635,13 @@ __device__ __inline__ void ParallelReduce< // to number of threads const int block_reduction_size = index_utils:: maskedSize<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( - blockDim); + block_dim); // Index in the reduction segment, can be an int since it's limited to // number of threads const int tid_in_block_reduction = index_utils:: maskedOffset<isReduce(X_THREAD), isReduce(Y_THREAD), isReduce(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // ID of the block reduction this thread is participating in // @@ -598,7 +651,7 @@ __device__ __inline__ void ParallelReduce< // dimension const int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); // Do not protect the smem buffer as it's not always necessary. impl::blockWelfordEach< @@ -607,7 +660,8 @@ __device__ __inline__ void ParallelReduce< Aligned, NumVals, DataType, - IndexType>( + IndexType, + BlockDimT>( block_result, block_result, shared_buf, @@ -615,7 +669,8 @@ __device__ __inline__ void ParallelReduce< tid_in_block_reduction, block_reduction_size, block_reduction_size, - block_reduction_idx); + block_reduction_idx, + block_dim); } template < @@ -627,7 +682,12 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, int NumVals, typename DataType, typename IndexType> +template < + bool Aligned, + int NumVals, + typename DataType, + typename IndexType, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -642,6 +702,10 @@ __device__ __inline__ void ParallelReduce< const VolatilePtrWelfordTripletTuple<NumVals, DataType, IndexType>& global_work_buffer, const LocalWelfordTripletTuple<NumVals, DataType, IndexType>& init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, PtrTuple<DataType, DataType, IndexType> shared_buf, nvfuser_index_t block_red_idx_offset, nvfuser_index_t num_thread_iters, @@ -670,12 +734,12 @@ __device__ __inline__ void ParallelReduce< int tid_in_block_reduction = index_utils::maskedOffset< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(threadIdx, blockDim); + activeNotIter(Z_THREAD)>(threadIdx, block_dim); int block_reduction_size = index_utils::maskedSize< activeNotIter(X_THREAD), activeNotIter(Y_THREAD), - activeNotIter(Z_THREAD)>(blockDim); + activeNotIter(Z_THREAD)>(block_dim); bool has_block_result = index_utils::maskedIsZero< activeNotIter(X_THREAD), @@ -700,7 +764,7 @@ __device__ __inline__ void ParallelReduce< // Which block reduction this thread is participating in int block_reduction_idx = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); impl::blockWelfordEach< BROADCAST, @@ -716,7 +780,8 @@ __device__ __inline__ void ParallelReduce< tid_in_block_reduction, block_reduction_size, min(grid_red_size, block_reduction_size), - block_reduction_idx); + block_reduction_idx, + block_dim); copyWelfordTripletTupleIf( out, diff --git a/runtime/fused_welford_impl_outer.cu b/runtime/fused_welford_impl_outer.cu index bd705bbec71..4d314c3bbac 100644 --- a/runtime/fused_welford_impl_outer.cu +++ b/runtime/fused_welford_impl_outer.cu @@ -212,7 +212,13 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, int NumVals, typename DataType, int BDIMX, int BDIMY> +template < + bool Aligned, + int NumVals, + typename DataType, + int BDIMX, + int BDIMY, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -229,6 +235,10 @@ __device__ __inline__ void ParallelReduce< const DataType in_avg[NumVals], const DataType in_var[NumVals], nvfuser_index_t in_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, DataType* global_buf_avg, DataType* global_buf_var, nvfuser_index_t* global_buf_N, @@ -258,11 +268,11 @@ __device__ __inline__ void ParallelReduce< auto iter_tid = index_utils:: maskedOffset<isIter(X_THREAD), isIter(Y_THREAD), isIter(Z_THREAD)>( - threadIdx, blockDim); + threadIdx, block_dim); auto per_block_result = impl::blockWelfordOuter<Aligned, NumVals, DataType, BDIMX, BDIMY>( - out_avg, out_var, in_N, shared_buf); + out_avg, out_var, in_N, block_dim, shared_buf); // At this point, threads with tid_in_group == 0 has valid partial // results. Store them to global buffer. @@ -310,7 +320,8 @@ __device__ __inline__ void ParallelReduce< isReduce(Y_BLOCK), isReduce(Z_BLOCK), PERSISTENT_REDUCTION, - Aligned>(global_sync_buffer[blockIdx.x], gridDim.y, last_block); + Aligned>( + global_sync_buffer[blockIdx.x], gridDim.y, last_block, block_dim); auto partial_results = welfordGroupAccumulateGlobalBuffer<NumVals, DataType, BDIMX, BDIMY>( @@ -321,6 +332,7 @@ __device__ __inline__ void ParallelReduce< partial_results.avg_.array, partial_results.var_.array, partial_results.N_, + block_dim, shared_buf); // At this point, each thread of the groups with tid_in_group=0 @@ -361,7 +373,13 @@ template < int Z_THREAD, bool PERSISTENT_REDUCTION, bool BROADCAST> -template <bool Aligned, int NumVals, typename DataType, int BDIMX, int BDIMY> +template < + bool Aligned, + int NumVals, + typename DataType, + int BDIMX, + int BDIMY, + typename BlockDimT> __device__ __inline__ void ParallelReduce< X_BLOCK, Y_BLOCK, @@ -378,6 +396,10 @@ __device__ __inline__ void ParallelReduce< const DataType in_avg[NumVals], const DataType in_var[NumVals], nvfuser_index_t in_N, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, DataType* global_buf_avg, DataType* global_buf_var, nvfuser_index_t* global_buf_N, @@ -399,6 +421,7 @@ __device__ __inline__ void ParallelReduce< in_avg, in_var, in_N, + block_dim, global_buf_avg, global_buf_var, global_buf_N, diff --git a/runtime/grid_broadcast.cu b/runtime/grid_broadcast.cu index 5f3db3ffd4a..8beb4055f7c 100644 --- a/runtime/grid_broadcast.cu +++ b/runtime/grid_broadcast.cu @@ -28,13 +28,18 @@ template < bool Y_THREAD, bool Z_THREAD, bool Aligned, - typename T> + typename T, + typename BlockDimT> __device__ void broadcast( T& out, const T& inp_val, volatile T* work_buf, Tensor<int64_t, 1> sync_flags, - bool read_write_pred) { + bool read_write_pred, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // Number of values broadcasted in the grid dimensions const auto grid_seg_size = index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim); @@ -47,13 +52,13 @@ __device__ void broadcast( // Number of threads not participating in a broadcast dimension, this is the // number of thread entries to expect in the work buffer, therefore a striding const auto block_stride = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); // Which broadcast in the block this is to line up the entry with the work // buffer const auto thread_offset = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); const bool has_valid_data = (!X_BLOCK || blockIdx.x == gridDim.x - 1) && (!Y_BLOCK || blockIdx.y == gridDim.y - 1) && @@ -67,7 +72,7 @@ __device__ void broadcast( } grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, true, Aligned>( - sync_flags[grid_seg_idx], grid_seg_size); + sync_flags[grid_seg_idx], grid_seg_size, block_dim); if (read_write_pred) { out = work_buf[grid_seg_idx * block_stride + thread_offset]; @@ -76,6 +81,6 @@ __device__ void broadcast( // Make sure everyone has read from the buffer before continuing the kernel // and potentially overwriting grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, true, Aligned>( - sync_flags[grid_seg_idx], grid_seg_size); + sync_flags[grid_seg_idx], grid_seg_size, block_dim); } } // namespace grid_broadcast diff --git a/runtime/grid_reduction.cu b/runtime/grid_reduction.cu index c4a8638910f..3a81432e27a 100644 --- a/runtime/grid_reduction.cu +++ b/runtime/grid_reduction.cu @@ -75,7 +75,8 @@ template < bool Z_THREAD, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void gridReduceLastBlock( T& out, const volatile T* in, @@ -87,7 +88,11 @@ __device__ void gridReduceLastBlock( Func reduction_op, T* shared_buf, bool write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // We have to do num_reductions across reduction_size. The reductions are // contiguous, but offset by reduction_size. There is an entry in "in" for // every block, and every thread marked as true. Threads in dimensions marked @@ -96,18 +101,18 @@ __device__ void gridReduceLastBlock( // Find the reduction id of the participating threads const auto block_reduction_segment_idx = index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Find an id associated within a reduction segment for all // "non-participating" threads, which will parallelize the reductions for the // "participating" threads const auto id_in_block_segment = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Stride by the "non-participating" threads const auto input_stride_for_thread_in_segment = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); T inp = init_val; @@ -123,7 +128,7 @@ __device__ void gridReduceLastBlock( // Block reduce the per thread values into per "participating" thread values T inp_tmp = init_val; blockReduce<!X_THREAD, !Y_THREAD, !Z_THREAD, Aligned>( - inp_tmp, inp, reduction_op, shared_buf, true, init_val); + inp_tmp, inp, reduction_op, shared_buf, true, init_val, block_dim); const bool should_write = (X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0); if (should_write && write_pred) { @@ -191,7 +196,8 @@ template < bool PERSISTENT_REDUCTION, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void gridReduce( T& out, const T& inp_val, @@ -203,7 +209,11 @@ __device__ void gridReduce( bool write_pred, T init_val, const nvfuser_index_t entrance_ind, - const nvfuser_index_t n_entrances) { + const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { T block_reduction_val = init_val; // Do block reduction when required @@ -215,7 +225,8 @@ __device__ void gridReduce( shared_buf, read_pred, true, - init_val); + init_val, + block_dim); } else if (read_pred) { block_reduction_val = inp_val; } @@ -233,7 +244,7 @@ __device__ void gridReduce( // Number of threads we can use in final reduction, Seems to assume all // threads in the block participate const auto block_reduction_segment_size = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); // Number of reductions in the grid const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION @@ -251,20 +262,23 @@ __device__ void gridReduce( index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim); auto thread_offset = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); auto work_buf_offset = block_offset * block_reduction_segment_size + thread_offset; work_buf[work_buf_offset] = block_reduction_val; } if (PERSISTENT_REDUCTION) { grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } else { // Use a different sync flag for each call grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( sync_flags[entrance_ind * grid_segment_size + idx_in_grid_segment], - grid_reduction_segment_size); + grid_reduction_segment_size, + block_dim); } bool last_block = @@ -280,14 +294,17 @@ __device__ void gridReduce( reduction_op, shared_buf, write_pred, - init_val); + init_val, + block_dim); } if (PERSISTENT_REDUCTION) { // Make sure we're done with global memory before we allow the kernel to // continue grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } } @@ -306,7 +323,8 @@ template < bool PERSISTENT_REDUCTION, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void gridReduce( T& out, const T& inp_val, @@ -319,6 +337,10 @@ __device__ void gridReduce( T init_val, const nvfuser_index_t entrance_ind, const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t& cycles, int64_t& count) { int64_t start_counter = 0; @@ -349,7 +371,8 @@ __device__ void gridReduce( write_pred, init_val, entrance_ind, - n_entrances); + n_entrances, + block_dim); if (index_utils::maskedIsLast<true, true, true>(blockIdx, gridDim) && index_utils::maskedIsZero<true, true, true>(threadIdx)) { @@ -368,11 +391,16 @@ template < bool Z_THREAD, bool Aligned, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void gridReduce2PartialReduction( const T& inp_val, T init_val, Func reduction_op, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, volatile T* work_buf, T* shared_buf, bool read_pred, @@ -390,7 +418,8 @@ __device__ void gridReduce2PartialReduction( shared_buf, read_pred, true, - init_val); + init_val, + block_dim); } else if (read_pred) { block_reduction_val = inp_val; } @@ -401,7 +430,7 @@ __device__ void gridReduce2PartialReduction( index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim); auto thread_offset = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); auto work_buf_offset = block_offset * block_reduction_segment_size + thread_offset; work_buf[work_buf_offset] = block_reduction_val; @@ -421,7 +450,8 @@ template < typename T1, typename Func1, typename T2, - typename Func2> + typename Func2, + typename BlockDimT> __device__ void gridReduceGroup( T1& out1, const T1& inp_val1, @@ -438,7 +468,11 @@ __device__ void gridReduceGroup( bool read_pred, bool write_pred, const nvfuser_index_t entrance_ind, - const nvfuser_index_t n_entrances) { + const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // Number of values to reduce in the reduction segment const auto grid_reduction_segment_size = index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim); @@ -452,7 +486,7 @@ __device__ void gridReduceGroup( // Number of threads we can use in final reduction, Seems to assume all // threads in the block participate const auto block_reduction_segment_size = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); // Number of reductions in the grid const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION @@ -478,6 +512,7 @@ __device__ void gridReduceGroup( inp_val1, init_val1, reduction_op1, + block_dim, work_buf1, (T1*)shared_buf, read_pred, @@ -496,6 +531,7 @@ __device__ void gridReduceGroup( inp_val2, init_val2, reduction_op2, + block_dim, work_buf2, (T2*)shared_buf, read_pred, @@ -505,11 +541,14 @@ __device__ void gridReduceGroup( if (PERSISTENT_REDUCTION) { grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } else { grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( sync_flags[entrance_ind * grid_segment_size + idx_in_grid_segment], - grid_reduction_segment_size); + grid_reduction_segment_size, + block_dim); } bool last_block = @@ -525,7 +564,8 @@ __device__ void gridReduceGroup( reduction_op1, (T1*)shared_buf, write_pred, - init_val1); + init_val1, + block_dim); gridReduceLastBlock<!X_THREAD, !Y_THREAD, !Z_THREAD, Aligned>( out2, work_buf2, @@ -534,14 +574,17 @@ __device__ void gridReduceGroup( reduction_op2, (T2*)shared_buf, write_pred, - init_val2); + init_val2, + block_dim); } if (PERSISTENT_REDUCTION) { // Make sure we're done with global memory before we allow the kernel to // continue grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } } @@ -558,7 +601,8 @@ template < typename T1, typename Func1, typename T2, - typename Func2> + typename Func2, + typename BlockDimT> __device__ void gridReduceGroup( T1& out1, const T1& inp_val1, @@ -576,6 +620,10 @@ __device__ void gridReduceGroup( bool write_pred, const nvfuser_index_t entrance_ind, const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, int64_t& cycles, int64_t& count) { int64_t start_counter = 0; @@ -613,7 +661,8 @@ __device__ void gridReduceGroup( read_pred, write_pred, entrance_ind, - n_entrances); + n_entrances, + block_dim); if (index_utils::maskedIsLast<true, true, true>(blockIdx, gridDim) && index_utils::maskedIsZero<true, true, true>(threadIdx)) { @@ -705,7 +754,8 @@ template < bool Aligned, int vec_size, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void iterGroupedGridReduceLastBlock( T* out, const volatile T* in, @@ -719,7 +769,11 @@ __device__ void iterGroupedGridReduceLastBlock( bool write_pred, T init_val, const nvfuser_index_t grid_segment_size, - const nvfuser_index_t idx_in_grid_segment) { + const nvfuser_index_t idx_in_grid_segment, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // We have to do num_reductions across reduction_size. The reductions are // contiguous, but offset by reduction_size. There is an entry in "in" for // every block, and every thread marked as true. Threads in dimensions marked @@ -728,14 +782,14 @@ __device__ void iterGroupedGridReduceLastBlock( // Find the reduction id of the participating threads const auto block_reduction_segment_idx = index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Find an id associated within a reduction segment for all // "non-participating" threads, which will parallelize the reductions for the // "participating" threads const auto id_in_block_segment = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // index into iteration dim. // Its calculation is same to that in [iterGroupedGridReduce]. Becuase when @@ -743,11 +797,11 @@ __device__ void iterGroupedGridReduceLastBlock( // X_THREAD, Y_THREAD, Z_THREAD are flipped. const auto thread_offset = index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Stride by the "non-participating" threads const auto input_stride_for_thread_in_segment = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); constexpr unsigned int max_align_bytes = 16; constexpr unsigned int vec_bytes = sizeof(T) * vec_size; @@ -814,7 +868,7 @@ __device__ void iterGroupedGridReduceLastBlock( inp_tmp[i] = init_val; } blockIterGroupedYdimReduce<Aligned, vec_size>( - inp_tmp, inp, reduction_op, shared_buf, true, init_val); + inp_tmp, inp, reduction_op, shared_buf, true, init_val, block_dim); const bool should_write = (X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0); if (should_write && write_pred) { @@ -846,7 +900,8 @@ template < bool Aligned, int vec_size, typename T, - typename Func> + typename Func, + typename BlockDimT> __device__ void iterGroupedGridReduce( T* out, const T* inp_val, @@ -856,7 +911,11 @@ __device__ void iterGroupedGridReduce( T* shared_buf, bool read_pred, bool write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // inp or block reduction results T block_reduction_val[vec_size]; @@ -873,7 +932,8 @@ __device__ void iterGroupedGridReduce( shared_buf, read_pred, true, - init_val); + init_val, + block_dim); } else if (read_pred) { #pragma unroll for (int i = 0; i < vec_size; i++) { @@ -893,7 +953,7 @@ __device__ void iterGroupedGridReduce( // Number of reductions in each block const auto block_segment_size = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); // Number of reductions in the grid const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION @@ -908,7 +968,7 @@ __device__ void iterGroupedGridReduce( index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim); auto thread_offset = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Max vectorized load/store size is 16 bytes, if each thread has more than // 16 bytes, split into multiple sections to ensure each thread occupies @@ -959,12 +1019,16 @@ __device__ void iterGroupedGridReduce( if (PERSISTENT_REDUCTION) { grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } else { // there is only one vectorized call grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } bool last_block = @@ -987,14 +1051,17 @@ __device__ void iterGroupedGridReduce( write_pred, init_val, grid_segment_size, - idx_in_grid_segment); + idx_in_grid_segment, + block_dim); } if (PERSISTENT_REDUCTION) { // Make sure we're done with global memory before we allow the kernel to // continue grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } } } // namespace reduction diff --git a/runtime/grid_sync.cu b/runtime/grid_sync.cu index 4e5289323b1..ba7c115a4d0 100644 --- a/runtime/grid_sync.cu +++ b/runtime/grid_sync.cu @@ -29,16 +29,21 @@ template < bool Y_BLOCK, bool Z_BLOCK, bool PERSISTENT, - bool Aligned> + bool Aligned, + typename BlockDimT> __device__ void sync( int64_t& semaphore, const uint64_t& segment_size, - const bool last_block) { + const bool last_block, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // Finish all global memory transactions before synchronizing __threadfence(); // Synchronize all threads in a block before synchronizing blocks - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Only allow linear_tid == 0 to participate in the synchronization if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { @@ -78,7 +83,7 @@ __device__ void sync( } // Sync block to make sure all other threads are waiting on the sync - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } template < @@ -86,12 +91,20 @@ template < bool Y_BLOCK, bool Z_BLOCK, bool PERSISTENT, - bool Aligned> -__device__ void sync(int64_t& semaphore, const uint64_t& segment_size) { + bool Aligned, + typename BlockDimT> +__device__ void sync( + int64_t& semaphore, + const uint64_t& segment_size, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT, Aligned>( semaphore, segment_size, - index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim)); + index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim), + block_dim); } // Grid sync that can be called multiple times in the same kernel without all @@ -105,16 +118,25 @@ __device__ void sync(int64_t& semaphore, const uint64_t& segment_size) { // // Note that this is not currently used by grid and welford reduction // as they use a separate sync flag for each each grid sync call. -template <bool X_BLOCK, bool Y_BLOCK, bool Z_BLOCK, bool Aligned> +template < + bool X_BLOCK, + bool Y_BLOCK, + bool Z_BLOCK, + bool Aligned, + typename BlockDimT> __device__ void sync( int64_t& semaphore, const uint64_t& segment_size, - const nvfuser_index_t n_entrances) { + const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // Finish all global memory transactions before synchronizing __threadfence(); // Synchronize all threads in a block before synchronizing blocks - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Only allow linear_tid == 0 to participate in the synchronization if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { @@ -147,7 +169,7 @@ __device__ void sync( } // Sync block to make sure all other threads are waiting on the sync - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // Non-blocking function to read the semaphore value in each calling thread diff --git a/runtime/warp.cu b/runtime/warp.cu index 03d78e74798..5678e4d05b0 100644 --- a/runtime/warp.cu +++ b/runtime/warp.cu @@ -21,14 +21,23 @@ __device__ __forceinline__ std::complex<T> shfl_xor( return std::complex<T>(real, imag); } -template <bool SINGLE_WARP, bool Aligned, typename T, typename Func> +template < + bool SINGLE_WARP, + bool Aligned, + typename T, + typename Func, + typename BlockDimT> __device__ void warpReduceTIDX( T& out, const T& inp_val, Func reduction_op, T* shared_mem, bool read_write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { constexpr int WARP_SIZE = 32; // Assume input padded to multiples of a warp @@ -49,19 +58,19 @@ __device__ void warpReduceTIDX( if (!SINGLE_WARP) { unsigned int warp_idx = threadIdx.x / WARP_SIZE; unsigned int lane_idx = threadIdx.x % WARP_SIZE; - unsigned int reduce_group_id = threadIdx.z * blockDim.y + threadIdx.y; + unsigned int reduce_group_id = threadIdx.z * block_dim.y + threadIdx.y; bool is_warp_head = lane_idx == 0; - unsigned int reduction_size = blockDim.x; + unsigned int reduction_size = block_dim.x; unsigned int num_of_warps = reduction_size / WARP_SIZE; unsigned int smem_offset = reduce_group_id * num_of_warps; - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (is_warp_head) { shared_mem[smem_offset + warp_idx] = reduce_val; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (warp_idx == 0) { // This assumes num_of_warps will be < 32, meaning < 1024 threads. @@ -82,20 +91,30 @@ __device__ void warpReduceTIDX( } // needs sync, otherwise other warps may access shared memory before this // reduction is done. - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } else { reduction_op(out, reduce_val); } } -template <int BDIMX, int BDIMY, bool Aligned, typename T, typename Func> +template < + int BDIMX, + int BDIMY, + bool Aligned, + typename T, + typename Func, + typename BlockDimT> __device__ void warpReduceTIDXY( T& out, const T& inp_val, Func reduction_op, T* shared_mem, bool read_write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { constexpr int WARP_SIZE = 32; constexpr int num_of_warps = BDIMX * BDIMY / WARP_SIZE; @@ -118,11 +137,11 @@ __device__ void warpReduceTIDXY( unsigned int idx = threadIdx.x + threadIdx.y * BDIMX; unsigned int warp_idx = idx / WARP_SIZE; unsigned int lane_idx = idx % WARP_SIZE; - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (lane_idx == 0) { shared_mem[warp_idx] = reduce_val; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); if (warp_idx == 0) { reduce_val = lane_idx < num_of_warps ? shared_mem[lane_idx] : init_val; @@ -137,7 +156,7 @@ __device__ void warpReduceTIDXY( } // needs sync, otherwise other warps may access shared memory before this // reduction is done. - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } else { reduction_op(out, reduce_val); } diff --git a/runtime/welford.cu b/runtime/welford.cu index aba379bffcb..35c5f69109b 100644 --- a/runtime/welford.cu +++ b/runtime/welford.cu @@ -114,7 +114,8 @@ template < bool Z_REDUCE, bool Aligned, typename T, - typename TN> + typename TN, + typename BlockDimT> __inline__ __device__ void blockWelford( T& out_avg, T& out_M2, @@ -127,24 +128,28 @@ __inline__ __device__ void blockWelford( TN* shared_mem_N, bool read_pred, bool write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // If this thread will output a final result bool should_write = index_utils::maskedIsZero<X_REDUCE, Y_REDUCE, Z_REDUCE>(threadIdx); // Size of the reduction segments unsigned int reduction_size = - index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(blockDim); + index_utils::maskedSize<X_REDUCE, Y_REDUCE, Z_REDUCE>(block_dim); // Index into the reduction segment unsigned int reduction_tid = index_utils::maskedOffset<X_REDUCE, Y_REDUCE, Z_REDUCE>( - threadIdx, blockDim); + threadIdx, block_dim); // Index of the reduction segment unsigned int reduction_idx = index_utils::maskedOffset<!X_REDUCE, !Y_REDUCE, !Z_REDUCE>( - threadIdx, blockDim); + threadIdx, block_dim); // Offset into smem for the current thread unsigned int smem_offset = reduction_idx * reduction_size + reduction_tid; @@ -159,7 +164,7 @@ __inline__ __device__ void blockWelford( shared_mem_N[smem_offset] = 0; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // Reduce down to nearest power of 2: int np2 = 1 << (31 - __clz(reduction_size)); @@ -172,7 +177,7 @@ __inline__ __device__ void blockWelford( shared_mem_M2[smem_offset + np2], shared_mem_N[smem_offset + np2]); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); // loop peel the final iteration to save one syncthread for the end for (int factor = np2 / 2; factor > 1; factor >>= 1) { @@ -185,7 +190,7 @@ __inline__ __device__ void blockWelford( shared_mem_M2[smem_offset + factor], shared_mem_N[smem_offset + factor]); } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } if (should_write && write_pred) { @@ -212,7 +217,7 @@ __inline__ __device__ void blockWelford( out_M2 = res_M2; out_N = res_N; } - block_sync::sync<Aligned>(); + block_sync::sync<Aligned>(block_dim); } // Use the same pred for both reads and writes @@ -222,7 +227,8 @@ template < bool Z_REDUCE, bool Aligned, typename T, - typename TN> + typename TN, + typename BlockDimT> __inline__ __device__ void blockWelford( T& out_avg, T& out_M2, @@ -234,7 +240,11 @@ __inline__ __device__ void blockWelford( T* shared_mem_M2, TN* shared_mem_N, bool read_write_pred, - T init_val) { + T init_val, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { blockWelford<X_REDUCE, Y_REDUCE, Z_REDUCE, Aligned, T, TN>( out_avg, out_M2, @@ -247,7 +257,8 @@ __inline__ __device__ void blockWelford( shared_mem_N, read_write_pred, read_write_pred, - init_val); + init_val, + block_dim); } // ----------------------------------------------------------------------------------------------- // Grid Welford Prototype @@ -260,7 +271,8 @@ template < bool Z_THREAD, bool Aligned, typename T, - typename TN> + typename TN, + typename BlockDimT> __device__ void gridWelfordLastBlock( T& out_avg, T& out_M2, @@ -271,8 +283,11 @@ __device__ void gridWelfordLastBlock( const nvfuser_index_t grid_reduction_segment_size, // Number of reductions across // grid reduce dimensions - const nvfuser_index_t - block_reduction_segment_size, // Number of reductions across the block + const nvfuser_index_t block_reduction_segment_size, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim, T* shared_buf_avg, T* shared_buf_M2, TN* shared_buf_N, @@ -286,18 +301,18 @@ __device__ void gridWelfordLastBlock( // Find the reduction id of the participating threads const auto block_reduction_segment_idx = index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Find an id associated within a reduction segment for all // "non-participating" threads, which will parallelize the reductions for the // "participating" threads const auto id_in_block_segment = index_utils::maskedOffset<!X_THREAD, !Y_THREAD, !Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); // Stride by the "non-participating" threads const auto input_stride_for_thread_in_segment = - index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(blockDim); + index_utils::maskedSize<!X_THREAD, !Y_THREAD, !Z_THREAD>(block_dim); T inp_avg = init_val; T inp_M2 = init_val; @@ -333,7 +348,8 @@ __device__ void gridWelfordLastBlock( shared_buf_M2, shared_buf_N, true, - init_val); + init_val, + block_dim); const bool should_write = (X_THREAD || threadIdx.x == 0) && (Y_THREAD || threadIdx.y == 0) && (Z_THREAD || threadIdx.z == 0); if (should_write && write_pred) { @@ -352,7 +368,8 @@ template < bool PERSISTENT_REDUCTION, bool Aligned, typename T, - typename TN> + typename TN, + typename BlockDimT> __device__ void gridWelford( T& out_avg, T& out_M2, @@ -371,7 +388,11 @@ __device__ void gridWelford( bool write_pred, T init_val, const nvfuser_index_t entrance_ind, - const nvfuser_index_t n_entrances) { + const nvfuser_index_t n_entrances, + // block_dim is basically just blockDim (wrapped as DefaultBlockDim) if + // there is no warp specialization in the kernel. If there is warp + // specialization, block_dim is the the dimension of the compute warps. + BlockDimT block_dim) { // entrance index only matters for non-persistent re-entrant grid reductions. const nvfuser_index_t entrance_ind_ = PERSISTENT_REDUCTION ? 0 : entrance_ind; const nvfuser_index_t n_entrances_ = PERSISTENT_REDUCTION ? 1 : n_entrances; @@ -389,7 +410,7 @@ __device__ void gridWelford( // Number of threads we can use in final reduction, Seems to assume all // threads in the block participate const auto block_reduction_segment_size = - index_utils::maskedSize<X_THREAD, Y_THREAD, Z_THREAD>(blockDim); + index_utils::maskedSize<X_THREAD, Y_THREAD, Z_THREAD>(block_dim); // Number of reductions in the grid const nvfuser_index_t grid_segment_size = PERSISTENT_REDUCTION @@ -411,7 +432,7 @@ __device__ void gridWelford( index_utils::maskedOffset<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim); auto thread_offset = index_utils::maskedOffset<X_THREAD, Y_THREAD, Z_THREAD>( - threadIdx, blockDim); + threadIdx, block_dim); auto work_buf_offset = block_offset * block_reduction_segment_size + thread_offset; if (read_pred) { @@ -427,12 +448,15 @@ __device__ void gridWelford( if (PERSISTENT_REDUCTION) { grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } else { // Use a different sync flag for each call grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( sync_flags[entrance_ind_ * grid_segment_size + idx_in_grid_segment], - grid_reduction_segment_size); + grid_reduction_segment_size, + block_dim); } bool last_block = @@ -449,6 +473,7 @@ __device__ void gridWelford( work_buf_N, grid_reduction_segment_size, block_reduction_segment_size, + block_dim, shared_buf_avg, shared_buf_M2, shared_buf_N, @@ -460,7 +485,9 @@ __device__ void gridWelford( // Make sure we're done with global memory before we allow the kernel to // continue grid_sync::sync<X_BLOCK, Y_BLOCK, Z_BLOCK, PERSISTENT_REDUCTION, Aligned>( - sync_flags[idx_in_grid_segment], grid_reduction_segment_size); + sync_flags[idx_in_grid_segment], + grid_reduction_segment_size, + block_dim); } } diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 725f2c128b9..a5eb595af9c 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1366,12 +1366,6 @@ TEST_P(TmaCircularBufferingTest, PointwiseCpAsync) { TEST_P(TmaCircularBufferingTest, InnerReduction) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); - if (std::holds_alternative<WarpSpecialized>(circular_buffer_type)) { - GTEST_SKIP() - << "This test uses block reduce, which uses hard-coded blockDim, " - << "which can cause deadlock when combined with warp specialization."; - } - std::unique_ptr<Fusion> fusion = std::make_unique<Fusion>(); FusionGuard fg(fusion.get()); @@ -1486,13 +1480,6 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) { TEST_P(TmaCircularBufferingTest, Persistent) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); - if (std::holds_alternative<WarpSpecialized>(circular_buffer_type)) { - GTEST_SKIP() - << "This test uses block reduce and block broadcast, " - << "which has hard-coded blockDim, " - << "which can cause deadlock when combined with warp specialization."; - } - constexpr at::ScalarType dtype = at::ScalarType::Float; constexpr int64_t correction = 0; constexpr int64_t reduction_axis = 1; diff --git a/tests/cpp/test_gpu2.cpp b/tests/cpp/test_gpu2.cpp index 08f7f631cca..c6df7767aca 100644 --- a/tests/cpp/test_gpu2.cpp +++ b/tests/cpp/test_gpu2.cpp @@ -2086,7 +2086,8 @@ __global__ void kernel1( (float*)mem_M2, (long*)mem_N, (bool)(threadIdx.x<inp.logical_size[0]), - 0.f); + 0.f, + blockDim); __syncthreads(); if(threadIdx.x<out_var.logical_size[0] && threadIdx.y==0){ welfordCombine( @@ -2175,7 +2176,8 @@ __global__ void kernel1( (float*)mem_M2, (long*)mem_N, (bool)(threadIdx.x<inp.logical_size[0]), - 0.f); + 0.f, + blockDim); __syncthreads(); if(threadIdx.x<out_var.logical_size[0] && threadIdx.y==0 && threadIdx.z==0){ out_avg[threadIdx.x*out_var.alloc_stride[0]]=tmp_avg; @@ -2251,7 +2253,8 @@ __global__ void kernel1( threadIdx.x<out_var.logical_size[0], 0.f, 0, - 1); + 1, + blockDim); if(blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1){ out_avg[threadIdx.x*out_avg.alloc_stride[0]]=tmp_avg; out_var[threadIdx.x*out_var.alloc_stride[0]]=tmp_M2/tmp_N; From 5a2184ca920bc36a808bebc4455ab7161c1ed5a0 Mon Sep 17 00:00:00 2001 From: Ryan Spring <rspring@nvidia.com> Date: Mon, 9 Dec 2024 14:26:09 -0800 Subject: [PATCH 2/3] Add support for 32B and 64B swizzles to hopper matmul scheduler (#3544) This PR adds support for 32B and 64B swizzles to StMatrix indexing and to the hopper matmul scheduler. ### Key Index Change The number of distinct swizzle rows is number of bytes for swizzle divided by size of megabank (16B). The number of times a swizzle pattern is repeated to fill core (8, 8) matrix is number of swizzle rows (8) divided by number of distinct rows. ```cpp MmaInputSmemSwizzle swizzle = getSwizzle(out_tv); int64_t swizzle_bytes = getBytesFromSwizzle(swizzle); constexpr int64_t megabank_size_bytes = 16; const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes; int row = ...; int col = ...; constexpr int64_t swizzle_row_size = 8; const int64_t swizzle_row_repetitions = swizzle_row_size / distinct_swizzle_row_size; int64_t row_in_swizzle_pattern = (row % swizzle_row_size) / swizzle_row_repetitions; int64_t swizzle_col = col ^ row_in_swizzle_pattern; ``` ### Testing Changes * Added `mma_macro` as testing value. * Created separate test suite called `Swizzle/HopperMatmulSchedulerTest` to test `32B`, `64B`, `128B` swizzles. --- csrc/device_lower/pass/index.cpp | 66 ++++++++++++++++++-------- csrc/scheduler/hopper_multi_matmul.cpp | 18 +++---- csrc/scheduler/hopper_multi_matmul.h | 7 ++- tests/cpp/test_matmul_scheduler.cpp | 61 ++++++++++++++++++++---- tests/cpp/utils.h | 34 +++++++++++++ 5 files changed, 144 insertions(+), 42 deletions(-) diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 7f1dc742917..74632876c8c 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1559,15 +1559,15 @@ void IndexLowering::handleCpAsyncBulkStore(const LoadStoreOp* ldst) { } static DataType getMmaInputAType(MmaMacro macro) { - int warp_group_size = isHopper(macro) ? 128 : 32; - int size = getM(macro) * getK(macro) / warp_group_size / - 2 /* halves per 32bit register */; + int64_t warp_group_size = isHopper(macro) ? 128L : 32L; + int64_t size = getM(macro) * getK(macro) / warp_group_size / + 2L /* halves per 32bit register */; return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size}; } static DataType getMmaInputBType(MmaMacro macro) { - int size = getN(macro) * getK(macro) / 32 /* threads per warp */ / - 2 /* halves per 32bit register */; + int64_t size = getN(macro) * getK(macro) / 32L /* threads per warp */ / + 2L /* halves per 32bit register */; return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size}; } @@ -1842,8 +1842,8 @@ Val* hardCodedIndexGenerationForStMatrix( // To account for the threadIdx.y, we have to add it to the offset: // offset_from_tdy = threadIdx.y * tma_m * tma_n * 2 (half) // -// Now, lets apply stmatrix tile to the TMA Box. -// [NO(2), MO(4), MI(16), NIO(4), NII(16)]. +// Now, lets apply stmatrix tile (16, 16) to the TMA Box [NO(2), M(64), NI(64)]. +// [NO(2), MO(4), MI(16), NIO(4), NII(16)]. // // A warp group of 128 threads contains four warps. StMatrix is a warp-level // operation, so four StMatrix operations can be issued simultaneously by the @@ -1865,6 +1865,7 @@ Val* hardCodedIndexGenerationForStMatrix( // domain is scheduled as [NO(2), M(64), NI(64)]. Therefore, we must store the // data in shared memory in [M(64), NI(64)] contiguous tiles. // +// NOTE: This offset is skipped if for-loop is trivial // To account for the outer_index, we have to add it to the offset: // offset_from_outer_index = outer_index * tma_m * NI(64) * 2 (half) // @@ -1928,8 +1929,13 @@ Val* hardCodedIndexGenerationForStMatrix( // with the 8 rows of the matrix to avoid bank conflicts. This swizzle pattern // is repeated along the rows of the TMA box. // +// The number of distinct swizzle rows is number of bytes for swizzle divided by +// size of megabank (16B). The number of times a swizzle pattern is repeated to +// fill core (8, 8) matrix is number of swizzle rows (8) divided by number of +// distinct rows. +// // Swizzle column -// row_in_swizzle_pattern = row % swizzle_row_size(8) +// row_in_swizzle_pattern = (row % swizzle_row_size(8)) / swizzle_repetitions // swizzle_col = column XOR row_in_swizzle_pattern // // Calculate Tile Offset @@ -1939,7 +1945,7 @@ Val* hardCodedIndexGenerationForStMatrix( // // Get shared memory offset // smem_offset = offset_from_tdy + offset_from_outer_index + tile_offset -Val* hardCodedIndexGenerationForStMatrix128BSwizzle( +Val* hardCodedIndexGenerationForStMatrixSwizzle( const LoadStoreOp* ldst, ForLoop* loop, const int64_t stsm_m_tile, @@ -1958,16 +1964,19 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle( NVF_ERROR(ldst->out()->isA<TensorView>()); TensorView* out_tv = ldst->out()->as<TensorView>(); - NVF_ERROR(getSwizzle(out_tv) == MmaInputSmemSwizzle::B128); + MmaInputSmemSwizzle swizzle = getSwizzle(out_tv); + int64_t swizzle_bytes = getBytesFromSwizzle(swizzle); // Constants constexpr int64_t dtype_size = 2; constexpr int64_t warp_size = 32; constexpr int64_t swizzle_row_size = 8; constexpr int64_t stsm_column_size = 8; - constexpr int64_t swizzle_n_tile = 64; + constexpr int64_t megabank_size_bytes = 16; // Derived constants + const int64_t swizzle_n_tile = swizzle_bytes / dtype_size; + const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes; constexpr int64_t stsm_column_stride = stsm_column_size * dtype_size; const int64_t swizzle_n_iter = swizzle_n_tile / stsm_n_tile; const int64_t swizzle_n_tile_stride = swizzle_n_tile * dtype_size; @@ -2000,8 +2009,6 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle( Val* warp_id = SimplifyingIrBuilder::divExpr(TDX, warp_size_val); Val* lane_id = SimplifyingIrBuilder::modExpr(TDX, warp_size_val); - Val* outer_index = - SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val); Val* inner_index = SimplifyingIrBuilder::modExpr(loop->index(), swizzle_n_iter_val); @@ -2021,6 +2028,17 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle( // Swizzle Column Val* row_in_swizzle_pattern = SimplifyingIrBuilder::modExpr(row, swizzle_row_size_val); + + // The swizzle pattern is repeated to fill (8, 8) matrix for 64B and 32B + // swizzles. swizzle_row_iter is the number of repetitions to fill 8 rows + // with distict swizzle rows. + const int64_t swizzle_row_iter = swizzle_row_size / distinct_swizzle_row_size; + if (swizzle_row_iter > 1) { + Val* swizzle_row_iter_val = + IrBuilder::create<Val>(swizzle_row_iter, DataType::Index); + row_in_swizzle_pattern = SimplifyingIrBuilder::divExpr( + row_in_swizzle_pattern, swizzle_row_iter_val); + } Val* swizzle_col = bitwise_xor(col, row_in_swizzle_pattern); // Calculate Tile Offset @@ -2031,16 +2049,22 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle( Val* offset = SimplifyingIrBuilder::addExpr(row_offset, col_offset); // Calculate Tile offset - Val* tile_offset = IrBuilder::mulExpr(outer_index, tile_stride_val); + // Skip tile offset if loop is trivial. + if (!loop->stop()->isOneInt()) { + Val* outer_index = + SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val); + Val* tile_offset = + SimplifyingIrBuilder::mulExpr(outer_index, tile_stride_val); + offset = SimplifyingIrBuilder::addExpr(tile_offset, offset); + } // Calculate TDY offset - Val* tdy_offset = IrBuilder::mulExpr(TDY, tdy_stride_val); + Val* tdy_offset = SimplifyingIrBuilder::mulExpr(TDY, tdy_stride_val); + offset = SimplifyingIrBuilder::addExpr(tdy_offset, offset); // Create shared memory TensorIndex Val* out_index = SimplifyingIrBuilder::addExpr( - IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)), - SimplifyingIrBuilder::addExpr( - tdy_offset, SimplifyingIrBuilder::addExpr(tile_offset, offset))); + IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)), offset); Val* out = IrBuilder::create<kir::TensorIndex>( dynamic_cast<TensorView*>(ldst->out()), out_index); return out; @@ -2092,11 +2116,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { ldst, for_loops_[0], m_tile, n_tile, m, n); break; case MmaInputSmemSwizzle::B128: - out = hardCodedIndexGenerationForStMatrix128BSwizzle( + case MmaInputSmemSwizzle::B64: + case MmaInputSmemSwizzle::B32: + out = hardCodedIndexGenerationForStMatrixSwizzle( ldst, for_loops_[0], m_tile, n_tile, m, n); break; - case MmaInputSmemSwizzle::B32: - case MmaInputSmemSwizzle::B64: default: NVF_ERROR("Unsupported Swizzle Type for StMatrix"); } diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index dad548d3fd0..4320fe80953 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -1027,13 +1027,6 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { const int64_t tma_m = getM(params_->mma_macro); const int64_t tma_n = getN(params_->mma_macro); - NVF_ERROR( - tma_n >= 64, - "Scheduler only supports 128B swizzle that requires N dimension of MMA ", - "macro to be >= 64, but received ", - tma_n, - "."); - fusion_->manage("st_matrix_m_tile", stmatrix_tile_m); fusion_->manage("st_matrix_n_tile", stmatrix_tile_n); fusion_->manage("st_matrix_m", tma_m); @@ -1084,12 +1077,14 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { dc->setAllocationDomain(s.as<IterDomain*>(), true); } + MmaInputSmemSwizzle swizzle = tmaSwizzleSharedMemory(d_smem); + // Schedule shared memory cache; Output from StMatrix scheduleStMatrixForMmaOutput( - d_smem, stmatrix_tile_m, stmatrix_tile_n, tma_m, tma_n); + d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n, tma_m, tma_n); // Schedule global memory output; Output from TMA Store - scheduleTMAStoreForMmaOutput(d, tma_m, tma_n); + scheduleTMAStoreForMmaOutput(d, swizzle, tma_m, tma_n); } } } @@ -1247,6 +1242,7 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() { void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput( TensorView* tv, + MmaInputSmemSwizzle swizzle, int64_t tile_m, int64_t tile_n, int64_t tma_m, @@ -1263,7 +1259,7 @@ void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput( mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain()); // Create tma store allocation domain with swizzle - scheduleTMAStoreForMmaOutput(tv, tma_m, tma_n); + scheduleTMAStoreForMmaOutput(tv, swizzle, tma_m, tma_n); tv->setLoopDomain(s.as<IterDomain*>()); @@ -1290,6 +1286,7 @@ void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput( void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput( TensorView* tv, + MmaInputSmemSwizzle swizzle, int64_t m, int64_t n) { // [M(m), N(n)] -> [MO(1), MI(m), NO(1), NI(n)] @@ -1301,7 +1298,6 @@ void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput( // [BDX, BDY, TDY, MO(1), NO(1), MI, NI] // skip the first 5 iterDomains int64_t num_ids_to_skip = 5; - MmaInputSmemSwizzle swizzle = MmaInputSmemSwizzle::B128; NVF_ERROR(num_ids_to_skip >= 0); if (swizzle == MmaInputSmemSwizzle::None) { diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index e44a6d8830b..8a16225f017 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -182,6 +182,7 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { //! registers to shared memory. void scheduleStMatrixForMmaOutput( TensorView* tv, + MmaInputSmemSwizzle swizzle, int64_t tile_m, int64_t tile_n, int64_t tma_m, @@ -189,7 +190,11 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { //! Schedules the copy operation of output of a Mma op which resided in the //! shared memory to global memory. - void scheduleTMAStoreForMmaOutput(TensorView* tv, int64_t m, int64_t n); + void scheduleTMAStoreForMmaOutput( + TensorView* tv, + MmaInputSmemSwizzle swizzle, + int64_t m, + int64_t n); // Map TensorView's iterDomain to its ValGroup. // Then, find the MatmulDimRole for the ValGroup. diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index 1a25a1ebfe4..e74844d22d3 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -3119,8 +3119,8 @@ using HopperMatmulSchedulerTestParams = std::tuple< bool, // b_k_inner int64_t, // M int64_t, // N - int64_t // K - >; + int64_t, // K + MmaMacro>; std::string hopperTestName( const testing::TestParamInfo<HopperMatmulSchedulerTestParams>& info) { @@ -3128,23 +3128,42 @@ std::string hopperTestName( bool use_smem_epilogue; bool a_k_inner, b_k_inner; int64_t M, N, K; - std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K) = info.param; + MmaMacro mma_macro; + std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) = + info.param; os << (a_k_inner ? "K" : "M"); os << (b_k_inner ? "K" : "N"); os << "_" << M << "_" << N << "_" << K; + os << "_MmaMacro_" << mma_macro_to_str_map.at(mma_macro); if (use_smem_epilogue) { os << "_tma_store"; } return os.str(); } +std::string hopperTestNameSwizzle( + const testing::TestParamInfo<HopperMatmulSchedulerTestParams>& info) { + std::unordered_map<MmaMacro, std::string> mma_macro_to_swizzle_str_map = { + {MmaMacro::Hopper_64_256_16, "128BSwizzle"}, + {MmaMacro::Hopper_64_128_16, "128BSwizzle"}, + {MmaMacro::Hopper_64_64_16, "128BSwizzle"}, + {MmaMacro::Hopper_64_32_16, "64BSwizzle"}, + {MmaMacro::Hopper_64_16_16, "32BSwizzle"}}; + MmaMacro mma_macro = std::get<6>(info.param); + std::ostringstream os; + os << hopperTestName(info); + os << "_" << mma_macro_to_swizzle_str_map.at(mma_macro); + return os.str(); +} + class HopperMatmulSchedulerTest : public NVFuserFixtureParamTest<HopperMatmulSchedulerTestParams> { protected: void SetUp() { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0); - std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K) = GetParam(); + std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) = + GetParam(); if (a_k_inner) { layout = b_k_inner ? MmaLayout::TN : MmaLayout::TT; @@ -3159,14 +3178,17 @@ class HopperMatmulSchedulerTest // Create custom Matmul Params MatMulTileOptions gemm_tile; // TODO cta tile is a multiple of mma macro for hopper. - gemm_tile.cta_tile = GemmTile(128, 256, 16); + // Default cta_tile configuration is 2-CTA. + gemm_tile.cta_tile = + GemmTile(2 * getM(mma_macro), getN(mma_macro), getK(mma_macro)); // TODO warp tile is (macroM, macroN, macroK) for hopper. - gemm_tile.warp_tile = GemmTile(64, 128, 16); + gemm_tile.warp_tile = + GemmTile(getM(mma_macro), getN(mma_macro), getK(mma_macro)); mparams.supported_vec_size = {8, 8, 4}; - mparams.mma_macro = MmaMacro::Hopper_64_128_16; + mparams.mma_macro = mma_macro; mparams.use_smem_epilogue = use_smem_epilogue; @@ -3203,6 +3225,7 @@ class HopperMatmulSchedulerTest bool use_smem_epilogue; bool a_k_inner, b_k_inner; int64_t M, N, K; + MmaMacro mma_macro; std::unique_ptr<Fusion> fusion_up; Fusion* fusion; std::unique_ptr<FusionGuard> fusion_guard; @@ -3275,7 +3298,7 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) { } INSTANTIATE_TEST_SUITE_P( - , + General, HopperMatmulSchedulerTest, testing::Combine( testing::Bool(), // use_smem_epilogue @@ -3283,8 +3306,28 @@ INSTANTIATE_TEST_SUITE_P( testing::Bool(), // b_k_inner testing::Values(512), // M testing::Values(256), // N - testing::Values(64) // K + testing::Values(64), // K + testing::Values(MmaMacro::Hopper_64_128_16) // mma_macros ), hopperTestName); +INSTANTIATE_TEST_SUITE_P( + Swizzle, + HopperMatmulSchedulerTest, + testing::Combine( + testing::Values(true), // use_smem_epilogue + testing::Bool(), // a_k_inner + testing::Bool(), // b_k_inner + testing::Values(512), // M + testing::Values(256), // N + testing::Values(64), // K + testing::Values( + MmaMacro::Hopper_64_256_16, + MmaMacro::Hopper_64_128_16, + MmaMacro::Hopper_64_64_16, + MmaMacro::Hopper_64_32_16, + MmaMacro::Hopper_64_16_16) // mma_macros + ), + hopperTestNameSwizzle); + } // namespace nvfuser diff --git a/tests/cpp/utils.h b/tests/cpp/utils.h index 766a5e369c7..f835766e576 100644 --- a/tests/cpp/utils.h +++ b/tests/cpp/utils.h @@ -703,6 +703,40 @@ static auto kAllHopperMacros = testing::Values( MmaMacro::Hopper_64_248_16, MmaMacro::Hopper_64_256_16); +static std::unordered_map<MmaMacro, std::string> mma_macro_to_str_map = { + {MmaMacro::Hopper_64_8_16, "m64_n8_k16"}, + {MmaMacro::Hopper_64_16_16, "m64_n16_k16"}, + {MmaMacro::Hopper_64_24_16, "m64_n24_k16"}, + {MmaMacro::Hopper_64_32_16, "m64_n32_k16"}, + {MmaMacro::Hopper_64_40_16, "m64_n40_k16"}, + {MmaMacro::Hopper_64_48_16, "m64_n48_k16"}, + {MmaMacro::Hopper_64_56_16, "m64_n56_k16"}, + {MmaMacro::Hopper_64_64_16, "m64_n64_k16"}, + {MmaMacro::Hopper_64_72_16, "m64_n72_k16"}, + {MmaMacro::Hopper_64_80_16, "m64_n80_k16"}, + {MmaMacro::Hopper_64_88_16, "m64_n88_k16"}, + {MmaMacro::Hopper_64_96_16, "m64_n96_k16"}, + {MmaMacro::Hopper_64_104_16, "m64_n104_k16"}, + {MmaMacro::Hopper_64_112_16, "m64_n112_k16"}, + {MmaMacro::Hopper_64_120_16, "m64_n120_k16"}, + {MmaMacro::Hopper_64_128_16, "m64_n128_k16"}, + {MmaMacro::Hopper_64_136_16, "m64_n136_k16"}, + {MmaMacro::Hopper_64_144_16, "m64_n144_k16"}, + {MmaMacro::Hopper_64_152_16, "m64_n152_k16"}, + {MmaMacro::Hopper_64_160_16, "m64_n160_k16"}, + {MmaMacro::Hopper_64_168_16, "m64_n168_k16"}, + {MmaMacro::Hopper_64_176_16, "m64_n176_k16"}, + {MmaMacro::Hopper_64_184_16, "m64_n184_k16"}, + {MmaMacro::Hopper_64_192_16, "m64_n192_k16"}, + {MmaMacro::Hopper_64_200_16, "m64_n200_k16"}, + {MmaMacro::Hopper_64_208_16, "m64_n208_k16"}, + {MmaMacro::Hopper_64_216_16, "m64_n216_k16"}, + {MmaMacro::Hopper_64_224_16, "m64_n224_k16"}, + {MmaMacro::Hopper_64_232_16, "m64_n232_k16"}, + {MmaMacro::Hopper_64_240_16, "m64_n240_k16"}, + {MmaMacro::Hopper_64_248_16, "m64_n248_k16"}, + {MmaMacro::Hopper_64_256_16, "m64_n256_k16"}}; + // Utility to generate matmul input tensors based on given layout at::Tensor atMatmul(at::Tensor a, at::Tensor b, MmaLayout layout); From 4a897a417c8d14839453baa4925ae2048016b411 Mon Sep 17 00:00:00 2001 From: Jingyue Wu <wujingyue@gmail.com> Date: Mon, 9 Dec 2024 15:30:30 -0800 Subject: [PATCH 3/3] Allgather with DID loop split (#3284) Another baby step towards #2563 --- csrc/multidevice/communication.cpp | 4 +- csrc/multidevice/communication.h | 5 + csrc/multidevice/lower_communication.cpp | 2 +- csrc/multidevice/utils.cpp | 188 +++++++++++------- csrc/multidevice/utils.h | 13 +- tests/cpp/multidevice.cpp | 16 +- .../test_multidevice_lower_communication.cpp | 67 +++++++ tests/cpp/test_multidevice_sharding.cpp | 48 +++++ tests/cpp/test_multidevice_transformer.cpp | 124 ++++++------ 9 files changed, 326 insertions(+), 141 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 522b755b96d..af122ee6e3d 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -328,8 +328,8 @@ c10::intrusive_ptr<c10d::Work> postAllgather( c10d::Backend* backend, at::Tensor input_tensor, at::Tensor output_tensor) { - auto splits = at::split(output_tensor, /*split_size=*/1, /*dim=*/0); - assertBufferCount(splits, communication->team().size()); + auto splits = + at::tensor_split(output_tensor, communication->team_size(), /*dim=*/0); assertBuffersHaveSameSize({input_tensor}, splits); // allgather primitive in c10d induces extra buffering time to copy out the diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 45c104b36d3..8631a1a04e5 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -90,6 +90,11 @@ class Communication : public Expr { return attribute<Team>(1); } + // A convenience helper so the user doesn't need to convert size_t to int64_t. + int64_t team_size() const { + return static_cast<int64_t>(team().size()); + } + DeviceIdxType root() const { return attribute<DeviceIdxType>(2); } diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index c8068b5a113..4b878ac7376 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -196,7 +196,7 @@ void lowerToReduceScatter( std::vector<Communication*>& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); auto reduction_axis = output_tv->getReductionAxis().value(); - auto scattered_axis = getShardedAxis(output_tv); + auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx); // The output tensor is sharded on scattered_axis and needs to be mapped // back onto the input. The input has an reduced axis, so the scattered axis // is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 24b7e582104..54f1303bc16 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -121,48 +121,133 @@ bool isSharded(const TensorView* tv) { return is_sharded; } -std::vector<int64_t> unshardedSizes( - const TensorView* tv, - c10::IntArrayRef sizes) { - std::vector<int64_t> unsharded_sizes = sizes.vec(); - - for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) { - const ParallelType parallel_type = alloc_id->getParallelType(); +namespace { +// Collect device-parallel IterDomains in `domain` and return them as a +// ParallelType-to-IterDomain map. +std::unordered_map<ParallelType, IterDomain*> mapDeviceParallelTypeToId( + const std::vector<IterDomain*>& domain) { + std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id; + parallel_type_to_id.reserve(kParallelTypeDIDs.size()); + for (IterDomain* id : domain) { + const ParallelType parallel_type = id->getParallelType(); if (!isParallelTypeDeviceDim(parallel_type)) { continue; } - const auto inputs = IterVisitor::getInputsTo( - {alloc_id}, - {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}); NVF_ERROR( - !inputs.empty(), - "IterVisitor::getInputsTo shouldn't return empty unless `of` is empty."); - NVF_ERROR( - inputs.size() == 1, - "Failed to find the single logical input to ", - alloc_id, - ". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ", - toDelimitedString(inputs)); - - const auto iter = std::find( - tv->getLogicalDomain().begin(), - tv->getLogicalDomain().end(), - inputs[0]); + parallel_type_to_id.try_emplace(parallel_type, id).second, + "Found multiple loop IterDomains with the same parallel type (", + parallel_type, + "): ", + toDelimitedString(domain)); + } + return parallel_type_to_id; +} + +std::unordered_map<IterDomain*, int64_t> mapIterDomainToTensorAxis( + const std::vector<IterDomain*>& domain) { + std::unordered_map<IterDomain*, int64_t> id_to_axis; + int64_t axis = 0; + for (auto* id : domain) { + // Reduction IterDomains are not materialized as an at::Tensor axis. + if (id->isReduction()) { + continue; + } + id_to_axis[id] = axis; + axis++; + } + return id_to_axis; +} + +} // namespace + +int64_t getShardedLogicalAxis( + const TensorView* tv, + const ParallelType parallel_type) { + std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id = + mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain()); + IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type); + if (alloc_id == nullptr) { + return -1; + } + + std::unordered_map<IterDomain*, int64_t> logical_id_to_axis = + mapIterDomainToTensorAxis(tv->getLogicalDomain()); + IterDomain* id = alloc_id; + while (logical_id_to_axis.count(id) == 0) { + Expr* def = id->definition(); NVF_ERROR( - iter != tv->getLogicalDomain().end(), - "The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ", - inputs[0]); - - // Count the number of non-reduction IterDomains before `iter`. Reduction - // IterDomains are not materialized in the at::Tensor's shape. - const auto index = std::count_if( - tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool { - return !id->isReduction(); - }); - unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type); + def != nullptr, + "Failed to find a non-reduction logical IterDomain that produces ", + alloc_id); + if (auto* split = dynamic_cast<Split*>(def)) { + // Returning just which tensor axis is sharded isn't sufficient to let + // shardTensor, a user of this function, know how to shard the tensor. + // For example, + // + // t = makeContigConcreteTensor({6}); + // t->split(0, 2, /*inner_split=*/true); + // t->axis(-1)->parallelize(DIDx); + // // [i{3}, iDIDx{2}] + // + // and the unsharded tensor is [0, 1, 2, 3, 4, 5], regardless of the + // stride. The sharded tensor ought to be [0, 2, 4] for GPU 0 and [1, 3, + // 5] for GPU 1. However, shardTensor as is will return [0, 1, 2] and [3, + // 4, 5], assuming the axis is sharded outermost. + // + // One potential way to solve the general problem is to replay and rewind + // the splits on the at::Tensor. For example, + // + // t = makeContigConcreteTensor({30}); + // t->split(0, 5); + // t->split(0, 3); + // t->axis(0)->parallelize(Host); + // t->axis(1)->parallelize(DIDx); + // // [iHost{2}, iDIDx{3}, i{5}] + // + // Given an unsharded at::Tensor of shape [30], we'll first replay the + // splits using `torch.view` to get a tensor of shape [2,3,5]. Then, we + // `torch.slice` axis 1 for DIDx to get a tensor of shape [2,1,5]. Then, + // we rewind the splits (and therefore apply merging) using + // `torch.reshape` to get a sharded tensor of shape [10]. + NVF_ERROR( + split->outer() == id, + "Currently, we don't support DID on inner splits: ", + split); + id = split->in(); + } else if (auto* merge = dynamic_cast<Merge*>(def)) { + // For example, + // + // t = makeContigTensor(2); + // t->merge(0, 1); + // t->axis(0)->parallelize(DIDx); + // + // When `unshardedSizes` is given a local tensor of shape [1, 1], it's + // unclear the global shape is [1, D] or [D, 1] or even [2, D/2], etc. + NVF_THROW( + "Failed to attribute the sharding to a single tensor axis and therefore bailed out: ", + merge); + } else { + NVF_THROW( + "Unexpected transforms from logical to a DID-parallel allocation IterDomain: ", + def); + } } + return logical_id_to_axis.at(id); +} + +std::vector<int64_t> unshardedSizes( + const TensorView* tv, + c10::IntArrayRef sizes) { + std::vector<int64_t> unsharded_sizes = sizes.vec(); + for (ParallelType parallel_type : kParallelTypeDIDs) { + const int64_t sharded_axis = getShardedLogicalAxis(tv, parallel_type); + if (sharded_axis == -1) { + continue; + } + unsharded_sizes.at(sharded_axis) *= tv->getDeviceMesh().size(parallel_type); + } return unsharded_sizes; } @@ -174,27 +259,6 @@ int64_t numDeviceDims(const TensorView* tv) { } namespace { -// Collect device-parallel IterDomains in `loop_domain` and return them as a -// ParallelType-to-IterDomain map. -std::unordered_map<ParallelType, IterDomain*> mapParallelTypeToId( - const std::vector<IterDomain*>& loop_domain) { - std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id; - parallel_type_to_id.reserve(kParallelTypeDIDs.size()); - for (IterDomain* loop_id : loop_domain) { - const ParallelType parallel_type = loop_id->getParallelType(); - if (!isParallelTypeDeviceDim(parallel_type)) { - continue; - } - - NVF_ERROR( - parallel_type_to_id.try_emplace(parallel_type, loop_id).second, - "Found multiple loop IterDomains with the same parallel type (", - parallel_type, - "): ", - toDelimitedString(loop_domain)); - } - return parallel_type_to_id; -} std::vector<IterDomain*> getInputsInTargetDomain( IterDomain* loop_id, @@ -294,9 +358,9 @@ bool haveDifferentShardings( // 3. Check if the two loop IterDomains are almost-exactly mapped in the // IdModel. std::unordered_map<ParallelType, IterDomain*> p_parallel_type_to_id = - mapParallelTypeToId(producer->getLoopDomain()); + mapDeviceParallelTypeToId(producer->getLoopDomain()); std::unordered_map<ParallelType, IterDomain*> c_parallel_type_to_id = - mapParallelTypeToId(consumer->getLoopDomain()); + mapDeviceParallelTypeToId(consumer->getLoopDomain()); for (const auto parallel_type : kParallelTypeDIDs) { IterDomain* p_loop_id = getOrDefault(p_parallel_type_to_id, parallel_type); @@ -502,16 +566,6 @@ std::set<DeviceIdxType> involvedDevices(Expr* expr) { return ret; } -int64_t getShardedAxis(TensorView* tv) { - auto ids = TensorDomain::noReductions(tv->getLogicalDomain()); - for (size_t i = 0; i < ids.size(); ++i) { - if (ids[i]->getParallelType() == ParallelType::DIDx) { - return static_cast<int64_t>(i); - } - } - return -1; -} - void reorderDIDToFront(TensorView* tv) { // new position to old position std::unordered_map<int64_t, int64_t> order_map; diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 5be2e11bd15..ef88fbdcf80 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -123,9 +123,16 @@ int64_t requestedNumberOfDevices(Fusion*); void unshard(Fusion*); void unshard(TensorView*); -// Returns the index of the a sharded axis if none return -1. -// TODO: Assumes no merges/splits on sharded axis. -int64_t getShardedAxis(TensorView*); +// Returns the index of the sharded logical axis that produces the allocation +// IterDomain sharded on `parallel_type`. If `tv` isn't sharded on the parallel +// type, returns -1. +// +// This is used to correlate `tv` and its corresponding at::Tensor, e.g., by +// `unshardedSizes` and `shardTensor`. `at::Tensor::sizes` and +// `tv->getLogicalDomain()` map one-to-one modulo reduction. However, a size in +// `at::Tensor::sizes` is a factor of the corresponding logical IterDomain's +// extent if that IterDomain is sharded. +int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type); // Reorders a TensorView so that the DID parallelized axis are in front. void reorderDIDToFront(TensorView*); diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp index bab5cdccc5e..22897dc5311 100644 --- a/tests/cpp/multidevice.cpp +++ b/tests/cpp/multidevice.cpp @@ -128,7 +128,10 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) { return tensor; } NVF_ERROR(tv->hasDeviceMesh(), "`tv` has no DeviceMesh: ", tv); - return shardTensor(tensor, getShardedAxis(tv), tv->getDeviceMesh()); + return shardTensor( + tensor, + getShardedLogicalAxis(tv, ParallelType::DIDx), + tv->getDeviceMesh()); } at::Tensor MultiDeviceTest::shardTensor( @@ -144,13 +147,10 @@ at::Tensor MultiDeviceTest::shardTensor( auto stride = extent / nslices; // TODO: returning slice 0 temporarily when device is not in the mesh. i = (i < 0) ? 0 : i; - auto slice = tensor.slice(axis, i * stride, (i + 1) * stride).contiguous(); - // Temporary until https://github.com/NVIDIA/Fuser/issues/2563. Adds DIDx - // axis in front representing the sharded extent of the tensor. - if (stride > 1) { - slice = slice.unsqueeze(0); - } - return slice; + // The following slicing is problematic when DID is on an inner split (cf. + // MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that and + // it's enforced by getShardedLogicalAxis. + return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous(); } } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index 643b5b2220d..d1f06d80e1d 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -202,6 +202,73 @@ TEST_F(LowerCollectiveTest, Allgather) { EXPECT_TRUE(at::equal(out_tensor, unsharded_tensor)); } +TEST_F(LowerCollectiveTest, Allgather_LoopSplit) { + auto fusion = std::make_unique<Fusion>(); + FusionGuard fg(fusion.get()); + + const auto num_devices = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(num_devices); + + TensorView* in = makeContigTensor(1); + in->setDeviceMesh(mesh); + TensorView* out = set(in); + fusion->addInput(in); + fusion->addOutput(out); + + in->split(0, num_devices, /*inner_split=*/false); + in->axis(0)->parallelize(ParallelType::DIDx); + in->setAllocationDomain(in->getLoopDomain(), true); + + out->split(0, num_devices, /*inner_split=*/false); + out->setAllocationDomain(out->getLoopDomain(), true); + + at::Tensor unsharded_tensor = + at::randn({num_devices * kTensorSize}, at::kFloat); + at::Tensor in_tensor = + shardTensor(unsharded_tensor, in).to(communicator_->device()); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + assertIsCompiledToHostIrContainer(fec); + + EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor)); +} + +// This currently fails due to getShardingChanges reads root/logical only: +// https://github.com/NVIDIA/Fuser/blob/1dda106a946adcfd1526b83e4f2d4abebb9e32e4/csrc/multidevice/utils.cpp#L77. +// Will try to fix this in a follow-up PR and reenable the test. +TEST_F(LowerCollectiveTest, DISABLED_Allgather_LoopSplit_Noncontiguous) { + auto fusion = std::make_unique<Fusion>(); + FusionGuard fg(fusion.get()); + + const auto num_devices = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(num_devices); + + TensorView* in = makeContigTensor(2); + in->setDeviceMesh(mesh); + TensorView* out = set(in); + fusion->addInput(in); + fusion->addOutput(out); + + in->split(1, num_devices, /*inner_split=*/false); + in->axis(1)->parallelize(ParallelType::DIDx); + in->setAllocationDomain(in->getLoopDomain(), true); + + out->split(1, num_devices, /*inner_split=*/false); + out->setAllocationDomain(out->getLoopDomain(), true); + + at::Tensor unsharded_tensor = + at::arange(2 * num_devices * 3, at::kFloat).view({2, num_devices * 3}); + at::Tensor in_tensor = + shardTensor(unsharded_tensor, in).to(communicator_->device()); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + assertIsCompiledToHostIrContainer(fec); + + EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor)); +} + TEST_F(LowerCollectiveTest, Broadcast) { auto fusion = std::make_unique<Fusion>(); FusionGuard fg(fusion.get()); diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 3adac90bc5e..aaa5d3a3218 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -491,4 +491,52 @@ TEST_P(MultiDeviceBroadcastTest, Expanded) { INSTANTIATE_TEST_SUITE_P(, MultiDeviceBroadcastTest, testing::Bool()); +TEST_F(MultiDeviceTest, ShardTensor_OuterSplit) { + const int d = communicator_->size(); + + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv = makeContigConcreteTensor({2, d * 3}); + tv->setDeviceMesh(DeviceMesh::createForNumDevices(d)); + tv->split(1, d, /*inner_split=*/false); + tv->axis(1)->parallelize(ParallelType::DIDx); + tv->setAllocationDomain(tv->getLoopDomain(), true); + + fusion.addInput(tv); + fusion.addOutput(tv); + + at::Tensor unsharded = at::arange(2 * d * 3).view({2, d * 3}); + at::Tensor sharded = shardTensor(unsharded, tv); + + EXPECT_THAT(sharded.sizes(), ElementsAre(2, 3)); + at::Tensor expected = unsharded.view({2, d, 3}).index( + {torch::indexing::Slice(), + communicator_->deviceId(), + torch::indexing::Slice()}); + EXPECT_TRUE(at::equal(sharded, expected)); +} + +TEST_F(MultiDeviceTest, ShardTensor_InnerSplit) { + const int d = communicator_->size(); + + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* tv = makeContigConcreteTensor({d * 3}); + tv->setDeviceMesh(DeviceMesh::createForNumDevices(d)); + tv->split(0, d, /*inner_split=*/true); + tv->axis(-1)->parallelize(ParallelType::DIDx); + tv->setAllocationDomain(tv->getLoopDomain(), true); + + fusion.addInput(tv); + fusion.addOutput(tv); + + at::Tensor unsharded = at::arange(d * 3); + EXPECT_THAT( + [&]() { shardTensor(unsharded, tv); }, + ::testing::ThrowsMessage<nvfuser::nvfError>( + ::testing::HasSubstr("DID on inner splits"))); +} + } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 2ef33dcdf8f..0f39ae6f6e5 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -720,14 +720,14 @@ TEST_P(DistributedTransformerTest, MLP_Layer) { std::vector<c10::IValue> inputs = { x, - shardTensor(w0, 0, mesh), - shardTensor(b0, 0, mesh), - shardTensor(w1, 1, mesh), + shardTensor(w0, 0, mesh).unsqueeze(0), + shardTensor(b0, 0, mesh).unsqueeze(0), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector<at::Tensor> expected_outputs = { - shardTensor(reference_outs[0], 1, mesh), - shardTensor(reference_outs[1], 1, mesh), + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), reference_outs[2], reference_outs[3]}; @@ -801,17 +801,17 @@ TEST_P(DistributedTransformerTest, Sequence_Parallel_MLP_Layer) { auto mask_ = reference_outs[4]; std::vector<c10::IValue> inputs = { - shardTensor(x_, 0, mesh), - shardTensor(w0_, 0, mesh), - shardTensor(b0_, 0, mesh), - shardTensor(w1_, 1, mesh), + shardTensor(x_, 0, mesh).unsqueeze(0), + shardTensor(w0_, 0, mesh).unsqueeze(0), + shardTensor(b0_, 0, mesh).unsqueeze(0), + shardTensor(w1_, 1, mesh).unsqueeze(0), b1_}; std::vector<at::Tensor> expected_outputs = { - shardTensor(reference_outs[0], 1, mesh), - shardTensor(reference_outs[1], 1, mesh), - shardTensor(reference_outs[2], 0, mesh), - shardTensor(reference_outs[3], 0, mesh)}; + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[2], 0, mesh).unsqueeze(0), + shardTensor(reference_outs[3], 0, mesh).unsqueeze(0)}; FusionExecutorCache executor_cache(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -866,12 +866,12 @@ TEST_P(DistributedTransformerTest, MultiheadAttention) { x, shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(b0.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector<at::Tensor> expected_outputs = { shardTensor(reference_outs[0].view({B * S, 3, E}), 2, mesh) .view({1, B * S, 3 * E / D}), - shardTensor(reference_outs[1], 1, mesh), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), reference_outs[2], reference_outs[3]}; @@ -929,17 +929,17 @@ TEST_P(DistributedTransformerTest, MultiheadAttention_SP) { at::manual_seed(getATenRandomSeed()); auto reference_outs = reference_mha(x, w0, b0, w1, b1); std::vector<c10::IValue> inputs = { - shardTensor(x, 0, mesh), + shardTensor(x, 0, mesh).unsqueeze(0), shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(b0.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector<at::Tensor> expected_outputs = { shardTensor(reference_outs[0].view({B * S, 3, E}), 2, mesh) .view({1, B * S, 3 * E / D}), - shardTensor(reference_outs[1], 1, mesh), - shardTensor(reference_outs[2], 0, mesh), - shardTensor(reference_outs[3], 0, mesh)}; + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[2], 0, mesh).unsqueeze(0), + shardTensor(reference_outs[3], 0, mesh).unsqueeze(0)}; FusionExecutorCache fec(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -1003,16 +1003,16 @@ TEST_P(DistributedTransformerTest, MLP_Backward) { grad_, x_, mask_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), - shardTensor(linear0_, 1, mesh)}; + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), + shardTensor(linear0_, 1, mesh).unsqueeze(0)}; std::vector<at::Tensor> expected_outputs = { outs[0], // dropout grad - shardTensor(outs[1], 1, mesh), // linear1 weight grad + shardTensor(outs[1], 1, mesh).unsqueeze(0), // linear1 weight grad outs[2], // linear1 bias grad - shardTensor(outs[3], 1, mesh), // gelu grad - shardTensor(outs[4], 0, mesh), // linear0 weight grad - shardTensor(outs[5], 0, mesh), // linear0 bias grad + shardTensor(outs[3], 1, mesh).unsqueeze(0), // gelu grad + shardTensor(outs[4], 0, mesh).unsqueeze(0), // linear0 weight grad + shardTensor(outs[5], 0, mesh).unsqueeze(0), // linear0 bias grad outs[6]}; // linear0 grad x FusionExecutorCache executor_cache(std::move(fusion)); @@ -1094,22 +1094,23 @@ TEST_P(DistributedTransformerTest, MHA_Backward) { std::vector<c10::IValue> inputs = { x, shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), grad, mask, - shardTensor(reference_outs[0], 1, mesh), // sdpa.output - shardTensor(reference_outs[1], 1, mesh), // sdpa.log_sumexp + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), // sdpa.output + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), // sdpa.log_sumexp reference_outs[2], // sdpa.seed reference_outs[3], // sdpa.offset - shardTensor(reference_outs[13], 1, mesh) // linear0 + shardTensor(reference_outs[13], 1, mesh).unsqueeze(0) // linear0 }; std::vector<at::Tensor> expected_outputs = { reference_outs[4], // dropout grad - shardTensor(reference_outs[5], 1, mesh), // linear1 weight grad + shardTensor(reference_outs[5], 1, mesh) + .unsqueeze(0), // linear1 weight grad reference_outs[6], // linear1 bias grad - shardTensor(reference_outs[7], 1, mesh), // q grad - shardTensor(reference_outs[8], 1, mesh), // k grad - shardTensor(reference_outs[9], 1, mesh), // v grad + shardTensor(reference_outs[7], 1, mesh).unsqueeze(0), // q grad + shardTensor(reference_outs[8], 1, mesh).unsqueeze(0), // k grad + shardTensor(reference_outs[9], 1, mesh).unsqueeze(0), // v grad shardTensor(reference_outs[10].view({3, E, E}), 1, mesh) .view({1, 3 * E / D, E}), // linear0 weight grad shardTensor(reference_outs[11].view({3, E}), 1, mesh) @@ -1234,26 +1235,26 @@ TEST_P(DistributedTransformerTest, Forward_SP) { auto at_out = (resid0_ + mlp_out_).to(at_dtype); std::vector<c10::IValue> inputs = { - shardTensor(x_, 0, mesh), + shardTensor(x_, 0, mesh).unsqueeze(0), ln0_w_, ln0_b_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(mha_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), mha_b1_, ln1_w_, ln1_b_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_b0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_b0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_b1_}; std::vector<at::Tensor> expected_outputs = { - shardTensor(ln0_out_, 0, mesh), - shardTensor(mha_out_, 0, mesh), - shardTensor(ln1_out_, 0, mesh), - shardTensor(mlp_out_, 0, mesh), - shardTensor(at_out, 0, mesh)}; + shardTensor(ln0_out_, 0, mesh).unsqueeze(0), + shardTensor(mha_out_, 0, mesh).unsqueeze(0), + shardTensor(ln1_out_, 0, mesh).unsqueeze(0), + shardTensor(mlp_out_, 0, mesh).unsqueeze(0), + shardTensor(at_out, 0, mesh).unsqueeze(0)}; FusionExecutorCache fec(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -1367,13 +1368,13 @@ TEST_P(DistributedTransformerTest, Forward) { ln0_b_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(mha_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), mha_b1_, ln1_w_, ln1_b_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_b0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_b0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_b1_}; std::vector<at::Tensor> expected_outputs = { @@ -1620,13 +1621,16 @@ TEST_P(DistributedTransformerTest, Backward) { auto dx_ = (ln0_x_grad_ + resid1_grad_).to(at_dtype); auto expected_outputs = { - shardTensor(mlp_grads_[1], 1, mesh), // mlp_linear1_weight_grad + shardTensor(mlp_grads_[1], 1, mesh) + .unsqueeze(0), // mlp_linear1_weight_grad mlp_grads_[2], // mlp_linear1_bias_grad - shardTensor(mlp_grads_[4], 0, mesh), // mlp_linear0_weight_grad - shardTensor(mlp_grads_[5], 0, mesh), // mlp_linear0_bias_grad + shardTensor(mlp_grads_[4], 0, mesh) + .unsqueeze(0), // mlp_linear0_weight_grad + shardTensor(mlp_grads_[5], 0, mesh).unsqueeze(0), // mlp_linear0_bias_grad ln1_w_grad_, ln1_b_grad_, - shardTensor(mha_grads_[5], 1, mesh), // mha linear1 weight grad + shardTensor(mha_grads_[5], 1, mesh) + .unsqueeze(0), // mha linear1 weight grad mha_grads_[6], // mha linear1 bias grad shardTensor( mha_grads_[10].view({3, E, E}), 1, mesh) // failing starting here @@ -1641,13 +1645,13 @@ TEST_P(DistributedTransformerTest, Backward) { x_, grad_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), - shardTensor(mha_w1_, 1, mesh), - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_out_[4], // mlp dropout mask mha_out_[4], // mha dropout mask - shardTensor(mha_grads_[0], 1, mesh), // sdpa output - shardTensor(mha_grads_[1], 1, mesh), // sdpa logsum_exp + shardTensor(mha_grads_[0], 1, mesh).unsqueeze(0), // sdpa output + shardTensor(mha_grads_[1], 1, mesh).unsqueeze(0), // sdpa logsum_exp mha_grads_[2], // sdpa seed mha_grads_[3], // sdpa offset ln1_w_, @@ -1658,9 +1662,9 @@ TEST_P(DistributedTransformerTest, Backward) { ln0_b_, ln0_mean_, ln0_rstd_, - shardTensor(mha_out_[0], 1, mesh), // mha linear0 + shardTensor(mha_out_[0], 1, mesh).unsqueeze(0), // mha linear0 mha_out_[2].to(at::kFloat), // mha linear1 - shardTensor(mlp_out_[0], 1, mesh) // mlp linear1 + shardTensor(mlp_out_[0], 1, mesh).unsqueeze(0) // mlp linear1 }; FusionExecutorCache executor_cache(std::move(fusion));