diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 70a94fc16b99d0..13a61b850586ac 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -91,6 +91,63 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * } } +static void dequantize_mul_mat_vec_q4_0(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, + const sycl::nd_item<3> &item_ct1) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1); + if (row >= nrows) { + return; + } + + const int tid = item_ct1.get_local_id(2); + +#ifdef GGML_SYCL_F16 + sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics +#else + float tmp = 0.0f; +#endif // GGML_SYCL_F16 + int constexpr ColTile = QK4_0 / WARP_SIZE; + static_assert(QK4_0 % WARP_SIZE == 0); + static_assert(ColTile == 2); + + const block_q4_0 * x = (const block_q4_0 *) vx; + + for (int i = 0; i < ncols; i += QK4_0) { + const int col = i + tid * ColTile; + const int ib = (row * ncols + col) / QK4_0; // x block index + const int iqs = (col % QK4_0) / QR4_0; // x quant index + const int iybs = col - col % QK4_0; // y block start index + const dfloat d = x[ib].d; + + const int vui = x[ib].qs[iqs]; + dfloat2 v; + v.x() = (vui & 0xF) * d; + v.y() = (vui >> 4) * d; +#ifdef GGML_SYCL_F16 + dfloat2 t1{ y[iybs + iqs + 0], + y[iybs + iqs + QK4_0 / 2] }; + tmp += v * t1; +#else + tmp += v.x() * y[iybs + iqs + 0]; + tmp += v.y() * y[iybs + iqs + QK4_0 / 2]; +#endif + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { +#ifdef GGML_SYCL_F16 + dst[row] = tmp.x() + tmp.y(); +#else + dst[row] = tmp; +#endif // GGML_SYCL_F16 + } +} + static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, @@ -764,6 +821,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, dpct::queue_ptr stream) { +#if WARP_SIZE==32 GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead @@ -780,6 +838,20 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y, vx, y, dst, ncols, nrows, item_ct1); }); } +#else + GGML_ASSERT(ncols % WARP_SIZE == 0); + const sycl::range<3> block_nums(1, 1, nrows); + const sycl::range<3> block_dims(1, 1, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec_q4_0(vx, y, dst, ncols, nrows, item_ct1); + }); + } +#endif } static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,