Skip to content

Commit

Permalink
add new q40 dmmv
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Jul 10, 2024
1 parent a59f8fd commit 6c2c841
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions ggml/src/ggml-sycl/dmmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,63 @@ static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat *
}
}

static void dequantize_mul_mat_vec_q4_0(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows,
const sycl::nd_item<3> &item_ct1) {
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1);
if (row >= nrows) {
return;
}

const int tid = item_ct1.get_local_id(2);

#ifdef GGML_SYCL_F16
sycl::half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
#else
float tmp = 0.0f;
#endif // GGML_SYCL_F16
int constexpr ColTile = QK4_0 / WARP_SIZE;
static_assert(QK4_0 % WARP_SIZE == 0);
static_assert(ColTile == 2);

const block_q4_0 * x = (const block_q4_0 *) vx;

for (int i = 0; i < ncols; i += QK4_0) {
const int col = i + tid * ColTile;
const int ib = (row * ncols + col) / QK4_0; // x block index
const int iqs = (col % QK4_0) / QR4_0; // x quant index
const int iybs = col - col % QK4_0; // y block start index
const dfloat d = x[ib].d;

const int vui = x[ib].qs[iqs];
dfloat2 v;
v.x() = (vui & 0xF) * d;
v.y() = (vui >> 4) * d;
#ifdef GGML_SYCL_F16
dfloat2 t1{ y[iybs + iqs + 0],
y[iybs + iqs + QK4_0 / 2] };
tmp += v * t1;
#else
tmp += v.x() * y[iybs + iqs + 0];
tmp += v.y() * y[iybs + iqs + QK4_0 / 2];
#endif
}

// sum up partial sums and write back result
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
tmp +=
dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
}

if (tid == 0) {
#ifdef GGML_SYCL_F16
dst[row] = tmp.x() + tmp.y();
#else
dst[row] = tmp;
#endif // GGML_SYCL_F16
}
}

static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols,
const int nrows,
Expand Down Expand Up @@ -764,6 +821,7 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols,
const int nrows,
dpct::queue_ptr stream) {
#if WARP_SIZE==32
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
Expand All @@ -780,6 +838,20 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
vx, y, dst, ncols, nrows, item_ct1);
});
}
#else
GGML_ASSERT(ncols % WARP_SIZE == 0);
const sycl::range<3> block_nums(1, 1, nrows);
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec_q4_0(vx, y, dst, ncols, nrows, item_ct1);
});
}
#endif
}

static void dequantize_mul_mat_vec_q4_1_sycl(const void *vx, const dfloat *y,
Expand Down

0 comments on commit 6c2c841

Please sign in to comment.