Skip to content

Commit

Permalink
Unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ElizaWszola committed Sep 20, 2024
1 parent 936f2b9 commit 98ec9b6
Show file tree
Hide file tree
Showing 14 changed files with 400 additions and 136 deletions.
19 changes: 9 additions & 10 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@ namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks,
int thread_k_blocks, bool has_act_order, int group_blocks,
int num_threads, int blocks, int max_shared_mem, cudaStream_t stream,
const int4* A_ptr, const int4* B_ptr, int4* C_ptr,
const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr,
const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr,
int num_groups, int expert_idx, int num_experts, int topk, int prob_m,
int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input,
bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) {

vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = true;

if (false) {
Expand Down
18 changes: 9 additions & 9 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks,
int thread_k_blocks, bool has_act_order, int group_blocks,
int num_threads, int blocks, int max_shared_mem, cudaStream_t stream,
const int4* A_ptr, const int4* B_ptr, int4* C_ptr,
const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr,
const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr,
int num_groups, int expert_idx, int num_experts, int topk, int prob_m,
int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input,
bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks);
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);

} // namespace marlin_moe
19 changes: 9 additions & 10 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@ namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks,
int thread_k_blocks, bool has_act_order, int group_blocks,
int num_threads, int blocks, int max_shared_mem, cudaStream_t stream,
const int4* A_ptr, const int4* B_ptr, int4* C_ptr,
const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr,
const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr,
int num_groups, int expert_idx, int num_experts, int topk, int prob_m,
int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input,
bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) {

vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;

if (false) {
Expand Down
18 changes: 9 additions & 9 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4b8(
vllm::ScalarType const& q_type, int thread_n_blocks,
int thread_k_blocks, bool has_act_order, int group_blocks,
int num_threads, int blocks, int max_shared_mem, cudaStream_t stream,
const int4* A_ptr, const int4* B_ptr, int4* C_ptr,
const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr,
const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr,
int num_groups, int expert_idx, int num_experts, int topk, int prob_m,
int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input,
bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks);
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);

} // namespace marlin_moe
19 changes: 9 additions & 10 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@ namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku8(
vllm::ScalarType const& q_type, int thread_n_blocks,
int thread_k_blocks, bool has_act_order, int group_blocks,
int num_threads, int blocks, int max_shared_mem, cudaStream_t stream,
const int4* A_ptr, const int4* B_ptr, int4* C_ptr,
const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr,
const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr,
int num_groups, int expert_idx, int num_experts, int topk, int prob_m,
int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input,
bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) {

vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = true;

if (false) {
Expand Down
18 changes: 9 additions & 9 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku8.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku8(
vllm::ScalarType const& q_type, int thread_n_blocks,
int thread_k_blocks, bool has_act_order, int group_blocks,
int num_threads, int blocks, int max_shared_mem, cudaStream_t stream,
const int4* A_ptr, const int4* B_ptr, int4* C_ptr,
const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr,
const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr,
int num_groups, int expert_idx, int num_experts, int topk, int prob_m,
int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input,
bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks);
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);

} // namespace marlin_moe
19 changes: 9 additions & 10 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@ namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks,
int thread_k_blocks, bool has_act_order, int group_blocks,
int num_threads, int blocks, int max_shared_mem, cudaStream_t stream,
const int4* A_ptr, const int4* B_ptr, int4* C_ptr,
const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr,
const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr,
int num_groups, int expert_idx, int num_experts, int topk, int prob_m,
int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input,
bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks) {

vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;

if (false) {
Expand Down
18 changes: 9 additions & 9 deletions csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
namespace marlin_moe {

// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku8b128(
vllm::ScalarType const& q_type, int thread_n_blocks,
int thread_k_blocks, bool has_act_order, int group_blocks,
int num_threads, int blocks, int max_shared_mem, cudaStream_t stream,
const int4* A_ptr, const int4* B_ptr, int4* C_ptr,
const int* sorted_ids_ptr, const float* topk_weights_ptr, const int4* s_ptr,
const int4* zp_ptr, const int* g_idx_ptr, int* expert_offsets_ptr,
int num_groups, int expert_idx, int num_experts, int topk, int prob_m,
int prob_n, int prob_k, int tot_m, int* locks, bool replicate_input,
bool apply_weights, int m_block, int max_par, int cfg_max_m_blocks);
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);

} // namespace marlin_moe
12 changes: 6 additions & 6 deletions csrc/moe/marlin_moe_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
const int4* zp_ptr =
(const int4*)zp +
(((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
prob_n / 8) *
prob_n / 4) *
expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
Expand Down Expand Up @@ -565,12 +565,12 @@ torch::Tensor marlin_gemm_moe(
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
TORCH_CHECK(b_zeros.size(0) == num_groups,
"b_zeros dim 0 = ", b_zeros.size(0),
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
TORCH_CHECK(b_zeros.size(1) == num_groups,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
"b_zeros dim 1 = ", b_scales.size(1),
// Validate the packed column dimension of b_zeros. The failure message must
// print b_zeros' own dimension; it previously printed b_scales.size(2), which
// reported an unrelated tensor's size under the "b_zeros dim 2" label.
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
            "b_zeros dim 2 = ", b_zeros.size(2),
            " is not size_n / pack_factor = ", size_n / pack_factor);
}

Expand Down
Loading

0 comments on commit 98ec9b6

Please sign in to comment.