Skip to content

Commit

Permalink
add new format
Browse files Browse the repository at this point in the history
  • Loading branch information
luoyu-intel committed Jul 11, 2024
1 parent 0a64cae commit 38ff2ac
Showing 1 changed file with 95 additions and 21 deletions.
116 changes: 95 additions & 21 deletions ggml/src/ggml-sycl/dmmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,31 +105,92 @@ static void dequantize_mul_mat_vec_q4_0(const void * __restrict__ vx, const dflo
#else
float tmp = 0.0f;
#endif // GGML_SYCL_F16
int constexpr ColTile = QK4_0 / WARP_SIZE;
static_assert(QK4_0 % WARP_SIZE == 0);
static_assert(ColTile == 2);

static_assert(QK4_0 == 32);
static_assert(WARP_SIZE == 16);
const block_q4_0 * x = (const block_q4_0 *) vx;
int constexpr WarpK = WARP_SIZE * QK4_0;
int constexpr Unroll = 2;
const int iqs = tid; // x quant index
for (int i = 0; i < ncols; i += QK4_0) {
const int ib = (row * ncols + i) / QK4_0; // x block index
const int iybs = i; // y block start index
const dfloat d = x[ib].d;

const int vui = x[ib].qs[iqs];
dfloat2 v;
v.x() = ((vui & 0xF) - 8) * d;
v.y() = ((vui >> 4) - 8) * d;
int ncols_pad = ncols - ncols % (WarpK * Unroll);
int i = 0;
for (; i < ncols_pad; i += WarpK * Unroll) {
#pragma unroll
for (int iu = 0; iu < Unroll; iu++)
{
const int iybs = i + tid * QK4_0 + iu * WarpK; // y block start index
const int ib = (row * ncols + i) / QK4_0 + tid + iu * WARP_SIZE; // x block index
const dfloat d = *(sycl::half *)((char *)x + ncols * nrows / 2 + ib * 2);
sycl::vec<uint8_t, QK4_0 / 2> tmp_qs = *(sycl::vec<uint8_t, QK4_0 / 2>*)((char *)x + ib * QK4_0 / 2);
int constexpr KUnroll = 1;
#pragma unroll
for (int ir = 0; ir < QK4_0 / 2; ir += KUnroll)
{
const int vui = tmp_qs[ir];
dfloat2 v;
v.x() = ((vui & 0xF) - 8) * d;
v.y() = ((vui >> 4) - 8) * d;
#ifdef GGML_SYCL_F16
dfloat2 t1{ y[iybs + iqs + 0],
y[iybs + iqs + QK4_0 / 2] };
tmp += v * t1;
dfloat2 t1{ y[iybs + ir + 0],
y[iybs + ir + QK4_0 / 2] };
tmp += v * t1;
#else
tmp += v.x() * y[iybs + iqs + 0];
tmp += v.y() * y[iybs + iqs + QK4_0 / 2];
tmp += v.x() * y[iybs + ir + 0];
tmp += v.y() * y[iybs + ir + QK4_0 / 2];
#endif
}
}
}
for (; i < ncols_pad; i += WarpK * 1) {
#pragma unroll
for (int iu = 0; iu < 1; iu++)
{
const int iybs = i + tid * QK4_0 + iu * WarpK; // y block start index
const int ib = (row * ncols + i) / QK4_0 + tid + iu * WARP_SIZE; // x block index
const dfloat d = *(sycl::half *)((char *)x + ncols * nrows / 2 + ib * 2);
sycl::vec<uint8_t, QK4_0 / 2> tmp_qs = *(sycl::vec<uint8_t, QK4_0 / 2>*)((char *)x + ib * QK4_0 / 2);
int constexpr KUnroll = 1;
#pragma unroll
for (int ir = 0; ir < QK4_0 / 2; ir += KUnroll)
{
const int vui = tmp_qs[ir];
dfloat2 v;
v.x() = ((vui & 0xF) - 8) * d;
v.y() = ((vui >> 4) - 8) * d;
#ifdef GGML_SYCL_F16
dfloat2 t1{ y[iybs + ir + 0],
y[iybs + ir + QK4_0 / 2] };
tmp += v * t1;
#else
tmp += v.x() * y[iybs + ir + 0];
tmp += v.y() * y[iybs + ir + QK4_0 / 2];
#endif
}
}
}
#if 1
for (; i < ncols; i += QK4_0) {
const int iybs = i; // y block start index
const int ib = (row * ncols + i) / QK4_0; // x block index
#pragma unroll
for (int ir = 0; ir < 1; ir++)
{
const dfloat d = *(sycl::half *)((char *)x + ncols * nrows / 2 + ib * 2);
const int vui = *(uint8_t *)((char *)x + ib * QK4_0 / 2 + iqs);
dfloat2 v;
v.x() = ((vui & 0xF) - 8) * d;
v.y() = ((vui >> 4) - 8) * d;
#ifdef GGML_SYCL_F16
dfloat2 t1{ y[iybs + ir * QK4_0 + iqs + 0],
y[iybs + ir * QK4_0 + iqs + QK4_0 / 2] };
tmp += v * t1;
#else
tmp += v.x() * y[iybs + ir * QK4_0 + iqs + 0];
tmp += v.y() * y[iybs + ir * QK4_0 + iqs + QK4_0 / 2];
#endif
}

}
#endif
// sum up partial sums and write back result
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
Expand Down Expand Up @@ -818,7 +879,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
float *dst, const int ncols,
const int nrows,
dpct::queue_ptr stream) {
dpct::queue_ptr stream,char* vx_tmp) {
#if WARP_SIZE==32
GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0);
const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
Expand All @@ -840,13 +901,24 @@ static void dequantize_mul_mat_vec_q4_0_sycl(const void *vx, const dfloat *y,
GGML_ASSERT(ncols % WARP_SIZE == 0);
const sycl::range<3> block_nums(1, 1, nrows);
const sycl::range<3> block_dims(1, 1, WARP_SIZE);
stream->parallel_for(
nrows * ncols / QK4_0,
[=](auto i) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
const block_q4_0 * x = (const block_q4_0 *) vx;
int ib = i;
typedef sycl::vec<uint8_t, QK4_0 / 2> CT;
CT tmp = *(CT *)x[ib].qs;
*(CT*)(vx_tmp + ib * QK4_0 / 2) = tmp;
*(sycl::half *)(vx_tmp + ncols * nrows / 2 + ib * 2) = x[ib].d;

});
{
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});
stream->parallel_for(
sycl::nd_range<3>(block_nums * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
dequantize_mul_mat_vec_q4_0(vx, y, dst, ncols, nrows, item_ct1);
dequantize_mul_mat_vec_q4_0(vx_tmp, y, dst, ncols, nrows, item_ct1);
});
}
#endif
Expand Down Expand Up @@ -1044,10 +1116,12 @@ void ggml_sycl_op_dequantize_mul_mat_vec(
#else
const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
#endif // GGML_SYCL_F16
ggml_sycl_pool_alloc<char> src0_test(ctx.pool());
char *src0_test_ptr = src0_test.alloc(ggml_nbytes(src0));

switch (src0->type) {
case GGML_TYPE_Q4_0:
dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
dequantize_mul_mat_vec_q4_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream,src0_test_ptr);
break;
case GGML_TYPE_Q4_1:
dequantize_mul_mat_vec_q4_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
Expand Down

0 comments on commit 38ff2ac

Please sign in to comment.