Skip to content

Commit

Permalink
Update heuristic
Browse files: browse the repository at this point in the history
  • Loading branch information
jacobhinkle committed Dec 10, 2024
1 parent 231fa5f commit faf0f95
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 0 deletions.
4 changes: 4 additions & 0 deletions csrc/scheduler/matmul_heuristic_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) {
};
config->load_stages =
mparams->circular_buffer_options.smem_circular_buffer_stage;
config->prefetch_distance =
mparams->circular_buffer_options.smem_circular_buffer_prefetch;
config->async_gmem_load_operands = mparams->async_gmem_load_operands;
setConfigTile(config->cta_tile, mparams->tile_sizes.cta_tile);
setConfigTile(config->warp_tile, mparams->tile_sizes.warp_tile);
Expand Down Expand Up @@ -163,6 +165,8 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) {
setGemmTile(mparams->tile_sizes.warp_tile, config->warp_tile);
mparams->circular_buffer_options.smem_circular_buffer_stage =
config->load_stages;
mparams->circular_buffer_options.smem_circular_buffer_prefetch =
config->prefetch_distance;
mparams->async_gmem_load_operands = config->async_gmem_load_operands;
// Update mma macro if necessary to match provided instruction tile
MmaMacroEncode menc(mparams->mma_macro); // this will record the family
Expand Down
1 change: 1 addition & 0 deletions csrc/scheduler/matmul_heuristic_plugin_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ struct KernelConfig {
Tile instruction_tile = {16, 16, 16};
uint16_t splitk_factor = 1;
uint8_t load_stages = 2;
uint8_t prefetch_distance = 1;
uint8_t grid_swizzle_factor = 0;
uint8_t cta_order = 0;
bool circular_buffer_smem_read = true;
Expand Down
8 changes: 8 additions & 0 deletions csrc/scheduler/matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ void limitCircularBufferingSmemOperands(

mparams->circular_buffer_options.circular_buffer_smem_write = (stages != 1);
mparams->circular_buffer_options.smem_circular_buffer_stage = (int)stages;
mparams->circular_buffer_options.smem_circular_buffer_prefetch = std::min(
mparams->circular_buffer_options.smem_circular_buffer_prefetch,
(int)stages - 1);
}

//! A wrapper for core heuristics initialization.
Expand Down Expand Up @@ -156,6 +159,8 @@ inline bool initCoreHeuristics(
mparams->circular_buffer_options.circular_buffer_smem_write = true;
mparams->circular_buffer_options.circular_buffer_smem_read = true;
mparams->circular_buffer_options.smem_circular_buffer_stage = stages;
mparams->circular_buffer_options.smem_circular_buffer_prefetch =
stages - 1;
}
}

Expand All @@ -181,6 +186,9 @@ inline bool initCoreHeuristics(
// most.
mparams->circular_buffer_options.smem_circular_buffer_stage = std::min(
2, mparams->circular_buffer_options.smem_circular_buffer_stage);
mparams->circular_buffer_options.smem_circular_buffer_prefetch = std::min(
mparams->circular_buffer_options.smem_circular_buffer_prefetch,
mparams->circular_buffer_options.smem_circular_buffer_stage - 1);
}
return true;
}
Expand Down

0 comments on commit faf0f95

Please sign in to comment.