From 5dba2575065f5e27d468f2776e3d460a21d916e6 Mon Sep 17 00:00:00 2001
From: wchen61
Date: Fri, 3 Jan 2025 06:58:56 +0800
Subject: [PATCH] Resolve race conditions in Marlin kernel (#11493)

The thread-block reduction, the global reduction over C_tmp, and the
final write-out all staged data through the same shared-memory buffer
(sh) that backs the A/B tile pipeline, so reduction writes could race
with tile data other warps were still reading. Carve out a dedicated
sh_red region after the scale buffer and route all reduction traffic
through it. Scales are likewise fetched into their own per-stage slot
on every pipeline iteration (s_gl_rd now advances only when a tile
crosses a group boundary), and is_valid_cache_size() accounts for the
reduction buffer when checking shared-memory capacity.

Signed-off-by: wchen61
---
 csrc/quantization/gptq_marlin/gptq_marlin.cu | 40 ++++++++++----------
 1 file changed, 21 insertions(+), 19 deletions(-)

diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
index 0c698ced7713d..04ef842fbdf95 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -834,6 +834,7 @@ __global__ void Marlin(
   int4* sh_g_idx = sh_b + (stages * b_sh_stage);
   int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
   int4* sh_s = sh_zp + (stages * zp_sh_stage);
+  int4* sh_red = sh_s + (stages * s_sh_stage);
 
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
@@ -932,11 +933,11 @@ __global__ void Marlin(
         int4* sh_s_stage = sh_s + s_sh_stage * pipe;
 
         if constexpr (group_blocks >= thread_k_blocks) {
+          if (s_sh_wr_pred) {
+            cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
+          }
           // Only fetch scales if this tile starts a new group
-          if (pipe % (group_blocks / thread_k_blocks) == 0) {
-            if (s_sh_wr_pred) {
-              cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
-            }
+          if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
             s_gl_rd += s_gl_rd_delta;
           }
         } else {
@@ -1038,9 +1039,7 @@ __global__ void Marlin(
     // No act-order case
     if constexpr (group_blocks != -1) {
       if constexpr (group_blocks >= thread_k_blocks) {
-        int4* sh_s_stage =
-            sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
-                                 (pipe / (group_blocks / thread_k_blocks)));
+        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
         reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
       } else {
         int warp_id = threadIdx.x / 32;
@@ -1339,15 +1338,15 @@ __global__ void Marlin(
             int red_sh_wr =
                 red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
             if (i < red_off) {
-              float* c_rd =
-                  reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
-              float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
+              float* c_rd = reinterpret_cast<float*>(
+                  &sh_red[red_sh_delta * j + red_sh_rd]);
+              float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
 #pragma unroll
               for (int k = 0; k < 4; k++)
                 reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                     c_rd[k] + c_wr[k];
             }
-            sh[red_sh_wr] =
+            sh_red[red_sh_wr] =
                 reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
           }
         }
@@ -1357,7 +1356,7 @@ __global__ void Marlin(
 #pragma unroll
         for (int i = 0; i < 4 * 2; i++) {
           float* c_rd =
-              reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
+              reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
 #pragma unroll
           for (int j = 0; j < 4; j++)
             reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
@@ -1397,7 +1396,7 @@ __global__ void Marlin(
 #pragma unroll
         for (int i = 0; i < thread_m_blocks * 4; i++) {
           cp_async4_pred(
-              &sh[c_sh_wr + c_sh_wr_delta * i],
+              &sh_red[c_sh_wr + c_sh_wr_delta * i],
              &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
                 c_gl_wr_delta_i * (i % 2)],
              i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
@@ -1410,7 +1409,7 @@ __global__ void Marlin(
         for (int i = 0; i < thread_m_blocks * 4; i++) {
           if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
             if (!first) {
-              int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+              int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
 #pragma unroll
               for (int j = 0; j < 2 * 4; j++) {
                 reinterpret_cast<float*>(
@@ -1461,10 +1460,10 @@ __global__ void Marlin(
     float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
 #pragma unroll
     for (int k = 0; k < th_size; k++) {
-      sh[threadIdx.x] =
+      sh_red[threadIdx.x] =
           C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
-      float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
+      float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
 #pragma unroll
       for (int f = 0; f < 4; f++) {
         frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
       }
@@ -1515,7 +1514,7 @@ __global__ void Marlin(
         res = __hmul2(res, s[0]);
       }
 
-      ((scalar_t2*)sh)[idx] = res;
+      ((scalar_t2*)sh_red)[idx] = res;
     };
 
     if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1543,7 +1542,7 @@ __global__ void Marlin(
          i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
          i++) {
       if (c_gl_wr < c_gl_wr_end) {
-        C[c_gl_wr] = sh[c_sh_rd];
+        C[c_gl_wr] = sh_red[c_sh_rd];
         c_gl_wr += c_gl_wr_delta;
         c_sh_rd += c_sh_rd_delta;
       }
@@ -1865,9 +1864,12 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
 
   float pipe_size = (a_size + b_size) * pipe_stages;
 
+  float reduce_size = max(th_config.num_threads * 32 * 4,
+                          (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
+
   TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity
 
-  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
+  return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
 }
 
 bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
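For readers who want to see the buffer-partitioning idea in isolation, the
sketch below is a minimal, self-contained CUDA example. It is not vLLM code:
all names and sizes are invented for illustration, and sh_red is the only
identifier borrowed from the patch. It partitions one dynamic shared-memory
allocation into a staging region and a dedicated reduction region, which is
the discipline the patch enforces in gptq_marlin.cu.

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative sizes only; the real kernel derives these from its tile shape.
constexpr int kStageFloats = 256;  // floats of staging space per pipeline slot
constexpr int kStages = 4;         // pipeline depth
constexpr int kThreads = 256;      // threads per block

__global__ void sum_with_dedicated_red_buffer(const float* in, float* out,
                                              int n) {
  // One dynamic shared allocation, partitioned into named regions the same
  // way gptq_marlin.cu derives sh_b / sh_s / sh_red from a single base.
  extern __shared__ float smem[];
  float* sh_stage = smem;                             // tile-staging region
  float* sh_red = sh_stage + kStages * kStageFloats;  // reduction region
  (void)sh_stage;  // stands in for the A/B pipeline; unused in this toy

  // Grid-stride accumulation into registers.
  float acc = 0.0f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x)
    acc += in[i];

  // Tree reduction confined to sh_red. Writing into sh_stage here instead
  // would reintroduce the bug class the patch removes: reduction stores
  // landing in memory that pipelined loads may still be reading.
  sh_red[threadIdx.x] = acc;
  __syncthreads();
  for (int off = blockDim.x / 2; off > 0; off >>= 1) {
    if (threadIdx.x < off) sh_red[threadIdx.x] += sh_red[threadIdx.x + off];
    __syncthreads();
  }
  if (threadIdx.x == 0) atomicAdd(out, sh_red[0]);
}

int main() {
  const int n = 1 << 20;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; i++) in[i] = 1.0f;
  *out = 0.0f;

  // The launch must reserve shared memory for both regions up front, the
  // same bookkeeping the patched is_valid_cache_size() performs by adding
  // reduce_size to pipe_size.
  size_t smem_bytes = (kStages * kStageFloats + kThreads) * sizeof(float);
  sum_with_dedicated_red_buffer<<<64, kThreads, smem_bytes>>>(in, out, n);
  cudaDeviceSynchronize();
  printf("sum = %.0f (expected %d)\n", *out, n);
  cudaFree(in);
  cudaFree(out);
  return 0;
}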