Commit

Merge branch 'main' into pm/add_thunder_bench
Priya2698 authored Dec 10, 2024
2 parents 3fa27fb + 4a897a4 commit d69811f
Showing 31 changed files with 1,203 additions and 450 deletions.
34 changes: 34 additions & 0 deletions csrc/codegen.cpp
@@ -1214,6 +1214,20 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}
}

std::string genComputeBlockDim() {
std::stringstream ss;
const auto& pdim_map = kernel_->summary().parallel_dimension_map;
if (!pdim_map.hasWarpSpecialization()) {
ss << "DefaultBlockDim()";
} else {
ss << "dim3("
<< genInlineOrOne(pdim_map.getRawCompute(ParallelType::TIDx)) << ", "
<< genInlineOrOne(pdim_map.getRawCompute(ParallelType::TIDy)) << ", "
<< genInlineOrOne(pdim_map.getRawCompute(ParallelType::TIDz)) << ")";
}
return ss.str();
}

std::string genReductionOp(BinaryOpType op_type, DataType data_type) {
std::stringstream lambda;
lambda << "[](" << data_type << " &a, " << data_type << " b) "
@@ -1252,6 +1266,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
func_args.arg(genStaticCast(genPtrType(data_type), "shared_mem"));
NVF_ERROR(stmt->predicate() != nullptr && stmt->predicate()->hasValue());
func_args.arg(genInline(stmt->predicate()));
func_args.arg(genComputeBlockDim());

indent() << genCall("broadcast::blockBroadcast", template_args, func_args)
<< ";\n";
@@ -1284,6 +1299,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
NVF_ERROR(read_pred != nullptr && read_pred->hasValue());
func_args.arg(genInline(read_pred));
func_args.arg(genStaticCast(output->dtype(), genInline(init)));
func_args.arg(genComputeBlockDim());

ArgumentBuilder template_args;
if (reduction_dims.first->getParallelType() == ParallelType::TIDx &&
@@ -1349,6 +1365,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
func_args.arg(genInline(write_pred));
}
func_args.arg(genCall(data_type, genInline(init)));
func_args.arg(genComputeBlockDim());

indent() << genCall("blockReduce", template_args, func_args) << ";\n";
}
@@ -1578,6 +1595,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
func_args.arg(genInline(wop->writePredicate()));
}
func_args.arg(genStaticCast(data_type, 0));
func_args.arg(genComputeBlockDim());

indent() << genCall("blockWelford", template_args, func_args) << ";\n";
}
@@ -1781,6 +1799,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
func_args.arg(genCall(data_type, genInline(grop->init())));
func_args.arg(genInline(grop->entrance_index()));
func_args.arg(genInline(grop->entrances()));
func_args.arg(genComputeBlockDim());

addProfileArguments(func_args, grop);

@@ -1915,6 +1934,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
func_args.arg(read_pred).arg(write_pred);
// init_val
func_args.arg(genCall("LocalTuple", data_type, genInline(grop->init())));
// block_dim
func_args.arg(genComputeBlockDim());
// reduction_op
func_args.arg(genReductionOp(op_type, out->dtype()));

@@ -1971,6 +1992,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
}
// Init val
func_args.arg(genCall(data_type, genInline(grop->initVal(0))));
// block_dim
func_args.arg(genComputeBlockDim());

addProfileArguments(func_args, grop);

@@ -2059,6 +2082,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {

func_args.arg(genInline(grouped_grop->entrance_index()));
func_args.arg(genInline(grouped_grop->entrances()));
func_args.arg(genComputeBlockDim());

addProfileArguments(func_args, grouped_grop);

@@ -2271,6 +2295,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
func_args.arg(genCall("ConstRefTuple", types, inputs));
func_args.arg(genCall("VolatilePtrTuple", types, work_bufs));
func_args.arg(genCall("LocalTuple", types, init_vals));
func_args.arg(genComputeBlockDim());

// global_sync_buffer
const auto sync_buffer =
@@ -2407,6 +2432,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
func_args.arg(genCall("LocalTuple", data_types, init_args[0]));
func_args.arg(genCall("LocalTuple", data_types, init_args[1]));
func_args.arg(genCall("LocalTuple", index_types, init_args[2]));
// block_dim
func_args.arg(genComputeBlockDim());
// work buffer
func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[0]));
func_args.arg(genCall("VolatilePtrTuple", data_types, work_bufs[1]));
@@ -2498,6 +2525,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
func_args.arg(genVariableNameConvertAlignedArray(input.get(1)));
func_args.arg(genVariableNameConvertAlignedArray(input.get(2)))
.append("[0]");
// block_dim
func_args.arg(genComputeBlockDim());

// global buf
for (const auto i : c10::irange(3)) {
@@ -2652,6 +2681,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
func_args.arg(genStaticCast(data_type, 0));
func_args.arg(genInline(gwop->entrance_index()));
func_args.arg(genInline(gwop->entrances()));
func_args.arg(genComputeBlockDim());

indent() << genCall("welford::gridWelford", template_args, func_args)
<< ";\n";
@@ -2751,6 +2781,8 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
func_args.arg(read_pred).arg(write_pred);
// init_val
func_args.arg(genCall("LocalTuple", data_type_args, init_args));
// block_dim
func_args.arg(genComputeBlockDim());
// reduction_op
func_args.arg(genTemplate(
"welfordCombine", ArgumentBuilder().arg(data_type).arg(index_type)));
@@ -2877,6 +2909,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
func_args.arg(genInline(write_pred));
}
func_args.arg(genCall(data_type, genInline(init)));
func_args.arg(genComputeBlockDim());

indent() << genCall("blockIterGroupedYdimReduce", template_args, func_args)
<< ";\n";
@@ -3315,6 +3348,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
.append(sync_idx)
.append("]");
sync_call_args.arg(sync_segment_size);
sync_call_args.arg(genComputeBlockDim());

auto sync_call =
genCall("grid_sync::sync", sync_call_template_parms, sync_call_args);
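
The pattern repeated across every codegen.cpp hunk above is the same: the block- and grid-level runtime helpers (broadcast::blockBroadcast, blockReduce, blockWelford, gridReduce, grid_sync::sync, and so on) now receive an explicit block_dim argument, and the new genComputeBlockDim() decides whether that argument is DefaultBlockDim() or a dim3(...) built from the raw compute extents when warp specialization splits the thread block. The standalone C++ sketch below mirrors only that string-building decision; ParallelDimensionMap, genInlineOrOne, and DefaultBlockDim are nvFuser internals shown in the diff, and the extent strings used in main() are made-up placeholders.

```cpp
// Minimal standalone sketch of the genComputeBlockDim() decision above
// (illustration only, not the nvFuser implementation).
#include <iostream>
#include <sstream>
#include <string>

struct ComputeDims {
  bool has_warp_specialization = false;
  // Stand-ins for genInlineOrOne(pdim_map.getRawCompute(ParallelType::TID*)).
  std::string tidx = "1", tidy = "1", tidz = "1";
};

std::string genComputeBlockDim(const ComputeDims& dims) {
  std::stringstream ss;
  if (!dims.has_warp_specialization) {
    // Compute threads span the whole block: keep using blockDim.
    ss << "DefaultBlockDim()";
  } else {
    // Warp specialization reserves some warps, so the runtime helpers must be
    // told the extent of the compute portion explicitly.
    ss << "dim3(" << dims.tidx << ", " << dims.tidy << ", " << dims.tidz << ")";
  }
  return ss.str();
}

int main() {
  std::cout << genComputeBlockDim({false}) << "\n";
  // Hypothetical warp-specialized kernel where TIDx/TIDy carry the compute
  // extents; the strings are placeholders for inlined index expressions.
  std::cout << genComputeBlockDim({true, "128", "T0.logical_size[0]", "1"})
            << "\n";
  return 0;
}
```

The first call prints DefaultBlockDim() and the second prints dim3(128, T0.logical_size[0], 1), which matches the shape of the trailing block_dim argument added to the generated calls above.
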
66 changes: 45 additions & 21 deletions csrc/device_lower/pass/index.cpp
@@ -1559,15 +1559,15 @@ void IndexLowering::handleCpAsyncBulkStore(const LoadStoreOp* ldst) {
}

static DataType getMmaInputAType(MmaMacro macro) {
int warp_group_size = isHopper(macro) ? 128 : 32;
int size = getM(macro) * getK(macro) / warp_group_size /
2 /* halves per 32bit register */;
int64_t warp_group_size = isHopper(macro) ? 128L : 32L;
int64_t size = getM(macro) * getK(macro) / warp_group_size /
2L /* halves per 32bit register */;
return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size};
}

static DataType getMmaInputBType(MmaMacro macro) {
int size = getN(macro) * getK(macro) / 32 /* threads per warp */ /
2 /* halves per 32bit register */;
int64_t size = getN(macro) * getK(macro) / 32L /* threads per warp */ /
2L /* halves per 32bit register */;
return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size};
}

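The widening from int to int64_t above also makes the register-count arithmetic easy to check by hand. A worked instance of the two formulas, assuming a Hopper wgmma shape of M = 64, N = 256, K = 16 with half-precision inputs (the concrete MmaMacro is not named in this hunk, so these dimensions are only an example):

```cpp
// Worked example of the getMmaInputAType/getMmaInputBType arithmetic above,
// with assumed dimensions M = 64, N = 256, K = 16 and half-precision inputs.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t M = 64, N = 256, K = 16;
  const int64_t warp_group_size = 128;  // Hopper: a full warp group owns A
  const int64_t threads_per_warp = 32;
  const int64_t halves_per_u32 = 2;     // two 16-bit halves per 32-bit register

  const int64_t a_regs = M * K / warp_group_size / halves_per_u32;   // 4
  const int64_t b_regs = N * K / threads_per_warp / halves_per_u32;  // 64
  std::printf("A operand: Array<UInt32, %ld>, B operand: Array<UInt32, %ld>\n",
              (long)a_regs, (long)b_regs);
  return 0;
}
```
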
@@ -1842,8 +1842,8 @@ Val* hardCodedIndexGenerationForStMatrix(
// To account for the threadIdx.y, we have to add it to the offset:
// offset_from_tdy = threadIdx.y * tma_m * tma_n * 2 (half)
//
// Now, let's apply stmatrix tile to the TMA Box.
// [NO(2), MO(4), MI(16), NIO(4), NII(16)].
// Now, let's apply stmatrix tile (16, 16) to the TMA Box [NO(2), M(64), NI(64)].
// [NO(2), MO(4), MI(16), NIO(4), NII(16)].
//
// A warp group of 128 threads contains four warps. StMatrix is a warp-level
// operation, so four StMatrix operations can be issued simultaneously by the
@@ -1865,6 +1865,7 @@ Val* hardCodedIndexGenerationForStMatrix(
// domain is scheduled as [NO(2), M(64), NI(64)]. Therefore, we must store the
// data in shared memory in [M(64), NI(64)] contiguous tiles.
//
// NOTE: This offset is skipped if the for-loop is trivial.
// To account for the outer_index, we have to add it to the offset:
// offset_from_outer_index = outer_index * tma_m * NI(64) * 2 (half)
//
@@ -1928,8 +1929,13 @@ Val* hardCodedIndexGenerationForStMatrix(
// with the 8 rows of the matrix to avoid bank conflicts. This swizzle pattern
// is repeated along the rows of the TMA box.
//
// The number of distinct swizzle rows is the number of bytes in the swizzle
// divided by the megabank size (16B). The number of times the swizzle pattern
// is repeated to fill the core (8, 8) matrix is the number of swizzle rows (8)
// divided by the number of distinct rows.
//
// Swizzle column
// row_in_swizzle_pattern = row % swizzle_row_size(8)
// row_in_swizzle_pattern = (row % swizzle_row_size(8)) / swizzle_repetitions
// swizzle_col = column XOR row_in_swizzle_pattern
//
// Calculate Tile Offset
@@ -1939,7 +1945,7 @@ Val* hardCodedIndexGenerationForStMatrix(
//
// Get shared memory offset
// smem_offset = offset_from_tdy + offset_from_outer_index + tile_offset
Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
Val* hardCodedIndexGenerationForStMatrixSwizzle(
const LoadStoreOp* ldst,
ForLoop* loop,
const int64_t stsm_m_tile,
@@ -1958,16 +1964,19 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(

NVF_ERROR(ldst->out()->isA<TensorView>());
TensorView* out_tv = ldst->out()->as<TensorView>();
NVF_ERROR(getSwizzle(out_tv) == MmaInputSmemSwizzle::B128);
MmaInputSmemSwizzle swizzle = getSwizzle(out_tv);
int64_t swizzle_bytes = getBytesFromSwizzle(swizzle);

// Constants
constexpr int64_t dtype_size = 2;
constexpr int64_t warp_size = 32;
constexpr int64_t swizzle_row_size = 8;
constexpr int64_t stsm_column_size = 8;
constexpr int64_t swizzle_n_tile = 64;
constexpr int64_t megabank_size_bytes = 16;

// Derived constants
const int64_t swizzle_n_tile = swizzle_bytes / dtype_size;
const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes;
constexpr int64_t stsm_column_stride = stsm_column_size * dtype_size;
const int64_t swizzle_n_iter = swizzle_n_tile / stsm_n_tile;
const int64_t swizzle_n_tile_stride = swizzle_n_tile * dtype_size;
Expand Down Expand Up @@ -2000,8 +2009,6 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
Val* warp_id = SimplifyingIrBuilder::divExpr(TDX, warp_size_val);
Val* lane_id = SimplifyingIrBuilder::modExpr(TDX, warp_size_val);

Val* outer_index =
SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val);
Val* inner_index =
SimplifyingIrBuilder::modExpr(loop->index(), swizzle_n_iter_val);

@@ -2021,6 +2028,17 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
// Swizzle Column
Val* row_in_swizzle_pattern =
SimplifyingIrBuilder::modExpr(row, swizzle_row_size_val);

// The swizzle pattern is repeated to fill the (8, 8) matrix for 64B and 32B
// swizzles. swizzle_row_iter is the number of repetitions needed to fill 8
// rows with distinct swizzle rows.
const int64_t swizzle_row_iter = swizzle_row_size / distinct_swizzle_row_size;
if (swizzle_row_iter > 1) {
Val* swizzle_row_iter_val =
IrBuilder::create<Val>(swizzle_row_iter, DataType::Index);
row_in_swizzle_pattern = SimplifyingIrBuilder::divExpr(
row_in_swizzle_pattern, swizzle_row_iter_val);
}
Val* swizzle_col = bitwise_xor(col, row_in_swizzle_pattern);

// Calculate Tile Offset
@@ -2031,16 +2049,22 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
Val* offset = SimplifyingIrBuilder::addExpr(row_offset, col_offset);

// Calculate Tile offset
Val* tile_offset = IrBuilder::mulExpr(outer_index, tile_stride_val);
// Skip tile offset if loop is trivial.
if (!loop->stop()->isOneInt()) {
Val* outer_index =
SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val);
Val* tile_offset =
SimplifyingIrBuilder::mulExpr(outer_index, tile_stride_val);
offset = SimplifyingIrBuilder::addExpr(tile_offset, offset);
}

// Calculate TDY offset
Val* tdy_offset = IrBuilder::mulExpr(TDY, tdy_stride_val);
Val* tdy_offset = SimplifyingIrBuilder::mulExpr(TDY, tdy_stride_val);
offset = SimplifyingIrBuilder::addExpr(tdy_offset, offset);

// Create shared memory TensorIndex
Val* out_index = SimplifyingIrBuilder::addExpr(
IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)),
SimplifyingIrBuilder::addExpr(
tdy_offset, SimplifyingIrBuilder::addExpr(tile_offset, offset)));
IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)), offset);
Val* out = IrBuilder::create<kir::TensorIndex>(
dynamic_cast<TensorView*>(ldst->out()), out_index);
return out;
@@ -2092,11 +2116,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
ldst, for_loops_[0], m_tile, n_tile, m, n);
break;
case MmaInputSmemSwizzle::B128:
out = hardCodedIndexGenerationForStMatrix128BSwizzle(
case MmaInputSmemSwizzle::B64:
case MmaInputSmemSwizzle::B32:
out = hardCodedIndexGenerationForStMatrixSwizzle(
ldst, for_loops_[0], m_tile, n_tile, m, n);
break;
case MmaInputSmemSwizzle::B32:
case MmaInputSmemSwizzle::B64:
default:
NVF_ERROR("Unsupported Swizzle Type for StMatrix");
}
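
The rename from hardCodedIndexGenerationForStMatrix128BSwizzle to hardCodedIndexGenerationForStMatrixSwizzle works because the swizzle-dependent quantities are now derived from swizzle_bytes: swizzle_n_tile = swizzle_bytes / dtype_size, distinct_swizzle_row_size = swizzle_bytes / 16B, and the swizzle row index is divided by the repetition count whenever fewer than 8 distinct rows exist. The standalone sketch below re-derives the swizzled column with plain integers for the three supported swizzles; it illustrates the arithmetic in the comments, not the lowering pass itself, which builds SimplifyingIrBuilder expressions over Vals.

```cpp
// Plain-integer illustration of the swizzled-column computation described in
// the comments above, for 2-byte (half) elements.
#include <cstdint>
#include <cstdio>

int64_t swizzledColumn(int64_t row, int64_t col, int64_t swizzle_bytes) {
  constexpr int64_t megabank_size_bytes = 16;
  constexpr int64_t swizzle_row_size = 8;  // rows of the core (8, 8) matrix

  const int64_t distinct_rows = swizzle_bytes / megabank_size_bytes;  // 8/4/2
  const int64_t row_repeats = swizzle_row_size / distinct_rows;       // 1/2/4

  // row_in_swizzle_pattern = (row % swizzle_row_size(8)) / swizzle_repetitions
  const int64_t row_in_pattern = (row % swizzle_row_size) / row_repeats;
  // swizzle_col = column XOR row_in_swizzle_pattern
  return col ^ row_in_pattern;
}

int main() {
  // For column 0: the 128B swizzle walks all 8 distinct rows, the 64B swizzle
  // repeats each of its 4 rows twice, and the 32B swizzle repeats each of its
  // 2 rows four times.
  for (int64_t bytes : {128, 64, 32}) {
    std::printf("%3ldB swizzle, col 0:", (long)bytes);
    for (int64_t row = 0; row < 8; ++row) {
      std::printf(" %ld", (long)swizzledColumn(row, /*col=*/0, bytes));
    }
    std::printf("\n");
  }
  return 0;
}
```
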
4 changes: 2 additions & 2 deletions csrc/multidevice/communication.cpp
@@ -328,8 +328,8 @@ c10::intrusive_ptr<c10d::Work> postAllgather(
c10d::Backend* backend,
at::Tensor input_tensor,
at::Tensor output_tensor) {
auto splits = at::split(output_tensor, /*split_size=*/1, /*dim=*/0);
assertBufferCount(splits, communication->team().size());
auto splits =
at::tensor_split(output_tensor, communication->team_size(), /*dim=*/0);
assertBuffersHaveSameSize({input_tensor}, splits);

// allgather primitive in c10d induces extra buffering time to copy out the
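
The allgather change above swaps a fixed split_size of 1 for at::tensor_split with team_size() sections: the output buffer then always yields exactly one slice per rank, whatever the per-rank extent along dim 0, and validation compares each slice against the input buffer instead of counting slices. A minimal ATen sketch of the two splitting strategies (a standalone illustration; the real code operates on the Communication's team and output buffer):

```cpp
// Compare the old and new ways of carving the allgather output into
// per-rank buffers along dim 0.
#include <ATen/ATen.h>
#include <iostream>
#include <vector>

int main() {
  const int64_t team_size = 4;  // stand-in for communication->team_size()
  // Gathered output: one row per rank in this toy example.
  at::Tensor output =
      at::arange(team_size * 3, at::kFloat).reshape({team_size, 3});

  // Old: split into size-1 chunks, then separately assert the chunk count
  // equals the team size.
  std::vector<at::Tensor> by_size =
      at::split(output, /*split_size=*/1, /*dim=*/0);

  // New: ask for exactly team_size sections, so the slice count is correct by
  // construction and only the per-buffer sizes need checking.
  std::vector<at::Tensor> by_sections =
      at::tensor_split(output, team_size, /*dim=*/0);

  std::cout << by_size.size() << " slices vs " << by_sections.size()
            << " slices\n";
  return 0;
}
```
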
5 changes: 5 additions & 0 deletions csrc/multidevice/communication.h
@@ -90,6 +90,11 @@ class Communication : public Expr {
return attribute<Team>(1);
}

// A convenience helper so the user doesn't need to convert size_t to int64_t.
int64_t team_size() const {
return static_cast<int64_t>(team().size());
}

DeviceIdxType root() const {
return attribute<DeviceIdxType>(2);
}
2 changes: 1 addition & 1 deletion csrc/multidevice/lower_communication.cpp
@@ -196,7 +196,7 @@ void lowerToReduceScatter(
std::vector<Communication*>& comms) {
const DeviceMesh& mesh = input_tv->getDeviceMesh();
auto reduction_axis = output_tv->getReductionAxis().value();
auto scattered_axis = getShardedAxis(output_tv);
auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx);
// The output tensor is sharded on scattered_axis and needs to be mapped
// back onto the input. The input has a reduced axis, so the scattered axis
// is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The