diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp index 47a605b01a07a..477f5cb02db52 100644 --- a/ggml-sycl.cpp +++ b/ggml-sycl.cpp @@ -2894,6 +2894,254 @@ namespace dpct using err0 = detail::generic_error_type; using err1 = detail::generic_error_type; + static inline void dpct_free(void *ptr, sycl::queue &q = get_default_queue()) { + detail::dpct_free(ptr, q); + } + + /// dpct accessor used as device function parameter. + template class accessor; + template class accessor { + public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<3>; + accessor(pointer_t data, const sycl::range<3> &in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type &acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t &acc, const sycl::range<3> &in_range) + : accessor(acc.get_pointer(), in_range) {} + accessor operator[](size_t index) const { + sycl::range<2> sub(_range.get(1), _range.get(2)); + return accessor(_data + index * sub.size(), sub); + } + + pointer_t get_ptr() const { return _data; } + + private: + pointer_t _data; + sycl::range<3> _range; + }; + template class accessor { + public: + using memory_t = detail::memory_traits; + using element_t = typename memory_t::element_t; + using pointer_t = typename memory_t::pointer_t; + using accessor_t = typename memory_t::template accessor_t<2>; + accessor(pointer_t data, const sycl::range<2> &in_range) + : _data(data), _range(in_range) {} + template + accessor(typename std::enable_if::type &acc) + : accessor(acc, acc.get_range()) {} + accessor(const accessor_t &acc, const sycl::range<2> &in_range) + : accessor(acc.get_pointer(), in_range) {} + + pointer_t operator[](size_t index) const { + return _data + _range.get(1) * index; + } + + pointer_t get_ptr() const { return _data; } + + private: + pointer_t _data; + sycl::range<2> _range; + }; + + namespace detail { + /// Device variable with address space of shared, global or constant. + template class device_memory { + public: + using accessor_t = + typename detail::memory_traits::template accessor_t; + using value_t = typename detail::memory_traits::value_t; + using dpct_accessor_t = dpct::accessor; + + device_memory() : device_memory(sycl::range(1)) {} + + /// Constructor of 1-D array with initializer list + device_memory(const sycl::range &in_range, + std::initializer_list &&init_list) + : device_memory(in_range) { + assert(init_list.size() <= in_range.size()); + _host_ptr = (value_t *)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + std::memcpy(_host_ptr, init_list.begin(), init_list.size() * sizeof(T)); + } + + /// Constructor of 2-D array with initializer list + template + device_memory( + const typename std::enable_if>::type &in_range, + std::initializer_list> &&init_list) + : device_memory(in_range) { + assert(init_list.size() <= in_range[0]); + _host_ptr = (value_t *)std::malloc(_size); + std::memset(_host_ptr, 0, _size); + auto tmp_data = _host_ptr; + for (auto sub_list : init_list) { + assert(sub_list.size() <= in_range[1]); + std::memcpy(tmp_data, sub_list.begin(), + sub_list.size() * sizeof(T)); + tmp_data += in_range[1]; + } + } + + /// Constructor with range + device_memory(const sycl::range &range_in) + : _size(range_in.size() * sizeof(T)), _range(range_in), + _reference(false), _host_ptr(nullptr), _device_ptr(nullptr) { + static_assert( + (Memory == global) || (Memory == constant) || (Memory == shared), + "device memory region should be global, constant or shared"); + // Make sure that singleton class mem_mgr and dev_mgr will destruct + // later than this. + detail::mem_mgr::instance(); + dev_mgr::instance(); + } + + /// Constructor with range + template + device_memory(Args... Arguments) + : device_memory(sycl::range(Arguments...)) {} + + ~device_memory() { + if (_device_ptr && !_reference) + dpct::dpct_free(_device_ptr); + if (_host_ptr) + std::free(_host_ptr); + } + + /// Allocate memory with default queue, and init memory if has initial + /// value. + void init() { init(dpct::get_default_queue()); } + /// Allocate memory with specified queue, and init memory if has initial + /// value. + void init(sycl::queue &q) { + if (_device_ptr) + return; + if (!_size) + return; + allocate_device(q); + if (_host_ptr) + detail::dpct_memcpy(q, _device_ptr, _host_ptr, _size, + host_to_device); + } + + /// The variable is assigned to a device pointer. + void assign(value_t *src, size_t size) { + this->~device_memory(); + new (this) device_memory(src, size); + } + + /// Get memory pointer of the memory object, which is virtual pointer when + /// usm is not used, and device pointer when usm is used. + value_t *get_ptr() { return get_ptr(get_default_queue()); } + /// Get memory pointer of the memory object, which is virtual pointer when + /// usm is not used, and device pointer when usm is used. + value_t *get_ptr(sycl::queue &q) { + init(q); + return _device_ptr; + } + + /// Get the device memory object size in bytes. + size_t get_size() { return _size; } + + template + typename std::enable_if::type &operator[](size_t index) { + init(); + #ifdef DPCT_USM_LEVEL_NONE + return dpct::get_buffer::type>( + _device_ptr) + .template get_access()[index]; + #else + return _device_ptr[index]; + #endif // DPCT_USM_LEVEL_NONE + } + + #ifdef DPCT_USM_LEVEL_NONE + /// Get sycl::accessor for the device memory object when usm is not used. + accessor_t get_access(sycl::handler &cgh) { + return get_buffer(_device_ptr) + .template reinterpret(_range) + .template get_access::mode, + detail::memory_traits::target>(cgh); + } + #else + /// Get dpct::accessor with dimension info for the device memory object + /// when usm is used and dimension is greater than 1. + template + typename std::enable_if::type + get_access(sycl::handler &cgh) { + return dpct_accessor_t((T *)_device_ptr, _range); + } + #endif // DPCT_USM_LEVEL_NONE + + private: + device_memory(value_t *memory_ptr, size_t size) + : _size(size), _range(size / sizeof(T)), _reference(true), + _device_ptr(memory_ptr) {} + + void allocate_device(sycl::queue &q) { + #ifndef DPCT_USM_LEVEL_NONE + if (Memory == shared) { + _device_ptr = (value_t *)sycl::malloc_shared(_size, q.get_device(), + q.get_context()); + return; + } + #ifdef SYCL_EXT_ONEAPI_USM_DEVICE_READ_ONLY + if (Memory == constant) { + _device_ptr = (value_t *)sycl::malloc_device( + _size, q.get_device(), q.get_context(), + sycl::ext::oneapi::property::usm::device_read_only()); + return; + } + #endif + #endif + _device_ptr = (value_t *)detail::dpct_malloc(_size, q); + } + + size_t _size; + sycl::range _range; + bool _reference; + value_t *_host_ptr; + value_t *_device_ptr; + }; + template + class device_memory : public device_memory { + public: + using base = device_memory; + using value_t = typename base::value_t; + using accessor_t = + typename detail::memory_traits::template accessor_t<0>; + + /// Constructor with initial value. + device_memory(const value_t &val) : base(sycl::range<1>(1), {val}) {} + + /// Default constructor + device_memory() : base(1) {} + + #ifdef DPCT_USM_LEVEL_NONE + /// Get sycl::accessor for the device memory object when usm is not used. + accessor_t get_access(sycl::handler &cgh) { + auto buf = get_buffer(base::get_ptr()) + .template reinterpret(sycl::range<1>(1)); + return accessor_t(buf, cgh); + } + #endif // DPCT_USM_LEVEL_NONE + }; + } // namespace detail + + template + using global_memory = detail::device_memory; + template + using constant_memory = detail::device_memory; + template + using shared_memory = detail::device_memory; + + } // COPY from DPCT head files @@ -2938,6 +3186,15 @@ static int g_work_group_size = 0; #pragma warning(disable: 4244 4267) // possible loss of data #endif +// dmmv = dequantize_mul_mat_vec +#ifndef GGML_SYCL_DMMV_X +#define GGML_SYCL_DMMV_X 32 +#endif +#ifndef GGML_SYCL_MMV_Y +#define GGML_SYCL_MMV_Y 1 +#endif + + static_assert(sizeof(sycl::half) == sizeof(ggml_fp16_t), "wrong fp16 size"); static void crash(){ @@ -3060,7 +3317,7 @@ typedef void (*ggml_sycl_op_flatten_t)(const ggml_tensor *src0, #define QK4_0 32 #define QR4_0 2 #define QI4_0 (QK4_0 / (4 * QR4_0)) -typedef struct dpct_type_471834 { +typedef struct dpct_type_block_q4_0 { sycl::half d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; @@ -3069,7 +3326,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 #define QK4_1 32 #define QR4_1 2 #define QI4_1 (QK4_1 / (4 * QR4_1)) -typedef struct dpct_type_143705 { +typedef struct dpct_type_block_q4_1 { sycl::half2 dm; // dm.x = delta, dm.y = min uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; @@ -3078,7 +3335,7 @@ static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong #define QK5_0 32 #define QR5_0 2 #define QI5_0 (QK5_0 / (4 * QR5_0)) -typedef struct dpct_type_673649 { +typedef struct dpct_type_block_q5_0 { sycl::half d; // delta uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_0 / 2]; // nibbles / quants @@ -3088,7 +3345,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5 #define QK5_1 32 #define QR5_1 2 #define QI5_1 (QK5_1 / (4 * QR5_1)) -typedef struct dpct_type_135589 { +typedef struct dpct_type_block_q5_1 { sycl::half2 dm; // dm.x = delta, dm.y = min uint8_t qh[4]; // 5-th bit of quants uint8_t qs[QK5_1 / 2]; // nibbles / quants @@ -3098,7 +3355,7 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + #define QK8_0 32 #define QR8_0 1 #define QI8_0 (QK8_0 / (4 * QR8_0)) -typedef struct dpct_type_122878 { +typedef struct dpct_type_block_q8_0 { sycl::half d; // delta int8_t qs[QK8_0]; // quants } block_q8_0; @@ -3107,7 +3364,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo #define QK8_1 32 #define QR8_1 1 #define QI8_1 (QK8_1 / (4 * QR8_1)) -typedef struct dpct_type_143721 { +typedef struct dpct_type_block_q8_1 { sycl::half2 ds; // ds.x = delta, ds.y = sum int8_t qs[QK8_0]; // quants } block_q8_1; @@ -3141,7 +3398,7 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)( #define QR2_K 4 #define QI2_K (QK_K / (4*QR2_K)) -typedef struct dpct_type_619598 { +typedef struct dpct_type_block_q2_K { uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits uint8_t qs[QK_K/4]; // quants sycl::half2 dm; // super-block scale for quantized scales/mins @@ -3150,7 +3407,7 @@ static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "w #define QR3_K 4 #define QI3_K (QK_K / (4*QR3_K)) -typedef struct dpct_type_138576 { +typedef struct dpct_type_block_q3_K { uint8_t hmask[QK_K/8]; // quants - high bit uint8_t qs[QK_K/4]; // quants - low 2 bits #ifdef GGML_QKK_64 @@ -3166,13 +3423,13 @@ typedef struct dpct_type_138576 { #define QI4_K (QK_K / (4*QR4_K)) #ifdef GGML_QKK_64 typedef struct { - half dm[2]; // super-block scales/mins + sycl::half dm[2]; // super-block scales/mins uint8_t scales[2]; // 4-bit block scales/mins uint8_t qs[QK_K/2]; // 4--bit quants } block_q4_K; -static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding"); +static_assert(sizeof(block_q4_K) == sizeof(sycl::half2) + QK_K/2 + 2, "wrong q4_K block size/padding"); #else -typedef struct dpct_type_154943 { +typedef struct dpct_type_block_q4_K { sycl::half2 dm; // super-block scale for quantized scales/mins uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits uint8_t qs[QK_K/2]; // 4--bit quants @@ -3184,14 +3441,14 @@ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, #define QI5_K (QK_K / (4*QR5_K)) #ifdef GGML_QKK_64 typedef struct { - half d; // super-block scale + sycl::half d; // super-block scale int8_t scales[QK_K/16]; // block scales uint8_t qh[QK_K/8]; // quants, high bit uint8_t qs[QK_K/2]; // quants, low 4 bits } block_q5_K; static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); #else -typedef struct dpct_type_866817 { +typedef struct dpct_type_block_q5_K { sycl::half2 dm; // super-block scale for quantized scales/mins uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits uint8_t qh[QK_K/8]; // quants, high bit @@ -3202,7 +3459,7 @@ static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/ #define QR6_K 2 #define QI6_K (QK_K / (4*QR6_K)) -typedef struct dpct_type_107281 { +typedef struct dpct_type_block_q6_K { uint8_t ql[QK_K/2]; // quants, lower 4 bits uint8_t qh[QK_K/4]; // quants, upper 2 bits int8_t scales[QK_K/16]; // scales @@ -3210,6 +3467,31 @@ typedef struct dpct_type_107281 { } block_q6_K; static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding"); +#define QR2_XXS 8 +#define QI2_XXS (QK_K / (4*QR2_XXS)) +typedef struct dpct_type_block_iq2_xxs { + sycl::half d; + uint16_t qs[QK_K/8]; +} block_iq2_xxs; +static_assert(sizeof(block_iq2_xxs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t), "wrong iq2_xxs block size/padding"); + +#define QR2_XS 8 +#define QI2_XS (QK_K / (4*QR2_XS)) +typedef struct dpct_type_block_iq2_xs { + sycl::half d; + uint16_t qs[QK_K/8]; + uint8_t scales[QK_K/32]; +} block_iq2_xs; +static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding"); + +#define QR3_XXS 8 +#define QI3_XXS (QK_K / (4*QR3_XXS)) +typedef struct dpct_type_block_iq3_xxs { + sycl::half d; + uint8_t qs[3*(QK_K/8)]; +} block_iq3_xxs; +static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding"); + #define WARP_SIZE 32 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses @@ -3478,7 +3760,7 @@ void log_ggml_var_device(const char*name, float *src, size_t total_elements, boo local_buf = (float *) ggml_sycl_host_malloc(total_size); ggml_sycl_set_device(g_main_device); dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0]; - main_stream->memcpy(local_buf, src, total_size); + main_stream->memcpy(local_buf, src, total_size).wait(); } else { local_buf = (float *)src; @@ -4129,6 +4411,66 @@ static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib, #endif // GGML_SYCL_F16 } +template +static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32, + const sycl::nd_item<3> &item_ct1) { + + const int i = item_ct1.get_group(2); + + // assume 32 threads + const int tid = item_ct1.get_local_id(2); + const int il = tid/8; + const int ir = tid%8; + const int ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + dst_t * y = yy + 256*i + 32*ir + 4*il; + + const block_q4_0 * x = (const block_q4_0 *)vx + ib; + const float d = sycl::vec(x->d) + .convert()[0]; + const float dm = -8*d; + + const uint8_t * q = x->qs + 4*il; + + for (int l = 0; l < 4; ++l) { + y[l+ 0] = d * (q[l] & 0xF) + dm; + y[l+16] = d * (q[l] >> 4) + dm; + } +} + +template +static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32, + const sycl::nd_item<3> &item_ct1) { + + const int i = item_ct1.get_group(2); + + // assume 32 threads + const int tid = item_ct1.get_local_id(2); + const int il = tid/8; + const int ir = tid%8; + const int ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + dst_t * y = yy + 256*i + 32*ir + 4*il; + + const block_q4_1 * x = (const block_q4_1 *)vx + ib; + const sycl::float2 d = + x->dm.convert(); + + const uint8_t * q = x->qs + 4*il; + + for (int l = 0; l < 4; ++l) { + y[l + 0] = d.x() * (q[l] & 0xF) + d.y(); + y[l + 16] = d.x() * (q[l] >> 4) + d.y(); + } +} + + //================================== k-quants template @@ -4158,8 +4500,9 @@ static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restri const int il = tid%16; // 0...15 const uint8_t q = x[i].qs[il] >> (2*is); dst_t * y = yy + i*QK_K + 16*is + il; - float dall = __low2half(x[i].dm); - float dmin = __high2half(x[i].dm); + + float dall = x[i].dm[0]; + float dmin = x[i].dm[1]; y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); #endif @@ -4198,7 +4541,7 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); #else - const int tid = threadIdx.x; + const int tid = item_ct1.get_local_id(2); const int is = tid/16; // 0 or 1 const int il = tid%16; // 0...15 const int im = il/8; // 0...1 @@ -4264,7 +4607,7 @@ static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restri y[l +32] = d2 * (q[l] >> 4) - m2; } #else - const int tid = threadIdx.x; + const int tid = item_ct1.get_local_id(2); const uint8_t * q = x[i].qs; dst_t * y = yy + i*QK_K; const float d = (float)x[i].dm[0]; @@ -4309,7 +4652,7 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; #else - const int tid = threadIdx.x; + const int tid = item_ct1.get_local_id(2); const uint8_t q = x[i].qs[tid]; const int im = tid/8; // 0...3 const int in = tid%8; // 0...7 @@ -4351,7 +4694,7 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri #else // assume 32 threads - const int tid = threadIdx.x; + const int tid = item_ct1.get_local_id(2); const int ip = tid/16; // 0 or 1 const int il = tid - 16*ip; // 0...15 @@ -4368,6 +4711,474 @@ static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restri #endif } +static dpct::global_memory + iq2xxs_grid(sycl::range<1>(256), + { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, + 0x0808080808082b08, 0x0808080808082b2b, 0x0808080808190819, + 0x0808080808191908, 0x08080808082b0808, 0x08080808082b082b, + 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819, + 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, + 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b082b2b, 0x080808082b2b082b, + 0x0808081908080819, 0x0808081908081908, 0x0808081908190808, + 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, + 0x080808192b192b08, 0x0808082b08080808, 0x0808082b0808082b, + 0x0808082b082b082b, 0x0808082b2b08082b, 0x0808190808080819, + 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, + 0x0808190819082b08, 0x08081908192b0808, 0x080819082b080819, + 0x080819082b081908, 0x080819082b190808, 0x080819082b2b1908, + 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08, + 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, + 0x080819192b080808, 0x080819192b190819, 0x0808192b08082b19, + 0x0808192b08190808, 0x0808192b19080808, 0x0808192b2b081908, + 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919, + 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, + 0x08082b0819080819, 0x08082b0819081908, 0x08082b0819190808, + 0x08082b081919082b, 0x08082b082b082b08, 0x08082b1908081908, + 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908, + 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, + 0x08190808082b0819, 0x0819080819080808, 0x08190808192b0808, + 0x081908082b081908, 0x081908082b190808, 0x081908082b191919, + 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808, + 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, + 0x0819082b082b1908, 0x0819082b19081919, 0x0819190808080808, + 0x0819190808082b08, 0x08191908082b0808, 0x08191908082b1919, + 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08, + 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, + 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808, + 0x08192b0819080808, 0x08192b082b080819, 0x08192b1908080808, + 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819, + 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, + 0x082b080819081908, 0x082b0808192b0819, 0x082b08082b080808, + 0x082b08082b08082b, 0x082b0819082b2b19, 0x082b081919082b08, + 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819, + 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, + 0x082b19081919192b, 0x082b191908080808, 0x082b191919080819, + 0x082b1919192b1908, 0x082b192b2b190808, 0x082b2b0808082b08, + 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908, + 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, + 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x1908080819082b08, 0x190808081919192b, + 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908, + 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, + 0x19080819192b0819, 0x190808192b080808, 0x190808192b081919, + 0x1908082b08080819, 0x1908082b08190808, 0x1908082b19082b08, + 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808, + 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, + 0x190819082b192b19, 0x190819190819082b, 0x19081919082b1908, + 0x1908192b08080808, 0x19082b0808080819, 0x19082b0808081908, + 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, + 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, + 0x19082b192b08082b, 0x19082b2b19081919, 0x19082b2b2b190808, + 0x1919080808080808, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808, + 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, + 0x191908192b2b1908, 0x1919082b2b190819, 0x191919082b190808, + 0x191919082b19082b, 0x1919191908082b2b, 0x1919192b08080819, + 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819, + 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, + 0x19192b2b08082b08, 0x192b080808081908, 0x192b080808190808, + 0x192b080819080808, 0x192b0808192b2b08, 0x192b081908080808, + 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808, + 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, + 0x192b19190819082b, 0x192b19192b081908, 0x192b2b081908082b, + 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808082b2b, + 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908, + 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, + 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, + 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908, + 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, + 0x2b082b080808082b, 0x2b082b1908081908, 0x2b082b2b08190819, + 0x2b19080808081908, 0x2b19080808190808, 0x2b190808082b1908, + 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b, + 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, + 0x2b191908082b082b, 0x2b19190819081908, 0x2b19191919190819, + 0x2b192b082b080819, 0x2b192b19082b0808, 0x2b2b08080808082b, + 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19, + 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, + 0x2b2b2b1908081908, + }); + +static dpct::global_memory + iq2xs_grid(sycl::range<1>(512), + { + 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, + 0x0808080808082b08, 0x0808080808082b2b, 0x0808080808190819, + 0x0808080808191908, 0x080808080819192b, 0x0808080808192b19, + 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919, + 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, + 0x080808081908192b, 0x0808080819082b19, 0x0808080819190808, + 0x080808081919082b, 0x0808080819191919, 0x0808080819192b08, + 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808, + 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, + 0x080808082b190819, 0x080808082b191908, 0x080808082b192b19, + 0x080808082b2b0808, 0x0808081908080819, 0x0808081908081908, + 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808, + 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, + 0x0808081908192b2b, 0x08080819082b0819, 0x08080819082b1908, + 0x0808081919080808, 0x080808191908082b, 0x0808081919081919, + 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908, + 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, + 0x080808192b081908, 0x080808192b190808, 0x0808082b08080808, + 0x0808082b0808082b, 0x0808082b08081919, 0x0808082b08082b08, + 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808, + 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, + 0x0808082b19191919, 0x0808082b2b080808, 0x0808082b2b082b2b, + 0x0808190808080819, 0x0808190808081908, 0x080819080808192b, + 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b, + 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, + 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, + 0x0808190819081919, 0x0808190819082b08, 0x0808190819190819, + 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808, + 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, + 0x0808191908080808, 0x080819190808082b, 0x0808191908081919, + 0x0808191908082b08, 0x0808191908190819, 0x0808191908191908, + 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908, + 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, + 0x0808192b08080819, 0x0808192b08081908, 0x0808192b08190808, + 0x0808192b082b192b, 0x0808192b19080808, 0x0808192b1908082b, + 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b, + 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, + 0x08082b0808190819, 0x08082b0808191908, 0x08082b08082b0808, + 0x08082b08082b1919, 0x08082b0819080819, 0x08082b0819081908, + 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808, + 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, + 0x08082b1908081908, 0x08082b1908190808, 0x08082b1919080808, + 0x08082b192b080819, 0x08082b192b082b19, 0x08082b2b08080808, + 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b, + 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, + 0x081908080808192b, 0x0819080808082b19, 0x0819080808190808, + 0x081908080819082b, 0x0819080808191919, 0x0819080808192b08, + 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808, + 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, + 0x0819080819190819, 0x0819080819191908, 0x08190808192b0808, + 0x08190808192b2b2b, 0x081908082b080819, 0x081908082b081908, + 0x081908082b190808, 0x0819081908080808, 0x081908190808082b, + 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, + 0x0819081908191908, 0x08190819082b0808, 0x0819081919080819, + 0x0819081919081908, 0x0819081919190808, 0x081908192b080808, + 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819, + 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, + 0x0819082b19080808, 0x0819082b192b0808, 0x0819190808080808, + 0x081919080808082b, 0x0819190808081919, 0x0819190808082b08, + 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808, + 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, + 0x0819190819190808, 0x08191908192b1908, 0x081919082b080808, + 0x0819191908080819, 0x0819191908081908, 0x0819191908190808, + 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908, + 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, + 0x08192b0808190808, 0x08192b080819082b, 0x08192b0819080808, + 0x08192b0819191908, 0x08192b082b08192b, 0x08192b1908080808, + 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819, + 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, + 0x082b080808081919, 0x082b080808082b08, 0x082b080808082b2b, + 0x082b080808190819, 0x082b080808191908, 0x082b0808082b0808, + 0x082b080819080819, 0x082b080819081908, 0x082b080819190808, + 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, + 0x082b081908081908, 0x082b081908190808, 0x082b081919080808, + 0x082b081919082b08, 0x082b0819192b1919, 0x082b082b08080808, + 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08, + 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, + 0x082b1908082b2b19, 0x082b190819080808, 0x082b191908080808, + 0x082b191919080819, 0x082b19191919082b, 0x082b19192b192b19, + 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b, + 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, + 0x082b2b08082b0808, 0x082b2b0819191919, 0x082b2b082b082b08, + 0x082b2b082b2b082b, 0x082b2b19192b2b08, 0x082b2b192b190808, + 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b, + 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, + 0x1908080808081908, 0x190808080808192b, 0x1908080808082b19, + 0x1908080808190808, 0x190808080819082b, 0x1908080808191919, + 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908, + 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, + 0x1908080819082b08, 0x1908080819082b2b, 0x1908080819190819, + 0x1908080819191908, 0x19080808192b0808, 0x19080808192b1919, + 0x190808082b080819, 0x190808082b081908, 0x190808082b190808, + 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, + 0x1908081908082b08, 0x1908081908190819, 0x1908081908191908, + 0x19080819082b0808, 0x1908081919080819, 0x1908081919081908, + 0x1908081919190808, 0x190808192b080808, 0x190808192b081919, + 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, + 0x1908082b08190808, 0x1908082b0819082b, 0x1908082b082b2b19, + 0x1908082b19080808, 0x1908190808080808, 0x190819080808082b, + 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819, + 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, + 0x1908190819080819, 0x1908190819081908, 0x1908190819190808, + 0x190819082b080808, 0x190819082b191908, 0x1908191908080819, + 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908, + 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, + 0x1908192b08082b2b, 0x1908192b19081908, 0x1908192b19190808, + 0x19082b0808080819, 0x19082b0808081908, 0x19082b0808190808, + 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908, + 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, + 0x19082b1919081908, 0x19082b1919190808, 0x19082b19192b2b19, + 0x19082b2b08081908, 0x1919080808080808, 0x191908080808082b, + 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819, + 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, + 0x1919080819080819, 0x1919080819081908, 0x1919080819190808, + 0x191908082b080808, 0x1919081908080819, 0x1919081908081908, + 0x1919081908190808, 0x1919081908191919, 0x1919081919080808, + 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, + 0x1919082b2b2b2b2b, 0x1919190808080819, 0x1919190808081908, + 0x1919190808190808, 0x19191908082b0819, 0x1919190819080808, + 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819, + 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, + 0x191919192b082b08, 0x1919192b082b0819, 0x1919192b192b2b08, + 0x1919192b2b2b0819, 0x19192b0808080808, 0x19192b0808191908, + 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19, + 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, + 0x19192b2b2b081919, 0x192b080808080819, 0x192b080808081908, + 0x192b080808190808, 0x192b080819080808, 0x192b080819191908, + 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19, + 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, + 0x192b082b2b19082b, 0x192b190808080808, 0x192b19080819192b, + 0x192b191908190808, 0x192b191919080808, 0x192b191919081919, + 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b, + 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, + 0x192b2b2b192b082b, 0x2b08080808080808, 0x2b0808080808082b, + 0x2b08080808081919, 0x2b08080808082b08, 0x2b08080808190819, + 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b, + 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, + 0x2b0808082b080808, 0x2b0808082b08082b, 0x2b0808082b2b2b08, + 0x2b0808082b2b2b2b, 0x2b08081908080819, 0x2b08081908081908, + 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808, + 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, + 0x2b08082b082b0808, 0x2b08082b2b080808, 0x2b08082b2b08082b, + 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08, 0x2b08190808080819, + 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b, + 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, + 0x2b0819082b082b19, 0x2b08191908080808, 0x2b08191919081908, + 0x2b0819192b2b1919, 0x2b08192b08192b08, 0x2b08192b192b2b2b, + 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919, + 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, + 0x2b082b082b2b2b08, 0x2b082b190808192b, 0x2b082b2b082b082b, + 0x2b082b2b2b080808, 0x2b082b2b2b082b08, 0x2b082b2b2b19192b, + 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908, + 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, + 0x2b1908082b081908, 0x2b19081908080808, 0x2b190819082b082b, + 0x2b190819192b1908, 0x2b19082b1919192b, 0x2b19082b2b082b19, + 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908, + 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, + 0x2b1919192b190808, 0x2b1919192b19082b, 0x2b19192b19080819, + 0x2b192b0819190819, 0x2b192b082b2b192b, 0x2b192b1919082b19, + 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808, + 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, + 0x2b2b0808082b0808, 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, + 0x2b2b081919190819, 0x2b2b081919192b19, 0x2b2b08192b2b192b, + 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08, + 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, + 0x2b2b190819080808, 0x2b2b19082b191919, 0x2b2b192b192b1919, + 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b, 0x2b2b2b08082b0808, + 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808, + 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, + 0x2b2b2b192b08192b, 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, + 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b, + }); + +static dpct::global_memory iq3xxs_grid( + sycl::range<1>(256), + { + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, + 0x04041404, 0x04041414, 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, + 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, 0x040c140c, 0x040c142c, + 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, + 0x04141c1c, 0x04141c3e, 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, + 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, 0x041c3e04, 0x04240c1c, + 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, + 0x043e0c24, 0x043e0c34, 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, + 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, 0x0c041c04, 0x0c041c14, + 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, + 0x0c14140c, 0x0c141c04, 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, + 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, 0x0c24042c, 0x0c242c04, + 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, + 0x14041414, 0x14041434, 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, + 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, 0x140c1c04, 0x140c341c, + 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, + 0x141c0c04, 0x141c0c24, 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, + 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, 0x143e040c, 0x143e041c, + 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, + 0x1c0c1404, 0x1c0c1c0c, 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, + 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, 0x1c1c0c0c, 0x1c1c1c1c, + 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, + 0x24040424, 0x24040c3e, 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, + 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, 0x24143404, 0x24143434, + 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, + 0x2c040c14, 0x2c04240c, 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, + 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, 0x2c1c0414, 0x2c1c2c1c, + 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, + 0x34043424, 0x340c140c, 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, + 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, 0x34341c1c, 0x343e041c, + 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, + 0x3e1c0404, 0x3e1c0c2c, 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, + 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, + }); + +static dpct::global_memory ksigns_iq2xs( + sycl::range<1>(128), + { + 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, + 141, 142, 15, 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, + 154, 27, 156, 29, 30, 159, 160, 33, 34, 163, 36, 165, 166, + 39, 40, 169, 170, 43, 172, 45, 46, 175, 48, 177, 178, 51, + 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63, 192, + 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, + 78, 207, 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, + 219, 92, 221, 222, 95, 96, 225, 226, 99, 228, 101, 102, 231, + 232, 105, 106, 235, 108, 237, 238, 111, 240, 113, 114, 243, 116, + 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255, + }); + +static dpct::global_memory + ksigns64(sycl::range<1>(128), + { + 0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, + 0x000000000000ffff, 0xff00000000ff0000, 0x0000000000ff00ff, + 0x0000000000ffff00, 0xff00000000ffffff, 0xff000000ff000000, + 0x00000000ff0000ff, 0x00000000ff00ff00, 0xff000000ff00ffff, + 0x00000000ffff0000, 0xff000000ffff00ff, 0xff000000ffffff00, + 0x00000000ffffffff, 0xff0000ff00000000, 0x000000ff000000ff, + 0x000000ff0000ff00, 0xff0000ff0000ffff, 0x000000ff00ff0000, + 0xff0000ff00ff00ff, 0xff0000ff00ffff00, 0x000000ff00ffffff, + 0x000000ffff000000, 0xff0000ffff0000ff, 0xff0000ffff00ff00, + 0x000000ffff00ffff, 0xff0000ffffff0000, 0x000000ffffff00ff, + 0x000000ffffffff00, 0xff0000ffffffffff, 0xff00ff0000000000, + 0x0000ff00000000ff, 0x0000ff000000ff00, 0xff00ff000000ffff, + 0x0000ff0000ff0000, 0xff00ff0000ff00ff, 0xff00ff0000ffff00, + 0x0000ff0000ffffff, 0x0000ff00ff000000, 0xff00ff00ff0000ff, + 0xff00ff00ff00ff00, 0x0000ff00ff00ffff, 0xff00ff00ffff0000, + 0x0000ff00ffff00ff, 0x0000ff00ffffff00, 0xff00ff00ffffffff, + 0x0000ffff00000000, 0xff00ffff000000ff, 0xff00ffff0000ff00, + 0x0000ffff0000ffff, 0xff00ffff00ff0000, 0x0000ffff00ff00ff, + 0x0000ffff00ffff00, 0xff00ffff00ffffff, 0xff00ffffff000000, + 0x0000ffffff0000ff, 0x0000ffffff00ff00, 0xff00ffffff00ffff, + 0x0000ffffffff0000, 0xff00ffffffff00ff, 0xff00ffffffffff00, + 0x0000ffffffffffff, 0xffff000000000000, 0x00ff0000000000ff, + 0x00ff00000000ff00, 0xffff00000000ffff, 0x00ff000000ff0000, + 0xffff000000ff00ff, 0xffff000000ffff00, 0x00ff000000ffffff, + 0x00ff0000ff000000, 0xffff0000ff0000ff, 0xffff0000ff00ff00, + 0x00ff0000ff00ffff, 0xffff0000ffff0000, 0x00ff0000ffff00ff, + 0x00ff0000ffffff00, 0xffff0000ffffffff, 0x00ff00ff00000000, + 0xffff00ff000000ff, 0xffff00ff0000ff00, 0x00ff00ff0000ffff, + 0xffff00ff00ff0000, 0x00ff00ff00ff00ff, 0x00ff00ff00ffff00, + 0xffff00ff00ffffff, 0xffff00ffff000000, 0x00ff00ffff0000ff, + 0x00ff00ffff00ff00, 0xffff00ffff00ffff, 0x00ff00ffffff0000, + 0xffff00ffffff00ff, 0xffff00ffffffff00, 0x00ff00ffffffffff, + 0x00ffff0000000000, 0xffffff00000000ff, 0xffffff000000ff00, + 0x00ffff000000ffff, 0xffffff0000ff0000, 0x00ffff0000ff00ff, + 0x00ffff0000ffff00, 0xffffff0000ffffff, 0xffffff00ff000000, + 0x00ffff00ff0000ff, 0x00ffff00ff00ff00, 0xffffff00ff00ffff, + 0x00ffff00ffff0000, 0xffffff00ffff00ff, 0xffffff00ffffff00, + 0x00ffff00ffffffff, 0xffffffff00000000, 0x00ffffff000000ff, + 0x00ffffff0000ff00, 0xffffffff0000ffff, 0x00ffffff00ff0000, + 0xffffffff00ff00ff, 0xffffffff00ffff00, 0x00ffffff00ffffff, + 0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, + 0x00ffffffff00ffff, 0xffffffffffff0000, 0x00ffffffffff00ff, + 0x00ffffffffffff00, 0xffffffffffffffff, + }); +//#endif + +static dpct::global_memory + kmask_iq2xs(sycl::range<1>(8), {1, 2, 4, 8, 16, 32, 64, 128}); + +template +static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1, + const uint64_t *iq2xxs_grid_ptr, + const uint8_t *ksigns_iq2xs_ptr, + const uint8_t *kmask_iq2xs_ptr) { + + const int i = item_ct1.get_group(2); + const block_iq2_xxs * x = (const block_iq2_xxs *) vx; + + const int tid = item_ct1.get_local_id(2); +#if QK_K == 256 + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * q2 = x[i].qs + 4*ib; + const uint8_t * aux8 = (const uint8_t *)q2; + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid_ptr + aux8[il]); + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f; + const uint8_t signs = ksigns_iq2xs_ptr[(aux32 >> 7*il) & 127]; + for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs_ptr[j] ? -1.f : 1.f); +#else + assert(false); +#endif + +} + +template +static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1, + const uint64_t *iq2xs_grid, + const uint8_t *ksigns_iq2xs, + const uint8_t *kmask_iq2xs) { + + const int i = item_ct1.get_group(2); + const block_iq2_xs * x = (const block_iq2_xs *) vx; + + const int tid = item_ct1.get_local_id(2); +#if QK_K == 256 + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint16_t * q2 = x[i].qs + 4*ib; + const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511)); + const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; + for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); +#else + assert(false); +#endif + +} + +template +static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> &item_ct1, + const uint32_t *iq3xxs_grid, + const uint8_t *ksigns_iq2xs, + const uint8_t *kmask_iq2xs) { + + const int i = item_ct1.get_group(2); + const block_iq3_xxs * x = (const block_iq3_xxs *) vx; + + const int tid = item_ct1.get_local_id(2); +#if QK_K == 256 + const int il = tid/8; // 0...3 + const int ib = tid%8; // 0...7 + dst_t * y = yy + i*QK_K + 32*ib + 8*il; + const uint8_t * q3 = x[i].qs + 8*ib; + const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib; + const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]); + const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]); + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; + for (int j = 0; j < 4; ++j) { + y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); + y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); + } +#else + assert(false); +#endif + +} + /* DPCT1110:4: The total declared local variable size in device function dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register @@ -4446,13 +5257,16 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, } #else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int tid = item_ct1.get_local_id(2) / + (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = item_ct1.get_local_id(2) % + (2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3 const int offset = tid * K_QUANTS_PER_ITERATION; uint32_t uaux[2]; const uint8_t * d = (const uint8_t *)uaux; + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { const float * y = yy + i * QK_K + offset; @@ -4462,7 +5276,8 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, uaux[0] = s[0] & 0x0f0f0f0f; uaux[1] = (s[0] >> 4) & 0x0f0f0f0f; - const float2 dall = __half22float2(x[i].dm); + const sycl::float2 dall = + x[i].dm.convert(); float sum1 = 0, sum2 = 0; for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { @@ -4473,8 +5288,9 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx, + y[l+48] * d[3] * ((ql >> 6) & 3); sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7]; } - tmp += dall.x * sum1 - dall.y * sum2; + tmp += dall.x() * sum1 - dall.y() * sum2; } + #endif // sum up partial sums and write back result @@ -4569,8 +5385,8 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, } #else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 const int in = offset/8; // 0 or 1 const int im = offset%8; // 0...7 @@ -4719,8 +5535,8 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, } #else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); const int step = tid * K_QUANTS_PER_ITERATION; @@ -4860,8 +5676,8 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx, } #else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); const int step = tid * K_QUANTS_PER_ITERATION; const int im = step/8; const int in = step%8; @@ -4969,8 +5785,8 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa #else - const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7 - const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3 + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3 const int step = tid * K_QUANTS_PER_ITERATION; @@ -5170,6 +5986,21 @@ static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y[iybs + iqs + y_offset] = v.y(); } +template +static void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, + const sycl::nd_item<3> &item_ct1) { + const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + if (i >= k) { + return; + } + + const src_t * x = (src_t *) vx; + + y[i] = x[i]; +} + // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q @@ -6588,8 +7419,8 @@ vec_dot_q4_K_q8_1(const void *__restrict__ vbq, const float dall = bq4_K->dm[0]; const float dmin = bq4_K->dm[1]; - const float d8_1 = __low2float(bq8_1[0].ds); - const float d8_2 = __low2float(bq8_1[1].ds); + const float d8_1 = bq8_1[0].ds[0]; + const float d8_2 = bq8_1[1].ds[1]; const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); @@ -6600,10 +7431,10 @@ vec_dot_q4_K_q8_1(const void *__restrict__ vbq, const int v1 = q4[0]; const int v2 = q4[4]; - const int dot1 = __dp4a(ui2, v2 & 0x0f0f0f0f, __dp4a(ui1, v1 & 0x0f0f0f0f, 0)); - const int dot2 = __dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, __dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); - const int dot3 = __dp4a(0x01010101, ui2, __dp4a(0x01010101, ui1, 0)); - const int dot4 = __dp4a(0x01010101, ui4, __dp4a(0x01010101, ui3, 0)); + const int dot1 = dpct::dp4a(ui2, v2 & 0x0f0f0f0f, dpct::dp4a(ui1, v1 & 0x0f0f0f0f, 0)); + const int dot2 = dpct::dp4a(ui4, (v2 >> 4) & 0x0f0f0f0f, dpct::dp4a(ui3, (v1 >> 4) & 0x0f0f0f0f, 0)); + const int dot3 = dpct::dp4a(0x01010101, ui2, dpct::dp4a(0x01010101, ui1, 0)); + const int dot4 = dpct::dp4a(0x01010101, ui4, dpct::dp4a(0x01010101, ui3, 0)); sumf_d += d8_1 * (dot1 * s[0]) + d8_2 * (dot2 * s[1]); sumf_m += d8_1 * (dot3 * s[2]) + d8_2 * (dot4 * s[3]); @@ -6772,8 +7603,8 @@ vec_dot_q5_K_q8_1(const void *__restrict__ vbq, const float d = bq5_K->d; - const float d8_1 = __low2half(bq8_1[0].ds); - const float d8_2 = __low2half(bq8_1[1].ds); + const float d8_1 = bq8_1[0].ds[0]; + const float d8_2 = bq8_1[1].ds[1]; const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); @@ -6794,8 +7625,8 @@ vec_dot_q5_K_q8_1(const void *__restrict__ vbq, const int v3 = (((vh >> 0) & 0x10101010) ^ 0x10101010) | ((vl1 >> 4) & 0x0f0f0f0f); const int v4 = (((vh >> 2) & 0x10101010) ^ 0x10101010) | ((vl2 >> 4) & 0x0f0f0f0f); - const float sumf_d = d8_1 * (__dp4a(ui1, v1, 0) * s[0] + __dp4a(ui2, v2, 0) * s[1]) - + d8_2 * (__dp4a(ui3, v3, 0) * s[2] + __dp4a(ui4, v4, 0) * s[3]); + const float sumf_d = d8_1 * (dpct::dp4a(ui1, v1, 0) * s[0] + dpct::dp4a(ui2, v2, 0) * s[1]) + + d8_2 * (dpct::dp4a(ui3, v3, 0) * s[2] + dpct::dp4a(ui4, v4, 0) * s[3]); return d * sumf_d; @@ -7051,6 +7882,150 @@ static __dpct_inline__ float vec_dot_q6_K_q8_1_mul_mat( return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); } + +static __dpct_inline__ float +vec_dot_iq2_xxs_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs, + const uint64_t *iq2xxs_grid, const uint8_t *ksigns_iq2xs, + const uint8_t *kmask_iq2xs) { +#if QK_K == 256 + const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq; + +#if QR2_XXS == 8 + const int ib32 = iqs; + const uint16_t * q2 = bq2->qs + 4*ib32; + const uint8_t * aux8 = (const uint8_t *)q2; + const int8_t * q8 = bq8_1[ib32].qs; + uint32_t aux32 = q2[2] | (q2[3] << 16); + int sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[l]); + const uint8_t signs = ksigns_iq2xs[aux32 & 127]; + for (int j = 0; j < 8; ++j) { + sumi += q8[j] * grid[j] * (signs & kmask_iq2xs[j] ? -1 : 1); + } + q8 += 8; + aux32 >>= 7; + } + const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.25f; + return d * sumi; +#else + // iqs is 0...15 + const int ib32 = iqs/2; + const int il = iqs%2; + const uint16_t * q2 = bq2->qs + 4*ib32; + const uint8_t * aux8 = (const uint8_t *)q2; + const uint8_t * grid1 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+0]); + const uint8_t * grid2 = (const uint8_t *)(iq2xxs_grid + aux8[2*il+1]); + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * bq8_1[ib32].ds[0] * 0.25f; + const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127]; + const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127]; + const int8_t * q8 = bq8_1[ib32].qs + 16*il; + int sumi1 = 0, sumi2 = 0; + for (int j = 0; j < 8; ++j) { + sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1); + sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1); + } + return d * (sumi1 + sumi2); +#endif +#else + assert(false); + return 0.f; +#endif +} + +static __dpct_inline__ float +vec_dot_iq2_xs_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs, + const uint64_t *iq2xs_grid, const uint64_t *ksigns64) { +#if DPCT_COMPATIBILITY_TEMP >= \ + MIN_CC_DP4A // lowest compute capability for integer intrinsics +#if QK_K == 256 + const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq; + + const int ib32 = iqs; + const uint16_t * q2 = bq2->qs + 4*ib32; + const int8_t * q8 = bq8_1[ib32].qs; + const uint8_t ls1 = bq2->scales[ib32] & 0xf; + const uint8_t ls2 = bq2->scales[ib32] >> 4; + int sumi1 = 0; + for (int l = 0; l < 2; ++l) { + const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + const int grid_l = dpct::vectorized_binary( + grid[0] ^ signs[0], signs[0], std::minus<>()); + const int grid_h = dpct::vectorized_binary( + grid[1] ^ signs[1], signs[1], std::minus<>()); + sumi1 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi1); + sumi1 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi1); + q8 += 8; + } + int sumi2 = 0; + for (int l = 2; l < 4; ++l) { + const uint32_t * grid = (const uint32_t *)(iq2xs_grid + (q2[l] & 511)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + const int grid_l = dpct::vectorized_binary( + grid[0] ^ signs[0], signs[0], std::minus<>()); + const int grid_h = dpct::vectorized_binary( + grid[1] ^ signs[1], signs[1], std::minus<>()); + sumi2 = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi2); + sumi2 = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi2); + q8 += 8; + } + const float d = (float)bq2->d * bq8_1[ib32].ds[0] * 0.25f; + return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2); +#else + assert(false); + return 0.f; +#endif +#else + assert(false); + return 0.f; +#endif +} + +static __dpct_inline__ float +vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq, + const block_q8_1 *__restrict__ bq8_1, const int &iqs, + const uint32_t *iq3xxs_grid, const uint64_t *ksigns64) { +#if DPCT_COMPATIBILITY_TEMP >= \ + MIN_CC_DP4A // lowest compute capability for integer intrinsics +#if QK_K == 256 + const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq; + + const int ib32 = iqs; + const uint8_t * q3 = bq2->qs + 8*ib32; + const uint16_t * gas = (const uint16_t *)(bq2->qs + QK_K/4) + 2*ib32; + const int8_t * q8 = bq8_1[ib32].qs; + uint32_t aux32 = gas[0] | (gas[1] << 16); + int sumi = 0; + for (int l = 0; l < 4; ++l) { + const uint32_t * grid1 = iq3xxs_grid + q3[2*l+0]; + const uint32_t * grid2 = iq3xxs_grid + q3[2*l+1]; + const uint32_t * signs = (const uint32_t *)(ksigns64 + (aux32 & 127)); + const int grid_l = dpct::vectorized_binary( + grid1[0] ^ signs[0], signs[0], std::minus<>()); + const int grid_h = dpct::vectorized_binary( + grid2[0] ^ signs[1], signs[1], std::minus<>()); + sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi); + sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi); + q8 += 8; + aux32 >>= 7; + } + const float d = (float)bq2->d * (0.5f + aux32) * bq8_1[ib32].ds[0] * 0.5f; + return d * sumi; +#else + assert(false); + return 0.f; +#endif +#else + assert(false); + return 0.f; +#endif +} + + template @@ -7632,7 +8607,8 @@ template static void template static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, - const sycl::nd_item<3> &item_ct1) { + const sycl::nd_item<3> &item_ct1, + const uint32_t *iq3xxs_grid_ptr, const uint64_t *ksigns64_ptr) { const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); @@ -7649,12 +8625,11 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ const block_q_t * x = (const block_q_t *) vx; const block_q8_1 * y = (const block_q8_1 *) vy; - for (int i = 0; i < blocks_per_row; i += blocks_per_warp) { - const int ibx = row * blocks_per_row + i + - item_ct1.get_local_id(2) / (qi / vdr); // x block index + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index - const int iby = (i + item_ct1.get_local_id(2) / (qi / vdr)) * - (qk / QK8_1); // y block index that aligns with ibx + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx const int iqs = vdr * @@ -7676,6 +8651,145 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ } } +template +static void mul_mat_vec_q_iq2_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, + const sycl::nd_item<3> &item_ct1, + const uint64_t *iq2xxs_grid_ptr, const uint8_t *ksigns_iq2xs_ptr, + const uint8_t *kmask_iq2xs_ptr ) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq2_xxs_q8_1(&x[ibx], &y[iby], iqs, iq2xxs_grid_ptr, ksigns_iq2xs_ptr, kmask_iq2xs_ptr); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void mul_mat_vec_q_iq2_xs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, + const sycl::nd_item<3> &item_ct1, + const uint64_t *iq2xs_grid_ptr, const uint64_t *ksigns64_ptr ) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq2_xs_q8_1(&x[ibx], &y[iby], iqs, iq2xs_grid_ptr, ksigns64_ptr); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template +static void mul_mat_vec_q_iq3_xxs_q8_1(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, + const sycl::nd_item<3> &item_ct1, + const uint32_t *iq3xxs_grid_ptr, const uint64_t *ksigns64_ptr ) { + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i; // x block index + + const int iby = i * (qk/QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (item_ct1.get_local_id(2) % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_iq3_xxs_q8_1(&x[ibx], &y[iby], iqs, iq3xxs_grid_ptr, ksigns64_ptr); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + template static void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows, const sycl::nd_item<3> &item_ct1) { @@ -9109,7 +10223,18 @@ static void dequantize_row_q2_K_sycl(const void *vx, dst_t *y, const int k, }); } #else - dequantize_block_q2_K<<>>(vx, y); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q2_K(vx, y, item_ct1); + }); + } + #endif } @@ -9130,10 +10255,57 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int k, }); } #else - dequantize_block_q3_K<<>>(vx, y); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q3_K(vx, y, item_ct1); + }); + } #endif } +template +static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int k, + dpct::queue_ptr stream) { + const int nb32 = k / 32; + const int nb = (k + 255) / 256; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q4_0(vx, y, nb32, item_ct1); + }); + } +} + +template +static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int k, + dpct::queue_ptr stream) { + const int nb32 = k / 32; + const int nb = (k + 255) / 256; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q4_1(vx, y, nb32, item_ct1); + }); + } +} + + template static void dequantize_row_q4_K_sycl(const void *vx, dst_t *y, const int k, dpct::queue_ptr stream) { @@ -9168,7 +10340,18 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int k, }); } #else - dequantize_block_q5_K<<>>(vx, y); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q5_K(vx, y, item_ct1); + }); + } + #endif } @@ -9189,11 +10372,132 @@ static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int k, }); } #else - dequantize_block_q6_K<<>>(vx, y); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q6_K(vx, y, item_ct1); + }); + } + #endif } -static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) { + +template +static void dequantize_row_iq2_xxs_sycl(const void *vx, dst_t *y, const int k, + dpct::queue_ptr stream) { + const int nb = k / QK_K; + { + iq2xxs_grid.init(*stream); + ksigns_iq2xs.init(*stream); + kmask_iq2xs.init(*stream); + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + auto iq2xxs_grid_ptr_ct1 = iq2xxs_grid.get_ptr(); + auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr(); + auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr(); + + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq2_xxs( + vx, y, item_ct1, iq2xxs_grid_ptr_ct1, + ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1); + }); + }); + } +} + +template +static void dequantize_row_iq2_xs_sycl(const void *vx, dst_t *y, const int k, + dpct::queue_ptr stream) { + const int nb = k / QK_K; + { + iq2xs_grid.init(*stream); + ksigns_iq2xs.init(*stream); + kmask_iq2xs.init(*stream); + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + auto iq2xs_grid_ptr_ct1 = iq2xs_grid.get_ptr(); + auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr(); + auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr(); + + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq2_xs( + vx, y, item_ct1, iq2xs_grid_ptr_ct1, + ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1); + }); + }); + } +} + +template +static void dequantize_row_iq3_xxs_sycl(const void *vx, dst_t *y, const int k, + dpct::queue_ptr stream) { + const int nb = k / QK_K; + { + iq3xxs_grid.init(*stream); + ksigns_iq2xs.init(*stream); + kmask_iq2xs.init(*stream); + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr(); + auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr(); + + cgh.parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * + sycl::range<3>(1, 1, 32), + sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_iq3_xxs( + vx, y, item_ct1, iq3xxs_grid_ptr_ct1, + ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1); + }); + }); + } +} + +template +static void convert_unary_sycl(const void *__restrict__ vx, + dst_t *__restrict__ y, const int k, + dpct::queue_ptr stream) { + const int num_blocks = (k + SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / SYCL_DEQUANTIZE_BLOCK_SIZE; + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * + sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + convert_unary(vx, y, k, item_ct1); + }); + } +} + + +static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) try { + int id; switch (type) { case GGML_TYPE_Q4_0: return dequantize_block_sycl; @@ -9215,19 +10519,30 @@ static to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type) { return dequantize_row_q5_K_sycl; case GGML_TYPE_Q6_K: return dequantize_row_q6_K_sycl; + case GGML_TYPE_IQ2_XXS: + return dequantize_row_iq2_xxs_sycl; + case GGML_TYPE_IQ2_XS: + return dequantize_row_iq2_xs_sycl; + case GGML_TYPE_IQ3_XXS: + return dequantize_row_iq3_xxs_sycl; case GGML_TYPE_F32: - return dequantize_block_sycl<1, 1, convert_f32>; + return convert_unary_sycl; default: return nullptr; } } +catch (sycl::exception const &exc) { + std::cerr << exc.what() << "Exception caught at file:" << __FILE__ + << ", line:" << __LINE__ << std::endl; + std::exit(1); +} static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: - return dequantize_block_sycl; + return dequantize_row_q4_0_sycl; case GGML_TYPE_Q4_1: - return dequantize_block_sycl; + return dequantize_row_q4_1_sycl; case GGML_TYPE_Q5_0: return dequantize_block_sycl; case GGML_TYPE_Q5_1: @@ -9244,8 +10559,14 @@ static to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type) { return dequantize_row_q5_K_sycl; case GGML_TYPE_Q6_K: return dequantize_row_q6_K_sycl; + case GGML_TYPE_IQ2_XXS: + return dequantize_row_iq2_xxs_sycl; + case GGML_TYPE_IQ2_XS: + return dequantize_row_iq2_xs_sycl; + case GGML_TYPE_IQ3_XXS: + return dequantize_row_iq3_xxs_sycl; case GGML_TYPE_F16: - return dequantize_block_sycl<1, 1, convert_f16>; + return convert_unary_sycl; default: return nullptr; } @@ -9455,24 +10776,385 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, } } -template -static void mul_mat_vec_q_sycl_submitter(const void *vx, const void *vy, + +static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1, + iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1, + iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1, + iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1, + iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1, + iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1, + iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1, + iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1, + iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1, + iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q( + vx, vy, dst, ncols, nrows, item_ct1, + iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq2xxs_grid.init(*stream); + ksigns_iq2xs.init(*stream); + kmask_iq2xs.init(*stream); + + + stream->submit([&](sycl::handler &cgh) { + auto iq2xxs_grid_ptr_ct1 = iq2xxs_grid.get_ptr(); + auto ksigns_iq2xs_ptr_ct1 = ksigns_iq2xs.get_ptr(); + auto kmask_iq2xs_ptr_ct1 = kmask_iq2xs.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q_iq2_xxs_q8_1( + vx, vy, dst, ncols, nrows, item_ct1, + iq2xxs_grid_ptr_ct1, ksigns_iq2xs_ptr_ct1, kmask_iq2xs_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, dpct::queue_ptr stream) { - GGML_ASSERT(ncols % QK4_0 == 0); - const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; - const sycl::range<3> block_nums(1, 1, block_num_y); - const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); - stream->parallel_for( - sycl::nd_range<3>(block_nums * block_dims, block_dims), [= - ](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] { - mul_mat_vec_q( - vx, vy, dst, ncols, nrows, item_ct1); - }); + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq2xs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq2xs_grid_ptr_ct1 = iq2xs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q_iq2_xs_q8_1( + vx, vy, dst, ncols, nrows, item_ct1, + iq2xs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } +} + +static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + iq3xxs_grid.init(*stream); + ksigns64.init(*stream); + + stream->submit([&](sycl::handler &cgh) { + auto iq3xxs_grid_ptr_ct1 = iq3xxs_grid.get_ptr(); + auto ksigns64_ptr_ct1 = ksigns64.get_ptr(); + + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[intel::reqd_sub_group_size(32)]] { + mul_mat_vec_q_iq3_xxs_q8_1( + vx, vy, dst, ncols, nrows, item_ct1, + iq3xxs_grid_ptr_ct1, ksigns64_ptr_ct1); + }); + }); + } } + static void ggml_mul_mat_q4_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols_x, const int nrows_x, const int ncols_y, @@ -11451,7 +13133,7 @@ struct sycl_pool_alloc { device_id = get_current_device_id(); device_index = g_sycl_gpu_mgr->get_index(device_id); ptr = (T *) ggml_sycl_pool_malloc(device_index, size * sizeof(T), &this->actual_size); - // GGML_SYCL_DEBUG("alloc %lu return %p actual size=%lu\n", size * sizeof(T), ptr, this->actual_size); + // GGML_SYCL_DEBUG("sycl_pool_alloc %lu return %p actual size=%lu\n", size * sizeof(T), ptr, this->actual_size); return ptr; } @@ -12242,63 +13924,46 @@ inline void ggml_sycl_op_mul_mat_vec_q( const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; - // TODO: support these quantization types - GGML_ASSERT(!(src0->type == GGML_TYPE_IQ2_XXS || - src0->type == GGML_TYPE_IQ2_XS || - src0->type == GGML_TYPE_IQ3_XXS || - src0->type == GGML_TYPE_IQ1_S)); - switch (src0->type) { case GGML_TYPE_Q4_0: - mul_mat_vec_q_sycl_submitter( - src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; + mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q4_1: - mul_mat_vec_q_sycl_submitter( - src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; + mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q5_0: - mul_mat_vec_q_sycl_submitter( - src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; + mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q5_1: - mul_mat_vec_q_sycl_submitter( - src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; + mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q8_0: - mul_mat_vec_q_sycl_submitter( - src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; + mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q2_K: - mul_mat_vec_q_sycl_submitter( - src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; + mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q3_K: - mul_mat_vec_q_sycl_submitter( - src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; + mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q4_K: - mul_mat_vec_q_sycl_submitter( - src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; + mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q5_K: - mul_mat_vec_q_sycl_submitter( - src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; + mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; case GGML_TYPE_Q6_K: - mul_mat_vec_q_sycl_submitter( - src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); - break; + mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XXS: + mul_mat_vec_iq2_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_vec_iq2_xs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_vec_iq3_xxs_q8_1_sycl(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream); + break; default: GGML_ASSERT(false); break; @@ -12311,6 +13976,7 @@ inline void ggml_sycl_op_mul_mat_vec_q( (void) src1_padded_row_size; } + inline void ggml_sycl_op_dequantize_mul_mat_vec( const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i, @@ -12318,10 +13984,11 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec( const int64_t src1_ncols, const int64_t src1_padded_row_size, const dpct::queue_ptr &stream) { - GGML_TENSOR_BINARY_OP_LOCALS; - + const int64_t ne00 = src0->ne[0]; const int64_t row_diff = row_high - row_low; + GGML_ASSERT(src1->type == GGML_TYPE_F32); + // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics #ifdef GGML_SYCL_F16 sycl_pool_alloc src1_dfloat_a; @@ -12333,15 +14000,10 @@ inline void ggml_sycl_op_dequantize_mul_mat_vec( src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; if (src1_convert_f16) { - if (src1->type == GGML_TYPE_F16) { - src1_dfloat = (sycl::half *)src1->data + src1_padded_row_size; - } else { - src1_dfloat = src1_dfloat_a.alloc(ne00); - ggml_cpy_f32_f16_sycl((const char *)src1_ddf_i, (char *)src1_dfloat, - ne00, ne00, ne01, ne02, nb00, nb01, nb02, - nb03, ne10, ne11, ne12, nb10, nb11, nb12, - nb13, stream); - } + src1_dfloat = src1_dfloat_a.alloc(ne00); + const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type); + GGML_ASSERT(to_fp16_sycl != nullptr); + to_fp16_sycl(src1_ddf_i, src1_dfloat, ne00, stream); } #else const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion @@ -12495,7 +14157,7 @@ inline void ggml_sycl_op_mul_mat_sycl( *g_sycl_handles[id], oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *g_sycl_handles[id]), src0_ddf_i, ne00, - src1_ddf_i, ne10, dpct::get_value(&beta, *g_sycl_handles[id]), + src1_ddf1_i, ne10, dpct::get_value(&beta, *g_sycl_handles[id]), dst_dd_i, ldc))); } (void) dst; @@ -12923,7 +14585,7 @@ static void ggml_sycl_op_flatten(const ggml_tensor *src0, // copy dst to host if necessary if (!dst_on_device) { SYCL_CHECK(CHECK_TRY_ERROR( - main_stream->memcpy(dst->data, dst_ddf, ggml_nbytes(dst)))); + main_stream->memcpy(dst->data, dst_ddf, ggml_nbytes(dst)).wait())); } if (dst->backend == GGML_BACKEND_TYPE_CPU) { @@ -13200,7 +14862,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy( src1_ddq_i, src1_ddq_i_source, src1_ncols * src1_padded_col_size * q8_1_ts / - q8_1_bs))); + q8_1_bs).wait())); } else { float * src1_ddf_i_source = (float *) src1_extra->data_device[g_main_device]; @@ -13294,7 +14956,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0, dhf_dst_i += src1_col_0*ne0; SYCL_CHECK(CHECK_TRY_ERROR( stream->memcpy(dhf_dst_i, dst_dd_i, - src1_ncols * ne0 * sizeof(float)))); + src1_ncols * ne0 * sizeof(float)).wait())); } } @@ -13852,13 +15514,13 @@ static __global__ void k_compute_batched_ptrs_id( src0_f16 = (half *) srcs_ar[i]; } else { src0_f16 = src0_as_f16; - if (threadIdx.x == 0 && threadIdx.y == 0) { + if (item_ct1.get_local_id(2) == 0 && threadIdx.y == 0) { const to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(src0_type); to_fp16(srcs_ar[i], src0_f16, src0_ne, syclStreamFireAndForget); } } - int i13 = blockIdx.x * blockDim.x + threadIdx.x; + int i13 = blockIdx.x * blockDim.x + item_ct1.get_local_id(2); int i12 = blockIdx.y * blockDim.y + threadIdx.y; if (i13 >= ne13 || i12 >= ne12) { @@ -14024,8 +15686,8 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, if (ids->backend == GGML_BACKEND_TYPE_GPU) { const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device]; SYCL_CHECK(CHECK_TRY_ERROR( - stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)))); - SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); + stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids)).wait())); + // SYCL_CHECK(CHECK_TRY_ERROR(stream->wait())); } else { memcpy(ids_host.data(), ids->data, ggml_nbytes(ids)); } @@ -14095,7 +15757,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, SYCL_CHECK(CHECK_TRY_ERROR( stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11, - src1_original + i01 * nb11, nb11))); + src1_original + i01 * nb11, nb11).wait())); num_src1_rows++; } @@ -14128,7 +15790,7 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0, SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy( dst_original + i01 * nb1, - dst_contiguous.get() + num_src1_rows * nb1, nb1))); + dst_contiguous.get() + num_src1_rows * nb1, nb1).wait())); num_src1_rows++; } } @@ -15522,7 +17184,7 @@ GGML_CALL static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend, GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[sycl_ctx->device][0]->memcpy( - (char *)tensor->data + offset, data, size))); + (char *)tensor->data + offset, data, size).wait())); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -15538,7 +17200,7 @@ GGML_CALL static void ggml_backend_sycl_get_tensor_async(ggml_backend_t backend, GGML_ASSERT(tensor->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type"); GGML_ASSERT(tensor->backend == GGML_BACKEND_TYPE_GPU); SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[sycl_ctx->device][0]->memcpy( - data, (const char *)tensor->data + offset, size))); + data, (const char *)tensor->data + offset, size).wait())); } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -15557,7 +17219,7 @@ GGML_CALL static bool ggml_backend_sycl_cpy_tensor_async(ggml_backend_t backend, was inserted. You need to rewrite this code. */ SYCL_CHECK(CHECK_TRY_ERROR(g_syclStreams[sycl_ctx->device][0]->memcpy( - dst->data, src->data, ggml_nbytes(dst)))); + dst->data, src->data, ggml_nbytes(dst)).wait())); return true; } @@ -15647,20 +17309,12 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons if (a->ne[3] != b->ne[3]) { return false; } - - if (a->type == GGML_TYPE_IQ1_S) { - return false; - } - if (a->type == GGML_TYPE_IQ3_XXS) { - return false; - } - if (a->type == GGML_TYPE_IQ2_XXS) { - return false; - } - if (a->type == GGML_TYPE_IQ2_XS) { + ggml_type a_type = a->type; + if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || + a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ3_S || + a_type == GGML_TYPE_IQ2_S || a_type == GGML_TYPE_IQ4_XS) { return false; } - return true; } break; case GGML_OP_GET_ROWS: @@ -15705,15 +17359,15 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons } return false; } break; - case GGML_OP_DUP: - case GGML_OP_REPEAT: case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; } break; + case GGML_OP_DUP: case GGML_OP_NONE: case GGML_OP_RESHAPE: + case GGML_OP_REPEAT: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: