diff --git a/csrc/scheduler/matmul_heuristic_plugin.cpp b/csrc/scheduler/matmul_heuristic_plugin.cpp index b3821787b67..5175be27112 100644 --- a/csrc/scheduler/matmul_heuristic_plugin.cpp +++ b/csrc/scheduler/matmul_heuristic_plugin.cpp @@ -135,6 +135,8 @@ void copyParamsToConfig(KernelConfig* config, const MatmulParams* mparams) { }; config->load_stages = mparams->circular_buffer_options.smem_circular_buffer_stage; + config->prefetch_distance = + mparams->circular_buffer_options.smem_circular_buffer_prefetch; config->async_gmem_load_operands = mparams->async_gmem_load_operands; setConfigTile(config->cta_tile, mparams->tile_sizes.cta_tile); setConfigTile(config->warp_tile, mparams->tile_sizes.warp_tile); @@ -163,6 +165,8 @@ void copyConfigToParams(MatmulParams* mparams, const KernelConfig* config) { setGemmTile(mparams->tile_sizes.warp_tile, config->warp_tile); mparams->circular_buffer_options.smem_circular_buffer_stage = config->load_stages; + mparams->circular_buffer_options.smem_circular_buffer_prefetch = + config->prefetch_distance; mparams->async_gmem_load_operands = config->async_gmem_load_operands; // Update mma macro if necessary to match provided instruction tile MmaMacroEncode menc(mparams->mma_macro); // this will record the family diff --git a/csrc/scheduler/matmul_heuristic_plugin_api.h b/csrc/scheduler/matmul_heuristic_plugin_api.h index 224705530e5..348c3c15f46 100644 --- a/csrc/scheduler/matmul_heuristic_plugin_api.h +++ b/csrc/scheduler/matmul_heuristic_plugin_api.h @@ -74,6 +74,7 @@ struct KernelConfig { Tile instruction_tile = {16, 16, 16}; uint16_t splitk_factor = 1; uint8_t load_stages = 2; + uint8_t prefetch_distance = 1; uint8_t grid_swizzle_factor = 0; uint8_t cta_order = 0; bool circular_buffer_smem_read = true; diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index 415a28829c3..1dc7325d99d 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -91,6 +91,9 @@ void limitCircularBufferingSmemOperands( mparams->circular_buffer_options.circular_buffer_smem_write = (stages != 1); mparams->circular_buffer_options.smem_circular_buffer_stage = (int)stages; + mparams->circular_buffer_options.smem_circular_buffer_prefetch = std::min( + mparams->circular_buffer_options.smem_circular_buffer_prefetch, + (int)stages - 1); } //! A wrapper for core heuristics initialization. @@ -156,6 +159,8 @@ inline bool initCoreHeuristics( mparams->circular_buffer_options.circular_buffer_smem_write = true; mparams->circular_buffer_options.circular_buffer_smem_read = true; mparams->circular_buffer_options.smem_circular_buffer_stage = stages; + mparams->circular_buffer_options.smem_circular_buffer_prefetch = + stages - 1; } } @@ -181,6 +186,9 @@ inline bool initCoreHeuristics( // most. mparams->circular_buffer_options.smem_circular_buffer_stage = std::min( 2, mparams->circular_buffer_options.smem_circular_buffer_stage); + mparams->circular_buffer_options.smem_circular_buffer_prefetch = std::min( + mparams->circular_buffer_options.smem_circular_buffer_prefetch, + mparams->circular_buffer_options.smem_circular_buffer_stage - 1); } return true; }