diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index a99adf0f808cc..75dafa617c6b5 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -3632,7 +3632,7 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // check data types and tensor shapes for custom matrix multiplication kernels: bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 2a3a2046be0e4..33f10ea8d9b22 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -883,7 +883,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, - dpct::queue_ptr stream,char* vx_tmp) { + dpct::queue_ptr stream) { #if WARP_SIZE==32 GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; @@ -911,7 +911,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y, stream->parallel_for( sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { - dequantize_mul_mat_vec_q4_0(vx_tmp, y, dst, ncols, nrows, item_ct1); + dequantize_mul_mat_vec_q4_0(vx, y, dst, ncols, nrows, item_ct1); }); } #endif @@ -1083,7 +1083,7 @@ void ggml_sycl_op_dequantize_mul_mat_vec( const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, float *dst_dd_i, const int64_t row_low, const int64_t row_high, - const int64_t src1_ncols, const int64_t src1_padded_row_size, + const int64_t src1_ncols, const int64_t src1_padded_col_size, const dpct::queue_ptr &stream) { const int64_t ne00 = src0->ne[0]; @@ -1111,50 +1111,51 @@ void ggml_sycl_op_dequantize_mul_mat_vec( #endif // GGML_SYCL_F16 ggml_sycl_pool_alloc src0_test(ctx.pool()); char *src0_test_ptr = src0_test.alloc(ggml_nbytes(src0)); - - switch (src0->type) { + for (int i = 0; i < src1_ncols; i++) + { + const dfloat* src1_dfloat_bs = src1_dfloat + i * src1_padded_col_size; + float* dst_dd_i_bs = dst_dd_i + i * dst->ne[0]; + switch (src0->type) { case GGML_TYPE_Q4_0: - dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream,src0_test_ptr); + dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q4_1: - dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q5_0: - dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + dequantize_mul_mat_vec_q5_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q5_1: - dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q2_K: - dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q3_K: - dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q4_K: - dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q5_K: - dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_Q6_K: - dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_F16: - convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat_bs, dst_dd_i_bs, ne00, row_diff, stream); break; default: printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); GGML_ASSERT(false); break; + } } - (void) src1; (void) dst; (void) src1_ddq_i; - (void) src1_ncols; - (void) src1_padded_row_size; }