Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: ElizaWszola <[email protected]>
  • Loading branch information
ElizaWszola committed Nov 11, 2024
1 parent 70abf99 commit 04c842f
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions csrc/quantization/gptq_marlin/gptq_marlin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1821,8 +1821,7 @@ thread_config_t large_batch_thread_configs[] = {

int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits, int group_size,
bool has_act_order, bool is_k_full,
bool is_zp_float) {
bool has_act_order, bool is_k_full) {
bool cache_scales_chunk = has_act_order && !is_k_full;

int tb_n = th_config.thread_n;
Expand Down Expand Up @@ -1890,7 +1889,7 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, bool has_act_order, bool is_k_full,
int max_shared_mem, bool is_zp_float) {
int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
Expand All @@ -1915,7 +1914,7 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
// Determine cache for scales
int scales_cache_size =
get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full, is_zp_float);
group_size, has_act_order, is_k_full);

// Check that pipeline fits into cache
if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
Expand Down Expand Up @@ -1950,22 +1949,22 @@ int determine_reduce_max_m(int prob_m, int max_par) {
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int num_bits, int group_size,
bool has_act_order, bool is_k_full,
int max_shared_mem, bool is_zp_float) {
int max_shared_mem) {
int max_m_blocks = 4;
while (max_m_blocks > 0) {
if (prob_m <= 16) {
for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem, is_zp_float)) {
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
} else {
for (auto th_config : large_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, has_act_order, is_k_full,
max_shared_mem, is_zp_float)) {
max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
Expand Down Expand Up @@ -2115,23 +2114,23 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
} else {
// Auto config
exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits,
group_size, has_act_order, is_k_full,
max_shared_mem, is_zp_float);
exec_cfg =
determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem);
}

TORCH_CHECK(
exec_cfg.max_m_blocks > 0 &&
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, has_act_order,
is_k_full, max_shared_mem, is_zp_float),
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
", thread_k = ", exec_cfg.tb_cfg.thread_k,
", thread_n = ", exec_cfg.tb_cfg.thread_n,
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m,
", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size, ", has_act_order = ", has_act_order,
", is_k_full = ", is_k_full, ", max_shared_mem = ", max_shared_mem);
TORCH_CHECK(exec_cfg.max_m_blocks > 0 &&
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks,
prob_m, prob_n, prob_k, num_bits, group_size,
has_act_order, is_k_full, max_shared_mem),
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
", thread_k = ", exec_cfg.tb_cfg.thread_k,
", thread_n = ", exec_cfg.tb_cfg.thread_n,
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [",
prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size,
", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full,
", max_shared_mem = ", max_shared_mem);

int num_threads = exec_cfg.tb_cfg.num_threads;
thread_k = exec_cfg.tb_cfg.thread_k;
Expand Down

0 comments on commit 04c842f

Please sign in to comment.