use dmmv as default
luoyu-intel committed Jul 18, 2024
1 parent a8c75c0 commit 0b8565d
Showing 2 changed files with 21 additions and 20 deletions.
ggml/src/ggml-sycl.cpp (2 changes: 1 addition & 1 deletion)
@@ -3632,7 +3632,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     // check data types and tensor shapes for custom matrix multiplication kernels:
     bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
 
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type)
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
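The only functional change in this file is the last clause of the condition: the dequantize-mul-mat-vec (dmmv) path, which previously required a single src1 column (src1->ne[1] == 1), now also accepts small batches of up to MMVQ_MAX_BATCH_SIZE columns. As a rough illustration, the predicate can be restated as the standalone sketch below; the enum, constants, and the simplified supports_dmmv() check are stand-ins for the real ggml definitions, not the actual values.

#include <cstdint>

// Stand-ins for the real ggml definitions; values are illustrative only.
enum class Type { F32, F16, Q4_0 };
constexpr int64_t kDmmvX        = 32;  // plays the role of GGML_SYCL_DMMV_X
constexpr int64_t kMmvqMaxBatch = 8;   // plays the role of MMVQ_MAX_BATCH_SIZE

static bool supports_dmmv(Type t) {
    // Simplified stand-in for ggml_sycl_supports_dmmv().
    return t == Type::Q4_0 || t == Type::F16;
}

// Mirrors the updated condition: F32 activations and output, row size divisible
// by the dmmv tile width, and a batch of at most kMmvqMaxBatch columns (was == 1).
static bool use_dequantize_mul_mat_vec(Type src0_t, Type src1_t, Type dst_t,
                                       int64_t src0_ne0, int64_t src1_ne1) {
    return supports_dmmv(src0_t)
        && src1_t == Type::F32 && dst_t == Type::F32
        && src0_ne0 % kDmmvX == 0
        && src1_ne1 <= kMmvqMaxBatch;
}

int main() {
    // A 4096-wide Q4_0 weight with a batch of 4 F32 columns now takes the dmmv path.
    return use_dequantize_mul_mat_vec(Type::Q4_0, Type::F32, Type::F32, 4096, 4) ? 0 : 1;
}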
ggml/src/ggml-sycl/dmmv.cpp (39 changes: 20 additions & 19 deletions)
@@ -883,7 +883,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
 static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
                                              float *dst, const int ncols,
                                              const int nrows,
-                                             dpct::queue_ptr stream,char* vx_tmp) {
+                                             dpct::queue_ptr stream) {
 #if WARP_SIZE==32
     GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
@@ -911,7 +911,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
         stream->parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                dequantize_mul_mat_vec_q4_0(vx_tmp, y, dst, ncols, nrows, item_ct1);
+                dequantize_mul_mat_vec_q4_0(vx, y, dst, ncols, nrows, item_ct1);
             });
     }
 #endif
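The q4_0 launcher drops the extra vx_tmp staging buffer and hands the quantized src0 pointer (vx) straight to the device kernel. The launch geometry itself is unchanged: a 3-D nd_range whose innermost dimension holds WARP_SIZE work-items per sub-group and whose groups cover GGML_SYCL_MMV_Y rows each. The fragment below is a minimal, self-contained sketch of that launch pattern in plain SYCL 2020 (a plain sycl::queue instead of dpct::queue_ptr, a dummy kernel body, and assumed placeholder constants), not the real dmmv kernel.

#include <sycl/sycl.hpp>

constexpr int kWarpSize = 32;  // stand-in for WARP_SIZE
constexpr int kMmvY     = 1;   // stand-in for GGML_SYCL_MMV_Y

int main() {
    sycl::queue q;
    const int nrows = 128;
    float *dst = sycl::malloc_shared<float>(nrows, q);

    // Same shape computation as the launcher above: one work-group per kMmvY rows,
    // kWarpSize work-items along the innermost dimension.
    const int block_num_y = (nrows + kMmvY - 1) / kMmvY;
    const sycl::range<3> block_nums(1, 1, block_num_y);
    const sycl::range<3> block_dims(1, kMmvY, kWarpSize);

    q.parallel_for(
         sycl::nd_range<3>(block_nums * block_dims, block_dims),
         [=](sycl::nd_item<3> it) [[sycl::reqd_sub_group_size(32)]] {
             // One sub-group per output row; a real dmmv kernel would reduce a
             // dot product here, this stub just records the row index.
             const int row = it.get_group(2) * kMmvY + it.get_local_id(1);
             if (row < nrows && it.get_local_id(2) == 0) {
                 dst[row] = static_cast<float>(row);
             }
         })
        .wait();

    sycl::free(dst, q);
    return 0;
}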
@@ -1083,7 +1083,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
     const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
     const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
     float *dst_dd_i, const int64_t row_low, const int64_t row_high,
-    const int64_t src1_ncols, const int64_t src1_padded_row_size,
+    const int64_t src1_ncols, const int64_t src1_padded_col_size,
     const dpct::queue_ptr &stream) {
 
     const int64_t ne00 = src0->ne[0];
@@ -1111,50 +1111,51 @@
 #endif // GGML_SYCL_F16
-    ggml_sycl_pool_alloc<char> src0_test(ctx.pool());
-    char *src0_test_ptr = src0_test.alloc(ggml_nbytes(src0));
 
-    switch (src0->type) {
+    for (int i = 0; i < src1_ncols; i++)
+    {
+        const dfloat* src1_dfloat_bs = src1_dfloat + i * src1_padded_col_size;
+        float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0];
+        switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream,src0_test_ptr);
+            dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q4_1:
-            dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q5_0:
-            dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q5_1:
-            dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q8_0:
-            dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q2_K:
-            dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q3_K:
-            dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q4_K:
-            dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q5_K:
-            dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_Q6_K:
-            dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
+            dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         case GGML_TYPE_F16:
-            convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream);
             break;
         default:
            printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type);
            GGML_ASSERT(false);
            break;
        }
+    }
 
     (void) src1;
     (void) dst;
     (void) src1_ddq_i;
-    (void) src1_ncols;
-    (void) src1_padded_row_size;
 }
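The net effect of this hunk: instead of multiplying a single src1 vector, ggml_sycl_op_dequantize_mul_mat_vec now loops over the src1_ncols columns of the batch, offsetting into the padded src1 buffer by src1_padded_col_size and into dst by dst->ne[0] per column, and launches the existing single-column kernel once per column; the temporary src0_test staging buffer is removed along the way. The sketch below is a host-side, illustrative restatement of that batching pattern with a trivial dot-product stub in place of the SYCL kernels; the sizes and names are made up for the example.

#include <cstdint>
#include <cstdio>
#include <vector>

// Trivial stand-in for the single-column dmmv kernels: y = W * x.
static void mat_vec_stub(const float *w, const float *x, float *y,
                         int64_t ncols, int64_t nrows) {
    for (int64_t r = 0; r < nrows; ++r) {
        float acc = 0.0f;
        for (int64_t c = 0; c < ncols; ++c) acc += w[r * ncols + c] * x[c];
        y[r] = acc;
    }
}

int main() {
    const int64_t ne00 = 4, row_diff = 3;    // src0 is row_diff x ne00
    const int64_t src1_ncols = 2;            // batch of 2 input vectors
    const int64_t src1_padded_col_size = 8;  // each column padded to 8 floats

    std::vector<float> src0(row_diff * ne00, 1.0f);
    std::vector<float> src1(src1_ncols * src1_padded_col_size, 2.0f);
    std::vector<float> dst(src1_ncols * row_diff, 0.0f);

    // Same loop structure as the patch: step through the padded src1 columns,
    // step through dst by the output row count, one kernel call per column.
    for (int64_t i = 0; i < src1_ncols; ++i) {
        const float *src1_col = src1.data() + i * src1_padded_col_size;
        float *dst_col = dst.data() + i * row_diff;
        mat_vec_stub(src0.data(), src1_col, dst_col, ne00, row_diff);
    }

    for (int64_t i = 0; i < src1_ncols; ++i) {
        std::printf("column %lld -> dst[0] = %.1f\n", (long long) i, dst[i * row_diff]);
    }
    return 0;
}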
