Skip to content

Commit

Permalink
Dequant improvements rebase (ggerganov#8255)
Browse files Browse the repository at this point in the history
* Single load for half2

* Store scales in local mem

* Vec load quantized values
  • Loading branch information
AidanBeltonS authored Jul 3, 2024
1 parent a27152b commit fadde67
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
6 changes: 6 additions & 0 deletions ggml/src/ggml-sycl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -351,4 +351,10 @@ static __dpct_inline__ float warp_reduce_max(float x,
return x;
}

// Helper for vec loading aligned data
template <typename Tp, int n>
inline sycl::vec<Tp, n> vec_aligned_load(const Tp* aligned_ptr) {
return *reinterpret_cast<const sycl::vec<Tp, n>*>(aligned_ptr);
}

#endif // GGML_SYCL_COMMON_HPP
7 changes: 5 additions & 2 deletions ggml/src/ggml-sycl/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,15 @@ static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k,
dpct::has_capability_or_fail(stream->get_device(),
{sycl::aspect::fp16});

stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
stream->submit([&](sycl::handler &cgh) {
sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(12), cgh);
cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) *
sycl::range<3>(1, 1, 32),
sycl::range<3>(1, 1, 32)),
[=](sycl::nd_item<3> item_ct1) {
dequantize_block_q4_K(vx, y, item_ct1);
dequantize_block_q4_K(vx, y, scale_local_acc.get_pointer(), item_ct1);
});
});
}
}

Expand Down
30 changes: 19 additions & 11 deletions ggml/src/ggml-sycl/dequantize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri
#if QK_K == 256
static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
if (j < 4) {
d = q[j] & 63; m = q[j + 4] & 63;
d = q[j] & 63;
m = q[j + 4] & 63;
} else {
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
Expand All @@ -303,7 +304,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8

template<typename dst_t>
static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
const sycl::nd_item<3> &item_ct1) {
uint8_t* scales_local, const sycl::nd_item<3> &item_ct1) {
const block_q4_K * x = (const block_q4_K *) vx;

const int i = item_ct1.get_group(2);
Expand All @@ -318,19 +319,26 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri

dst_t * y = yy + i*QK_K + 64*il + n*ir;

const float dall = x[i].dm[0];
const float dmin = x[i].dm[1];
const sycl::half2 dm = x[i].dm;
const float dall = dm[0];
const float dmin = dm[1];

const uint8_t * q = x[i].qs + 32*il + n*ir;
if (tid < 12)
scales_local[tid] = x[i].scales[tid];
item_ct1.barrier(sycl::access::fence_space::local_space);

uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, sc, m);
const float d1 = dall * sc; const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, sc, m);
const float d2 = dall * sc; const float m2 = dmin * m;
get_scale_min_k4(is + 0, scales_local, sc, m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, scales_local, sc, m);
const float d2 = dall * sc;
const float m2 = dmin * m;

sycl::vec<uint8_t, n> q_vec = vec_aligned_load<uint8_t, n>(x[i].qs + 32*il + n*ir);
for (int l = 0; l < n; ++l) {
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l +32] = d2 * (q[l] >> 4) - m2;
y[l + 0] = d1 * (q_vec[l] & 0xF) - m1;
y[l +32] = d2 * (q_vec[l] >> 4) - m2;
}
#else
const int tid = item_ct1.get_local_id(2);
Expand Down

0 comments on commit fadde67

Please sign in to comment.