From aa2effb4e0579e8f9b3edf2bf4627567e179ab01 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Wed, 8 Nov 2023 06:30:37 +0800 Subject: [PATCH] feat: sync whisper.cpp and enable metal on ios (#154) * feat: sync whisper.cpp and update patches * feat(ios): enable metal * fix(android): build * fix: tests * fix(ios): update ggml-metal patch * fix(ios): force disable coreml when metal enabled * feat(example): update options * fix: tests * feat: sync whisper.cpp * feat(ts): update useGpu option doc * chore: apply patches * feat(ios, example): update RNWHISPER_DISABLE_METAL usage * feat: sync whisper.cpp --- android/src/main/CMakeLists.txt | 2 + .../main/java/com/rnwhisper/RNWhisper.java | 7 +- cpp/ggml-alloc.c | 693 +- cpp/ggml-alloc.h | 75 +- cpp/ggml-backend-impl.h | 87 + cpp/ggml-backend.c | 950 ++ cpp/ggml-backend.h | 136 + cpp/ggml-impl.h | 243 + cpp/ggml-metal-whisper.metal | 712 +- cpp/ggml-metal.h | 21 + cpp/ggml-metal.m | 853 +- cpp/ggml-quants.c | 7377 ++++++++++++++++ cpp/ggml-quants.h | 224 + cpp/ggml.c | 7650 ++++++++--------- cpp/ggml.h | 295 +- cpp/whisper.cpp | 285 +- cpp/whisper.h | 59 +- docs/API/README.md | 37 +- docs/API/classes/WhisperContext.md | 36 +- docs/API/enums/AudioSessionCategoryIos.md | 12 +- .../enums/AudioSessionCategoryOptionIos.md | 14 +- docs/API/enums/AudioSessionModeIos.md | 16 +- example/ios/Podfile | 3 +- example/ios/Podfile.lock | 6 +- example/src/context-opts.ios.ts | 26 +- ios/RNWhisper.mm | 8 +- ios/RNWhisperContext.h | 6 +- ios/RNWhisperContext.mm | 78 +- jest/mock.js | 2 +- scripts/bootstrap.sh | 12 + scripts/ggml-metal.m.patch | 44 +- scripts/whisper.cpp.patch | 63 +- scripts/whisper.h.patch | 22 +- src/NativeRNWhisper.ts | 9 +- src/index.ts | 24 +- src/version.json | 2 +- whisper-rn.podspec | 3 +- whisper.cpp | 2 +- 38 files changed, 14925 insertions(+), 5169 deletions(-) create mode 100644 cpp/ggml-backend-impl.h create mode 100644 cpp/ggml-backend.c create mode 100644 cpp/ggml-backend.h create mode 100644 cpp/ggml-impl.h create mode 100644 cpp/ggml-quants.c create mode 100644 cpp/ggml-quants.h diff --git a/android/src/main/CMakeLists.txt b/android/src/main/CMakeLists.txt index 9ade9fa..febe580 100644 --- a/android/src/main/CMakeLists.txt +++ b/android/src/main/CMakeLists.txt @@ -9,6 +9,8 @@ set( SOURCE_FILES ${RNWHISPER_LIB_DIR}/ggml.c ${RNWHISPER_LIB_DIR}/ggml-alloc.c + ${RNWHISPER_LIB_DIR}/ggml-backend.c + ${RNWHISPER_LIB_DIR}/ggml-quants.c ${RNWHISPER_LIB_DIR}/whisper.cpp ${RNWHISPER_LIB_DIR}/rn-whisper.cpp ${CMAKE_SOURCE_DIR}/jni.cpp diff --git a/android/src/main/java/com/rnwhisper/RNWhisper.java b/android/src/main/java/com/rnwhisper/RNWhisper.java index 32c90be..22da1d1 100644 --- a/android/src/main/java/com/rnwhisper/RNWhisper.java +++ b/android/src/main/java/com/rnwhisper/RNWhisper.java @@ -13,6 +13,7 @@ import com.facebook.react.bridge.LifecycleEventListener; import com.facebook.react.bridge.ReadableMap; import com.facebook.react.bridge.WritableMap; +import com.facebook.react.bridge.Arguments; import java.util.HashMap; import java.util.Random; @@ -107,7 +108,11 @@ protected void onPostExecute(Integer id) { promise.reject(exception); return; } - promise.resolve(id); + WritableMap result = Arguments.createMap(); + result.putInt("contextId", id); + result.putBoolean("gpu", false); + result.putString("reasonNoGPU", "Currently not supported"); + promise.resolve(result); tasks.remove(this); } }.execute(); diff --git a/cpp/ggml-alloc.c b/cpp/ggml-alloc.c index d8d79b9..da4e4ba 100644 --- a/cpp/ggml-alloc.c +++ b/cpp/ggml-alloc.c @@ -1,69 +1,21 @@ #include "ggml-alloc.h" +#include "ggml-backend-impl.h" #include "ggml.h" +#include "ggml-impl.h" #include +#include #include #include #include #include -#ifdef __has_include - #if __has_include() - #include - #if defined(_POSIX_MAPPED_FILES) - #include - #include - #endif - #endif -#endif - -#if defined(_WIN32) - #define WIN32_LEAN_AND_MEAN - #ifndef NOMINMAX - #define NOMINMAX - #endif - #include - #include -#endif - - -#define UNUSED(x) (void)(x) #define MAX(a, b) ((a) > (b) ? (a) : (b)) -#define WSP_GGML_MAX_CONCUR (2*WSP_GGML_MAX_NODES) +#define MAX_FREE_BLOCKS 256 //#define WSP_GGML_ALLOCATOR_DEBUG -//#define AT_PRINTF printf -#define AT_PRINTF(...) ((void)0) - -struct hash_node { - struct wsp_ggml_tensor * t; - int n_children; - int n_views; -}; - -static size_t hash(void * p) { - return (size_t)p % WSP_GGML_GRAPH_HASHTABLE_SIZE; -} - -static struct hash_node * hash_get(struct hash_node hash_table[], struct wsp_ggml_tensor * t) { - size_t h = hash(t); - - // linear probing - size_t i = h; - while (hash_table[i].t != NULL) { - if (hash_table[i].t == t) { - return &hash_table[i]; - } - i = (i + 1) % WSP_GGML_GRAPH_HASHTABLE_SIZE; - if (i == h) { - // hash table is full - WSP_GGML_ASSERT(false); - } - } - - hash_table[i].t = t; - return &hash_table[i]; -} +//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__) +#define AT_PRINTF(...) // TODO: WSP_GGML_PAD ? static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) { @@ -77,19 +29,18 @@ struct free_block { size_t size; }; -#define MAX_FREE_BLOCKS 128 - -struct wsp_ggml_allocr { - void * data; - size_t size; +struct wsp_ggml_tallocr { + struct wsp_ggml_backend_buffer * buffer; + bool buffer_owned; + void * base; size_t alignment; + int n_free_blocks; struct free_block free_blocks[MAX_FREE_BLOCKS]; - struct hash_node hash_table[WSP_GGML_GRAPH_HASHTABLE_SIZE]; + size_t max_size; + bool measure; - int parse_seq[WSP_GGML_MAX_CONCUR]; - int parse_seq_len; #ifdef WSP_GGML_ALLOCATOR_DEBUG struct wsp_ggml_tensor * allocated_tensors[1024]; @@ -97,7 +48,7 @@ struct wsp_ggml_allocr { }; #ifdef WSP_GGML_ALLOCATOR_DEBUG -static void add_allocated_tensor(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * tensor) { +static void add_allocated_tensor(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) { for (int i = 0; i < 1024; i++) { if (alloc->allocated_tensors[i] == NULL) { alloc->allocated_tensors[i] = tensor; @@ -106,7 +57,7 @@ static void add_allocated_tensor(struct wsp_ggml_allocr * alloc, struct wsp_ggml } WSP_GGML_ASSERT(!"out of allocated_tensors"); } -static void remove_allocated_tensor(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * tensor) { +static void remove_allocated_tensor(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) { for (int i = 0; i < 1024; i++) { if (alloc->allocated_tensors[i] == tensor || (alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) { @@ -119,28 +70,20 @@ static void remove_allocated_tensor(struct wsp_ggml_allocr * alloc, struct wsp_g } #endif -static size_t wsp_ggml_allocr_get_alloc_size(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * tensor) { - return wsp_ggml_nbytes(tensor); - - UNUSED(alloc); -} - // check if a tensor is allocated by this buffer -static bool wsp_ggml_allocr_is_own(struct wsp_ggml_allocr * alloc, const struct wsp_ggml_tensor * tensor) { - void * ptr = tensor->data; - return ptr >= alloc->data && (char *)ptr < (char *)alloc->data + alloc->max_size; +static bool wsp_ggml_tallocr_is_own(wsp_ggml_tallocr_t alloc, const struct wsp_ggml_tensor * tensor) { + return tensor->buffer == alloc->buffer; } static bool wsp_ggml_is_view(struct wsp_ggml_tensor * t) { return t->view_src != NULL; } -void wsp_ggml_allocr_alloc(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * tensor) { -#ifdef WSP_GGML_ALLOCATOR_DEBUG +void wsp_ggml_tallocr_alloc(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) { WSP_GGML_ASSERT(!wsp_ggml_is_view(tensor)); // views generally get data pointer from one of their sources WSP_GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated -#endif - size_t size = wsp_ggml_allocr_get_alloc_size(alloc, tensor); + + size_t size = wsp_ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor); size = aligned_offset(NULL, size, alloc->alignment); AT_PRINTF("%s: allocating %s (%zu bytes) - ", __func__, tensor->name, size); @@ -187,6 +130,10 @@ void wsp_ggml_allocr_alloc(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tenso } tensor->data = addr; + tensor->buffer = alloc->buffer; + if (!alloc->measure) { + wsp_ggml_backend_buffer_init_tensor(alloc->buffer, tensor); + } #ifdef WSP_GGML_ALLOCATOR_DEBUG add_allocated_tensor(alloc, tensor); @@ -202,23 +149,28 @@ void wsp_ggml_allocr_alloc(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tenso } #endif - alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size); + alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->base + size); } // this is a very naive implementation, but for our case the number of free blocks should be very small -static void wsp_ggml_allocr_free_tensor(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * tensor) { - void * ptr = tensor->data; - - if (wsp_ggml_allocr_is_own(alloc, tensor) == false) { +static void wsp_ggml_tallocr_free_tensor(wsp_ggml_tallocr_t alloc, struct wsp_ggml_tensor * tensor) { + if (wsp_ggml_tallocr_is_own(alloc, tensor) == false) { // the tensor was not allocated in this buffer // this can happen because the graph allocator will try to free weights and other tensors from different buffers // the easiest way to deal with this is just to ignore it + // AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer); return; } - size_t size = wsp_ggml_allocr_get_alloc_size(alloc, tensor); + void * ptr = tensor->data; + + size_t size = wsp_ggml_backend_buffer_get_alloc_size(alloc->buffer, tensor); size = aligned_offset(NULL, size, alloc->alignment); - AT_PRINTF("%s: freeing %s (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, size, alloc->n_free_blocks); + AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks); + + if (!alloc->measure) { + wsp_ggml_backend_buffer_free_tensor(alloc->buffer, tensor); + } #ifdef WSP_GGML_ALLOCATOR_DEBUG remove_allocated_tensor(alloc, tensor); @@ -272,136 +224,180 @@ static void wsp_ggml_allocr_free_tensor(struct wsp_ggml_allocr * alloc, struct w alloc->n_free_blocks++; } -void wsp_ggml_allocr_set_parse_seq(struct wsp_ggml_allocr * alloc, const int * list, int n) { - for (int i = 0; i < n; i++) { - alloc->parse_seq[i] = list[i]; +void wsp_ggml_tallocr_reset(wsp_ggml_tallocr_t alloc) { + alloc->n_free_blocks = 1; + size_t align_offset = aligned_offset(alloc->base, 0, alloc->alignment); + alloc->free_blocks[0].addr = (char *)alloc->base + align_offset; + + if (alloc->measure) { + alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows + } else { + alloc->free_blocks[0].size = wsp_ggml_backend_buffer_get_size(alloc->buffer) - align_offset; } - alloc->parse_seq_len = n; } -void wsp_ggml_allocr_reset(struct wsp_ggml_allocr * alloc) { - alloc->n_free_blocks = 1; - size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment); - alloc->free_blocks[0].addr = (char *)alloc->data + align_offset; - alloc->free_blocks[0].size = alloc->size - align_offset; -} +wsp_ggml_tallocr_t wsp_ggml_tallocr_new(void * data, size_t size, size_t alignment) { + struct wsp_ggml_backend_buffer * buffer = wsp_ggml_backend_cpu_buffer_from_ptr(NULL, data, size); -struct wsp_ggml_allocr * wsp_ggml_allocr_new(void * data, size_t size, size_t alignment) { - struct wsp_ggml_allocr * alloc = (struct wsp_ggml_allocr *)malloc(sizeof(struct wsp_ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */); + wsp_ggml_tallocr_t alloc = (wsp_ggml_tallocr_t)malloc(sizeof(struct wsp_ggml_tallocr)); - *alloc = (struct wsp_ggml_allocr){ - /*.data = */ data, - /*.size = */ size, + *alloc = (struct wsp_ggml_tallocr) { + /*.buffer = */ buffer, + /*.buffer_owned = */ true, + /*.base = */ wsp_ggml_backend_buffer_get_base(buffer), /*.alignment = */ alignment, /*.n_free_blocks = */ 0, /*.free_blocks = */ {{0}}, - /*.hash_table = */ {{0}}, /*.max_size = */ 0, /*.measure = */ false, - /*.parse_seq = */ {0}, - /*.parse_seq_len = */ 0, #ifdef WSP_GGML_ALLOCATOR_DEBUG /*.allocated_tensors = */ {0}, #endif }; - wsp_ggml_allocr_reset(alloc); + wsp_ggml_tallocr_reset(alloc); return alloc; } -// OS specific functions to allocate and free uncommitted virtual memory -static void * alloc_vmem(size_t size) { -#if defined(_WIN32) - return VirtualAlloc(NULL, size, MEM_RESERVE, PAGE_NOACCESS); -#elif defined(_POSIX_MAPPED_FILES) - void * ptr = mmap(NULL, size, PROT_NONE, MAP_PRIVATE | MAP_ANON, -1, 0); - if (ptr == MAP_FAILED) { - return NULL; - } - return ptr; -#else - // use a fixed address for other platforms - uintptr_t base_addr = (uintptr_t)-size - 0x100; - return (void *)base_addr; -#endif -} +wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure(size_t alignment) { + wsp_ggml_tallocr_t alloc = wsp_ggml_tallocr_new((void *)0x1000, SIZE_MAX/2, alignment); + alloc->measure = true; -static void free_vmem(void * base_addr, size_t size) { -#if defined(_WIN32) - VirtualFree(base_addr, 0, MEM_RELEASE); - UNUSED(size); -#elif defined(_POSIX_MAPPED_FILES) - munmap(base_addr, size); -#else - // nothing to do - UNUSED(base_addr); - UNUSED(size); -#endif + return alloc; } -// allocate uncommitted virtual memory to measure the size of the graph -static void alloc_measure_vmem(void ** base_addr, size_t * size) { - // 128GB for 64-bit, 1GB for 32-bit - *size = sizeof(void *) == 4 ? 1ULL<<30 : 1ULL<<37; - do { - *base_addr = alloc_vmem(*size); - if (*base_addr != NULL) { - AT_PRINTF("allocated %.2f GB of virtual memory for measure buffer at %p\n", *size / 1024.0 / 1024.0 / 1024.0, *base_addr); - return; - } - // try again with half the size - *size /= 2; - } while (*size > 0); +wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure_from_backend(struct wsp_ggml_backend * backend) { + // create a backend buffer to get the correct tensor allocation sizes + wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_alloc_buffer(backend, 1); - WSP_GGML_ASSERT(!"failed to allocate virtual memory for measure buffer"); + // TODO: move alloc initialization to a common wsp_ggml_tallocr_new_impl function + wsp_ggml_tallocr_t alloc = wsp_ggml_tallocr_new_from_buffer(buffer); + alloc->buffer_owned = true; + alloc->measure = true; + wsp_ggml_tallocr_reset(alloc); + return alloc; } -static void free_measure_vmem(void * base_addr, size_t size) { - free_vmem(base_addr, size); +wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_backend(struct wsp_ggml_backend * backend, size_t size) { + wsp_ggml_backend_buffer_t buffer = wsp_ggml_backend_alloc_buffer(backend, size); + wsp_ggml_tallocr_t alloc = wsp_ggml_tallocr_new_from_buffer(buffer); + alloc->buffer_owned = true; + return alloc; } -struct wsp_ggml_allocr * wsp_ggml_allocr_new_measure(size_t alignment) { - struct wsp_ggml_allocr * alloc = (struct wsp_ggml_allocr *)malloc(sizeof(struct wsp_ggml_allocr) /* + n_free_blocks * sizeof(struct free_block) */); +wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_buffer(struct wsp_ggml_backend_buffer * buffer) { + wsp_ggml_tallocr_t alloc = (wsp_ggml_tallocr_t)malloc(sizeof(struct wsp_ggml_tallocr)); - void * base_addr; - size_t size; - - alloc_measure_vmem(&base_addr, &size); - - *alloc = (struct wsp_ggml_allocr){ - /*.data = */ base_addr, - /*.size = */ size, - /*.alignment = */ alignment, + *alloc = (struct wsp_ggml_tallocr) { + /*.buffer = */ buffer, + /*.buffer_owned = */ false, + /*.base = */ wsp_ggml_backend_buffer_get_base(buffer), + /*.alignment = */ wsp_ggml_backend_buffer_get_alignment(buffer), /*.n_free_blocks = */ 0, /*.free_blocks = */ {{0}}, - /*.hash_table = */ {{0}}, /*.max_size = */ 0, - /*.measure = */ true, - /*.parse_seq = */ {0}, - /*.parse_seq_len = */ 0, + /*.measure = */ false, #ifdef WSP_GGML_ALLOCATOR_DEBUG /*.allocated_tensors = */ {0}, #endif }; - wsp_ggml_allocr_reset(alloc); + wsp_ggml_tallocr_reset(alloc); return alloc; } -void wsp_ggml_allocr_free(struct wsp_ggml_allocr * alloc) { - if (alloc->measure) { - free_measure_vmem(alloc->data, alloc->size); +struct wsp_ggml_backend_buffer * wsp_ggml_tallocr_get_buffer(wsp_ggml_tallocr_t alloc) { + return alloc->buffer; +} + +void wsp_ggml_tallocr_free(wsp_ggml_tallocr_t alloc) { + if (alloc == NULL) { + return; + } + + if (alloc->buffer_owned) { + wsp_ggml_backend_buffer_free(alloc->buffer); } free(alloc); } -bool wsp_ggml_allocr_is_measure(struct wsp_ggml_allocr * alloc) { +bool wsp_ggml_tallocr_is_measure(wsp_ggml_tallocr_t alloc) { return alloc->measure; } -//////////// compute graph allocator +size_t wsp_ggml_tallocr_max_size(wsp_ggml_tallocr_t alloc) { + return alloc->max_size; +} + +// graph allocator + +struct hash_node { + int n_children; + int n_views; +}; + +struct wsp_ggml_gallocr { + wsp_ggml_tallocr_t talloc; + struct wsp_ggml_hash_set hash_set; + struct hash_node * hash_values; + size_t hash_values_size; + wsp_ggml_tallocr_t * hash_allocs; + int * parse_seq; + int parse_seq_len; +}; + +wsp_ggml_gallocr_t wsp_ggml_gallocr_new(void) { + wsp_ggml_gallocr_t galloc = (wsp_ggml_gallocr_t)malloc(sizeof(struct wsp_ggml_gallocr)); + + *galloc = (struct wsp_ggml_gallocr) { + /*.talloc = */ NULL, + /*.hash_set = */ {0}, + /*.hash_values = */ NULL, + /*.hash_values_size = */ 0, + /*.hash_allocs = */ NULL, + /*.parse_seq = */ NULL, + /*.parse_seq_len = */ 0, + }; + + return galloc; +} + +void wsp_ggml_gallocr_free(wsp_ggml_gallocr_t galloc) { + if (galloc == NULL) { + return; + } + + if (galloc->hash_set.keys != NULL) { + free(galloc->hash_set.keys); + } + if (galloc->hash_values != NULL) { + free(galloc->hash_values); + } + if (galloc->hash_allocs != NULL) { + free(galloc->hash_allocs); + } + if (galloc->parse_seq != NULL) { + free(galloc->parse_seq); + } + free(galloc); +} + +void wsp_ggml_gallocr_set_parse_seq(wsp_ggml_gallocr_t galloc, const int * list, int n) { + free(galloc->parse_seq); + galloc->parse_seq = malloc(sizeof(int) * n); + + for (int i = 0; i < n; i++) { + galloc->parse_seq[i] = list[i]; + } + galloc->parse_seq_len = n; +} + +static struct hash_node * hash_get(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * t) { + size_t i = wsp_ggml_hash_find_or_insert(galloc->hash_set, t); + return &galloc->hash_values[i]; +} static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) { if (a->type != b->type) { @@ -435,7 +431,6 @@ static bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op) { case WSP_GGML_OP_ROPE: case WSP_GGML_OP_RMS_NORM: case WSP_GGML_OP_SOFT_MAX: - case WSP_GGML_OP_CONT: return true; default: @@ -443,12 +438,38 @@ static bool wsp_ggml_op_can_inplace(enum wsp_ggml_op op) { } } -static void allocate_node(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * node) { - struct hash_node * ht = alloc->hash_table; +static wsp_ggml_tallocr_t node_tallocr(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node) { + if (galloc->talloc != NULL) { + return galloc->talloc; + } + + return galloc->hash_allocs[wsp_ggml_hash_find_or_insert(galloc->hash_set, node)]; +} + +static void init_view(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * view) { + wsp_ggml_tallocr_t alloc = node_tallocr(galloc, view); + + //printf("init_view: %s from src %s\n", view->name, view->view_src->name); + WSP_GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL); + view->backend = view->view_src->backend; + view->buffer = view->view_src->buffer; + view->data = (char *)view->view_src->data + view->view_offs; + + // FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend + // due to the wsp_ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras + assert(wsp_ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend); + + if (!alloc->measure) { + wsp_ggml_backend_buffer_init_tensor(alloc->buffer, view); + } +} + +static void allocate_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node) { + wsp_ggml_tallocr_t alloc = node_tallocr(galloc, node); + if (node->data == NULL) { if (wsp_ggml_is_view(node)) { - assert(node->view_src->data != NULL); - node->data = (char *)node->view_src->data + node->view_offs; + init_view(galloc, node); } else { // see if we can reuse a parent's buffer (inplace) if (wsp_ggml_op_can_inplace(node->op)) { @@ -459,16 +480,16 @@ static void allocate_node(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor } // if the node's data is external, then we cannot re-use it - if (wsp_ggml_allocr_is_own(alloc, parent) == false) { + if (wsp_ggml_tallocr_is_own(alloc, parent) == false) { AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data); continue; } - struct hash_node * p_hn = hash_get(ht, parent); + struct hash_node * p_hn = hash_get(galloc, parent); if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && wsp_ggml_are_same_layout(node, parent)) { if (wsp_ggml_is_view(parent)) { struct wsp_ggml_tensor * view_src = parent->view_src; - struct hash_node * view_src_hn = hash_get(ht, view_src); + struct hash_node * view_src_hn = hash_get(galloc, view_src); if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { // TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite // the parent's data that it will need later (same layout requirement). the problem is that then @@ -476,158 +497,270 @@ static void allocate_node(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor // adding a view_src pointer to the tensor would solve this and simplify the code dealing with views // for now, we only reuse the parent's data if the offset is zero (view_src->data == parent->data) AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name); - node->data = parent->data; + node->view_src = view_src; + view_src_hn->n_views += 1; + init_view(galloc, node); return; } } else { AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name); - node->data = parent->data; + node->view_src = parent; + p_hn->n_views += 1; + init_view(galloc, node); return; } } } } - wsp_ggml_allocr_alloc(alloc, node); + wsp_ggml_tallocr_alloc(alloc, node); } } } -static size_t wsp_ggml_allocr_alloc_graph_tensors_n( - struct wsp_ggml_allocr * alloc, - struct wsp_ggml_cgraph ** graphs, int n_graphs, - struct wsp_ggml_tensor *** inputs, struct wsp_ggml_tensor *** outputs) { +static void free_node(wsp_ggml_gallocr_t galloc, struct wsp_ggml_tensor * node) { + wsp_ggml_tallocr_t alloc = node_tallocr(galloc, node); - // reset hash table - struct hash_node * ht = alloc->hash_table; - memset(ht, 0, sizeof(struct hash_node) * WSP_GGML_GRAPH_HASHTABLE_SIZE); + wsp_ggml_tallocr_free_tensor(alloc, node); +} + +static void wsp_ggml_tallocr_alloc_graph_impl(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * gf) { + const int * parse_seq = galloc->parse_seq; + int parse_seq_len = galloc->parse_seq_len; // count number of children and views - for (int g = 0; g < n_graphs; g++) { - struct wsp_ggml_cgraph * gf = graphs[g]; - for (int i = 0; i < gf->n_nodes; i++) { - struct wsp_ggml_tensor * node = gf->nodes[i]; + for (int i = 0; i < gf->n_nodes; i++) { + struct wsp_ggml_tensor * node = gf->nodes[i]; + + if (wsp_ggml_is_view(node)) { + struct wsp_ggml_tensor * view_src = node->view_src; + hash_get(galloc, view_src)->n_views += 1; + if (node->buffer == NULL && node->data != NULL) { + // view of a pre-allocated tensor, didn't call init_view() yet + init_view(galloc, node); + } + } - if (wsp_ggml_is_view(node)) { - struct wsp_ggml_tensor * view_src = node->view_src; - hash_get(ht, view_src)->n_views += 1; + for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { + struct wsp_ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + break; } + hash_get(galloc, parent)->n_children += 1; + if (wsp_ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) { + init_view(galloc, parent); + } + } + } + // allocate tensors + // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers + int last_barrier_pos = 0; + int n_nodes = parse_seq_len ? parse_seq_len : gf->n_nodes; + + for (int ind = 0; ind < n_nodes; ind++) { + // allocate a node if there is no parse_seq or this is not a barrier + if (parse_seq_len == 0 || parse_seq[ind] != -1) { + int i = parse_seq_len ? parse_seq[ind] : ind; + struct wsp_ggml_tensor * node = gf->nodes[i]; + + // allocate parents (leafs) for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { struct wsp_ggml_tensor * parent = node->src[j]; if (parent == NULL) { break; } - hash_get(ht, parent)->n_children += 1; + allocate_node(galloc, parent); } - } - } - // allocate tensors - for (int g = 0; g < n_graphs; g++) { - struct wsp_ggml_cgraph * gf = graphs[g]; - AT_PRINTF("####### graph %d/%d\n", g, n_graphs); - // graph inputs are allocated first to ensure that they are not overwritten by each other - if (inputs != NULL && inputs[g] != NULL) { - for (int i = 0; inputs[g][i] != NULL; i++) { - struct wsp_ggml_tensor * input = inputs[g][i]; - AT_PRINTF("input: %s\n", input->name); - allocate_node(alloc, input); + // allocate node + allocate_node(galloc, node); + + AT_PRINTF("exec: %s (%s) <= ", wsp_ggml_op_name(node->op), node->name); + for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { + struct wsp_ggml_tensor * parent = node->src[j]; + if (parent == NULL) { + break; + } + AT_PRINTF("%s", parent->name); + if (j < WSP_GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) { + AT_PRINTF(", "); + } } + AT_PRINTF("\n"); } - // if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers - int last_barrier_pos = 0; - int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes; - for (int ind = 0; ind < n_nodes; ind++) { - // allocate a node if there is no parse_seq or this is not a barrier - if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) { - int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind; - struct wsp_ggml_tensor * node = gf->nodes[i]; + // update parents + // update immediately if there is no parse_seq + // update only at barriers if there is parse_seq + if ((parse_seq_len == 0) || parse_seq[ind] == -1) { + int update_start = parse_seq_len ? last_barrier_pos : ind; + int update_end = parse_seq_len ? ind : ind + 1; + for (int i = update_start; i < update_end; i++) { + int node_i = parse_seq_len ? parse_seq[i] : i; + struct wsp_ggml_tensor * node = gf->nodes[node_i]; - // allocate parents (leafs) for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { struct wsp_ggml_tensor * parent = node->src[j]; if (parent == NULL) { break; } - allocate_node(alloc, parent); - } + struct hash_node * p_hn = hash_get(galloc, parent); + p_hn->n_children -= 1; - // allocate node - allocate_node(alloc, node); + //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views); - AT_PRINTF("exec: %s (%s) <= ", wsp_ggml_op_name(node->op), node->name); - for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { - struct wsp_ggml_tensor * parent = node->src[j]; - if (parent == NULL) { - break; - } - AT_PRINTF("%s", parent->name); - if (j < WSP_GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) { - AT_PRINTF(", "); - } - } - AT_PRINTF("\n"); - } - - // update parents - // update immediately if there is no parse_seq - // update only at barriers if there is parse_seq - if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) { - int update_start = alloc->parse_seq_len ? last_barrier_pos : ind; - int update_end = alloc->parse_seq_len ? ind : ind + 1; - for (int i = update_start; i < update_end; i++) { - int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i; - struct wsp_ggml_tensor * node = gf->nodes[node_i]; - - for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { - struct wsp_ggml_tensor * parent = node->src[j]; - if (parent == NULL) { - break; - } - struct hash_node * p_hn = hash_get(ht, parent); - p_hn->n_children -= 1; - - //AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views); - - if (p_hn->n_children == 0 && p_hn->n_views == 0) { - if (wsp_ggml_is_view(parent)) { - struct wsp_ggml_tensor * view_src = parent->view_src; - struct hash_node * view_src_hn = hash_get(ht, view_src); - view_src_hn->n_views -= 1; - AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views); - if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) { - wsp_ggml_allocr_free_tensor(alloc, view_src); - } - } - else { - if (parent->data != node->data) { - wsp_ggml_allocr_free_tensor(alloc, parent); - } + if (p_hn->n_children == 0 && p_hn->n_views == 0) { + if (wsp_ggml_is_view(parent)) { + struct wsp_ggml_tensor * view_src = parent->view_src; + struct hash_node * view_src_hn = hash_get(galloc, view_src); + view_src_hn->n_views -= 1; + AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views); + if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0) { + free_node(galloc, view_src); } } + else { + free_node(galloc, parent); + } } } - AT_PRINTF("\n"); - if (alloc->parse_seq_len) { - last_barrier_pos = ind + 1; - } } - } - // free graph outputs here that wouldn't be freed otherwise because they have no children - if (outputs != NULL && outputs[g] != NULL) { - for (int i = 0; outputs[g][i] != NULL; i++) { - struct wsp_ggml_tensor * output = outputs[g][i]; - AT_PRINTF("output: %s\n", output->name); - wsp_ggml_allocr_free_tensor(alloc, output); + AT_PRINTF("\n"); + if (parse_seq_len) { + last_barrier_pos = ind + 1; } } } +} - return alloc->max_size; +size_t wsp_ggml_gallocr_alloc_graph(wsp_ggml_gallocr_t galloc, wsp_ggml_tallocr_t talloc, struct wsp_ggml_cgraph * graph) { + size_t hash_size = graph->visited_hash_table.size; + + // check if the hash table is initialized and large enough + if (galloc->hash_set.size < hash_size) { + if (galloc->hash_set.keys != NULL) { + free(galloc->hash_set.keys); + } + if (galloc->hash_values != NULL) { + free(galloc->hash_values); + } + galloc->hash_set.keys = malloc(sizeof(struct wsp_ggml_tensor *) * hash_size); + galloc->hash_set.size = hash_size; + galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size); + } + + // reset hash table + memset(galloc->hash_set.keys, 0, sizeof(struct wsp_ggml_tensor *) * hash_size); + memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size); + + galloc->talloc = talloc; + wsp_ggml_tallocr_alloc_graph_impl(galloc, graph); + galloc->talloc = NULL; + + size_t max_size = wsp_ggml_tallocr_max_size(talloc); + + return max_size; +} + +void wsp_ggml_gallocr_alloc_graph_n(wsp_ggml_gallocr_t galloc, struct wsp_ggml_cgraph * graph, struct wsp_ggml_hash_set hash_set, wsp_ggml_tallocr_t * hash_node_alloct) { + const size_t hash_size = hash_set.size; + + WSP_GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs)); + + galloc->talloc = NULL; + + // alloc hash_values if needed + if (galloc->hash_values == NULL || galloc->hash_values_size < hash_size) { + free(galloc->hash_values); + galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size); + galloc->hash_values_size = hash_size; + } + + // free hash_set.keys if needed + if (galloc->hash_set.keys != NULL) { + free(galloc->hash_set.keys); + } + galloc->hash_set = hash_set; + + // reset hash values + memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size); + + galloc->hash_allocs = hash_node_alloct; + + wsp_ggml_tallocr_alloc_graph_impl(galloc, graph); + + // remove unowned resources + galloc->hash_set.keys = NULL; + galloc->hash_allocs = NULL; +} + +// legacy API wrapper + +struct wsp_ggml_allocr { + wsp_ggml_tallocr_t talloc; + wsp_ggml_gallocr_t galloc; +}; + +static wsp_ggml_allocr_t wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_t talloc) { + wsp_ggml_allocr_t alloc = (wsp_ggml_allocr_t)malloc(sizeof(struct wsp_ggml_allocr)); + *alloc = (struct wsp_ggml_allocr) { + /*.talloc = */ talloc, + /*.galloc = */ wsp_ggml_gallocr_new(), + }; + return alloc; +} + +wsp_ggml_allocr_t wsp_ggml_allocr_new(void * data, size_t size, size_t alignment) { + return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new(data, size, alignment)); +} + +wsp_ggml_allocr_t wsp_ggml_allocr_new_measure(size_t alignment) { + return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_measure(alignment)); +} + +wsp_ggml_allocr_t wsp_ggml_allocr_new_from_buffer(struct wsp_ggml_backend_buffer * buffer) { + return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_from_buffer(buffer)); +} + +wsp_ggml_allocr_t wsp_ggml_allocr_new_from_backend(struct wsp_ggml_backend * backend, size_t size) { + return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_from_backend(backend, size)); +} + +wsp_ggml_allocr_t wsp_ggml_allocr_new_measure_from_backend(struct wsp_ggml_backend * backend) { + return wsp_ggml_allocr_new_impl(wsp_ggml_tallocr_new_measure_from_backend(backend)); +} + +struct wsp_ggml_backend_buffer * wsp_ggml_allocr_get_buffer(wsp_ggml_allocr_t alloc) { + return wsp_ggml_tallocr_get_buffer(alloc->talloc); +} + +void wsp_ggml_allocr_set_parse_seq(wsp_ggml_allocr_t alloc, const int * list, int n) { + wsp_ggml_gallocr_set_parse_seq(alloc->galloc, list, n); +} + +void wsp_ggml_allocr_free(wsp_ggml_allocr_t alloc) { + wsp_ggml_gallocr_free(alloc->galloc); + wsp_ggml_tallocr_free(alloc->talloc); + free(alloc); +} + +bool wsp_ggml_allocr_is_measure(wsp_ggml_allocr_t alloc) { + return wsp_ggml_tallocr_is_measure(alloc->talloc); +} + +void wsp_ggml_allocr_reset(wsp_ggml_allocr_t alloc) { + wsp_ggml_tallocr_reset(alloc->talloc); +} + +void wsp_ggml_allocr_alloc(wsp_ggml_allocr_t alloc, struct wsp_ggml_tensor * tensor) { + wsp_ggml_tallocr_alloc(alloc->talloc, tensor); +} + +size_t wsp_ggml_allocr_max_size(wsp_ggml_allocr_t alloc) { + return wsp_ggml_tallocr_max_size(alloc->talloc); } -size_t wsp_ggml_allocr_alloc_graph(struct wsp_ggml_allocr * alloc, struct wsp_ggml_cgraph * graph) { - return wsp_ggml_allocr_alloc_graph_tensors_n(alloc, &graph, 1, NULL, NULL); +size_t wsp_ggml_allocr_alloc_graph(wsp_ggml_allocr_t alloc, struct wsp_ggml_cgraph * graph) { + return wsp_ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph); } diff --git a/cpp/ggml-alloc.h b/cpp/ggml-alloc.h index 35454d4..6a39668 100644 --- a/cpp/ggml-alloc.h +++ b/cpp/ggml-alloc.h @@ -6,20 +6,79 @@ extern "C" { #endif +struct wsp_ggml_backend; +struct wsp_ggml_backend_buffer; -WSP_GGML_API struct wsp_ggml_allocr * wsp_ggml_allocr_new(void * data, size_t size, size_t alignment); -WSP_GGML_API struct wsp_ggml_allocr * wsp_ggml_allocr_new_measure(size_t alignment); +// +// Legacy API +// + +typedef struct wsp_ggml_allocr * wsp_ggml_allocr_t; + +// initialize allocator for use with CPU backend only +WSP_GGML_API wsp_ggml_allocr_t wsp_ggml_allocr_new(void * data, size_t size, size_t alignment); +WSP_GGML_API wsp_ggml_allocr_t wsp_ggml_allocr_new_measure(size_t alignment); + +// initialize allocator for use with ggml-backend +WSP_GGML_API wsp_ggml_allocr_t wsp_ggml_allocr_new_from_buffer(struct wsp_ggml_backend_buffer * buffer); +WSP_GGML_API wsp_ggml_allocr_t wsp_ggml_allocr_new_from_backend(struct wsp_ggml_backend * backend, size_t size); // allocates an owned buffer +WSP_GGML_API wsp_ggml_allocr_t wsp_ggml_allocr_new_measure_from_backend(struct wsp_ggml_backend * backend); + +WSP_GGML_API struct wsp_ggml_backend_buffer * wsp_ggml_allocr_get_buffer(wsp_ggml_allocr_t alloc); // tell the allocator to parse nodes following the order described in the list // you should call this if your graph are optimized to execute out-of-order -WSP_GGML_API void wsp_ggml_allocr_set_parse_seq(struct wsp_ggml_allocr * alloc, const int * list, int n); +WSP_GGML_API void wsp_ggml_allocr_set_parse_seq(wsp_ggml_allocr_t alloc, const int * list, int n); + +WSP_GGML_API void wsp_ggml_allocr_free (wsp_ggml_allocr_t alloc); +WSP_GGML_API bool wsp_ggml_allocr_is_measure (wsp_ggml_allocr_t alloc); +WSP_GGML_API void wsp_ggml_allocr_reset (wsp_ggml_allocr_t alloc); +WSP_GGML_API void wsp_ggml_allocr_alloc (wsp_ggml_allocr_t alloc, struct wsp_ggml_tensor * tensor); +WSP_GGML_API size_t wsp_ggml_allocr_max_size (wsp_ggml_allocr_t alloc); + +WSP_GGML_API size_t wsp_ggml_allocr_alloc_graph(wsp_ggml_allocr_t alloc, struct wsp_ggml_cgraph * graph); + +// +// ggml-backend v2 API +// + +// Seperate tensor and graph allocator objects +// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators +// The original API is kept as a wrapper around the new API + +// Tensor allocator +typedef struct wsp_ggml_tallocr * wsp_ggml_tallocr_t; + +WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new(void * data, size_t size, size_t alignment); +WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure(size_t alignment); +WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_buffer(struct wsp_ggml_backend_buffer * buffer); +WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new_from_backend(struct wsp_ggml_backend * backend, size_t size); // allocates an owned buffer +WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_tallocr_new_measure_from_backend(struct wsp_ggml_backend * backend); + +WSP_GGML_API struct wsp_ggml_backend_buffer * wsp_ggml_tallocr_get_buffer(wsp_ggml_tallocr_t talloc); + +WSP_GGML_API void wsp_ggml_tallocr_free (wsp_ggml_tallocr_t talloc); +WSP_GGML_API bool wsp_ggml_tallocr_is_measure (wsp_ggml_tallocr_t talloc); +WSP_GGML_API void wsp_ggml_tallocr_reset (wsp_ggml_tallocr_t talloc); +WSP_GGML_API void wsp_ggml_tallocr_alloc (wsp_ggml_tallocr_t talloc, struct wsp_ggml_tensor * tensor); +WSP_GGML_API size_t wsp_ggml_tallocr_max_size (wsp_ggml_tallocr_t talloc); + + +// Graph allocator +typedef struct wsp_ggml_gallocr * wsp_ggml_gallocr_t; + +WSP_GGML_API wsp_ggml_gallocr_t wsp_ggml_gallocr_new(void); +WSP_GGML_API void wsp_ggml_gallocr_free(wsp_ggml_gallocr_t galloc); -WSP_GGML_API void wsp_ggml_allocr_free(struct wsp_ggml_allocr * alloc); -WSP_GGML_API bool wsp_ggml_allocr_is_measure(struct wsp_ggml_allocr * alloc); -WSP_GGML_API void wsp_ggml_allocr_reset(struct wsp_ggml_allocr * alloc); -WSP_GGML_API void wsp_ggml_allocr_alloc(struct wsp_ggml_allocr * alloc, struct wsp_ggml_tensor * tensor); -WSP_GGML_API size_t wsp_ggml_allocr_alloc_graph(struct wsp_ggml_allocr * alloc, struct wsp_ggml_cgraph * graph); +WSP_GGML_API void wsp_ggml_gallocr_set_parse_seq(wsp_ggml_gallocr_t galloc, const int * list, int n); +WSP_GGML_API size_t wsp_ggml_gallocr_alloc_graph(wsp_ggml_gallocr_t galloc, wsp_ggml_tallocr_t talloc, struct wsp_ggml_cgraph * graph); +// Allocate tensors from the allocators given by the hash table +WSP_GGML_API void wsp_ggml_gallocr_alloc_graph_n( + wsp_ggml_gallocr_t galloc, + struct wsp_ggml_cgraph * graph, + struct wsp_ggml_hash_set hash_set, + wsp_ggml_tallocr_t * hash_node_talloc); #ifdef __cplusplus } diff --git a/cpp/ggml-backend-impl.h b/cpp/ggml-backend-impl.h new file mode 100644 index 0000000..35a8737 --- /dev/null +++ b/cpp/ggml-backend-impl.h @@ -0,0 +1,87 @@ +#pragma once + +// ggml-backend internal header + +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + + // + // Backend buffer + // + + typedef void * wsp_ggml_backend_buffer_context_t; + + struct wsp_ggml_backend_buffer_i { + void (*free_buffer) (wsp_ggml_backend_buffer_t buffer); + void * (*get_base) (wsp_ggml_backend_buffer_t buffer); // get base pointer + size_t (*get_alloc_size)(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); // pre-allocation callback + void (*init_tensor) (wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); // post-allocation callback + void (*free_tensor) (wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); // pre-free callback + }; + + struct wsp_ggml_backend_buffer { + struct wsp_ggml_backend_buffer_i iface; + + wsp_ggml_backend_t backend; + wsp_ggml_backend_buffer_context_t context; + + size_t size; + }; + + WSP_GGML_API wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init( + struct wsp_ggml_backend * backend, + struct wsp_ggml_backend_buffer_i iface, + wsp_ggml_backend_buffer_context_t context, + size_t size); + + // + // Backend + // + + typedef void * wsp_ggml_backend_context_t; + + struct wsp_ggml_backend_i { + const char * (*get_name)(wsp_ggml_backend_t backend); + + void (*free)(wsp_ggml_backend_t backend); + + // buffer allocation + wsp_ggml_backend_buffer_t (*alloc_buffer)(wsp_ggml_backend_t backend, size_t size); + + // get buffer alignment + size_t (*get_alignment)(wsp_ggml_backend_t backend); + + // tensor data access + // these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize + void (*set_tensor_async)(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async)(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*synchronize) (wsp_ggml_backend_t backend); + + // (optional) copy tensor between different backends, allow for single-copy tranfers + void (*cpy_tensor_from)(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst); + void (*cpy_tensor_to) (wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst); + + // compute graph with a plan + wsp_ggml_backend_graph_plan_t (*graph_plan_create) (wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph); + void (*graph_plan_free) (wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan); + void (*graph_plan_compute)(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan); + + // compute graph without a plan + void (*graph_compute)(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph); + + // check if the backend supports an operation + bool (*supports_op)(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op); + }; + + struct wsp_ggml_backend { + struct wsp_ggml_backend_i iface; + + wsp_ggml_backend_context_t context; + }; + +#ifdef __cplusplus +} +#endif diff --git a/cpp/ggml-backend.c b/cpp/ggml-backend.c new file mode 100644 index 0000000..2a22c34 --- /dev/null +++ b/cpp/ggml-backend.c @@ -0,0 +1,950 @@ +#include "ggml-backend-impl.h" +#include "ggml-alloc.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include +#include + +#define UNUSED WSP_GGML_UNUSED + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// backend buffer + +wsp_ggml_backend_buffer_t wsp_ggml_backend_buffer_init( + struct wsp_ggml_backend * backend, + struct wsp_ggml_backend_buffer_i iface, + wsp_ggml_backend_buffer_context_t context, + size_t size) { + wsp_ggml_backend_buffer_t buffer = malloc(sizeof(struct wsp_ggml_backend_buffer)); + + WSP_GGML_ASSERT(iface.get_base != NULL); + + (*buffer) = (struct wsp_ggml_backend_buffer) { + /* .interface = */ iface, + /* .backend = */ backend, + /* .context = */ context, + /* .size = */ size, + }; + + return buffer; +} + +void wsp_ggml_backend_buffer_free(wsp_ggml_backend_buffer_t buffer) { + if (buffer == NULL) { + return; + } + + if (buffer->iface.free_buffer != NULL) { + buffer->iface.free_buffer(buffer); + } + free(buffer); +} + +size_t wsp_ggml_backend_buffer_get_alignment(wsp_ggml_backend_buffer_t buffer) { + return wsp_ggml_backend_get_alignment(buffer->backend); +} + +size_t wsp_ggml_backend_buffer_get_size(wsp_ggml_backend_buffer_t buffer) { + return buffer->size; +} + +void * wsp_ggml_backend_buffer_get_base(wsp_ggml_backend_buffer_t buffer) { + void * base = buffer->iface.get_base(buffer); + + WSP_GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL"); + + return base; +} + +size_t wsp_ggml_backend_buffer_get_alloc_size(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) { + // get_alloc_size is optional, defaults to wsp_ggml_nbytes + if (buffer->iface.get_alloc_size) { + return buffer->iface.get_alloc_size(buffer, tensor); + } + return wsp_ggml_nbytes(tensor); +} + +void wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) { + // init_tensor is optional + if (buffer->iface.init_tensor) { + buffer->iface.init_tensor(buffer, tensor); + } +} + +void wsp_ggml_backend_buffer_free_tensor(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor) { + // free_tensor is optional + if (buffer->iface.free_tensor) { + buffer->iface.free_tensor(buffer, tensor); + } +} + +// backend + +wsp_ggml_backend_t wsp_ggml_get_backend(const struct wsp_ggml_tensor * tensor) { + return tensor->buffer ? tensor->buffer->backend : NULL; +} + +const char * wsp_ggml_backend_name(wsp_ggml_backend_t backend) { + if (backend == NULL) { + return "NULL"; + } + return backend->iface.get_name(backend); +} + +void wsp_ggml_backend_free(wsp_ggml_backend_t backend) { + if (backend == NULL) { + return; + } + + backend->iface.free(backend); +} + +wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_buffer(wsp_ggml_backend_t backend, size_t size) { + return backend->iface.alloc_buffer(backend, size); +} + +size_t wsp_ggml_backend_get_alignment(wsp_ggml_backend_t backend) { + return backend->iface.get_alignment(backend); +} + +void wsp_ggml_backend_tensor_set_async(struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + wsp_ggml_get_backend(tensor)->iface.set_tensor_async(wsp_ggml_get_backend(tensor), tensor, data, offset, size); +} + +void wsp_ggml_backend_tensor_get_async(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) { + wsp_ggml_get_backend(tensor)->iface.get_tensor_async(wsp_ggml_get_backend(tensor), tensor, data, offset, size); +} + +void wsp_ggml_backend_tensor_set(struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + wsp_ggml_backend_t backend = wsp_ggml_get_backend(tensor); + + WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + WSP_GGML_ASSERT(backend != NULL && "tensor backend not set"); + + backend->iface.set_tensor_async(backend, tensor, data, offset, size); + backend->iface.synchronize(backend); +} + +void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) { + wsp_ggml_backend_t backend = wsp_ggml_get_backend(tensor); + + WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + WSP_GGML_ASSERT(backend != NULL && "tensor backend not set"); + + backend->iface.get_tensor_async(backend, tensor, data, offset, size); + backend->iface.synchronize(backend); +} + +void wsp_ggml_backend_synchronize(wsp_ggml_backend_t backend) { + backend->iface.synchronize(backend); +} + +wsp_ggml_backend_graph_plan_t wsp_ggml_backend_graph_plan_create(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { + return backend->iface.graph_plan_create(backend, cgraph); +} + +void wsp_ggml_backend_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) { + backend->iface.graph_plan_free(backend, plan); +} + +void wsp_ggml_backend_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) { + backend->iface.graph_plan_compute(backend, plan); +} + +void wsp_ggml_backend_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { + backend->iface.graph_compute(backend, cgraph); +} + +bool wsp_ggml_backend_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) { + return backend->iface.supports_op(backend, op); +} + +// backend copy + +static bool wsp_ggml_are_same_layout(const struct wsp_ggml_tensor * a, const struct wsp_ggml_tensor * b) { + if (a->type != b->type) { + return false; + } + for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) { + if (a->ne[i] != b->ne[i]) { + return false; + } + if (a->nb[i] != b->nb[i]) { + return false; + } + } + return true; +} + +void wsp_ggml_backend_tensor_copy(struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { + //printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]); + //printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]); + WSP_GGML_ASSERT(wsp_ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); + + // fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, wsp_ggml_backend_name(src->backend), wsp_ggml_backend_name(dst->backend), wsp_ggml_nbytes(src)); + + if (src == dst) { + return; + } + + // TODO: allow backends to support copy to/from same backend + + if (wsp_ggml_get_backend(dst)->iface.cpy_tensor_from != NULL) { + wsp_ggml_get_backend(dst)->iface.cpy_tensor_from(wsp_ggml_get_backend(dst)->context, src, dst); + } else if (wsp_ggml_get_backend(src)->iface.cpy_tensor_to != NULL) { + wsp_ggml_get_backend(src)->iface.cpy_tensor_to(wsp_ggml_get_backend(src)->context, src, dst); + } else { + // shouldn't be hit when copying from/to CPU + #ifndef NDEBUG + fprintf(stderr, "wsp_ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", wsp_ggml_backend_name(src->buffer->backend), wsp_ggml_backend_name(dst->buffer->backend)); + #endif + size_t nbytes = wsp_ggml_nbytes(src); + void * data = malloc(nbytes); + wsp_ggml_backend_tensor_get(src, data, 0, nbytes); + wsp_ggml_backend_tensor_set(dst, data, 0, nbytes); + free(data); + } +} + +// backend CPU + +struct wsp_ggml_backend_cpu_context { + int n_threads; + void * work_data; + size_t work_size; +}; + +static const char * wsp_ggml_backend_cpu_name(wsp_ggml_backend_t backend) { + return "CPU"; + + UNUSED(backend); +} + +static void wsp_ggml_backend_cpu_free(wsp_ggml_backend_t backend) { + struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context; + free(cpu_ctx->work_data); + free(cpu_ctx); + free(backend); +} + +static void * wsp_ggml_backend_cpu_buffer_get_base(wsp_ggml_backend_buffer_t buffer) { + return (void *)buffer->context; +} + +static void wsp_ggml_backend_cpu_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) { + free(buffer->context); + UNUSED(buffer); +} + +static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i = { + /* .free_buffer = */ wsp_ggml_backend_cpu_buffer_free_buffer, + /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base, + /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes + /* .init_tensor = */ NULL, // no initialization required + /* .free_tensor = */ NULL, // no cleanup required +}; + +// for buffers from ptr, free is not called +static struct wsp_ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { + /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed + /* .get_base = */ wsp_ggml_backend_cpu_buffer_get_base, + /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes + /* .init_tensor = */ NULL, + /* .free_tensor = */ NULL, +}; + +static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512 + +static wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_alloc_buffer(wsp_ggml_backend_t backend, size_t size) { + size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned + void * data = malloc(size); // TODO: maybe use WSP_GGML_ALIGNED_MALLOC? + + WSP_GGML_ASSERT(data != NULL && "failed to allocate buffer"); + + return wsp_ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size); +} + +static size_t wsp_ggml_backend_cpu_get_alignment(wsp_ggml_backend_t backend) { + return TENSOR_ALIGNMENT; + UNUSED(backend); +} + +static void wsp_ggml_backend_cpu_set_tensor_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds"); + WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy((char *)tensor->data + offset, data, size); + + UNUSED(backend); +} + +static void wsp_ggml_backend_cpu_get_tensor_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) { + WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds"); + WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy(data, (const char *)tensor->data + offset, size); + + UNUSED(backend); +} + +static void wsp_ggml_backend_cpu_synchronize(wsp_ggml_backend_t backend) { + UNUSED(backend); +} + +static void wsp_ggml_backend_cpu_cpy_tensor_from(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { + wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src)); + + UNUSED(backend); +} + +static void wsp_ggml_backend_cpu_cpy_tensor_to(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { + wsp_ggml_backend_tensor_set(dst, src->data, 0, wsp_ggml_nbytes(src)); + + UNUSED(backend); +} + +struct wsp_ggml_backend_plan_cpu { + struct wsp_ggml_cplan cplan; + struct wsp_ggml_cgraph cgraph; +}; + +static wsp_ggml_backend_graph_plan_t wsp_ggml_backend_cpu_graph_plan_create(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { + struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context; + + struct wsp_ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct wsp_ggml_backend_plan_cpu)); + + cpu_plan->cplan = wsp_ggml_graph_plan(cgraph, cpu_ctx->n_threads); + cpu_plan->cgraph = *cgraph; + + if (cpu_plan->cplan.work_size > 0) { + cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); + } + + return cpu_plan; +} + +static void wsp_ggml_backend_cpu_graph_plan_free(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) { + struct wsp_ggml_backend_plan_cpu * cpu_plan = (struct wsp_ggml_backend_plan_cpu *)plan; + + free(cpu_plan->cplan.work_data); + free(cpu_plan); + + UNUSED(backend); +} + +static void wsp_ggml_backend_cpu_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan) { + struct wsp_ggml_backend_plan_cpu * cpu_plan = (struct wsp_ggml_backend_plan_cpu *)plan; + + wsp_ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan); + + UNUSED(backend); +} + +static void wsp_ggml_backend_cpu_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { + struct wsp_ggml_backend_cpu_context * cpu_ctx = (struct wsp_ggml_backend_cpu_context *)backend->context; + + struct wsp_ggml_cplan cplan = wsp_ggml_graph_plan(cgraph, cpu_ctx->n_threads); + + if (cpu_ctx->work_size < cplan.work_size) { + // TODO: may be faster to free and use malloc to avoid the copy + cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size); + cpu_ctx->work_size = cplan.work_size; + } + + cplan.work_data = cpu_ctx->work_data; + + wsp_ggml_graph_compute(cgraph, &cplan); +} + +static bool wsp_ggml_backend_cpu_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) { + return true; + UNUSED(backend); + UNUSED(op); +} + +static struct wsp_ggml_backend_i cpu_backend_i = { + /* .get_name = */ wsp_ggml_backend_cpu_name, + /* .free = */ wsp_ggml_backend_cpu_free, + /* .alloc_buffer = */ wsp_ggml_backend_cpu_alloc_buffer, + /* .get_alignment = */ wsp_ggml_backend_cpu_get_alignment, + /* .set_tensor_async = */ wsp_ggml_backend_cpu_set_tensor_async, + /* .get_tensor_async = */ wsp_ggml_backend_cpu_get_tensor_async, + /* .synchronize = */ wsp_ggml_backend_cpu_synchronize, + /* .cpy_tensor_from = */ wsp_ggml_backend_cpu_cpy_tensor_from, + /* .cpy_tensor_to = */ wsp_ggml_backend_cpu_cpy_tensor_to, + /* .graph_plan_create = */ wsp_ggml_backend_cpu_graph_plan_create, + /* .graph_plan_free = */ wsp_ggml_backend_cpu_graph_plan_free, + /* .graph_plan_compute = */ wsp_ggml_backend_cpu_graph_plan_compute, + /* .graph_compute = */ wsp_ggml_backend_cpu_graph_compute, + /* .supports_op = */ wsp_ggml_backend_cpu_supports_op, +}; + +wsp_ggml_backend_t wsp_ggml_backend_cpu_init(void) { + struct wsp_ggml_backend_cpu_context * ctx = malloc(sizeof(struct wsp_ggml_backend_cpu_context)); + + ctx->n_threads = WSP_GGML_DEFAULT_N_THREADS; + ctx->work_data = NULL; + ctx->work_size = 0; + + wsp_ggml_backend_t cpu_backend = malloc(sizeof(struct wsp_ggml_backend)); + + *cpu_backend = (struct wsp_ggml_backend) { + /* .interface = */ cpu_backend_i, + /* .context = */ ctx + }; + return cpu_backend; +} + +bool wsp_ggml_backend_is_cpu(wsp_ggml_backend_t backend) { + return backend->iface.get_name == wsp_ggml_backend_cpu_name; +} + +void wsp_ggml_backend_cpu_set_n_threads(wsp_ggml_backend_t backend_cpu, int n_threads) { + WSP_GGML_ASSERT(wsp_ggml_backend_is_cpu(backend_cpu)); + + struct wsp_ggml_backend_cpu_context * ctx = (struct wsp_ggml_backend_cpu_context *)backend_cpu->context; + ctx->n_threads = n_threads; +} + +wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(wsp_ggml_backend_t backend_cpu, void * ptr, size_t size) { + return wsp_ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size); +} + +// scheduler + +#define WSP_GGML_MAX_BACKENDS 4 +#define WSP_GGML_MAX_SPLITS 256 +#define WSP_GGML_MAX_SPLIT_INPUTS 16 + +struct wsp_ggml_backend_sched_split { + wsp_ggml_tallocr_t tallocr; + int i_start; + int i_end; + struct wsp_ggml_tensor * inputs[WSP_GGML_MAX_SPLIT_INPUTS]; + int n_inputs; + struct wsp_ggml_cgraph * graph; +}; + +struct wsp_ggml_backend_sched { + int n_backends; + wsp_ggml_backend_t backends[WSP_GGML_MAX_BACKENDS]; + wsp_ggml_tallocr_t tallocs[WSP_GGML_MAX_BACKENDS]; + + wsp_ggml_gallocr_t galloc; + + struct wsp_ggml_hash_set hash_set; + wsp_ggml_tallocr_t * node_talloc; // [hash_set.size] + struct wsp_ggml_tensor * (* node_copies)[WSP_GGML_MAX_BACKENDS]; // [hash_set.size][WSP_GGML_MAX_BACKENDS] + + struct wsp_ggml_cgraph * graph; + struct wsp_ggml_backend_sched_split splits[WSP_GGML_MAX_SPLITS]; + int n_splits; + + struct wsp_ggml_context * ctx; + + // align context_buffer to WSP_GGML_MEM_ALIGN + #ifdef _MSC_VER + __declspec(align(WSP_GGML_MEM_ALIGN)) + #else + __attribute__((aligned(WSP_GGML_MEM_ALIGN))) + #endif + char context_buffer[WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS*sizeof(struct wsp_ggml_tensor) + WSP_GGML_MAX_SPLITS*sizeof(struct wsp_ggml_cgraph)]; +}; + +#define hash_id(node) wsp_ggml_hash_find_or_insert(sched->hash_set, node) +#define node_allocr(node) sched->node_talloc[hash_id(node)] + +static bool wsp_ggml_is_view_op(enum wsp_ggml_op op) { + return op == WSP_GGML_OP_VIEW || op == WSP_GGML_OP_RESHAPE || op == WSP_GGML_OP_PERMUTE || op == WSP_GGML_OP_TRANSPOSE; +} + +// returns the priority of the backend, lower is better +static int sched_backend_prio(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) { + for (int i = 0; i < sched->n_backends; i++) { + if (sched->backends[i] == backend) { + return i; + } + } + return INT_MAX; +} + +static int sched_allocr_prio(wsp_ggml_backend_sched_t sched, wsp_ggml_tallocr_t allocr) { + for (int i = 0; i < sched->n_backends; i++) { + if (sched->tallocs[i] == allocr) { + return i; + } + } + return INT_MAX; +} + +// returns the backend that should be used for the node based on the current locations +char causes[WSP_GGML_DEFAULT_GRAPH_SIZE*4 + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS][128]; // debug, remove +static wsp_ggml_backend_t sched_backend_from_cur(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node) { + // if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there + // ie. kv cache updates + // note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend. + // dst + wsp_ggml_backend_t cur_backend = wsp_ggml_get_backend(node); + if (cur_backend != NULL) { + sprintf(causes[hash_id(node)], "1.dst"); + return cur_backend; + } + + // view_src + if (node->view_src != NULL && wsp_ggml_get_backend(node->view_src) != NULL) { + sprintf(causes[hash_id(node)], "1.vsrc"); + return wsp_ggml_get_backend(node->view_src); + } + + // src + int cur_prio = INT_MAX; + size_t cur_size = 0; + + for (int i = 0; i < WSP_GGML_MAX_SRC; i++) { + const struct wsp_ggml_tensor * src = node->src[i]; + if (src == NULL) { + break; + } + wsp_ggml_backend_t src_backend = wsp_ggml_get_backend(src); + if (src_backend != NULL) { + int src_prio = sched_backend_prio(sched, src_backend); + size_t src_size = wsp_ggml_nbytes(src); + if (src_prio < cur_prio && src_size >= cur_size) { + cur_prio = src_prio; + cur_size = src_size; + cur_backend = src_backend; + sprintf(causes[hash_id(node)], "1.src%d", i); + } + } + } + return cur_backend; +} + +static char * fmt_size(size_t size) { + static char buffer[128]; + if (size >= 1024*1024) { + sprintf(buffer, "%zuM", size/1024/1024); + } else { + sprintf(buffer, "%zuK", size/1024); + } + return buffer; +} + +static void sched_print_assignments(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) { + int cur_split = 0; + for (int i = 0; i < graph->n_nodes; i++) { + if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) { + wsp_ggml_backend_t split_backend = wsp_ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend; + fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, wsp_ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs); + for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) { + fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(wsp_ggml_nbytes(sched->splits[cur_split].inputs[j]))); + } + fprintf(stderr, "\n"); + cur_split++; + } + struct wsp_ggml_tensor * node = graph->nodes[i]; + if (wsp_ggml_is_view_op(node->op)) { + continue; + } + wsp_ggml_tallocr_t node_allocr = node_allocr(node); + wsp_ggml_backend_t node_backend = node_allocr ? wsp_ggml_tallocr_get_buffer(node_allocr)->backend : NULL; + fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, wsp_ggml_op_name(node->op), node->name, fmt_size(wsp_ggml_nbytes(node)), node_allocr ? wsp_ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]); + for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { + struct wsp_ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + wsp_ggml_tallocr_t src_allocr = node_allocr(src); + wsp_ggml_backend_t src_backend = src_allocr ? wsp_ggml_tallocr_get_buffer(src_allocr)->backend : NULL; + fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(wsp_ggml_nbytes(src)), src_backend ? wsp_ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]); + } + fprintf(stderr, "\n"); + } +} + +// creates a copy of the tensor with the same memory layout +static struct wsp_ggml_tensor * wsp_ggml_dup_tensor_layout(struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * tensor) { + struct wsp_ggml_tensor * dup = wsp_ggml_dup_tensor(ctx, tensor); + for (int i = 0; i < WSP_GGML_MAX_DIMS; i++) { + dup->nb[i] = tensor->nb[i]; + } + return dup; +} + +// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend +// TODO: merge passes +static void sched_split_graph(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) { + // reset state + size_t hash_size = sched->hash_set.size; + memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size); + memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size); + memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size); + sched->n_splits = 0; + + struct wsp_ggml_init_params params = { + /*.mem_size = */ sizeof(sched->context_buffer), + /*.mem_buffer = */ sched->context_buffer, + /*.no_alloc = */ true + }; + + if (sched->ctx != NULL) { + wsp_ggml_free(sched->ctx); + } + + sched->ctx = wsp_ggml_init(params); + + // pass 1: assign backends to ops with allocated inputs + for (int i = 0; i < graph->n_leafs; i++) { + struct wsp_ggml_tensor * leaf = graph->leafs[i]; + if (node_allocr(leaf) != NULL) { + // do not overwrite user assignments + continue; + } + wsp_ggml_backend_t leaf_backend = wsp_ggml_get_backend(leaf); + if (leaf_backend == NULL && leaf->view_src != NULL) { + leaf_backend = wsp_ggml_get_backend(leaf->view_src); + } + if (leaf_backend != NULL) { + node_allocr(leaf) = wsp_ggml_backend_sched_get_tallocr(sched, leaf_backend); + } + } + + for (int i = 0; i < graph->n_nodes; i++) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + if (node_allocr(node) != NULL) { + // do not overwrite user assignments + continue; + } + wsp_ggml_backend_t node_backend = sched_backend_from_cur(sched, node); + if (node_backend != NULL) { + node_allocr(node) = wsp_ggml_backend_sched_get_tallocr(sched, node_backend); + } + } + //printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); + + // pass 2: assign backends to ops from current assignments + // TODO: + // - reuse sched_backend_from_cur + for (int i = 0; i < graph->n_nodes; i++) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + wsp_ggml_tallocr_t node_allocr = node_allocr(node); + if (node_allocr == NULL) { + int cur_prio = INT_MAX; + size_t cur_size = 0; + for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { + struct wsp_ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + wsp_ggml_tallocr_t src_allocr = node_allocr(src); + if (src_allocr != NULL) { + int src_prio = sched_allocr_prio(sched, src_allocr); + size_t src_size = wsp_ggml_nbytes(src); + if (src_prio < cur_prio && src_size >= cur_size) { + cur_prio = src_prio; + cur_size = src_size; + node_allocr = src_allocr; + sprintf(causes[hash_id(node)], "2.src%d", j); + } + } + } + if (node_allocr != NULL) { + node_allocr(node) = node_allocr; + } + } + } + //printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); + + // pass 3: assign backends to remaining src from dst (should only be leafs) + for (int i = 0; i < graph->n_nodes; i++) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + wsp_ggml_tallocr_t node_allocr = node_allocr(node); + for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { + struct wsp_ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + wsp_ggml_tallocr_t src_allocr = node_allocr(src); + if (src_allocr == NULL) { + node_allocr(src) = node_allocr; + } + } + } + //printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); + + // pass 4: split graph, find tensors that need to be copied + // TODO: + // - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost + // find first backend + int cur_split = 0; + for (int i = 0; i < graph->n_nodes; i++) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + if (node->view_src == NULL) { + sched->splits[0].tallocr = node_allocr(node); + break; + } + } + sched->splits[0].i_start = 0; + sched->splits[0].n_inputs = 0; + memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK + wsp_ggml_tallocr_t cur_allocr = sched->splits[0].tallocr; + size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr); + for (int i = 0; i < graph->n_nodes; i++) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + + if (wsp_ggml_is_view_op(node->op)) { + continue; + } + + wsp_ggml_tallocr_t node_allocr = node_allocr(node); + + if (node_allocr != cur_allocr) { + sched->splits[cur_split].i_end = i; + cur_split++; + WSP_GGML_ASSERT(cur_split < WSP_GGML_MAX_SPLITS); + sched->splits[cur_split].tallocr = node_allocr; + sched->splits[cur_split].i_start = i; + sched->splits[cur_split].n_inputs = 0; + memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK + cur_allocr = node_allocr; + cur_backend_id = sched_allocr_prio(sched, cur_allocr); + } + + // find inputs that are not on the same backend + for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { + struct wsp_ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + wsp_ggml_tallocr_t src_allocr = node_allocr(src); + if (src_allocr != node_allocr) { + int n_inputs = sched->splits[cur_split].n_inputs++; + WSP_GGML_ASSERT(n_inputs < WSP_GGML_MAX_SPLIT_INPUTS); + sched->splits[cur_split].inputs[n_inputs] = (struct wsp_ggml_tensor *)src; + + // create copies + size_t id = hash_id(src); + if (sched->node_copies[id][cur_backend_id] == NULL) { + struct wsp_ggml_tensor * tensor_copy = wsp_ggml_dup_tensor_layout(sched->ctx, src); + sched->node_copies[id][cur_backend_id] = tensor_copy; + node_allocr(tensor_copy) = cur_allocr; + wsp_ggml_backend_t backend = wsp_ggml_tallocr_get_buffer(cur_allocr)->backend; + wsp_ggml_format_name(tensor_copy, "%s#%s", wsp_ggml_backend_name(backend), src->name); + } + node->src[j] = sched->node_copies[id][cur_backend_id]; + } + } + } + sched->splits[cur_split].i_end = graph->n_nodes; + sched->n_splits = cur_split + 1; + + //fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout); + +#if 1 + // sanity check: all sources should have the same backend as the node + for (int i = 0; i < graph->n_nodes; i++) { + struct wsp_ggml_tensor * node = graph->nodes[i]; + wsp_ggml_tallocr_t node_allocr = node_allocr(node); + if (node_allocr == NULL) { + fprintf(stderr, "!!!!!!! %s has no backend\n", node->name); + } + for (int j = 0; j < WSP_GGML_MAX_SRC; j++) { + struct wsp_ggml_tensor * src = node->src[j]; + if (src == NULL) { + break; + } + wsp_ggml_tallocr_t src_allocr = node_allocr(src); + if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now + fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n", + node->name, node_allocr ? wsp_ggml_backend_name(wsp_ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL", + j, src->name, src_allocr ? wsp_ggml_backend_name(wsp_ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL"); + } + } + } +#endif + + // create copies of the graph for each split + // FIXME: avoid this copy, pass split inputs to wsp_ggml_gallocr_alloc_graph_n in some other way + struct wsp_ggml_cgraph * graph_copy = wsp_ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*WSP_GGML_MAX_SPLIT_INPUTS, false); + for (int i = 0; i < sched->n_splits; i++) { + struct wsp_ggml_backend_sched_split * split = &sched->splits[i]; + split->graph = wsp_ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end); + + // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split + for (int j = 0; j < split->n_inputs; j++) { + struct wsp_ggml_tensor * input = split->inputs[j]; + struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)]; + input_cpy->src[0] = input; + graph_copy->nodes[graph_copy->n_nodes++] = input_cpy; + } + + for (int j = split->i_start; j < split->i_end; j++) { + graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j]; + } + } + sched->graph = graph_copy; +} + +static void sched_alloc_splits(wsp_ggml_backend_sched_t sched) { + wsp_ggml_gallocr_alloc_graph_n( + sched->galloc, + sched->graph, + sched->hash_set, + sched->node_talloc); +} + +static void sched_compute_splits(wsp_ggml_backend_sched_t sched) { + uint64_t copy_us[WSP_GGML_MAX_BACKENDS] = {0}; + uint64_t compute_us[WSP_GGML_MAX_BACKENDS] = {0}; + + struct wsp_ggml_backend_sched_split * splits = sched->splits; + + for (int i = 0; i < sched->n_splits; i++) { + struct wsp_ggml_backend_sched_split * split = &splits[i]; + wsp_ggml_backend_t split_backend = wsp_ggml_tallocr_get_buffer(split->tallocr)->backend; + int split_backend_id = sched_backend_prio(sched, split_backend); + + // copy the input tensors to the split backend + uint64_t copy_start_us = wsp_ggml_time_us(); + for (int j = 0; j < split->n_inputs; j++) { + struct wsp_ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)]; + if (split->inputs[j]->buffer == NULL) { + if (split->inputs[j]->view_src == NULL) { + fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name); + exit(1); + } + struct wsp_ggml_tensor * view = split->inputs[j]; + view->backend = view->view_src->backend; + view->buffer = view->view_src->buffer; + view->data = (char *)view->view_src->data + view->view_offs; + wsp_ggml_backend_buffer_init_tensor(wsp_ggml_backend_sched_get_buffer(sched, view->buffer->backend), view); + } + if (input_cpy->buffer == NULL) { + fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name); + exit(1); + } + WSP_GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend); + WSP_GGML_ASSERT(input_cpy->buffer->backend == split_backend); + wsp_ggml_backend_tensor_copy(split->inputs[j], input_cpy); + } + // wsp_ggml_backend_synchronize(split_backend); + int64_t copy_end_us = wsp_ggml_time_us(); + copy_us[split_backend_id] += copy_end_us - copy_start_us; + +#if 0 + char split_filename[WSP_GGML_MAX_NAME]; + snprintf(split_filename, WSP_GGML_MAX_NAME, "split_%i_%s.dot", i, wsp_ggml_backend_name(split_backend)); + wsp_ggml_graph_dump_dot(split->graph, NULL, split_filename); +#endif + + uint64_t compute_start_us = wsp_ggml_time_us(); + wsp_ggml_backend_graph_compute(split_backend, split->graph); + // wsp_ggml_backend_synchronize(split_backend); + uint64_t compute_end_us = wsp_ggml_time_us(); + compute_us[split_backend_id] += compute_end_us - compute_start_us; + } + +#if 0 + // per-backend timings + fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits); + for (int i = 0; i < sched->n_backends; i++) { + if (copy_us[i] > 0 || compute_us[i] > 0) { + fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", wsp_ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]); + } + } +#endif +} + +static void sched_reset(wsp_ggml_backend_sched_t sched) { + for (int i = 0; i < sched->n_backends; i++) { + wsp_ggml_tallocr_reset(sched->tallocs[i]); + } +} + +wsp_ggml_backend_sched_t wsp_ggml_backend_sched_new(wsp_ggml_backend_t * backends, int n_backends) { + WSP_GGML_ASSERT(n_backends <= WSP_GGML_MAX_BACKENDS); + + struct wsp_ggml_backend_sched * sched = malloc(sizeof(struct wsp_ggml_backend_sched)); + memset(sched, 0, sizeof(struct wsp_ggml_backend_sched)); + + fprintf(stderr, "wsp_ggml_backend_sched size: %lu KB\n", sizeof(struct wsp_ggml_backend_sched)/1024); + + sched->n_backends = n_backends; + for (int i = 0; i < n_backends; i++) { + sched->backends[i] = backends[i]; + } + + sched->galloc = wsp_ggml_gallocr_new(); + + // init measure allocs for each backend + for (int i = 0; i < n_backends; i++) { + sched->tallocs[i] = wsp_ggml_tallocr_new_measure_from_backend(backends[i]); + } + + return sched; +} + +void wsp_ggml_backend_sched_free(wsp_ggml_backend_sched_t sched) { + if (sched == NULL) { + return; + } + for (int i = 0; i < sched->n_backends; i++) { + wsp_ggml_tallocr_free(sched->tallocs[i]); + } + wsp_ggml_gallocr_free(sched->galloc); + free(sched->hash_set.keys); + free(sched->node_talloc); + free(sched->node_copies); + free(sched); +} + +void wsp_ggml_backend_sched_init_measure(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * measure_graph) { + // initialize hash tables + size_t hash_size = measure_graph->visited_hash_table.size + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS; + sched->hash_set.size = hash_size; + sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size); + sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size); + sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size); + + sched_split_graph(sched, measure_graph); + sched_alloc_splits(sched); + + // allocate buffers and reset allocators + for (int i = 0; i < sched->n_backends; i++) { + size_t size = wsp_ggml_tallocr_max_size(sched->tallocs[i]); + wsp_ggml_tallocr_free(sched->tallocs[i]); + sched->tallocs[i] = wsp_ggml_tallocr_new_from_backend(sched->backends[i], size); + } + + sched_reset(sched); +} + +void wsp_ggml_backend_sched_graph_compute(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * graph) { + WSP_GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + WSP_GGML_MAX_SPLITS*WSP_GGML_MAX_SPLIT_INPUTS); + + sched_split_graph(sched, graph); + sched_alloc_splits(sched); + sched_compute_splits(sched); + sched_reset(sched); +} + +wsp_ggml_tallocr_t wsp_ggml_backend_sched_get_tallocr(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) { + int backend_index = sched_backend_prio(sched, backend); + return sched->tallocs[backend_index]; +} + +wsp_ggml_backend_buffer_t wsp_ggml_backend_sched_get_buffer(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend) { + int backend_index = sched_backend_prio(sched, backend); + return wsp_ggml_tallocr_get_buffer(sched->tallocs[backend_index]); +} + +void wsp_ggml_backend_sched_set_node_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node, wsp_ggml_backend_t backend) { + int backend_index = sched_backend_prio(sched, backend); + WSP_GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); + node_allocr(node) = sched->tallocs[backend_index]; +} diff --git a/cpp/ggml-backend.h b/cpp/ggml-backend.h new file mode 100644 index 0000000..08bd134 --- /dev/null +++ b/cpp/ggml-backend.h @@ -0,0 +1,136 @@ +#pragma once + +#include "ggml.h" +#include "ggml-alloc.h" + +#ifdef __cplusplus +extern "C" { +#endif + + // + // Backend buffer + // + + struct wsp_ggml_backend_buffer; + typedef struct wsp_ggml_backend_buffer * wsp_ggml_backend_buffer_t; + + // backend buffer functions + WSP_GGML_API void wsp_ggml_backend_buffer_free (wsp_ggml_backend_buffer_t buffer); + WSP_GGML_API size_t wsp_ggml_backend_buffer_get_alignment (wsp_ggml_backend_buffer_t buffer); + WSP_GGML_API void * wsp_ggml_backend_buffer_get_base (wsp_ggml_backend_buffer_t buffer); + WSP_GGML_API size_t wsp_ggml_backend_buffer_get_size (wsp_ggml_backend_buffer_t buffer); + WSP_GGML_API size_t wsp_ggml_backend_buffer_get_alloc_size(wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); + WSP_GGML_API void wsp_ggml_backend_buffer_init_tensor (wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); + WSP_GGML_API void wsp_ggml_backend_buffer_free_tensor (wsp_ggml_backend_buffer_t buffer, struct wsp_ggml_tensor * tensor); + + // + // Backend + // + + struct wsp_ggml_backend; + typedef struct wsp_ggml_backend * wsp_ggml_backend_t; + typedef void * wsp_ggml_backend_graph_plan_t; + + WSP_GGML_API wsp_ggml_backend_t wsp_ggml_get_backend(const struct wsp_ggml_tensor * tensor); + + WSP_GGML_API const char * wsp_ggml_backend_name(wsp_ggml_backend_t backend); + WSP_GGML_API void wsp_ggml_backend_free(wsp_ggml_backend_t backend); + + WSP_GGML_API wsp_ggml_backend_buffer_t wsp_ggml_backend_alloc_buffer(wsp_ggml_backend_t backend, size_t size); + + WSP_GGML_API size_t wsp_ggml_backend_get_alignment(wsp_ggml_backend_t backend); + + WSP_GGML_API void wsp_ggml_backend_tensor_set_async( struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size); + WSP_GGML_API void wsp_ggml_backend_tensor_get_async(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size); + + WSP_GGML_API void wsp_ggml_backend_tensor_set( struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size); + WSP_GGML_API void wsp_ggml_backend_tensor_get(const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size); + + WSP_GGML_API void wsp_ggml_backend_synchronize(wsp_ggml_backend_t backend); + + WSP_GGML_API wsp_ggml_backend_graph_plan_t wsp_ggml_backend_graph_plan_create (wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph); + + WSP_GGML_API void wsp_ggml_backend_graph_plan_free (wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan); + WSP_GGML_API void wsp_ggml_backend_graph_plan_compute(wsp_ggml_backend_t backend, wsp_ggml_backend_graph_plan_t plan); + WSP_GGML_API void wsp_ggml_backend_graph_compute (wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph); + WSP_GGML_API bool wsp_ggml_backend_supports_op (wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op); + + // tensor copy between different backends + WSP_GGML_API void wsp_ggml_backend_tensor_copy(struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst); + + // + // CPU backend + // + + WSP_GGML_API wsp_ggml_backend_t wsp_ggml_backend_cpu_init(void); + + WSP_GGML_API bool wsp_ggml_backend_is_cpu(wsp_ggml_backend_t backend); + WSP_GGML_API void wsp_ggml_backend_cpu_set_n_threads(wsp_ggml_backend_t backend_cpu, int n_threads); + + // Create a backend buffer from an existing pointer + WSP_GGML_API wsp_ggml_backend_buffer_t wsp_ggml_backend_cpu_buffer_from_ptr(wsp_ggml_backend_t backend_cpu, void * ptr, size_t size); + + + // + // Backend scheduler + // + + // The backend scheduler allows for multiple backends to be used together + // Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends + // The backends are selected based on: + // - the backend that supports the operation + // - the location of the pre-allocated tensors (e.g. the weights) + /* + Example usage: + + sched = wsp_ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends); + // sched is initialized with measure allocators and cannot be used until allocated with a measure graph + + // initialize buffers from a measure graph + measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed + + // in build_graph: + build_graph(...) { + // allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer) + alloc_cpu = wsp_ggml_backend_sched_get_allocr(sched, backend_cpu); + wsp_ggml_allocr_alloc(alloc_cpu, tensor); + + // manually assigning nodes to a backend (optional, shouldn't be needed in most cases) + struct wsp_ggml_tensor * node = wsp_ggml_mul_mat(ctx, ...); + wsp_ggml_backend_sched_set_node_backend(sched, node, backend_gpu); + } + + // allocate backend buffers from measure graph + wsp_ggml_backend_sched_init_measure(sched, measure_graph); + + // the scheduler is now ready to compute graphs + + // compute + graph = build_graph(sched); + wsp_ggml_backend_sched_graph_compute(sched, graph); + */ + + struct wsp_ggml_backend_sched; + typedef struct wsp_ggml_backend_sched * wsp_ggml_backend_sched_t; + + // Initialize a backend scheduler + WSP_GGML_API wsp_ggml_backend_sched_t wsp_ggml_backend_sched_new(wsp_ggml_backend_t * backends, int n_backends); + + WSP_GGML_API void wsp_ggml_backend_sched_free(wsp_ggml_backend_sched_t sched); + + // Initialize backend buffers from a measure graph + WSP_GGML_API void wsp_ggml_backend_sched_init_measure(wsp_ggml_backend_sched_t sched, struct wsp_ggml_cgraph * measure_graph); + + WSP_GGML_API wsp_ggml_tallocr_t wsp_ggml_backend_sched_get_tallocr(wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend); + WSP_GGML_API wsp_ggml_backend_buffer_t wsp_ggml_backend_sched_get_buffer (wsp_ggml_backend_sched_t sched, wsp_ggml_backend_t backend); + + WSP_GGML_API void wsp_ggml_backend_sched_set_node_backend(wsp_ggml_backend_sched_t sched, struct wsp_ggml_tensor * node, wsp_ggml_backend_t backend); + + // Allocate a graph on the backend scheduler + WSP_GGML_API void wsp_ggml_backend_sched_graph_compute( + wsp_ggml_backend_sched_t sched, + struct wsp_ggml_cgraph * graph); + +#ifdef __cplusplus +} +#endif diff --git a/cpp/ggml-impl.h b/cpp/ggml-impl.h new file mode 100644 index 0000000..ee96697 --- /dev/null +++ b/cpp/ggml-impl.h @@ -0,0 +1,243 @@ +#pragma once + +#include "ggml.h" + +// GGML internal header + +#include +#include +#include +#include // memcpy +#include // fabsf + +#ifdef __cplusplus +extern "C" { +#endif + +// static_assert should be a #define, but if it's not, +// fall back to the _Static_assert C11 keyword. +// if C99 - static_assert is noop +// ref: https://stackoverflow.com/a/53923785/4039976 +#ifndef static_assert +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) +#define static_assert(cond, msg) _Static_assert(cond, msg) +#else +#define static_assert(cond, msg) struct global_scope_noop_trick +#endif +#endif + +// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 +#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) +#ifndef __FMA__ +#define __FMA__ +#endif +#ifndef __F16C__ +#define __F16C__ +#endif +#ifndef __SSE3__ +#define __SSE3__ +#endif +#endif + +// 16-bit float +// on Arm, we use __fp16 +// on x86, we use uint16_t +#if defined(__ARM_NEON) && !defined(_MSC_VER) + +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include + +#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) +#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) (x) + +#define WSP_GGML_FP16_TO_FP32(x) ((float) (x)) +#define WSP_GGML_FP32_TO_FP16(x) (x) + +#else + +#ifdef __wasm_simd128__ +#include +#else +#ifdef __POWER9_VECTOR__ +#include +#undef bool +#define bool _Bool +#else +#if defined(_MSC_VER) || defined(__MINGW32__) +#include +#else +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) +#if !defined(__riscv) +#include +#endif +#endif +#endif +#endif +#endif + +#ifdef __riscv_v_intrinsic +#include +#endif + +#ifdef __F16C__ + +#ifdef _MSC_VER +#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) +#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) +#else +#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) +#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) +#endif + +#elif defined(__POWER9_VECTOR__) + +#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x) +#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x) +/* the inline asm below is about 12% faster than the lookup method */ +#define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x) +#define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x) + +static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) { + register float f; + register double d; + __asm__( + "mtfprd %0,%2\n" + "xscvhpdp %0,%0\n" + "frsp %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=f"(f): + /* in */ "r"(h)); + return f; +} + +static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) { + register double d; + register wsp_ggml_fp16_t r; + __asm__( /* xscvdphp can work on double or single precision */ + "xscvdphp %0,%2\n" + "mffprd %1,%0\n" : + /* temp */ "=d"(d), + /* out */ "=r"(r): + /* in */ "f"(f)); + return r; +} + +#else + +// FP16 <-> FP32 +// ref: https://github.com/Maratyszcza/FP16 + +static inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32; + fp32.as_bits = w; + return fp32.as_value; +} + +static inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32; + fp32.as_value = f; + return fp32.as_bits; +} + +static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) { + const uint32_t w = (uint32_t) h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + const uint32_t exp_offset = UINT32_C(0xE0) << 23; +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float exp_scale = 0x1.0p-112f; +#else + const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); +#endif + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = sign | + (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) { +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) + const float scale_to_inf = 0x1.0p+112f; + const float scale_to_zero = 0x1.0p-110f; +#else + const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); + const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); +#endif + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); +} + +#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x) +#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x) + +#endif // __F16C__ + +#endif // __ARM_NEON + +// precomputed f32 table for f16 (256 KB) +// defined in ggml.c, initialized in wsp_ggml_init() +extern float wsp_ggml_table_f32_f16[1 << 16]; + +// On ARM NEON, it's quicker to directly convert x -> x instead of calling into wsp_ggml_lookup_fp16_to_fp32, +// so we define WSP_GGML_FP16_TO_FP32 and WSP_GGML_FP32_TO_FP16 elsewhere for NEON. +// This is also true for POWER9. +#if !defined(WSP_GGML_FP16_TO_FP32) || !defined(WSP_GGML_FP32_TO_FP16) + +inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) { + uint16_t s; + memcpy(&s, &f, sizeof(uint16_t)); + return wsp_ggml_table_f32_f16[s]; +} + +#define WSP_GGML_FP16_TO_FP32(x) wsp_ggml_lookup_fp16_to_fp32(x) +#define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x) + +#endif + +#define WSP_GGML_HASHTABLE_FULL ((size_t)-1) +#define WSP_GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2) + +bool wsp_ggml_hash_contains (const struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key); + +// returns WSP_GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted +size_t wsp_ggml_hash_find (const struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key); + +// returns WSP_GGML_HAHSHTABLE_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full +size_t wsp_ggml_hash_insert ( struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key); + +// return index, asserts if table is full +size_t wsp_ggml_hash_find_or_insert( struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key); + +#ifdef __cplusplus +} +#endif diff --git a/cpp/ggml-metal-whisper.metal b/cpp/ggml-metal-whisper.metal index 3087ecd..7c35f23 100644 --- a/cpp/ggml-metal-whisper.metal +++ b/cpp/ggml-metal-whisper.metal @@ -13,23 +13,85 @@ typedef struct { #define QK4_1 32 typedef struct { - half d; // delta - half m; // min + half d; // delta + half m; // min uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; +#define QK5_0 32 +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; + +#define QK5_1 32 +typedef struct { + half d; // delta + half m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; + #define QK8_0 32 typedef struct { half d; // delta int8_t qs[QK8_0]; // quants } block_q8_0; +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient kernel void kernel_add( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig]; + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant int64_t & nb00, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant int64_t & nb0, + constant int64_t & nb1, + constant int64_t & nb2, + constant int64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0]; + + src0_ptr += ntg.x*nb00; + src1_ptr += ntg.x*nb10; + dst_ptr += ntg.x*nb0; + } } // assumption: src1 is a row @@ -38,7 +100,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb, + constant int64_t & nb [[buffer(27)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } @@ -63,9 +125,17 @@ kernel void kernel_mul_row( } kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_scale_4( device const float4 * src0, device float4 * dst, - constant float & scale, + constant float & scale, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * scale; } @@ -85,6 +155,13 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +kernel void kernel_sqr( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + constant float GELU_COEF_A = 0.044715f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; @@ -107,36 +184,73 @@ kernel void kernel_soft_max( constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max - float lmax = psrc0[tpitg[0]]; - for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) { + float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY; + + for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) { lmax = MAX(lmax, psrc0[i00]); } - const float max = simd_max(lmax); + + float max = simd_max(lmax); + if (tiisg == 0) { + buf[sgitg] = max; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // broadcast, simd group number is ntg / 32 + for (uint i = ntg / 32 / 2; i > 0; i /= 2) { + if (tpitg < i) { + buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max = buf[0]; // parallel sum float lsum = 0.0f; - for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { const float exp_psrc0 = exp(psrc0[i00] - max); lsum += exp_psrc0; // Remember the result of exp here. exp is expensive, so we really do not - // whish to compute it twice. + // wish to compute it twice. pdst[i00] = exp_psrc0; } - const float sum = simd_sum(lsum); + float sum = simd_sum(lsum); + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // broadcast, simd group number is ntg / 32 + for (uint i = ntg / 32 / 2; i > 0; i /= 2) { + if (tpitg < i) { + buf[tpitg] += buf[tpitg + i]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[0]; - for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { + for (int i00 = tpitg; i00 < ne00; i00 += ntg) { pdst[i00] /= sum; } } @@ -147,37 +261,73 @@ kernel void kernel_soft_max_4( constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; + threadgroup float * buf [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint ntg[[threads_per_threadgroup]]) { + const int64_t i03 = (tgpig) / (ne02*ne01); + const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; + const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max - float4 lmax4 = psrc4[tpitg[0]]; - for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) { + float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY; + + for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) { lmax4 = fmax(lmax4, psrc4[i00]); } - float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); - const float max = simd_max(lmax); + const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); + float max = simd_max(lmax); + if (tiisg == 0) { + buf[sgitg] = max; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // broadcast, simd group number is ntg / 32 + for (uint i = ntg / 32 / 2; i > 0; i /= 2) { + if (tpitg < i) { + buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]); + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + max = buf[0]; // parallel sum float4 lsum4 = 0.0f; - for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { const float4 exp_psrc4 = exp(psrc4[i00] - max); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } - float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; - const float sum = simd_sum(lsum); + const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + float sum = simd_sum(lsum); + if (tiisg == 0) { + buf[sgitg] = sum; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // broadcast, simd group number is ntg / 32 + for (uint i = ntg / 32 / 2; i > 0; i /= 2) { + if (tpitg < i) { + buf[tpitg] += buf[tpitg + i]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sum = buf[0]; - for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { pdst4[i00] /= sum; } } @@ -197,7 +347,7 @@ kernel void kernel_diag_mask_inf( dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; } else { dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; - } + } } kernel void kernel_diag_mask_inf_8( @@ -291,10 +441,11 @@ kernel void kernel_rms_norm( uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); - device const float * x_scalar = (device const float *) x; - float4 sumf=0; - float all_sum=0; + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + device const float * x_scalar = (device const float *) x; + + float4 sumf = 0; + float all_sum = 0; // parallel sum for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { @@ -307,6 +458,7 @@ kernel void kernel_rms_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast, simd group number is ntg / 32 for (uint i = ntg / 32 / 2; i > 0; i /= 2) { if (tpitg < i) { @@ -314,7 +466,9 @@ kernel void kernel_rms_norm( } } if (tpitg == 0) { - for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];} + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } sum[0] /= ne00; } @@ -329,7 +483,9 @@ kernel void kernel_rms_norm( y[i00] = x[i00] * scale; } if (tpitg == 0) { - for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;} + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } } } @@ -339,8 +495,11 @@ kernel void kernel_rms_norm( // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -357,8 +516,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -368,9 +530,52 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1]) + sumy * m; } +// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (sumy * -16.f + acc[0] + acc[1]); +} + +// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_1/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + // putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 //Note: This is a template, but strictly speaking it only applies to // quantizations where the block size is 32. It also does not @@ -381,18 +586,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; + const int first_row = (r0 * nsg + sgitg) * nr; + const uint offset0 = first_row * nb + im/gqa*(nb*ne0); + device const block_q_type * x = (device const block_q_type *) src0 + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; // src1 vector cache - float sumf[nr]={0.f}; - const int ix = tiisg/2; - const int il = 8*(tiisg%2); + float yl[16]; // src1 vector cache + float sumf[nr] = {0.f}; + + const int ix = (tiisg/2); + const int il = (tiisg%2)*8; device const float * yb = y + ix * QK4_0 + il; @@ -403,6 +613,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device sumy += yb[i] + yb[i+1]; yl[i+0] = yb[i+ 0]; yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; yl[i+8] = yb[i+16]/16.f; yl[i+9] = yb[i+17]/4096.f; @@ -418,12 +629,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; } } } -kernel void kernel_mul_mat_q4_0_f32( +kernel void kernel_mul_mv_q4_0_f32( device const void * src0, device const float * src1, device float * dst, @@ -436,12 +647,12 @@ kernel void kernel_mul_mat_q4_0_f32( constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } -kernel void kernel_mul_mat_q4_1_f32( +kernel void kernel_mul_mv_q4_1_f32( device const void * src0, device const float * src1, device float * dst, @@ -459,9 +670,46 @@ kernel void kernel_mul_mat_q4_1_f32( mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } +kernel void kernel_mul_mv_q5_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q5_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + + #define NB_Q8_0 8 -kernel void kernel_mul_mat_q8_0_f32( +kernel void kernel_mul_mv_q8_0_f32( device const void * src0, device const float * src1, device float * dst, @@ -525,7 +773,7 @@ kernel void kernel_mul_mat_q8_0_f32( #define N_F32_F32 4 -kernel void kernel_mul_mat_f32_f32( +kernel void kernel_mul_mv_f32_f32( device const char * src0, device const char * src1, device float * dst, @@ -596,7 +844,7 @@ kernel void kernel_mul_mat_f32_f32( } } -kernel void kernel_mul_mat_f16_f32_1row( +kernel void kernel_mul_mv_f16_f32_1row( device const char * src0, device const char * src1, device float * dst, @@ -615,7 +863,7 @@ kernel void kernel_mul_mat_f16_f32_1row( constant int64_t & ne0, constant int64_t & ne1, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -650,7 +898,7 @@ kernel void kernel_mul_mat_f16_f32_1row( #define N_F16_F32 4 -kernel void kernel_mul_mat_f16_f32( +kernel void kernel_mul_mv_f16_f32( device const char * src0, device const char * src1, device float * dst, @@ -722,7 +970,7 @@ kernel void kernel_mul_mat_f16_f32( } // Assumes row size (ne00) is a multiple of 4 -kernel void kernel_mul_mat_f16_f32_l4( +kernel void kernel_mul_mv_f16_f32_l4( device const char * src0, device const char * src1, device float * dst, @@ -783,7 +1031,9 @@ kernel void kernel_alibi_f32( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, - constant float & m0, + constant float & m0, + constant float & m1, + constant int & n_heads_log2_floor, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -799,37 +1049,122 @@ kernel void kernel_alibi_f32( const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - float m_k = pow(m0, i2 + 1); + float m_k; + if (i2 < n_heads_log2_floor) { + m_k = pow(m0, i2 + 1); + } else { + m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1); + } for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); } } +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + thread float * cos_theta, thread float * sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / freq_scale); + } + *cos_theta = cos(theta) * mscale; + *sin_theta = sin(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float rope_yarn_corr_factor(int n_dims, int n_orig_ctx, float n_rot, float base) { + return n_dims * log(n_orig_ctx / (n_rot * 2 * M_PI_F)) / (2 * log(base)); +} + +static void rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + dims[0] = max(0.0f, floor(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_fast, freq_base))); + dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_orig_ctx, beta_slow, freq_base))); +} + +typedef void (rope_t)( + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant int & n_orig_ctx, + constant float & freq_base, + constant float & freq_scale, + constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]); + +template kernel void kernel_rope( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant float & freq_base, - constant float & freq_scale, + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant int & n_orig_ctx, + constant float & freq_base, + constant float & freq_scale, + constant float & ext_factor, + constant float & attn_factor, + constant float & beta_fast, + constant float & beta_slow, uint tiitg[[thread_index_in_threadgroup]], uint3 tptg[[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { @@ -839,23 +1174,28 @@ kernel void kernel_rope( const bool is_neox = mode & 2; - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + float corr_dims[2]; + rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); - const float theta_0 = freq_scale * (float)p; + device const int32_t * pos = src1; + + const int64_t p = pos[i2]; + + const float theta_0 = (float)p; const float inv_ndims = -1.f/n_dims; if (!is_neox) { for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { const float theta = theta_0 * pow(freq_base, inv_ndims*i0); - const float cos_theta = cos(theta); - const float sin_theta = sin(theta); + float cos_theta, sin_theta; + rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - const float x0 = src[0]; - const float x1 = src[1]; + const T x0 = src[0]; + const T x1 = src[1]; dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; @@ -864,14 +1204,17 @@ kernel void kernel_rope( for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) { - const float theta = theta_0 * pow(freq_base, inv_ndims*ic - ib); - const float cos_theta = cos(theta); - const float sin_theta = sin(theta); + // simplified from `(ib * n_dims + ic) * inv_ndims` + const float cur_rot = inv_ndims*ic - ib; + + const float theta = theta_0 * pow(freq_base, cur_rot); + float cos_theta, sin_theta; + rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); const int64_t i0 = ib*n_dims + ic/2; - device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; const float x1 = src[n_dims/2]; @@ -883,6 +1226,9 @@ kernel void kernel_rope( } } +template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; +template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, @@ -1008,6 +1354,62 @@ kernel void kernel_cpy_f32_f32( } } +kernel void kernel_concat( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i02 < ne02) { + ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; + src0_ptr += ntg.x*nb00; + } else { + ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; + src1_ptr += ntg.x*nb10; + } + dst_ptr += ntg.x*nb0; + } +} + //============================================ k-quants ====================================================== #ifndef QK_K @@ -1100,7 +1502,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { //====================================== dot products ========================= -kernel void kernel_mul_mat_q2_K_f32( +kernel void kernel_mul_mv_q2_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1244,7 +1646,7 @@ kernel void kernel_mul_mat_q2_K_f32( } #if QK_K == 256 -kernel void kernel_mul_mat_q3_K_f32( +kernel void kernel_mul_mv_q3_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1273,8 +1675,8 @@ kernel void kernel_mul_mat_q3_K_f32( float yl[32]; - const uint16_t kmask1 = 0x3030; - const uint16_t kmask2 = 0x0f0f; + //const uint16_t kmask1 = 0x3030; + //const uint16_t kmask2 = 0x0f0f; const int tid = tiisg/4; const int ix = tiisg%4; @@ -1396,7 +1798,7 @@ kernel void kernel_mul_mat_q3_K_f32( } } #else -kernel void kernel_mul_mat_q3_K_f32( +kernel void kernel_mul_mv_q3_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1467,7 +1869,7 @@ kernel void kernel_mul_mat_q3_K_f32( #endif #if QK_K == 256 -kernel void kernel_mul_mat_q4_K_f32( +kernel void kernel_mul_mv_q4_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1573,7 +1975,7 @@ kernel void kernel_mul_mat_q4_K_f32( } } #else -kernel void kernel_mul_mat_q4_K_f32( +kernel void kernel_mul_mv_q4_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1662,7 +2064,7 @@ kernel void kernel_mul_mat_q4_K_f32( } #endif -kernel void kernel_mul_mat_q5_K_f32( +kernel void kernel_mul_mv_q5_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1835,7 +2237,7 @@ kernel void kernel_mul_mat_q5_K_f32( } -kernel void kernel_mul_mat_q6_K_f32( +kernel void kernel_mul_mv_q6_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1984,6 +2386,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg } } +template +void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + md; + reg[i/2][2*(i%2)+1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + m; + reg[i/2][2*(i%2)+1] = d * x1 + m; + } +} + template void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { device const int8_t * qs = ((device const int8_t *)xb->qs); @@ -2173,7 +2631,7 @@ kernel void kernel_get_rows( } #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A -#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B #define BLOCK_SIZE_K 32 #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B @@ -2210,9 +2668,11 @@ kernel void kernel_mul_mm(device const uchar * src0, const uint r0 = tgpig.y; const uint r1 = tgpig.x; const uint im = tgpig.z; + // if this block is of 64x32 shape or smaller short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + // a thread shouldn't load data outside of the matrix short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; @@ -2236,26 +2696,30 @@ kernel void kernel_mul_mm(device const uchar * src0, + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - //load data and store to threadgroup memory + // load data and store to threadgroup memory half4x4 temp_a; dequantize_func(x, il, temp_a); threadgroup_barrier(mem_flags::mem_threadgroup); + #pragma unroll(16) for (int i = 0; i < 16; i++) { *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; } - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \ - = *((device float2x4 *)y); + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2+nl-1)/nl : x; y += BLOCK_SIZE_K; threadgroup_barrier(mem_flags::mem_threadgroup); - //load matrices from threadgroup memory and conduct outer products + + // load matrices from threadgroup memory and conduct outer products threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + #pragma unroll(4) for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { #pragma unroll(4) @@ -2270,6 +2734,7 @@ kernel void kernel_mul_mm(device const uchar * src0, lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + #pragma unroll(8) for (int i = 0; i < 8; i++){ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); @@ -2278,25 +2743,26 @@ kernel void kernel_mul_mm(device const uchar * src0, } if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0; + device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ + + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; for (int i = 0; i < 8; i++) { simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); } } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float *temp_str = ((threadgroup float *)shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; for (int i = 0; i < 8; i++) { simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); - device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; - if (sgitg==0) { + + device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; + if (sgitg == 0) { for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); } } @@ -2317,6 +2783,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; @@ -2345,6 +2813,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/cpp/ggml-metal.h b/cpp/ggml-metal.h index 872fedf..e827472 100644 --- a/cpp/ggml-metal.h +++ b/cpp/ggml-metal.h @@ -19,6 +19,9 @@ #pragma once +#include "ggml.h" +#include "ggml-backend.h" + #include #include @@ -33,8 +36,15 @@ struct wsp_ggml_cgraph; extern "C" { #endif +// +// internal API +// temporary exposed to user-code +// + struct wsp_ggml_metal_context; +void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data); + // number of command buffers to use struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb); void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx); @@ -79,6 +89,17 @@ int * wsp_ggml_metal_get_concur_list(struct wsp_ggml_metal_context * ctx); // creates gf->n_threads command buffers in parallel void wsp_ggml_metal_graph_compute(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_cgraph * gf); +// +// backend API +// user-code should use only these functions +// + +WSP_GGML_API wsp_ggml_backend_t wsp_ggml_backend_metal_init(void); + +WSP_GGML_API bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend); + +WSP_GGML_API void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb); + #ifdef __cplusplus } #endif diff --git a/cpp/ggml-metal.m b/cpp/ggml-metal.m index cbcc6f8..3973987 100644 --- a/cpp/ggml-metal.m +++ b/cpp/ggml-metal.m @@ -1,5 +1,6 @@ #import "ggml-metal.h" +#import "ggml-backend-impl.h" #import "ggml.h" #import @@ -11,16 +12,19 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) -// TODO: temporary - reuse llama.cpp logging #ifdef WSP_GGML_METAL_NDEBUG -#define metal_printf(...) +#define WSP_GGML_METAL_LOG_INFO(...) +#define WSP_GGML_METAL_LOG_WARN(...) +#define WSP_GGML_METAL_LOG_ERROR(...) #else -#define metal_printf(...) fprintf(stderr, __VA_ARGS__) +#define WSP_GGML_METAL_LOG_INFO(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_INFO, __VA_ARGS__) +#define WSP_GGML_METAL_LOG_WARN(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_WARN, __VA_ARGS__) +#define WSP_GGML_METAL_LOG_ERROR(...) wsp_ggml_metal_log(WSP_GGML_LOG_LEVEL_ERROR, __VA_ARGS__) #endif #define UNUSED(x) (void)(x) -#define WSP_GGML_MAX_CONCUR (2*WSP_GGML_MAX_NODES) +#define WSP_GGML_MAX_CONCUR (2*WSP_GGML_DEFAULT_GRAPH_SIZE) struct wsp_ggml_metal_buffer { const char * name; @@ -59,6 +63,7 @@ WSP_GGML_METAL_DECL_KERNEL(mul); WSP_GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast WSP_GGML_METAL_DECL_KERNEL(scale); + WSP_GGML_METAL_DECL_KERNEL(scale_4); WSP_GGML_METAL_DECL_KERNEL(silu); WSP_GGML_METAL_DECL_KERNEL(relu); WSP_GGML_METAL_DECL_KERNEL(gelu); @@ -70,6 +75,8 @@ WSP_GGML_METAL_DECL_KERNEL(get_rows_f16); WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_0); WSP_GGML_METAL_DECL_KERNEL(get_rows_q4_1); + WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_0); + WSP_GGML_METAL_DECL_KERNEL(get_rows_q5_1); WSP_GGML_METAL_DECL_KERNEL(get_rows_q8_0); WSP_GGML_METAL_DECL_KERNEL(get_rows_q2_K); WSP_GGML_METAL_DECL_KERNEL(get_rows_q3_K); @@ -78,33 +85,40 @@ WSP_GGML_METAL_DECL_KERNEL(get_rows_q6_K); WSP_GGML_METAL_DECL_KERNEL(rms_norm); WSP_GGML_METAL_DECL_KERNEL(norm); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_f32_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32); - WSP_GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_f32_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32); WSP_GGML_METAL_DECL_KERNEL(mul_mm_f32_f32); WSP_GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32); + WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32); WSP_GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); WSP_GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); WSP_GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); WSP_GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); WSP_GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); WSP_GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); - WSP_GGML_METAL_DECL_KERNEL(rope); + WSP_GGML_METAL_DECL_KERNEL(rope_f32); + WSP_GGML_METAL_DECL_KERNEL(rope_f16); WSP_GGML_METAL_DECL_KERNEL(alibi_f32); WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f16); WSP_GGML_METAL_DECL_KERNEL(cpy_f32_f32); WSP_GGML_METAL_DECL_KERNEL(cpy_f16_f16); + WSP_GGML_METAL_DECL_KERNEL(concat); + WSP_GGML_METAL_DECL_KERNEL(sqr); #undef WSP_GGML_METAL_DECL_KERNEL }; @@ -120,8 +134,37 @@ @interface WSPGGMLMetalClass : NSObject @implementation WSPGGMLMetalClass @end +wsp_ggml_log_callback wsp_ggml_metal_log_callback = NULL; +void * wsp_ggml_metal_log_user_data = NULL; + +void wsp_ggml_metal_log_set_callback(wsp_ggml_log_callback log_callback, void * user_data) { + wsp_ggml_metal_log_callback = log_callback; + wsp_ggml_metal_log_user_data = user_data; +} + +static void wsp_ggml_metal_log(enum wsp_ggml_log_level level, const char* format, ...){ + if (wsp_ggml_metal_log_callback != NULL) { + va_list args; + va_start(args, format); + char buffer[128]; + int len = vsnprintf(buffer, 128, format, args); + if (len < 128) { + wsp_ggml_metal_log_callback(level, buffer, wsp_ggml_metal_log_user_data); + } else { + char* buffer2 = malloc(len+1); + vsnprintf(buffer2, len+1, format, args); + buffer2[len] = 0; + wsp_ggml_metal_log_callback(level, buffer2, wsp_ggml_metal_log_user_data); + free(buffer2); + } + va_end(args); + } +} + + + struct wsp_ggml_metal_context * wsp_ggml_metal_init(int n_cb) { - metal_printf("%s: allocating\n", __func__); + WSP_GGML_METAL_LOG_INFO("%s: allocating\n", __func__); id device; NSString * s; @@ -131,14 +174,14 @@ @implementation WSPGGMLMetalClass NSArray * devices = MTLCopyAllDevices(); for (device in devices) { s = [device name]; - metal_printf("%s: found device: %s\n", __func__, [s UTF8String]); + WSP_GGML_METAL_LOG_INFO("%s: found device: %s\n", __func__, [s UTF8String]); } #endif // Pick and show default Metal device device = MTLCreateSystemDefaultDevice(); s = [device name]; - metal_printf("%s: picking default device: %s\n", __func__, [s UTF8String]); + WSP_GGML_METAL_LOG_INFO("%s: picking default device: %s\n", __func__, [s UTF8String]); // Configure context struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context)); @@ -150,68 +193,69 @@ @implementation WSPGGMLMetalClass ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); -#ifdef WSP_GGML_SWIFT - // load the default.metallib file + // load library { - NSError * error = nil; - - NSBundle * bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]]; - NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"]; - NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath]; - NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"]; - NSURL * libURL = [NSURL fileURLWithPath:libPath]; - - // Load the metallib file into a Metal library - ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; - - if (error) { - metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - } + NSBundle * bundle = nil; +#ifdef SWIFT_PACKAGE + bundle = SWIFTPM_MODULE_BUNDLE; #else - UNUSED(msl_library_source); - - // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource - { + bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]]; +#endif NSError * error = nil; + NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"]; + if (libPath != nil) { + NSURL * libURL = [NSURL fileURLWithPath:libPath]; + WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); + ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + } else { + WSP_GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__); + + NSString * sourcePath; + NSString * ggmlMetalPathResources = [[NSProcessInfo processInfo].environment objectForKey:@"WSP_GGML_METAL_PATH_RESOURCES"]; + if (ggmlMetalPathResources) { + sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"]; + } else { + sourcePath = [bundle pathForResource:@"ggml-metal-whisper" ofType:@"metal"]; + } + if (sourcePath == nil) { + WSP_GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); + sourcePath = @"ggml-metal.metal"; + } + WSP_GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]); + NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error]; + if (error) { + WSP_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); + return NULL; + } - //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; - NSBundle * bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]]; - NSString * path = [bundle pathForResource:@"ggml-metal-whisper" ofType:@"metal"]; - metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]); - - NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; - if (error) { - metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]); - return NULL; - } - + MTLCompileOptions* options = nil; #ifdef WSP_GGML_QKK_64 - MTLCompileOptions* options = [MTLCompileOptions new]; - options.preprocessorMacros = @{ @"QK_K" : @(64) }; - ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; -#else - ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error]; + options = [MTLCompileOptions new]; + options.preprocessorMacros = @{ @"QK_K" : @(64) }; #endif + ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + } + if (error) { - metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]); + WSP_GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; } } -#endif // load kernels { NSError * error = nil; + + /* + WSP_GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \ + (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \ + (int) ctx->pipeline_##name.threadExecutionWidth); \ + */ #define WSP_GGML_METAL_ADD_KERNEL(name) \ ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \ - metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (__bridge void *) ctx->pipeline_##name, \ - (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \ - (int) ctx->pipeline_##name.threadExecutionWidth); \ if (error) { \ - metal_printf("%s: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + WSP_GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ return NULL; \ } @@ -220,6 +264,7 @@ @implementation WSPGGMLMetalClass WSP_GGML_METAL_ADD_KERNEL(mul); WSP_GGML_METAL_ADD_KERNEL(mul_row); WSP_GGML_METAL_ADD_KERNEL(scale); + WSP_GGML_METAL_ADD_KERNEL(scale_4); WSP_GGML_METAL_ADD_KERNEL(silu); WSP_GGML_METAL_ADD_KERNEL(relu); WSP_GGML_METAL_ADD_KERNEL(gelu); @@ -231,6 +276,8 @@ @implementation WSPGGMLMetalClass WSP_GGML_METAL_ADD_KERNEL(get_rows_f16); WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_0); WSP_GGML_METAL_ADD_KERNEL(get_rows_q4_1); + WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_0); + WSP_GGML_METAL_ADD_KERNEL(get_rows_q5_1); WSP_GGML_METAL_ADD_KERNEL(get_rows_q8_0); WSP_GGML_METAL_ADD_KERNEL(get_rows_q2_K); WSP_GGML_METAL_ADD_KERNEL(get_rows_q3_K); @@ -239,44 +286,66 @@ @implementation WSPGGMLMetalClass WSP_GGML_METAL_ADD_KERNEL(get_rows_q6_K); WSP_GGML_METAL_ADD_KERNEL(rms_norm); WSP_GGML_METAL_ADD_KERNEL(norm); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_f32_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); - WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); - WSP_GGML_METAL_ADD_KERNEL(rope); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_f32_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32); + if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { + WSP_GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); + WSP_GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); + } + WSP_GGML_METAL_ADD_KERNEL(rope_f32); + WSP_GGML_METAL_ADD_KERNEL(rope_f16); WSP_GGML_METAL_ADD_KERNEL(alibi_f32); WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f16); WSP_GGML_METAL_ADD_KERNEL(cpy_f32_f32); WSP_GGML_METAL_ADD_KERNEL(cpy_f16_f16); + WSP_GGML_METAL_ADD_KERNEL(concat); + WSP_GGML_METAL_ADD_KERNEL(sqr); #undef WSP_GGML_METAL_ADD_KERNEL } - metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); #if TARGET_OS_OSX - metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); + // print MTL GPU family: + WSP_GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]); + + // determine max supported GPU family + // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf + // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf + for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) { + if ([ctx->device supportsFamily:i]) { + WSP_GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i); + break; + } + } + + WSP_GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false"); + WSP_GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); if (ctx->device.maxTransferRate != 0) { - metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); + WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0); } else { - metal_printf("%s: maxTransferRate = built-in GPU\n", __func__); + WSP_GGML_METAL_LOG_INFO("%s: maxTransferRate = built-in GPU\n", __func__); } #endif @@ -284,7 +353,7 @@ @implementation WSPGGMLMetalClass } void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { - metal_printf("%s: deallocating\n", __func__); + WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); #define WSP_GGML_METAL_DEL_KERNEL(name) \ WSP_GGML_METAL_DEL_KERNEL(add); @@ -292,6 +361,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { WSP_GGML_METAL_DEL_KERNEL(mul); WSP_GGML_METAL_DEL_KERNEL(mul_row); WSP_GGML_METAL_DEL_KERNEL(scale); + WSP_GGML_METAL_DEL_KERNEL(scale_4); WSP_GGML_METAL_DEL_KERNEL(silu); WSP_GGML_METAL_DEL_KERNEL(relu); WSP_GGML_METAL_DEL_KERNEL(gelu); @@ -303,6 +373,8 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { WSP_GGML_METAL_DEL_KERNEL(get_rows_f16); WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_0); WSP_GGML_METAL_DEL_KERNEL(get_rows_q4_1); + WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_0); + WSP_GGML_METAL_DEL_KERNEL(get_rows_q5_1); WSP_GGML_METAL_DEL_KERNEL(get_rows_q8_0); WSP_GGML_METAL_DEL_KERNEL(get_rows_q2_K); WSP_GGML_METAL_DEL_KERNEL(get_rows_q3_K); @@ -311,33 +383,42 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { WSP_GGML_METAL_DEL_KERNEL(get_rows_q6_K); WSP_GGML_METAL_DEL_KERNEL(rms_norm); WSP_GGML_METAL_DEL_KERNEL(norm); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_f32_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); - WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); - WSP_GGML_METAL_DEL_KERNEL(rope); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_f32_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32); + if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) { + WSP_GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); + WSP_GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); + } + WSP_GGML_METAL_DEL_KERNEL(rope_f32); + WSP_GGML_METAL_DEL_KERNEL(rope_f16); WSP_GGML_METAL_DEL_KERNEL(alibi_f32); WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f16); WSP_GGML_METAL_DEL_KERNEL(cpy_f32_f32); WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16); + WSP_GGML_METAL_DEL_KERNEL(concat); + WSP_GGML_METAL_DEL_KERNEL(sqr); #undef WSP_GGML_METAL_DEL_KERNEL @@ -348,7 +429,7 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { void * data = NULL; const int result = posix_memalign((void **) &data, sysconf(_SC_PAGESIZE), n); if (result != 0) { - metal_printf("%s: error: posix_memalign failed\n", __func__); + WSP_GGML_METAL_LOG_ERROR("%s: error: posix_memalign failed\n", __func__); return NULL; } @@ -376,7 +457,7 @@ int wsp_ggml_metal_if_optimized(struct wsp_ggml_metal_context * ctx) { // Metal buffer based on the host memory pointer // static id wsp_ggml_metal_get_buffer(struct wsp_ggml_metal_context * ctx, struct wsp_ggml_tensor * t, size_t * offs) { - //metal_printf("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); + //WSP_GGML_METAL_LOG_INFO("%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach); const int64_t tsize = wsp_ggml_nbytes(t); @@ -384,17 +465,17 @@ int wsp_ggml_metal_if_optimized(struct wsp_ggml_metal_context * ctx) { for (int i = 0; i < ctx->n_buffers; ++i) { const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data; - //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name); + //WSP_GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name); if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) { *offs = (size_t) ioffs; - //metal_printf("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs); + //WSP_GGML_METAL_LOG_INFO("%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs); return ctx->buffers[i].metal; } } - metal_printf("%s: error: buffer is nil\n", __func__); + WSP_GGML_METAL_LOG_ERROR("%s: error: buffer is nil\n", __func__); return nil; } @@ -406,7 +487,7 @@ bool wsp_ggml_metal_add_buffer( size_t size, size_t max_size) { if (ctx->n_buffers >= WSP_GGML_METAL_MAX_BUFFERS) { - metal_printf("%s: too many buffers\n", __func__); + WSP_GGML_METAL_LOG_ERROR("%s: error: too many buffers\n", __func__); return false; } @@ -416,7 +497,7 @@ bool wsp_ggml_metal_add_buffer( const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data; if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) { - metal_printf("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name); + WSP_GGML_METAL_LOG_ERROR("%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name); return false; } } @@ -437,11 +518,11 @@ bool wsp_ggml_metal_add_buffer( ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil]; if (ctx->buffers[ctx->n_buffers].metal == nil) { - metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0); + WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0); return false; } - metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0); + WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0); ++ctx->n_buffers; } else { @@ -461,13 +542,13 @@ bool wsp_ggml_metal_add_buffer( ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil]; if (ctx->buffers[ctx->n_buffers].metal == nil) { - metal_printf("%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0); + WSP_GGML_METAL_LOG_ERROR("%s: error: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0); return false; } - metal_printf("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i); + WSP_GGML_METAL_LOG_INFO("%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i); if (i + size_step < size) { - metal_printf("\n"); + WSP_GGML_METAL_LOG_INFO("\n"); } ++ctx->n_buffers; @@ -475,17 +556,17 @@ bool wsp_ggml_metal_add_buffer( } #if TARGET_OS_OSX - metal_printf(", (%8.2f / %8.2f)", + WSP_GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", ctx->device.currentAllocatedSize / 1024.0 / 1024.0, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) { - metal_printf(", warning: current allocated size is greater than the recommended max working set size\n"); + WSP_GGML_METAL_LOG_WARN(", warning: current allocated size is greater than the recommended max working set size\n", __func__); } else { - metal_printf("\n"); + WSP_GGML_METAL_LOG_INFO("\n"); } #else - metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0); + WSP_GGML_METAL_LOG_INFO(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0); #endif } @@ -598,7 +679,7 @@ void wsp_ggml_metal_graph_find_concurrency( } if (ctx->concur_list_len > WSP_GGML_MAX_CONCUR) { - metal_printf("%s: too many elements for metal ctx->concur_list!\n", __func__); + WSP_GGML_METAL_LOG_WARN("%s: too many elements for metal ctx->concur_list!\n", __func__); } } @@ -652,12 +733,26 @@ void wsp_ggml_metal_graph_compute( continue; } - //metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op)); + //WSP_GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, wsp_ggml_op_name(gf->nodes[i]->op)); struct wsp_ggml_tensor * src0 = gf->nodes[i]->src[0]; struct wsp_ggml_tensor * src1 = gf->nodes[i]->src[1]; struct wsp_ggml_tensor * dst = gf->nodes[i]; + switch (dst->op) { + case WSP_GGML_OP_NONE: + case WSP_GGML_OP_RESHAPE: + case WSP_GGML_OP_VIEW: + case WSP_GGML_OP_TRANSPOSE: + case WSP_GGML_OP_PERMUTE: + { + // noop -> next node + } continue; + default: + { + } break; + } + const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0; const int64_t ne02 = src0 ? src0->ne[2] : 0; @@ -696,53 +791,117 @@ void wsp_ggml_metal_graph_compute( id id_src1 = src1 ? wsp_ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil; id id_dst = dst ? wsp_ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil; - //metal_printf("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op)); + //WSP_GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, wsp_ggml_op_name(dst->op)); //if (src0) { - // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02, + // WSP_GGML_METAL_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src0t), ne00, ne01, ne02, // wsp_ggml_is_contiguous(src0), src0->name); //} //if (src1) { - // metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12, + // WSP_GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, wsp_ggml_type_name(src1t), ne10, ne11, ne12, // wsp_ggml_is_contiguous(src1), src1->name); //} //if (dst) { - // metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2, + // WSP_GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, wsp_ggml_type_name(dstt), ne0, ne1, ne2, // dst->name); //} switch (dst->op) { - case WSP_GGML_OP_NONE: - case WSP_GGML_OP_RESHAPE: - case WSP_GGML_OP_VIEW: - case WSP_GGML_OP_TRANSPOSE: - case WSP_GGML_OP_PERMUTE: + case WSP_GGML_OP_CONCAT: { - // noop + const int64_t nb = ne00; + + [encoder setComputePipelineState:ctx->pipeline_concat]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; + + const int nth = MIN(1024, ne0); + + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case WSP_GGML_OP_ADD: { WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src1)); - // utilize float4 - WSP_GGML_ASSERT(ne00 % 4 == 0); - const int64_t nb = ne00/4; + bool bcast_row = false; - if (wsp_ggml_nelements(src1) == ne10) { + int64_t nb = ne00; + + if (wsp_ggml_nelements(src1) == ne10 && ne00 % 4 == 0) { // src1 is a row WSP_GGML_ASSERT(ne11 == 1); + + nb = ne00 / 4; [encoder setComputePipelineState:ctx->pipeline_add_row]; + + bcast_row = true; } else { [encoder setComputePipelineState:ctx->pipeline_add]; } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:3]; - - const int64_t n = wsp_ggml_nelements(dst)/4; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; + [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; + + if (bcast_row) { + const int64_t n = wsp_ggml_nelements(dst)/4; + + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + const int nth = MIN(1024, ne0); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } } break; case WSP_GGML_OP_MUL: { @@ -775,13 +934,19 @@ void wsp_ggml_metal_graph_compute( const float scale = *(const float *) src1->data; - [encoder setComputePipelineState:ctx->pipeline_scale]; + int64_t n = wsp_ggml_nelements(dst); + + if (n % 4 == 0) { + n /= 4; + [encoder setComputePipelineState:ctx->pipeline_scale_4]; + } else { + [encoder setComputePipelineState:ctx->pipeline_scale]; + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - const int64_t n = wsp_ggml_nelements(dst)/4; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case WSP_GGML_OP_UNARY: @@ -792,9 +957,10 @@ void wsp_ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = wsp_ggml_nelements(dst)/4; + const int64_t n = wsp_ggml_nelements(dst); + WSP_GGML_ASSERT(n % 4 == 0); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case WSP_GGML_UNARY_OP_RELU: { @@ -812,23 +978,39 @@ void wsp_ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = wsp_ggml_nelements(dst)/4; + const int64_t n = wsp_ggml_nelements(dst); + WSP_GGML_ASSERT(n % 4 == 0); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; default: { - metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op)); + WSP_GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op)); WSP_GGML_ASSERT(false); } } break; + case WSP_GGML_OP_SQR: + { + WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0)); + + [encoder setComputePipelineState:ctx->pipeline_sqr]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + + const int64_t n = wsp_ggml_nelements(dst); + [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case WSP_GGML_OP_SOFT_MAX: { - const int nth = 32; + int nth = 32; // SIMD width if (ne00%4 == 0) { [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; } else { + do { + nth *= 2; + } while (nth <= ne00 && nth <= 1024); + nth /= 2; [encoder setComputePipelineState:ctx->pipeline_soft_max]; } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -836,8 +1018,9 @@ void wsp_ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth/32*sizeof(float), 16) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case WSP_GGML_OP_DIAG_MASK_INF: { @@ -863,26 +1046,53 @@ void wsp_ggml_metal_graph_compute( } break; case WSP_GGML_OP_MUL_MAT: { - // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 - WSP_GGML_ASSERT(ne00 == ne10); - // WSP_GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere - uint gqa = ne12/ne02; WSP_GGML_ASSERT(ne03 == ne13); + const uint gqa = ne12/ne02; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + int ne11_mm_min = 1; + +#if 0 + // the numbers below are measured on M2 Ultra for 7B and 13B models + // these numbers do not translate to other devices or model sizes + // TODO: need to find a better approach + if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) { + switch (src0t) { + case WSP_GGML_TYPE_F16: ne11_mm_min = 2; break; + case WSP_GGML_TYPE_Q8_0: ne11_mm_min = 7; break; + case WSP_GGML_TYPE_Q2_K: ne11_mm_min = 15; break; + case WSP_GGML_TYPE_Q3_K: ne11_mm_min = 7; break; + case WSP_GGML_TYPE_Q4_0: + case WSP_GGML_TYPE_Q4_1: ne11_mm_min = 15; break; + case WSP_GGML_TYPE_Q4_K: ne11_mm_min = 11; break; + case WSP_GGML_TYPE_Q5_0: // not tested yet + case WSP_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet + case WSP_GGML_TYPE_Q5_K: ne11_mm_min = 7; break; + case WSP_GGML_TYPE_Q6_K: ne11_mm_min = 7; break; + default: ne11_mm_min = 1; break; + } + } +#endif + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if (!wsp_ggml_is_transposed(src0) && + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + !wsp_ggml_is_transposed(src0) && !wsp_ggml_is_transposed(src1) && src1t == WSP_GGML_TYPE_F32 && - [ctx->device supportsFamily:MTLGPUFamilyApple7] && - ne00%32 == 0 && - ne11 > 1) { + ne00 % 32 == 0 && ne00 >= 64 && + ne11 > ne11_mm_min) { + //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); switch (src0->type) { case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break; case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; + case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break; + case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break; case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; @@ -906,17 +1116,18 @@ void wsp_ggml_metal_graph_compute( [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { int nth0 = 32; int nth1 = 1; int nrows = 1; + //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // use custom matrix x vector kernel switch (src0t) { case WSP_GGML_TYPE_F32: { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32]; nrows = 4; } break; case WSP_GGML_TYPE_F16: @@ -924,12 +1135,12 @@ void wsp_ggml_metal_graph_compute( nth0 = 32; nth1 = 1; if (ne11 * ne12 < 4) { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row]; } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4]; nrows = ne11; } else { - [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32]; nrows = 4; } } break; @@ -940,7 +1151,7 @@ void wsp_ggml_metal_graph_compute( nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32]; } break; case WSP_GGML_TYPE_Q4_1: { @@ -949,7 +1160,25 @@ void wsp_ggml_metal_graph_compute( nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; + } break; + case WSP_GGML_TYPE_Q5_0: + { + WSP_GGML_ASSERT(ne02 == 1); + WSP_GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32]; + } break; + case WSP_GGML_TYPE_Q5_1: + { + WSP_GGML_ASSERT(ne02 == 1); + WSP_GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32]; } break; case WSP_GGML_TYPE_Q8_0: { @@ -958,7 +1187,7 @@ void wsp_ggml_metal_graph_compute( nth0 = 8; nth1 = 8; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32]; } break; case WSP_GGML_TYPE_Q2_K: { @@ -967,7 +1196,7 @@ void wsp_ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32]; } break; case WSP_GGML_TYPE_Q3_K: { @@ -976,7 +1205,7 @@ void wsp_ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32]; } break; case WSP_GGML_TYPE_Q4_K: { @@ -985,7 +1214,7 @@ void wsp_ggml_metal_graph_compute( nth0 = 4; //1; nth1 = 8; //32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32]; } break; case WSP_GGML_TYPE_Q5_K: { @@ -994,7 +1223,7 @@ void wsp_ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32]; } break; case WSP_GGML_TYPE_Q6_K: { @@ -1003,11 +1232,11 @@ void wsp_ggml_metal_graph_compute( nth0 = 2; nth1 = 32; - [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32]; } break; default: { - metal_printf("Asserting on type %d\n",(int)src0t); + WSP_GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); WSP_GGML_ASSERT(false && "not implemented"); } }; @@ -1031,8 +1260,9 @@ void wsp_ggml_metal_graph_compute( [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; - if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 || src0t == WSP_GGML_TYPE_Q8_0 || - src0t == WSP_GGML_TYPE_Q2_K) {// || src0t == WSP_GGML_TYPE_Q4_K) { + if (src0t == WSP_GGML_TYPE_Q4_0 || src0t == WSP_GGML_TYPE_Q4_1 || + src0t == WSP_GGML_TYPE_Q5_0 || src0t == WSP_GGML_TYPE_Q5_1 || src0t == WSP_GGML_TYPE_Q8_0 || + src0t == WSP_GGML_TYPE_Q2_K) { // || src0t == WSP_GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == WSP_GGML_TYPE_Q4_K) { @@ -1063,6 +1293,8 @@ void wsp_ggml_metal_graph_compute( case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case WSP_GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case WSP_GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; + case WSP_GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break; + case WSP_GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break; case WSP_GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break; case WSP_GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; case WSP_GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; @@ -1085,10 +1317,12 @@ void wsp_ggml_metal_graph_compute( } break; case WSP_GGML_OP_RMS_NORM: { + WSP_GGML_ASSERT(ne00 % 4 == 0); + float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = 512; + const int nth = MIN(512, ne00); [encoder setComputePipelineState:ctx->pipeline_rms_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1096,7 +1330,7 @@ void wsp_ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth/32*sizeof(float), 16) atIndex:0]; const int64_t nrows = wsp_ggml_nrows(src0); @@ -1107,7 +1341,7 @@ void wsp_ggml_metal_graph_compute( float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = 256; + const int nth = MIN(256, ne00); [encoder setComputePipelineState:ctx->pipeline_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1115,7 +1349,7 @@ void wsp_ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:WSP_GGML_PAD(nth*sizeof(float), 16) atIndex:0]; const int64_t nrows = wsp_ggml_nrows(src0); @@ -1125,17 +1359,16 @@ void wsp_ggml_metal_graph_compute( { WSP_GGML_ASSERT((src0t == WSP_GGML_TYPE_F32)); - const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + const int nth = MIN(1024, ne00); + + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - if (__builtin_popcount(n_head) != 1) { - WSP_GGML_ASSERT(false && "only power-of-two n_head implemented"); - } - const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); [encoder setComputePipelineState:ctx->pipeline_alibi_f32]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1156,55 +1389,74 @@ void wsp_ggml_metal_graph_compute( [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; - - const int nth = 32; + [encoder setBytes:&m0 length:sizeof( float) atIndex:18]; + [encoder setBytes:&m1 length:sizeof( float) atIndex:19]; + [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case WSP_GGML_OP_ROPE: { - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; + WSP_GGML_ASSERT(ne10 == ne02); - float freq_base; - float freq_scale; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + const int nth = MIN(1024, ne00); - [encoder setComputePipelineState:ctx->pipeline_rope]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:18]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:19]; - [encoder setBytes:&mode length:sizeof( int) atIndex:20]; - [encoder setBytes:&freq_base length:sizeof(float) atIndex:21]; - [encoder setBytes:&freq_scale length:sizeof(float) atIndex:22]; + const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[3]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + switch (src0->type) { + case WSP_GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break; + case WSP_GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break; + default: WSP_GGML_ASSERT(false); + }; + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:4]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:12]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:15]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:16]; + [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&n_past length:sizeof( int) atIndex:19]; + [encoder setBytes:&n_dims length:sizeof( int) atIndex:20]; + [encoder setBytes:&mode length:sizeof( int) atIndex:21]; + [encoder setBytes:&n_orig_ctx length:sizeof( int) atIndex:22]; + [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; + [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; + [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; + [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; + [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; + [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case WSP_GGML_OP_DUP: case WSP_GGML_OP_CPY: case WSP_GGML_OP_CONT: { - const int nth = 32; + const int nth = MIN(1024, ne00); switch (src0t) { case WSP_GGML_TYPE_F32: @@ -1249,7 +1501,7 @@ void wsp_ggml_metal_graph_compute( } break; default: { - metal_printf("%s: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op)); + WSP_GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, wsp_ggml_op_name(dst->op)); WSP_GGML_ASSERT(false); } } @@ -1274,10 +1526,147 @@ void wsp_ggml_metal_graph_compute( MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status]; if (status != MTLCommandBufferStatusCompleted) { - metal_printf("%s: command buffer %d failed with status %lu\n", __func__, i, status); + WSP_GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status); WSP_GGML_ASSERT(false); } } } } + +//////////////////////////////////////////////////////////////////////////////// + +// backend interface + +static const char * wsp_ggml_backend_metal_name(wsp_ggml_backend_t backend) { + return "Metal"; + + UNUSED(backend); +} + +static void wsp_ggml_backend_metal_free(wsp_ggml_backend_t backend) { + struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context; + wsp_ggml_metal_free(ctx); + free(backend); +} + +static void * wsp_ggml_backend_metal_buffer_get_base(wsp_ggml_backend_buffer_t buffer) { + return (void *)buffer->context; +} + +static void wsp_ggml_backend_metal_buffer_free_buffer(wsp_ggml_backend_buffer_t buffer) { + free(buffer->context); + UNUSED(buffer); +} + +static struct wsp_ggml_backend_buffer_i metal_backend_buffer_i = { + /* .free_buffer = */ wsp_ggml_backend_metal_buffer_free_buffer, + /* .get_base = */ wsp_ggml_backend_metal_buffer_get_base, + /* .get_alloc_size = */ NULL, // defaults to wsp_ggml_nbytes + /* .init_tensor = */ NULL, // no initialization required + /* .free_tensor = */ NULL, // no cleanup required +}; + +static wsp_ggml_backend_buffer_t wsp_ggml_backend_metal_alloc_buffer(wsp_ggml_backend_t backend, size_t size) { + struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context; + + void * data = wsp_ggml_metal_host_malloc(size); + + // TODO: set proper name of the buffers + wsp_ggml_metal_add_buffer(ctx, "backend", data, size, 0); + + return wsp_ggml_backend_buffer_init(backend, metal_backend_buffer_i, data, size); +} + +static size_t wsp_ggml_backend_metal_get_alignment(wsp_ggml_backend_t backend) { + return 32; + UNUSED(backend); +} + +static void wsp_ggml_backend_metal_set_tensor_async(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor write out of bounds"); + WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy((char *)tensor->data + offset, data, size); + + UNUSED(backend); +} + +static void wsp_ggml_backend_metal_get_tensor_async(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * tensor, void * data, size_t offset, size_t size) { + WSP_GGML_ASSERT(offset + size <= wsp_ggml_nbytes(tensor) && "tensor read out of bounds"); + WSP_GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + memcpy(data, (const char *)tensor->data + offset, size); + + UNUSED(backend); +} + +static void wsp_ggml_backend_metal_synchronize(wsp_ggml_backend_t backend) { + UNUSED(backend); +} + +static void wsp_ggml_backend_metal_cpy_tensor_from(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { + wsp_ggml_backend_tensor_get(src, dst->data, 0, wsp_ggml_nbytes(src)); + + UNUSED(backend); +} + +static void wsp_ggml_backend_metal_cpy_tensor_to(wsp_ggml_backend_t backend, struct wsp_ggml_tensor * src, struct wsp_ggml_tensor * dst) { + wsp_ggml_backend_tensor_set_async(dst, src->data, 0, wsp_ggml_nbytes(src)); + + UNUSED(backend); +} + +static void wsp_ggml_backend_metal_graph_compute(wsp_ggml_backend_t backend, struct wsp_ggml_cgraph * cgraph) { + struct wsp_ggml_metal_context * metal_ctx = (struct wsp_ggml_metal_context *)backend->context; + + wsp_ggml_metal_graph_compute(metal_ctx, cgraph); +} + +static bool wsp_ggml_backend_metal_supports_op(wsp_ggml_backend_t backend, const struct wsp_ggml_tensor * op) { + return true; + UNUSED(backend); + UNUSED(op); +} + +static struct wsp_ggml_backend_i metal_backend_i = { + /* .get_name = */ wsp_ggml_backend_metal_name, + /* .free = */ wsp_ggml_backend_metal_free, + /* .alloc_buffer = */ wsp_ggml_backend_metal_alloc_buffer, + /* .get_alignment = */ wsp_ggml_backend_metal_get_alignment, + /* .set_tensor_async = */ wsp_ggml_backend_metal_set_tensor_async, + /* .get_tensor_async = */ wsp_ggml_backend_metal_get_tensor_async, + /* .synchronize = */ wsp_ggml_backend_metal_synchronize, + /* .cpy_tensor_from = */ wsp_ggml_backend_metal_cpy_tensor_from, + /* .cpy_tensor_to = */ wsp_ggml_backend_metal_cpy_tensor_to, + /* .graph_plan_create = */ NULL, // the metal implementation does not require creating graph plans atm + /* .graph_plan_free = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ wsp_ggml_backend_metal_graph_compute, + /* .supports_op = */ wsp_ggml_backend_metal_supports_op, +}; + +wsp_ggml_backend_t wsp_ggml_backend_metal_init(void) { + struct wsp_ggml_metal_context * ctx = malloc(sizeof(struct wsp_ggml_metal_context)); + + ctx = wsp_ggml_metal_init(WSP_GGML_DEFAULT_N_THREADS); + + wsp_ggml_backend_t metal_backend = malloc(sizeof(struct wsp_ggml_backend)); + + *metal_backend = (struct wsp_ggml_backend) { + /* .interface = */ metal_backend_i, + /* .context = */ ctx, + }; + + return metal_backend; +} + +bool wsp_ggml_backend_is_metal(wsp_ggml_backend_t backend) { + return backend->iface.get_name == wsp_ggml_backend_metal_name; +} + +void wsp_ggml_backend_metal_set_n_cb(wsp_ggml_backend_t backend, int n_cb) { + struct wsp_ggml_metal_context * ctx = (struct wsp_ggml_metal_context *)backend->context; + + wsp_ggml_metal_set_n_cb(ctx, n_cb); +} diff --git a/cpp/ggml-quants.c b/cpp/ggml-quants.c new file mode 100644 index 0000000..7603f86 --- /dev/null +++ b/cpp/ggml-quants.c @@ -0,0 +1,7377 @@ +#include "ggml-quants.h" +#include "ggml-impl.h" + +#include +#include +#include +#include + +#ifdef __ARM_NEON + +// if YCM cannot find , make a symbolic link to it, for example: +// +// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ +// +#include + +#else + +#ifdef __wasm_simd128__ +#include +#else +#ifdef __POWER9_VECTOR__ +#include +#undef bool +#define bool _Bool +#else +#if defined(_MSC_VER) || defined(__MINGW32__) +#include +#else +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) +#if !defined(__riscv) +#include +#endif +#endif +#endif +#endif +#endif +#endif + +#ifdef __riscv_v_intrinsic +#include +#endif + +#undef MIN +#undef MAX + +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) + +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) +// multiply int8_t, add results pairwise twice +static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { + // Get absolute values of x vectors + const __m128i ax = _mm_sign_epi8(x, x); + // Sign the values of the y vectors + const __m128i sy = _mm_sign_epi8(y, x); + // Perform multiplication and create 16-bit values + const __m128i dot = _mm_maddubs_epi16(ax, sy); + const __m128i ones = _mm_set1_epi16(1); + return _mm_madd_epi16(ones, dot); +} + +#if __AVX__ || __AVX2__ || __AVX512F__ +// horizontally add 8 floats +static inline float hsum_float_8(const __m256 x) { + __m128 res = _mm256_extractf128_ps(x, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(x)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// horizontally add 8 int32_t +static inline int hsum_i32_8(const __m256i a) { + const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); + const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); + const __m128i sum64 = _mm_add_epi32(hi64, sum128); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +// horizontally add 4 int32_t +static inline int hsum_i32_4(const __m128i a) { + const __m128i hi64 = _mm_unpackhi_epi64(a, a); + const __m128i sum64 = _mm_add_epi32(hi64, a); + const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); + return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); +} + +#if defined(__AVX2__) || defined(__AVX512F__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m256i shuf_mask = _mm256_set_epi64x( + 0x0303030303030303, 0x0202020202020202, + 0x0101010101010101, 0x0000000000000000); + __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); + const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytes = _mm256_or_si256(bytes, bit_mask); + return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); + const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8( 0xF ); + return _mm256_and_si256(lowMask, bytes); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m256i x) { + const __m256i ones = _mm256_set1_epi16(1); + const __m256i summed_pairs = _mm256_madd_epi16(ones, x); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { +#if __AVXVNNI__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Perform multiplication and create 16-bit values + const __m256i dot = _mm256_maddubs_epi16(ax, sy); + return sum_i16_pairs_float(dot); +#endif +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_float(ax, sy); +#endif +} + +static inline __m128i packNibbles( __m256i bytes ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh +#if __AVX512F__ + const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 + bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh + return _mm256_cvtepi16_epi8(bytes); // abcd_efgh +#else + const __m256i lowByte = _mm256_set1_epi16( 0xFF ); + __m256i high = _mm256_andnot_si256( lowByte, bytes ); + __m256i low = _mm256_and_si256( lowByte, bytes ); + high = _mm256_srli_epi16( high, 4 ); + bytes = _mm256_or_si256( low, high ); + + // Compress uint16_t lanes into bytes + __m128i r0 = _mm256_castsi256_si128( bytes ); + __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); + return _mm_packus_epi16( r0, r1 ); +#endif +} +#elif defined(__AVX__) +// spread 32 bits to 32 bytes { 0x00, 0xFF } +static inline __m256i bytes_from_bits_32(const uint8_t * x) { + uint32_t x32; + memcpy(&x32, x, sizeof(uint32_t)); + const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); + __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); + __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); + const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); + bytesl = _mm_or_si128(bytesl, bit_mask); + bytesh = _mm_or_si128(bytesh, bit_mask); + bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); + bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); + return MM256_SET_M128I(bytesh, bytesl); +} + +// Unpack 32 4-bit fields into 32 bytes +// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval +static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) +{ + // Load 16 bytes from memory + __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); + __m128i tmph = _mm_srli_epi16(tmpl, 4); + const __m128i lowMask = _mm_set1_epi8(0xF); + tmpl = _mm_and_si128(lowMask, tmpl); + tmph = _mm_and_si128(lowMask, tmph); + return MM256_SET_M128I(tmph, tmpl); +} + +// add int16_t pairwise and return as float vector +static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { + const __m128i ones = _mm_set1_epi16(1); + const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); + const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); + const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); + return _mm256_cvtepi32_ps(summed_pairs); +} + +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + const __m128i axl = _mm256_castsi256_si128(ax); + const __m128i axh = _mm256_extractf128_si256(ax, 1); + const __m128i syl = _mm256_castsi256_si128(sy); + const __m128i syh = _mm256_extractf128_si256(sy, 1); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { + const __m128i xl = _mm256_castsi256_si128(x); + const __m128i xh = _mm256_extractf128_si256(x, 1); + const __m128i yl = _mm256_castsi256_si128(y); + const __m128i yh = _mm256_extractf128_si256(y, 1); + // Get absolute values of x vectors + const __m128i axl = _mm_sign_epi8(xl, xl); + const __m128i axh = _mm_sign_epi8(xh, xh); + // Sign the values of the y vectors + const __m128i syl = _mm_sign_epi8(yl, xl); + const __m128i syh = _mm_sign_epi8(yh, xh); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + +static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) +{ + // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh + const __m128i lowByte = _mm_set1_epi16( 0xFF ); + __m128i high = _mm_andnot_si128( lowByte, bytes1 ); + __m128i low = _mm_and_si128( lowByte, bytes1 ); + high = _mm_srli_epi16( high, 4 ); + bytes1 = _mm_or_si128( low, high ); + high = _mm_andnot_si128( lowByte, bytes2 ); + low = _mm_and_si128( lowByte, bytes2 ); + high = _mm_srli_epi16( high, 4 ); + bytes2 = _mm_or_si128( low, high ); + + return _mm_packus_epi16( bytes1, bytes2); +} +#endif +#elif defined(__SSSE3__) +// horizontally add 4x4 floats +static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { + __m128 res_0 =_mm_hadd_ps(a, b); + __m128 res_1 =_mm_hadd_ps(c, d); + __m128 res =_mm_hadd_ps(res_0, res_1); + res =_mm_hadd_ps(res, res); + res =_mm_hadd_ps(res, res); + + return _mm_cvtss_f32(res); +} +#endif // __AVX__ || __AVX2__ || __AVX512F__ +#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) + +#if defined(__ARM_NEON) +#if !defined(__aarch64__) + +// 64-bit compatibility + +// vaddvq_s16 +// vpaddq_s16 +// vaddvq_s32 +// vaddvq_f32 +// vmaxvq_f32 +// vcvtnq_s32_f32 + +inline static int32_t vaddvq_s16(int16x8_t v) { + return + (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + + (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + + (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + + (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +} + +inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { + int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a)); + int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b)); + return vcombine_s16(a0, b0); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { + return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); +} + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +inline static float vmaxvq_f32(float32x4_t v) { + return + MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), + MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); +} + +inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { + int32x4_t res; + + res[0] = roundf(vgetq_lane_f32(v, 0)); + res[1] = roundf(vgetq_lane_f32(v, 1)); + res[2] = roundf(vgetq_lane_f32(v, 2)); + res[3] = roundf(vgetq_lane_f32(v, 3)); + + return res; +} + +// vld1q_s16_x2 +// vld1q_u8_x2 +// vld1q_u8_x4 +// vld1q_s8_x2 +// vld1q_s8_x4 +// TODO: double-check these work correctly + +typedef struct wsp_ggml_int16x8x2_t { + int16x8_t val[2]; +} wsp_ggml_int16x8x2_t; + +inline static wsp_ggml_int16x8x2_t wsp_ggml_vld1q_s16_x2(const int16_t * ptr) { + wsp_ggml_int16x8x2_t res; + + res.val[0] = vld1q_s16(ptr + 0); + res.val[1] = vld1q_s16(ptr + 8); + + return res; +} + +typedef struct wsp_ggml_uint8x16x2_t { + uint8x16_t val[2]; +} wsp_ggml_uint8x16x2_t; + +inline static wsp_ggml_uint8x16x2_t wsp_ggml_vld1q_u8_x2(const uint8_t * ptr) { + wsp_ggml_uint8x16x2_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + + return res; +} + +typedef struct wsp_ggml_uint8x16x4_t { + uint8x16_t val[4]; +} wsp_ggml_uint8x16x4_t; + +inline static wsp_ggml_uint8x16x4_t wsp_ggml_vld1q_u8_x4(const uint8_t * ptr) { + wsp_ggml_uint8x16x4_t res; + + res.val[0] = vld1q_u8(ptr + 0); + res.val[1] = vld1q_u8(ptr + 16); + res.val[2] = vld1q_u8(ptr + 32); + res.val[3] = vld1q_u8(ptr + 48); + + return res; +} + +typedef struct wsp_ggml_int8x16x2_t { + int8x16_t val[2]; +} wsp_ggml_int8x16x2_t; + +inline static wsp_ggml_int8x16x2_t wsp_ggml_vld1q_s8_x2(const int8_t * ptr) { + wsp_ggml_int8x16x2_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + + return res; +} + +typedef struct wsp_ggml_int8x16x4_t { + int8x16_t val[4]; +} wsp_ggml_int8x16x4_t; + +inline static wsp_ggml_int8x16x4_t wsp_ggml_vld1q_s8_x4(const int8_t * ptr) { + wsp_ggml_int8x16x4_t res; + + res.val[0] = vld1q_s8(ptr + 0); + res.val[1] = vld1q_s8(ptr + 16); + res.val[2] = vld1q_s8(ptr + 32); + res.val[3] = vld1q_s8(ptr + 48); + + return res; +} + +#else + +#define wsp_ggml_int16x8x2_t int16x8x2_t +#define wsp_ggml_uint8x16x2_t uint8x16x2_t +#define wsp_ggml_uint8x16x4_t uint8x16x4_t +#define wsp_ggml_int8x16x2_t int8x16x2_t +#define wsp_ggml_int8x16x4_t int8x16x4_t + +#define wsp_ggml_vld1q_s16_x2 vld1q_s16_x2 +#define wsp_ggml_vld1q_u8_x2 vld1q_u8_x2 +#define wsp_ggml_vld1q_u8_x4 vld1q_u8_x4 +#define wsp_ggml_vld1q_s8_x2 vld1q_s8_x2 +#define wsp_ggml_vld1q_s8_x4 vld1q_s8_x4 + +#endif +#endif + +#if defined(__ARM_NEON) || defined(__wasm_simd128__) +#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s +#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) +#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) +#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) +#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) +#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) +#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) +#define B8(c,s ) B7(c,s, c), B7(c,s, s) + +// precomputed tables for expanding 8bits to 8 bytes: +static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 +static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 +#endif + +// reference implementation for deterministic creation of model files +void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + y[i].qs[j] = xi0; + y[i].qs[j] |= xi1 << 4; + } + } +} + +void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { + quantize_row_q4_0_reference(x, y, k); +} + +void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) { + const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 4) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + y[i].m = WSP_GGML_FP32_TO_FP16(min); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = (x[i*qk + 0 + j] - min)*id; + const float x1 = (x[i*qk + qk/2 + j] - min)*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); + + y[i].qs[j] = xi0; + y[i].qs[j] |= xi1 << 4; + } + } +} + +void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { + quantize_row_q4_1_reference(x, y, k); +} + +void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { + static const int qk = QK5_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -16; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + + uint32_t qh = 0; + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + + const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); + const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); + + y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2); + } + + memcpy(&y[i].qh, &qh, sizeof(qh)); + } +} + +void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { + quantize_row_q5_0_reference(x, y, k); +} + +void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { + const int qk = QK5_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float min = FLT_MAX; + float max = -FLT_MAX; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + + if (v < min) min = v; + if (v > max) max = v; + } + + const float d = (max - min) / ((1 << 5) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + y[i].m = WSP_GGML_FP32_TO_FP16(min); + + uint32_t qh = 0; + + for (int j = 0; j < qk/2; ++j) { + const float x0 = (x[i*qk + 0 + j] - min)*id; + const float x1 = (x[i*qk + qk/2 + j] - min)*id; + + const uint8_t xi0 = (uint8_t)(x0 + 0.5f); + const uint8_t xi1 = (uint8_t)(x1 + 0.5f); + + y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); + + // get the 5-th bit and store it in qh at the right position + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + qk/2); + } + + memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); + } +} + +void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { + quantize_row_q5_1_reference(x, y, k); +} + +// reference implementation for deterministic creation of model files +void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = x[i*QK8_0 + j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = x[i*QK8_0 + j]*id; + + y[i].qs[j] = roundf(x0); + } + } +} + +void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + } + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + } + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = WSP_GGML_FP32_TO_FP16(d); + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#elif defined(__riscv_v_intrinsic) + + size_t vl = __riscv_vsetvl_e32m4(QK8_0); + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_0, vl); + + vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = WSP_GGML_FP32_TO_FP16(d); + + vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + + // convert to integer + vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); + vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + + // store result + __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + } +#else + WSP_GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_reference(x, y, k); +#endif +} + +// reference implementation for deterministic creation of model files +void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { + assert(QK8_1 == 32); + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_1; j++) { + const float v = x[i*QK8_1 + j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + int sum = 0; + + for (int j = 0; j < QK8_1/2; ++j) { + const float v0 = x[i*QK8_1 + j]*id; + const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id; + + y[i].qs[ j] = roundf(v0); + y[i].qs[QK8_1/2 + j] = roundf(v1); + + sum += y[i].qs[ j]; + sum += y[i].qs[QK8_1/2 + j]; + } + + y[i].s = sum*d; + } +} + +void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { + assert(k % QK8_1 == 0); + const int nb = k / QK8_1; + + block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + int32x4_t accv = vdupq_n_s32(0); + + for (int j = 0; j < 8; j++) { + const float32x4_t v = vmulq_n_f32(srcv[j], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); + + accv = vaddq_s32(accv, vi); + } + + y[i].s = d * vaddvq_s32(accv); + } +#elif defined(__wasm_simd128__) + for (int i = 0; i < nb; i++) { + v128_t srcv [8]; + v128_t asrcv[8]; + v128_t amaxv[8]; + + for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); + for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); + + for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); + for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); + for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); + + const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), + wasm_f32x4_extract_lane(amaxv[0], 1)), + MAX(wasm_f32x4_extract_lane(amaxv[0], 2), + wasm_f32x4_extract_lane(amaxv[0], 3))); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + v128_t accv = wasm_i32x4_splat(0); + + for (int j = 0; j < 8; j++) { + const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); + const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); + + y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); + y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); + y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); + y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); + + accv = wasm_i32x4_add(accv, vi); + } + + y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) + + wasm_i32x4_extract_lane(accv, 1) + + wasm_i32x4_extract_lane(accv, 2) + + wasm_i32x4_extract_lane(accv, 3)); + } +#elif defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + y[i].d = d; + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Compute the sum of the quants and set y[i].s + y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); + + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Compute the sum of the quants and set y[i].s + const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); + const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); + y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1)); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#elif defined(__riscv_v_intrinsic) + + size_t vl = __riscv_vsetvl_e32m4(QK8_1); + + for (int i = 0; i < nb; i++) { + // load elements + vfloat32m4_t v_x = __riscv_vle32_v_f32m4(x+i*QK8_1, vl); + + vfloat32m4_t vfabs = __riscv_vfabs_v_f32m4(v_x, vl); + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0, vl); + vfloat32m1_t vmax = __riscv_vfredmax_vs_f32m4_f32m1(vfabs, tmp, vl); + float amax = __riscv_vfmv_f_s_f32m1_f32(vmax); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = d; + + vfloat32m4_t x0 = __riscv_vfmul_vf_f32m4(v_x, id, vl); + + // convert to integer + vint16m2_t vi = __riscv_vfncvt_x_f_w_i16m2(x0, vl); + vint8m1_t vs = __riscv_vncvt_x_x_w_i8m1(vi, vl); + + // store result + __riscv_vse8_v_i8m1(y[i].qs , vs, vl); + + // compute sum for y[i].s + vint16m1_t tmp2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t vwrs = __riscv_vwredsum_vs_i8m1_i16m1(vs, tmp2, vl); + + // set y[i].s + int sum = __riscv_vmv_x_s_i16m1_i16(vwrs); + y[i].s = sum*d; + } +#else + WSP_GGML_UNUSED(nb); + // scalar + quantize_row_q8_1_reference(x, y, k); +#endif +} + +void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F) - 8; + const int x1 = (x[i].qs[j] >> 4) - 8; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) { + static const int qk = QK4_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + const float m = WSP_GGML_FP16_TO_FP32(x[i].m); + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F); + const int x1 = (x[i].qs[j] >> 4); + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + +void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) { + static const int qk = QK5_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + +void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) { + static const int qk = QK5_1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + const float m = WSP_GGML_FP16_TO_FP32(x[i].m); + + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int x0 = (x[i].qs[j] & 0x0F) | xh_0; + const int x1 = (x[i].qs[j] >> 4) | xh_1; + + y[i*qk + j + 0 ] = x0*d + m; + y[i*qk + j + qk/2] = x1*d + m; + } + } +} + +void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k) { + static const int qk = QK8_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + + for (int j = 0; j < qk; ++j) { + y[i*qk + j] = x[i].qs[j]*d; + } + } +} + +// +// 2-6 bit quantization in super-blocks +// + +// +// ===================== Helper functions +// +static inline int nearest_int(float fval) { + assert(fval <= 4194303.f); + float val = fval + 12582912.f; + int i; memcpy(&i, &val, sizeof(int)); + return (i & 0x007fffff) - 0x00400000; +} + +static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (amax < 1e-30f) { // all zero + for (int i = 0; i < n; ++i) { + L[i] = 0; + } + return 0.f; + } + float iscale = -nmax / max; + if (rmse_type == 0) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + return 1/iscale; + } + bool return_early = false; + if (rmse_type < 0) { + rmse_type = -rmse_type; + return_early = true; + } + int weight_type = rmse_type%2; + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + float w = weight_type == 1 ? x[i] * x[i] : 1; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + float scale = sumlx/suml2; + if (return_early) return suml2 > 0 ? 0.5f*(scale + 1/iscale) : 1/iscale; + float best = scale * sumlx; + for (int is = -9; is <= 9; ++is) { + if (is == 0) { + continue; + } + iscale = -(nmax + 0.1f*is) / max; + sumlx = suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + float w = weight_type == 1 ? x[i] * x[i] : 1; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + if (suml2 > 0 && sumlx*sumlx > best*suml2) { + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + L[i] = nmax + MAX(-nmax, MIN(nmax-1, l)); + } + scale = sumlx/suml2; best = scale*sumlx; + } + } + return scale; +} + +static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) { + float max = 0; + float amax = 0; + for (int i = 0; i < n; ++i) { + float ax = fabsf(x[i]); + if (ax > amax) { amax = ax; max = x[i]; } + } + if (!amax) { // all zero + for (int i = 0; i < n; ++i) { L[i] = 0; } + return 0.f; + } + float iscale = -nmax / max; + if (do_rmse) { + float sumlx = 0; + float suml2 = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l; + float w = x[i]*x[i]; + sumlx += w*x[i]*l; + suml2 += w*l*l; + } + for (int itry = 0; itry < 5; ++itry) { + int n_changed = 0; + for (int i = 0; i < n; ++i) { + float w = x[i]*x[i]; + float slx = sumlx - w*x[i]*L[i]; + if (slx > 0) { + float sl2 = suml2 - w*L[i]*L[i]; + int new_l = nearest_int(x[i] * sl2 / slx); + new_l = MAX(-nmax, MIN(nmax-1, new_l)); + if (new_l != L[i]) { + slx += w*x[i]*new_l; + sl2 += w*new_l*new_l; + if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) { + L[i] = new_l; sumlx = slx; suml2 = sl2; + ++n_changed; + } + } + } + } + if (!n_changed) { + break; + } + } + for (int i = 0; i < n; ++i) { + L[i] += nmax; + } + return sumlx / suml2; + } + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale * x[i]); + l = MAX(-nmax, MIN(nmax-1, l)); + L[i] = l + nmax; + } + return 1/iscale; +} + +static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, + int ntry, float alpha) { + float min = x[0]; + float max = x[0]; + for (int i = 1; i < n; ++i) { + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + } + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = 0; + return 0.f; + } + if (min > 0) min = 0; + float iscale = nmax/(max - min); + float scale = 1/iscale; + for (int itry = 0; itry < ntry; ++itry) { + float sumlx = 0; int suml2 = 0; + bool did_change = false; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + if (l != L[i]) { + L[i] = l; + did_change = true; + } + sumlx += (x[i] - min)*l; + suml2 += l*l; + } + scale = sumlx/suml2; + float sum = 0; + for (int i = 0; i < n; ++i) { + sum += x[i] - scale*L[i]; + } + min = alpha*min + (1 - alpha)*sum/n; + if (min > 0) min = 0; + iscale = 1/scale; + if (!did_change) break; + } + *the_min = -min; + return scale; +} + +static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights, + uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux, + float rmin, float rdelta, int nstep, bool use_mad) { + float min = x[0]; + float max = x[0]; + float sum_w = weights[0]; + float sum_x = sum_w * x[0]; + for (int i = 1; i < n; ++i) { + if (x[i] < min) min = x[i]; + if (x[i] > max) max = x[i]; + float w = weights[i]; + sum_w += w; + sum_x += w * x[i]; + } + if (min > 0) min = 0; + if (max == min) { + for (int i = 0; i < n; ++i) L[i] = 0; + *the_min = -min; + return 0.f; + } + float iscale = nmax/(max - min); + float scale = 1/iscale; + float best_mad = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + L[i] = MAX(0, MIN(nmax, l)); + float diff = scale * L[i] + min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + best_mad += w * diff; + } + if (nstep < 1) { + *the_min = -min; + return scale; + } + for (int is = 0; is <= nstep; ++is) { + iscale = (rmin + rdelta*is + nmax)/(max - min); + float sum_l = 0, sum_l2 = 0, sum_xl = 0; + for (int i = 0; i < n; ++i) { + int l = nearest_int(iscale*(x[i] - min)); + l = MAX(0, MIN(nmax, l)); + Laux[i] = l; + float w = weights[i]; + sum_l += w*l; + sum_l2 += w*l*l; + sum_xl += w*l*x[i]; + } + float D = sum_w * sum_l2 - sum_l * sum_l; + if (D > 0) { + float this_scale = (sum_w * sum_xl - sum_x * sum_l)/D; + float this_min = (sum_l2 * sum_x - sum_l * sum_xl)/D; + if (this_min > 0) { + this_min = 0; + this_scale = sum_xl / sum_l2; + } + float mad = 0; + for (int i = 0; i < n; ++i) { + float diff = this_scale * Laux[i] + this_min - x[i]; + diff = use_mad ? fabsf(diff) : diff * diff; + float w = weights[i]; + mad += w * diff; + } + if (mad < best_mad) { + for (int i = 0; i < n; ++i) { + L[i] = Laux[i]; + } + best_mad = mad; + scale = this_scale; + min = this_min; + } + } + } + *the_min = -min; + return scale; +} + +#if QK_K == 256 +static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { + if (j < 4) { + *d = q[j] & 63; *m = q[j + 4] & 63; + } else { + *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} +#endif + +//========================- 2-bit (de)-quantization + +void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[16]; + float weights[16]; + float mins[QK_K/16]; + float scales[QK_K/16]; + + const float q4scale = 15.f; + + for (int i = 0; i < nb; i++) { + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 16; ++l) weights[l] = fabsf(x[16*j + l]); + scales[j] = make_qkx2_quants(16, 3, x + 16*j, weights, L + 16*j, &mins[j], Laux, -0.5f, 0.1f, 15, true); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + if (max_scale > 0) { + float iscale = q4scale/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*scales[j]); + y[i].scales[j] = l; + } + y[i].d = WSP_GGML_FP32_TO_FP16(max_scale/q4scale); + } else { + for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0; + y[i].d = WSP_GGML_FP32_TO_FP16(0.f); + } + if (max_min > 0) { + float iscale = q4scale/max_min; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*mins[j]); + y[i].scales[j] |= (l << 4); + } + y[i].dmin = WSP_GGML_FP32_TO_FP16(max_min/q4scale); + } else { + y[i].dmin = WSP_GGML_FP32_TO_FP16(0.f); + } + for (int j = 0; j < QK_K/16; ++j) { + const float d = WSP_GGML_FP16_TO_FP32(y[i].d) * (y[i].scales[j] & 0xF); + if (!d) continue; + const float dm = WSP_GGML_FP16_TO_FP32(y[i].dmin) * (y[i].scales[j] >> 4); + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int((x[16*j + ii] + dm)/d); + l = MAX(0, MIN(3, l)); + L[16*j + ii] = l; + } + } + +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif + + x += QK_K; + + } +} + +void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + const float min = WSP_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * q = x[i].qs; + +#if QK_K == 256 + int is = 0; + float dl, ml; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + uint8_t sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml; + + sc = x[i].scales[is++]; + dl = d * (sc & 0xF); ml = min * (sc >> 4); + for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml; + + shift += 2; + } + q += 32; + } +#else + float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4); + float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4); + float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4); + float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4); + for (int l = 0; l < 16; ++l) { + y[l+ 0] = dl1 * ((int8_t)((q[l] >> 0) & 3)) - ml1; + y[l+16] = dl2 * ((int8_t)((q[l] >> 2) & 3)) - ml2; + y[l+32] = dl3 * ((int8_t)((q[l] >> 4) & 3)) - ml3; + y[l+48] = dl4 * ((int8_t)((q[l] >> 6) & 3)) - ml4; + } + y += QK_K; +#endif + } +} + +void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) { + quantize_row_q2_K_reference(x, vy, k); +} + +size_t wsp_ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K; + quantize_row_q2_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q2_K)); +} + +//========================= 3-bit (de)-quantization + +void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K / 16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float amax = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true); + float scale = fabsf(scales[j]); + if (scale > amax) { + amax = scale; max_scale = scales[j]; + } + } + +#if QK_K == 256 + memset(y[i].scales, 0, 12); + if (max_scale) { + float iscale = -32.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int8_t l = nearest_int(iscale*scales[j]); + l = MAX(-32, MIN(31, l)) + 32; + if (j < 8) { + y[i].scales[j] = l & 0xF; + } else { + y[i].scales[j-8] |= ((l & 0xF) << 4); + } + l >>= 4; + y[i].scales[j%4 + 8] |= (l << (2*(j/4))); + } + y[i].d = WSP_GGML_FP32_TO_FP16(1/iscale); + } else { + y[i].d = WSP_GGML_FP32_TO_FP16(0.f); + } + + int8_t sc; + for (int j = 0; j < QK_K/16; ++j) { + sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4; + sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32; + float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } +#else + if (max_scale) { + float iscale = -8.f/max_scale; + for (int j = 0; j < QK_K/16; j+=2) { + int l1 = nearest_int(iscale*scales[j]); + l1 = 8 + MAX(-8, MIN(7, l1)); + int l2 = nearest_int(iscale*scales[j+1]); + l2 = 8 + MAX(-8, MIN(7, l2)); + y[i].scales[j/2] = l1 | (l2 << 4); + } + y[i].d = WSP_GGML_FP32_TO_FP16(1/iscale); + } else { + for (int j = 0; j < QK_K/16; j+=2) { + y[i].scales[j/2] = 0; + } + y[i].d = WSP_GGML_FP32_TO_FP16(0.f); + } + for (int j = 0; j < QK_K/16; ++j) { + int s = j%2 == 0 ? y[i].scales[j/2] & 0xF : y[i].scales[j/2] >> 4; + float d = WSP_GGML_FP16_TO_FP32(y[i].d) * (s - 8); + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-4, MIN(3, l)); + L[16*j + ii] = l + 4; + } + } +#endif + + memset(y[i].hmask, 0, QK_K/8); + // We put the high-bit for the 1st 8 quants into bit 0, the next 8 into bit 1, etc. + int m = 0; + uint8_t hm = 1; + for (int j = 0; j < QK_K; ++j) { + if (L[j] > 3) { + y[i].hmask[m] |= hm; + L[j] -= 4; + } + if (++m == QK_K/8) { + m = 0; hm <<= 1; + } + } +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6); + } + } +#else + for (int l = 0; l < 16; ++l) { + y[i].qs[l] = L[l] | (L[l + 16] << 2) | (L[l + 32] << 4) | (L[l + 48] << 6); + } +#endif + + x += QK_K; + } +} + +#if QK_K == 256 +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + uint32_t aux[4]; + const int8_t * scales = (const int8_t*)aux; + + for (int i = 0; i < nb; i++) { + + const float d_all = WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + uint8_t m = 1; + + memcpy(aux, x[i].scales, 12); + uint32_t tmp = aux[2]; + aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + int is = 0; + float dl; + for (int n = 0; n < QK_K; n += 128) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4)); + } + + dl = d_all * (scales[is++] - 32); + for (int l = 0; l < 16; ++l) { + *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4)); + } + + shift += 2; + m <<= 1; + } + q += 32; + } + + } +} +#else +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + assert(QK_K == 64); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d_all = WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + + const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); + const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); + const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); + const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); + + for (int l=0; l<8; ++l) { + uint8_t h = hm[l]; + y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4)); + y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4)); + y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4)); + y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4)); + y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4)); + y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4)); + y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4)); + y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4)); + } + y += QK_K; + } +} +#endif + +void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { + quantize_row_q3_K_reference(x, vy, k); +} + +size_t wsp_ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K; + quantize_row_q3_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q3_K)); +} + +// ====================== 4-bit (de)-quantization + +void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + uint8_t L[QK_K]; + uint8_t Laux[32]; + float weights[32]; + float mins[QK_K/32]; + float scales[QK_K/32]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + //scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); + float sum_x2 = 0; + for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; + float av_x = sqrtf(sum_x2/32); + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + scales[j] = make_qkx2_quants(32, 15, x + 32*j, weights, L + 32*j, &mins[j], Laux, -1.f, 0.1f, 20, false); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + +#if QK_K == 256 + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = WSP_GGML_FP32_TO_FP16(max_scale/63.f); + y[i].dmin = WSP_GGML_FP32_TO_FP16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = WSP_GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + } + } +#else + const float s_factor = 15.f; + float inv_scale = max_scale > 0 ? s_factor/max_scale : 0.f; + float inv_min = max_min > 0 ? s_factor/max_min : 0.f; + int d1 = nearest_int(inv_scale*scales[0]); + int m1 = nearest_int(inv_min*mins[0]); + int d2 = nearest_int(inv_scale*scales[1]); + int m2 = nearest_int(inv_min*mins[1]); + y[i].scales[0] = d1 | (m1 << 4); + y[i].scales[1] = d2 | (m2 << 4); + y[i].d[0] = WSP_GGML_FP32_TO_FP16(max_scale/s_factor); + y[i].d[1] = WSP_GGML_FP32_TO_FP16(max_min/s_factor); + + float sumlx = 0; + int suml2 = 0; + for (int j = 0; j < QK_K/32; ++j) { + const uint8_t sd = y[i].scales[j] & 0xF; + const uint8_t sm = y[i].scales[j] >> 4; + const float d = WSP_GGML_FP16_TO_FP32(y[i].d[0]) * sd; + if (!d) continue; + const float m = WSP_GGML_FP16_TO_FP32(y[i].d[1]) * sm; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + m)/d); + l = MAX(0, MIN(15, l)); + L[32*j + ii] = l; + sumlx += (x[32*j + ii] + m)*l*sd; + suml2 += l*l*sd*sd; + } + } + if (suml2) { + y[i].d[0] = WSP_GGML_FP32_TO_FP16(sumlx/suml2); + } +#endif + uint8_t * q = y[i].qs; + for (int j = 0; j < QK_K; j += 64) { + for (int l = 0; l < 32; ++l) q[l] = L[j + l] | (L[j + l + 32] << 4); + q += 32; + } + + x += QK_K; + + } +} + +void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const uint8_t * q = x[i].qs; + +#if QK_K == 256 + + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + const float min = WSP_GGML_FP16_TO_FP32(x[i].dmin); + + int is = 0; + uint8_t sc, m; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; + q += 32; is += 2; + } +#else + const float dall = WSP_GGML_FP16_TO_FP32(x[i].d[0]); + const float mall = WSP_GGML_FP16_TO_FP32(x[i].d[1]); + const float d1 = dall * (x[i].scales[0] & 0xF), m1 = mall * (x[i].scales[0] >> 4); + const float d2 = dall * (x[i].scales[1] & 0xF), m2 = mall * (x[i].scales[1] >> 4); + for (int l = 0; l < 32; ++l) { + y[l+ 0] = d1 * (q[l] & 0xF) - m1; + y[l+32] = d2 * (q[l] >> 4) - m2; + } + y += QK_K; +#endif + + } +} + +void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q4_K * restrict y = vy; + quantize_row_q4_K_reference(x, y, k); +} + +size_t wsp_ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + assert(k % QK_K == 0); + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K; + quantize_row_q4_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q4_K)); +} + +// ====================== 5-bit (de)-quantization + +void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + +#if QK_K == 256 + uint8_t L[QK_K]; + float mins[QK_K/32]; + float scales[QK_K/32]; + float weights[32]; + uint8_t Laux[32]; +#else + int8_t L[QK_K]; + float scales[QK_K/16]; +#endif + + for (int i = 0; i < nb; i++) { + +#if QK_K == 256 + + float max_scale = 0; // as we are deducting the min, scales are always positive + float max_min = 0; + for (int j = 0; j < QK_K/32; ++j) { + //scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 9, 0.5f); + float sum_x2 = 0; + for (int l = 0; l < 32; ++l) sum_x2 += x[32*j + l] * x[32*j + l]; + float av_x = sqrtf(sum_x2/32); + for (int l = 0; l < 32; ++l) weights[l] = av_x + fabsf(x[32*j + l]); + scales[j] = make_qkx2_quants(32, 31, x + 32*j, weights, L + 32*j, &mins[j], Laux, -0.5f, 0.1f, 15, false); + float scale = scales[j]; + if (scale > max_scale) { + max_scale = scale; + } + float min = mins[j]; + if (min > max_min) { + max_min = min; + } + } + + float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f; + float inv_min = max_min > 0 ? 63.f/max_min : 0.f; + for (int j = 0; j < QK_K/32; ++j) { + uint8_t ls = nearest_int(inv_scale*scales[j]); + uint8_t lm = nearest_int(inv_min*mins[j]); + ls = MIN(63, ls); + lm = MIN(63, lm); + if (j < 4) { + y[i].scales[j] = ls; + y[i].scales[j+4] = lm; + } else { + y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4); + y[i].scales[j-4] |= ((ls >> 4) << 6); + y[i].scales[j-0] |= ((lm >> 4) << 6); + } + } + y[i].d = WSP_GGML_FP32_TO_FP16(max_scale/63.f); + y[i].dmin = WSP_GGML_FP32_TO_FP16(max_min/63.f); + + uint8_t sc, m; + for (int j = 0; j < QK_K/32; ++j) { + get_scale_min_k4(j, y[i].scales, &sc, &m); + const float d = WSP_GGML_FP16_TO_FP32(y[i].d) * sc; + if (!d) continue; + const float dm = WSP_GGML_FP16_TO_FP32(y[i].dmin) * m; + for (int ii = 0; ii < 32; ++ii) { + int l = nearest_int((x[32*j + ii] + dm)/d); + l = MAX(0, MIN(31, l)); + L[32*j + ii] = l; + } + } + + uint8_t * restrict qh = y[i].qh; + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + uint8_t m1 = 1, m2 = 2; + for (int n = 0; n < QK_K; n += 64) { + for (int j = 0; j < 32; ++j) { + int l1 = L[n + j]; + if (l1 > 15) { + l1 -= 16; qh[j] |= m1; + } + int l2 = L[n + j + 32]; + if (l2 > 15) { + l2 -= 16; qh[j] |= m2; + } + ql[j] = l1 | (l2 << 4); + } + m1 <<= 2; m2 <<= 2; + ql += 32; + } +#else + float max_scale = 0, amax = 0; + for (int j = 0; j < QK_K/16; ++j) { + scales[j] = make_qx_quants(16, 16, x + 16*j, L + 16*j, 1); + float abs_scale = fabsf(scales[j]); + if (abs_scale > amax) { + amax = abs_scale; + max_scale = scales[j]; + } + } + + float iscale = -128.f/max_scale; + for (int j = 0; j < QK_K/16; ++j) { + int l = nearest_int(iscale*scales[j]); + y[i].scales[j] = MAX(-128, MIN(127, l)); + } + y[i].d = WSP_GGML_FP32_TO_FP16(1/iscale); + + for (int j = 0; j < QK_K/16; ++j) { + const float d = WSP_GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; + if (!d) continue; + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-16, MIN(15, l)); + L[16*j + ii] = l + 16; + } + } + + uint8_t * restrict qh = y[i].qh; + uint8_t * restrict ql = y[i].qs; + memset(qh, 0, QK_K/8); + + for (int j = 0; j < 32; ++j) { + int jm = j%8; + int is = j/8; + int l1 = L[j]; + if (l1 > 15) { + l1 -= 16; qh[jm] |= (1 << is); + } + int l2 = L[j + 32]; + if (l2 > 15) { + l2 -= 16; qh[jm] |= (1 << (4 + is)); + } + ql[j] = l1 | (l2 << 4); + } +#endif + + x += QK_K; + + } +} + +void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const uint8_t * ql = x[i].qs; + const uint8_t * qh = x[i].qh; + +#if QK_K == 256 + + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + const float min = WSP_GGML_FP16_TO_FP32(x[i].dmin); + + int is = 0; + uint8_t sc, m; + uint8_t u1 = 1, u2 = 2; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float d1 = d * sc; const float m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float d2 = d * sc; const float m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2; + ql += 32; is += 2; + u1 <<= 2; u2 <<= 2; + } +#else + float d = WSP_GGML_FP16_TO_FP32(x[i].d); + const int8_t * restrict s = x[i].scales; + for (int l = 0; l < 8; ++l) { + y[l+ 0] = d * s[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16)); + y[l+ 8] = d * s[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16)); + y[l+16] = d * s[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16)); + y[l+24] = d * s[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16)); + y[l+32] = d * s[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16)); + y[l+40] = d * s[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16)); + y[l+48] = d * s[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16)); + y[l+56] = d * s[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16)); + } + y += QK_K; +#endif + } +} + +void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q5_K * restrict y = vy; + quantize_row_q5_K_reference(x, y, k); +} + +size_t wsp_ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { + assert(k % QK_K == 0); + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K; + quantize_row_q5_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q5_K)); +} + +// ====================== 6-bit (de)-quantization + +void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + int8_t L[QK_K]; + float scales[QK_K/16]; + + for (int i = 0; i < nb; i++) { + + float max_scale = 0; + float max_abs_scale = 0; + + for (int ib = 0; ib < QK_K/16; ++ib) { + + const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1); + scales[ib] = scale; + + const float abs_scale = fabsf(scale); + if (abs_scale > max_abs_scale) { + max_abs_scale = abs_scale; + max_scale = scale; + } + + } + + if (!max_abs_scale) { + memset(&y[i], 0, sizeof(block_q6_K)); + y[i].d = WSP_GGML_FP32_TO_FP16(0.f); + x += QK_K; + continue; + } + + float iscale = -128.f/max_scale; + y[i].d = WSP_GGML_FP32_TO_FP16(1/iscale); + for (int ib = 0; ib < QK_K/16; ++ib) { + y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib])); + } + + for (int j = 0; j < QK_K/16; ++j) { + float d = WSP_GGML_FP16_TO_FP32(y[i].d) * y[i].scales[j]; + if (!d) { + continue; + } + for (int ii = 0; ii < 16; ++ii) { + int l = nearest_int(x[16*j + ii]/d); + l = MAX(-32, MIN(31, l)); + L[16*j + ii] = l + 32; + } + } + + uint8_t * restrict ql = y[i].ql; + uint8_t * restrict qh = y[i].qh; +#if QK_K == 256 + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[j + l + 0] & 0xF; + const uint8_t q2 = L[j + l + 32] & 0xF; + const uint8_t q3 = L[j + l + 64] & 0xF; + const uint8_t q4 = L[j + l + 96] & 0xF; + ql[l+ 0] = q1 | (q3 << 4); + ql[l+32] = q2 | (q4 << 4); + qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6); + } + ql += 64; + qh += 32; + } +#else + for (int l = 0; l < 32; ++l) { + const uint8_t q1 = L[l + 0] & 0xF; + const uint8_t q2 = L[l + 32] & 0xF; + ql[l] = q1 | (q2 << 4); + } + for (int l = 0; l < 16; ++l) { + qh[l] = (L[l] >> 4) | ((L[l + 16] >> 4) << 2) | ((L[l + 32] >> 4) << 4) | ((L[l + 48] >> 4) << 6); + } +#endif + + x += QK_K; + + } +} + +void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + const float d = WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict ql = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict sc = x[i].scales; + +#if QK_K == 256 + for (int n = 0; n < QK_K; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l + 0] = d * sc[is + 0] * q1; + y[l + 32] = d * sc[is + 2] * q2; + y[l + 64] = d * sc[is + 4] * q3; + y[l + 96] = d * sc[is + 6] * q4; + } + y += 128; + ql += 64; + qh += 32; + sc += 8; + } +#else + for (int l = 0; l < 16; ++l) { + const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l+ 0] = d * sc[0] * q1; + y[l+16] = d * sc[1] * q2; + y[l+32] = d * sc[2] * q3; + y[l+48] = d * sc[3] * q4; + } + y += 64; +#endif + + } +} + +void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) { + assert(k % QK_K == 0); + block_q6_K * restrict y = vy; + quantize_row_q6_K_reference(x, y, k); +} + +size_t wsp_ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) { + assert(k % QK_K == 0); + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { + block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K; + quantize_row_q6_K_reference(src + j, y, k); + } + return (n/QK_K*sizeof(block_q6_K)); +} + +//===================================== Q8_K ============================================== + +void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + + float max = 0; + float amax = 0; + for (int j = 0; j < QK_K; ++j) { + float ax = fabsf(x[j]); + if (ax > amax) { + amax = ax; max = x[j]; + } + } + if (!amax) { + y[i].d = 0; + memset(y[i].qs, 0, QK_K); + x += QK_K; + continue; + } + const float iscale = -128.f/max; + for (int j = 0; j < QK_K; ++j) { + int v = nearest_int(iscale*x[j]); + y[i].qs[j] = MIN(127, v); + } + for (int j = 0; j < QK_K/16; ++j) { + int sum = 0; + for (int ii = 0; ii < 16; ++ii) { + sum += y[i].qs[j*16 + ii]; + } + y[i].bsums[j] = sum; + } + y[i].d = 1/iscale; + x += QK_K; + } +} + +void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) { + assert(k % QK_K == 0); + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < QK_K; ++j) { + *y++ = x[i].d * x[i].qs[j]; + } + } +} + +void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) { + quantize_row_q8_K_reference(x, y, k); +} + +//===================================== Dot ptoducts ================================= + +// +// Helper functions +// +#if __AVX__ || __AVX2__ || __AVX512F__ + +// shuffles to pick the required scales in dot products +static inline __m256i get_scale_shuffle_q3k(int i) { + static const uint8_t k_shuffle[128] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15, + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m256i get_scale_shuffle_k4(int i) { + static const uint8_t k_shuffle[256] = { + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, + 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, + 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, + 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, + 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11, + 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, + 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15 + }; + return _mm256_loadu_si256((const __m256i*)k_shuffle + i); +} +static inline __m128i get_scale_shuffle(int i) { + static const uint8_t k_shuffle[128] = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, + 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, + 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11, + 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13, + 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15 + }; + return _mm_loadu_si128((const __m128i*)k_shuffle + i); +} +#endif + +void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q4_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q4_0 * restrict x0 = &x[i + 0]; + const block_q4_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps( WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d) ); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + bx = _mm256_sub_epi8( bx, off ); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps( WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d) ); + + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); + + __m128i bx = _mm_and_si128(lowMask, tmp); + __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); + bx = _mm_sub_epi8(bx, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx, by); + + bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); + by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx = _mm_sub_epi8(bx, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx, by); + + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); + + // Apply the scale, and accumulate + acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + } + + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + // set constants + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + // Initialize accumulator with zeros + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + // First round without accumulation + { + _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[0].d) * WSP_GGML_FP16_TO_FP32(y[0].d) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[1].d) * WSP_GGML_FP16_TO_FP32(y[1].d) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + acc_0 = _mm_mul_ps( d_0_1, p0 ); + acc_1 = _mm_mul_ps( d_0_1, p1 ); + acc_2 = _mm_mul_ps( d_2_3, p2 ); + acc_3 = _mm_mul_ps( d_2_3, p3 ); + } + + assert(nb % 2 == 0); // TODO: handle odd nb + + // Main loop + for (int i = 2; i < nb; i+=2) { + _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[i + 1].d) * WSP_GGML_FP16_TO_FP32(y[i + 1].d) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); + __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); + __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); + __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); + + // Acummulate + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); + acc_2 = _mm_add_ps(p2_d, acc_2); + acc_3 = _mm_add_ps(p3_d, acc_3); + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + // load elements + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + // subtract offset + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 8, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 8, vl); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += sumi*WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d); + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0x0F) - 8; + const int v1 = (x[i].qs[j] >> 4) - 8; + + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + } + + sumf += sumi*WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d); + } + + *s = sumf; +#endif +} + +void wsp_ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q4_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + + // TODO: add WASM SIMD +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs = 0; + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q4_1 * restrict x0 = &x[i + 0]; + const block_q4_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i + 0]; + const block_q8_1 * restrict y1 = &y[i + 1]; + + summs += WSP_GGML_FP16_TO_FP32(x0->m) * y0->s + WSP_GGML_FP16_TO_FP32(x1->m) * y1->s; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + // Main loop + for (int i = 0; i < nb; ++i) { + const float d0 = WSP_GGML_FP16_TO_FP32(x[i].d); + const float d1 = y[i].d; + + summs += WSP_GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + const __m256 d0v = _mm256_set1_ps( d0 ); + const __m256 d1v = _mm256_set1_ps( d1 ); + + // Compute combined scales + const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); + + // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes + const __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); + + const __m256 xy = mul_sum_us8_pairs_float(bx, by); + + // Accumulate d0*d1*x*y +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d0d1, xy, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); +#endif + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + for (int i = 0; i < nb; i++) { + // load elements + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + + // mask and store lower part of x, and then upper part + vuint8mf2_t x_a = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_l = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + WSP_GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0x0F); + const int v1 = (x[i].qs[j] >> 4); + + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); + } + + sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + WSP_GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#endif +} + +void wsp_ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(qk == QK5_0); + + const block_q5_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q5_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + // extract the 5th bit via lookup table ((!b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_1[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_1[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (int i = 0; i < nb; ++i) { + const block_q5_0 * restrict x0 = &x[i]; + const block_q8_0 * restrict y0 = &y[i]; + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_1[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) + const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); + const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( + wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(WSP_GGML_FP16_TO_FP32(x0->d) * WSP_GGML_FP16_TO_FP32(y0->d)))); + } + + *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d)); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); + bx = _mm256_or_si256(bx, bxhi); + + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps(d, q, acc); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8((char)0xF0); + + // Main loop + for (int i = 0; i < nb; i++) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d)); + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_andnot_si128(bxhil, mask); + bxhih = _mm_andnot_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx); + __m128i bxh = _mm256_extractf128_si256(bx, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx = MM256_SET_M128I(bxh, bxl); + + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + /* Multiply q with scale and accumulate */ + acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); + } + + *s = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + uint32_t qh; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + // These tempory registers are for masking and shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl); + + vuint32m2_t vt_3 = __riscv_vsll_vx_u32m2(vt_2, 16, vl); + vuint32m2_t vt_4 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + + for (int i = 0; i < nb; i++) { + memcpy(&qh, x[i].qh, sizeof(uint32_t)); + + // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(vt_2, qh, vl); + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(xha_0, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + + // ((qh & (1u << (j + 16))) >> (j + 12)); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(vt_3, qh, vl); + vuint32m2_t xhl_1 = __riscv_vsrl_vv_u32m2(xha_1, vt_4, vl); + + // narrowing + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xhl_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); + + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xhl_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); + + // load + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); + + vint8mf2_t x_ai = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t x_li = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint8mf2_t v0 = __riscv_vsub_vx_i8mf2(x_ai, 16, vl); + vint8mf2_t v1 = __riscv_vsub_vx_i8mf2(x_li, 16, vl); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d)) * sumi; + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; + const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); + + const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; + const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; + + sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + } + + sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d)) * sumi; + } + + *s = sumf; +#endif +} + +void wsp_ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(qk == QK5_1); + + const block_q5_1 * restrict x = vx; + const block_q8_1 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + float summs0 = 0.0f; + float summs1 = 0.0f; + + uint32_t qh0; + uint32_t qh1; + + uint64_t tmp0[4]; + uint64_t tmp1[4]; + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q5_1 * restrict x1 = &x[i + 1]; + const block_q8_1 * restrict y0 = &y[i]; + const block_q8_1 * restrict y1 = &y[i + 1]; + + const uint8x16_t m4b = vdupq_n_u8(0x0F); + + summs0 += WSP_GGML_FP16_TO_FP32(x0->m) * y0->s; + summs1 += WSP_GGML_FP16_TO_FP32(x1->m) * y1->s; + + // extract the 5th bit via lookup table ((b) << 4) + memcpy(&qh0, x0->qh, sizeof(qh0)); + memcpy(&qh1, x1->qh, sizeof(qh1)); + + tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; + tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; + tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; + tmp0[3] = table_b2b_0[(qh0 >> 24) ]; + + tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; + tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; + tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; + tmp1[3] = table_b2b_0[(qh1 >> 24) ]; + + const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); + const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); + const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); + const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); + + const uint8x16_t v0_0 = vld1q_u8(x0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); + + // add high bit + const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); + const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); + const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); + const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); + + // load y + const int8x16_t v1_0l = vld1q_s8(y0->qs); + const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); + const int8x16_t v1_1l = vld1q_s8(y1->qs); + const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; +#elif defined(__wasm_simd128__) + v128_t sumv = wasm_f32x4_splat(0.0f); + + float summs = 0.0f; + + uint32_t qh; + uint64_t tmp[4]; + + // TODO: check if unrolling this is better + for (int i = 0; i < nb; ++i) { + const block_q5_1 * restrict x0 = &x[i]; + const block_q8_1 * restrict y0 = &y[i]; + + summs += WSP_GGML_FP16_TO_FP32(x0->m) * y0->s; + + const v128_t m4b = wasm_i8x16_splat(0x0F); + + // extract the 5th bit + memcpy(&qh, x0->qh, sizeof(qh)); + + tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; + tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; + tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; + tmp[3] = table_b2b_0[(qh >> 24) ]; + + const v128_t qhl = wasm_v128_load(tmp + 0); + const v128_t qhh = wasm_v128_load(tmp + 2); + + const v128_t v0 = wasm_v128_load(x0->qs); + + // 4-bit -> 8-bit + const v128_t v0l = wasm_v128_and (v0, m4b); + const v128_t v0h = wasm_u8x16_shr(v0, 4); + + // add high bit + const v128_t v0lf = wasm_v128_or(v0l, qhl); + const v128_t v0hf = wasm_v128_or(v0h, qhh); + + // load y + const v128_t v1l = wasm_v128_load(y0->qs); + const v128_t v1h = wasm_v128_load(y0->qs + 16); + + // int8x16 -> int16x8 + const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); + const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); + const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); + const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); + + const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); + const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); + const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); + const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); + + // dot product + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(WSP_GGML_FP16_TO_FP32(x0->d) * y0->d))); + } + + *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d)); + + summs += WSP_GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + __m256i bxhi = bytes_from_bits_32(x[i].qh); + bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); + bx = _mm256_or_si256(bx, bxhi); + + const __m256 dy = _mm256_set1_ps(y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_us8_pairs_float(bx, by); + + acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + __m128i mask = _mm_set1_epi8(0x10); + + float summs = 0.0f; + + // Main loop + for (int i = 0; i < nb; i++) { + const __m256 dx = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d)); + + summs += WSP_GGML_FP16_TO_FP32(x[i].m) * y[i].s; + + __m256i bx = bytes_from_nibbles_32(x[i].qs); + const __m256i bxhi = bytes_from_bits_32(x[i].qh); + __m128i bxhil = _mm256_castsi256_si128(bxhi); + __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); + bxhil = _mm_and_si128(bxhil, mask); + bxhih = _mm_and_si128(bxhih, mask); + __m128i bxl = _mm256_castsi256_si128(bx); + __m128i bxh = _mm256_extractf128_si256(bx, 1); + bxl = _mm_or_si128(bxl, bxhil); + bxh = _mm_or_si128(bxh, bxhih); + bx = MM256_SET_M128I(bxh, bxl); + + const __m256 dy = _mm256_set1_ps(y[i].d); + const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_us8_pairs_float(bx, by); + + acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); + } + + *s = hsum_float_8(acc) + summs; +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + + uint32_t qh; + + size_t vl = __riscv_vsetvl_e8m1(qk/2); + + // temporary registers for shift operations + vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl); + vuint32m2_t vt_2 = __riscv_vadd_vx_u32m2(vt_1, 12, vl); + + for (int i = 0; i < nb; i++) { + memcpy(&qh, x[i].qh, sizeof(uint32_t)); + + // load qh + vuint32m2_t vqh = __riscv_vmv_v_x_u32m2(qh, vl); + + // ((qh >> (j + 0)) << 4) & 0x10; + vuint32m2_t xhr_0 = __riscv_vsrl_vv_u32m2(vqh, vt_1, vl); + vuint32m2_t xhl_0 = __riscv_vsll_vx_u32m2(xhr_0, 4, vl); + vuint32m2_t xha_0 = __riscv_vand_vx_u32m2(xhl_0, 0x10, vl); + + // ((qh >> (j + 12)) ) & 0x10; + vuint32m2_t xhr_1 = __riscv_vsrl_vv_u32m2(vqh, vt_2, vl); + vuint32m2_t xha_1 = __riscv_vand_vx_u32m2(xhr_1, 0x10, vl); + + // narrowing + vuint16m1_t xhc_0 = __riscv_vncvt_x_x_w_u16m1(xha_0, vl); + vuint8mf2_t xh_0 = __riscv_vncvt_x_x_w_u8mf2(xhc_0, vl); + + vuint16m1_t xhc_1 = __riscv_vncvt_x_x_w_u16m1(xha_1, vl); + vuint8mf2_t xh_1 = __riscv_vncvt_x_x_w_u8mf2(xhc_1, vl); + + // load + vuint8mf2_t tx = __riscv_vle8_v_u8mf2(x[i].qs, vl); + + vint8mf2_t y0 = __riscv_vle8_v_i8mf2(y[i].qs, vl); + vint8mf2_t y1 = __riscv_vle8_v_i8mf2(y[i].qs+16, vl); + + vuint8mf2_t x_at = __riscv_vand_vx_u8mf2(tx, 0x0F, vl); + vuint8mf2_t x_lt = __riscv_vsrl_vx_u8mf2(tx, 0x04, vl); + + vuint8mf2_t x_a = __riscv_vor_vv_u8mf2(x_at, xh_0, vl); + vuint8mf2_t x_l = __riscv_vor_vv_u8mf2(x_lt, xh_1, vl); + + vint8mf2_t v0 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_a); + vint8mf2_t v1 = __riscv_vreinterpret_v_u8mf2_i8mf2(x_l); + + vint16m1_t vec_mul1 = __riscv_vwmul_vv_i16m1(v0, y0, vl); + vint16m1_t vec_mul2 = __riscv_vwmul_vv_i16m1(v1, y1, vl); + + vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + + vint32m1_t vs1 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul1, vec_zero, vl); + vint32m1_t vs2 = __riscv_vwredsum_vs_i16m1_i32m1(vec_mul2, vs1, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(vs2); + + sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + WSP_GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); + + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; + + const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; + const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; + + sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); + } + + sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + WSP_GGML_FP16_TO_FP32(x[i].m)*y[i].s; + } + + *s = sumf; +#endif +} + +void wsp_ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + + const block_q8_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + assert(nb % 2 == 0); // TODO: handle odd nb + + for (int i = 0; i < nb; i += 2) { + const block_q8_0 * restrict x0 = &x[i + 0]; + const block_q8_0 * restrict x1 = &x[i + 1]; + const block_q8_0 * restrict y0 = &y[i + 0]; + const block_q8_0 * restrict y1 = &y[i + 1]; + + const int8x16_t x0_0 = vld1q_s8(x0->qs); + const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); + const int8x16_t x1_0 = vld1q_s8(x1->qs); + const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); + + // load y + const int8x16_t y0_0 = vld1q_s8(y0->qs); + const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); + const int8x16_t y1_0 = vld1q_s8(y1->qs); + const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); + +#if defined(__ARM_FEATURE_DOTPROD) + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), + vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( + vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), + vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); + +#else + const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); + const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); + const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); + const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); + + const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); + const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); + const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); + const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); + + const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); + const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); + const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); + const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); +#endif + } + + *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); +#elif defined(__AVX2__) || defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d)); + __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); + __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(bx, by); + + // Multiply q with scale and accumulate +#if defined(__AVX2__) + acc = _mm256_fmadd_ps( d, q, acc ); +#else + acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); +#endif + } + + *s = hsum_float_8(acc); +#elif defined(__riscv_v_intrinsic) + float sumf = 0.0; + size_t vl = __riscv_vsetvl_e8m1(qk); + + for (int i = 0; i < nb; i++) { + // load elements + vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); + vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); + + vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); + + int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + + sumf += sumi*(WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d)); + } + + *s = sumf; +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[i].qs[j]*y[i].qs[j]; + } + + sumf += sumi*(WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d)); + } + + *s = sumf; +#endif +} + +#if QK_K == 256 +void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + + const block_q2_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m3 = vdupq_n_u8(0x3); + const uint8x16_t m4 = vdupq_n_u8(0xF); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + wsp_ggml_int8x16x2_t q2bytes; + uint8_t aux[16]; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint8_t * restrict sc = x[i].scales; + + const uint8x16_t mins_and_scales = vld1q_u8(sc); + const uint8x16_t scales = vandq_u8(mins_and_scales, m4); + vst1q_u8(aux, scales); + + const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); + const wsp_ggml_int16x8x2_t q8sums = wsp_ggml_vld1q_s16_x2(y[i].bsums); + const wsp_ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))}; + const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])), + vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0]))); + const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])), + vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1]))); + sum += dmin * vaddvq_s32(vaddq_s32(s0, s1)); + + int isum = 0; + int is = 0; + +// We use this macro instead of a function call because for some reason +// the code runs 2-3% slower, even if the function is declared inline +#if defined(__ARM_FEATURE_DOTPROD) +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; +#else +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + {\ + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ + vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ + vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ + isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\ + } +#endif + +#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ + q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32;\ + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ + MULTIPLY_ACCUM_WITH_SCALE((index)); + + + for (int j = 0; j < QK_K/128; ++j) { + + const wsp_ggml_uint8x16x2_t q2bits = wsp_ggml_vld1q_u8_x2(q2); q2 += 32; + + wsp_ggml_int8x16x2_t q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32; + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); + MULTIPLY_ACCUM_WITH_SCALE(0); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); + + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + + is += 8; + } + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales8 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m256i mins = _mm256_cvtepi8_epi16(mins8); + const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc); + + const __m256i all_scales = _mm256_cvtepi8_epi16(scales8); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32; + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + const __m256i q2_0 = _mm256_and_si256(q2bits, m3); + const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3); + const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3); + const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3); + + __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2); + __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3); + + p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0); + p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1); + p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2); + p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3); + + p0 = _mm256_add_epi32(p0, p1); + p2 = _mm256_add_epi32(p2, p3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2)); + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(0x3); + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float dall = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // load mins and scales from block_q2_K.scales[QK_K/16] + const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m128i scales16 = _mm_and_si128(mins_and_scales, m4); + const __m128i mins16 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4); + const __m128i mins_0 = _mm_cvtepi8_epi16(mins16); + const __m128i mins_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(mins16, mins16)); + + // summs = y[i].bsums * (x[i].scales >> 4) in 16bits*8*2 to 32bits*4*2 + const __m128i summs_0 = _mm_madd_epi16(mins_0, _mm_loadu_si128((const __m128i*)&y[i].bsums[0])); + const __m128i summs_1 = _mm_madd_epi16(mins_1, _mm_loadu_si128((const __m128i*)&y[i].bsums[8])); + + // sumf += -dmin * summs in 32bits*8 + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(MM256_SET_M128I(summs_1, summs_0))), acc); + + const __m128i scales_0 = _mm_cvtepi8_epi16(scales16); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales16, scales16)); + const __m128i scales[2] = { scales_0, scales_1 }; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + + // load Q8 quants int8*16*8 from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // load 2bits*16*8 from block_q2_K.qs[QK_K/4] + __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_4 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_6 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + q2bits = _mm_loadu_si128((const __m128i*)q2); q2 += 16; + const __m128i q2_1 = _mm_and_si128(q2bits, m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_5 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_7 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + // isuml = q8[l] * ((q2[l] >> shift) & 3) in 8bits*16*8 to 16bits*8*8 + __m128i p0 = _mm_maddubs_epi16(q2_0, q8_0); + __m128i p1 = _mm_maddubs_epi16(q2_1, q8_1); + __m128i p2 = _mm_maddubs_epi16(q2_2, q8_2); + __m128i p3 = _mm_maddubs_epi16(q2_3, q8_3); + __m128i p4 = _mm_maddubs_epi16(q2_4, q8_4); + __m128i p5 = _mm_maddubs_epi16(q2_5, q8_5); + __m128i p6 = _mm_maddubs_epi16(q2_6, q8_6); + __m128i p7 = _mm_maddubs_epi16(q2_7, q8_7); + + // isum += (x[i].scales[is++] & 0xF) * isuml in 16bits*8*8 to 32bits*4*8 + __m128i shuffle = _mm_set1_epi16(0x0100); + p0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p0); + shuffle = _mm_add_epi16(shuffle, m2); + p1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p1); + shuffle = _mm_add_epi16(shuffle, m2); + p2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p2); + shuffle = _mm_add_epi16(shuffle, m2); + p3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p3); + shuffle = _mm_add_epi16(shuffle, m2); + p4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p4); + shuffle = _mm_add_epi16(shuffle, m2); + p5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p5); + shuffle = _mm_add_epi16(shuffle, m2); + p6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p6); + shuffle = _mm_add_epi16(shuffle, m2); + p7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p7); + + p0 = _mm_add_epi32(p0, p1); + p2 = _mm_add_epi32(p2, p3); + p4 = _mm_add_epi32(p4, p5); + p6 = _mm_add_epi32(p6, p7); + + // isum in 32bits*4*2 + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p0, p2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p4, p6)); + } + + // sumf += dall * isum - dmin * summs in 32bits + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&dall), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + uint8_t temp_01[32] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + const float dall = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + size_t vl = 16; + + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + + vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + + uint8_t is=0; + int isum=0; + + for (int j = 0; j < QK_K/128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03 , vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03 , vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03 , vl); + + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0+is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2+is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4+is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6+is, vl), vl); + + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8+64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8+96, vl); + + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2+=32; q8+=128; is=8; + + } + + sumf += dall * isum; + + } + + *s = sumf; + +#else + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < 16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + int isum = 0; + int is = 0; + int d; + for (int k = 0; k < QK_K/128; ++k) { + int shift = 0; + for (int j = 0; j < 4; ++j) { + d = sc[is++] & 0xF; + int isuml = 0; + for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + d = sc[is++] & 0xF; + isuml = 0; + for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3); + isum += d * isuml; + shift += 2; + q8 += 32; + } + q2 += 32; + } + sumf += dall * isum - dmin * summs; + } + *s = sumf; +#endif +} + +#else + +void wsp_ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + + const block_q2_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m3 = vdupq_n_u8(0x3); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + wsp_ggml_int8x16x4_t q2bytes; + + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const float dmin = -y[i].d * (float)x[i].dmin; + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + + sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); + + int isum1 = 0, isum2 = 0; + + const uint8x16_t q2bits = vld1q_u8(q2); + + const wsp_ggml_int8x16x4_t q8bytes = wsp_ggml_vld1q_s8_x4(q8); + + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3)); + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3)); + q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3)); + q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3)); + +#if defined(__ARM_FEATURE_DOTPROD) + isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0]; + isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1]; + isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2]; + isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3]; +#else + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum1 += vaddvq_s16(p1) * scales[0]; + isum2 += vaddvq_s16(p2) * scales[1]; + + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum1 += vaddvq_s16(p3) * scales[2]; + isum2 += vaddvq_s16(p4) * scales[3]; +#endif + sum += d * (isum1 + isum2); + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; + + float summs = 0; + + // TODO: optimize this + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; + + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; + + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m256i q2_0 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 2), q2bits), m3); + const __m256i q2_1 = _mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q2bits, 6), _mm_srli_epi16(q2bits, 4)), m3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0); + const __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1); + + const __m256i p_0 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 0)); + const __m256i p_1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p0, 1)); + const __m256i p_2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 0)); + const __m256i p_3 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(p1, 1)); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2), acc); + acc = _mm256_fmadd_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3), acc); + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t ud, um; + const uint8_t * restrict db = (const uint8_t *)&ud; + const uint8_t * restrict mb = (const uint8_t *)&um; + + float summs = 0; + + // TODO: optimize this + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + ud = (sc[0] >> 0) & 0x0f0f0f0f; + um = (sc[0] >> 4) & 0x0f0f0f0f; + + int32_t smin = mb[0] * y[i].bsums[0] + mb[1] * y[i].bsums[1] + mb[2] * y[i].bsums[2] + mb[3] * y[i].bsums[3]; + summs += dmin * smin; + + const __m128i q2bits = _mm_loadu_si128((const __m128i*)q2); + const __m128i q2_0 = _mm_and_si128(q2bits, m3); + const __m128i q2_1 = _mm_and_si128(_mm_srli_epi16(q2bits, 2), m3); + const __m128i q2_2 = _mm_and_si128(_mm_srli_epi16(q2bits, 4), m3); + const __m128i q2_3 = _mm_and_si128(_mm_srli_epi16(q2bits, 6), m3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p0 = _mm_maddubs_epi16(q2_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i p1 = _mm_maddubs_epi16(q2_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p2 = _mm_maddubs_epi16(q2_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p3 = _mm_maddubs_epi16(q2_3, _mm256_extractf128_si256(q8_1, 1)); + + const __m256i p_0 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p0, p0)), _mm_cvtepi16_epi32(p0)); + const __m256i p_1 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p1, p1)), _mm_cvtepi16_epi32(p1)); + const __m256i p_2 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p2, p2)), _mm_cvtepi16_epi32(p2)); + const __m256i p_3 = MM256_SET_M128I(_mm_cvtepi16_epi32(_mm_unpackhi_epi64(p3, p3)), _mm_cvtepi16_epi32(p3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[0]), _mm256_cvtepi32_ps(p_0)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[1]), _mm256_cvtepi32_ps(p_1)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[2]), _mm256_cvtepi32_ps(p_2)), acc); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d * db[3]), _mm256_cvtepi32_ps(p_3)), acc); + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __riscv_v_intrinsic + + uint32_t aux32[2]; + const uint8_t * scales = (const uint8_t *)aux32; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const float dmin = -y[i].d * (float)x[i].dmin; + + const uint8_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + const uint32_t * restrict sc = (const uint32_t *)x[i].scales; + + aux32[0] = sc[0] & 0x0f0f0f0f; + aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f; + + sumf += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]); + + int isum1 = 0; + int isum2 = 0; + + size_t vl = 16; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + // load Q2 + vuint8mf2_t q2_x = __riscv_vle8_v_u8mf2(q2, vl); + + vint8mf2_t q2_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q2_x, 0x03, vl)); + vint8mf2_t q2_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x2, vl), 0x03 , vl)); + vint8mf2_t q2_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x4, vl), 0x03 , vl)); + vint8mf2_t q2_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q2_x, 0x6, vl), 0x03 , vl)); + + // load Q8, and take product with Q2 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q2_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q2_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q2_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q2_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m1_i16m1(p0, vzero, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m1_i16m1(p1, vzero, vl); + vint16m1_t vs_2 = __riscv_vredsum_vs_i16m1_i16m1(p2, vzero, vl); + vint16m1_t vs_3 = __riscv_vredsum_vs_i16m1_i16m1(p3, vzero, vl); + + isum1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[0]; + isum2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[1]; + isum1 += __riscv_vmv_x_s_i16m1_i16(vs_2) * scales[2]; + isum2 += __riscv_vmv_x_s_i16m1_i16(vs_3) * scales[3]; + + sumf += d * (isum1 + isum2); + + } + + *s = sumf; + +#else + + float sumf = 0; + + int isum[4]; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + + int summs = 0; + for (int j = 0; j < QK_K/16; ++j) { + summs += y[i].bsums[j] * (sc[j] >> 4); + } + + const float dall = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + isum[0] = isum[1] = isum[2] = isum[3] = 0; + for (int l = 0; l < 16; ++l) { + isum[0] += q8[l+ 0] * ((q2[l] >> 0) & 3); + isum[1] += q8[l+16] * ((q2[l] >> 2) & 3); + isum[2] += q8[l+32] * ((q2[l] >> 4) & 3); + isum[3] += q8[l+48] * ((q2[l] >> 6) & 3); + } + for (int l = 0; l < 4; ++l) { + isum[l] *= (sc[l] & 0xF); + } + sumf += dall * (isum[0] + isum[1] + isum[2] + isum[3]) - dmin * summs; + } + *s = sumf; +#endif +} +#endif + +#if QK_K == 256 +void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + uint32_t aux[3]; + uint32_t utmp[4]; + + const uint8x16_t m3b = vdupq_n_u8(0x3); +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t m1 = vshlq_n_u8(m0, 1); + const uint8x16_t m2 = vshlq_n_u8(m0, 2); + const uint8x16_t m3 = vshlq_n_u8(m0, 3); + const int8_t m32 = 32; + + wsp_ggml_int8x16x4_t q3bytes; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + wsp_ggml_uint8x16x2_t qhbits = wsp_ggml_vld1q_u8_x2(qh); + + wsp_ggml_uint8x16x4_t q3h; + + int32_t isum = 0; + + // Set up scales + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= m32; + + for (int j = 0; j < QK_K/128; ++j) { + + const wsp_ggml_uint8x16x2_t q3bits = wsp_ggml_vld1q_u8_x2(q3); q3 += 32; + const wsp_ggml_int8x16x4_t q8bytes_1 = wsp_ggml_vld1q_s8_x4(q8); q8 += 64; + const wsp_ggml_int8x16x4_t q8bytes_2 = wsp_ggml_vld1q_s8_x4(q8); q8 += 64; + + q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); + q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); + q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); + q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3]; +#else + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + scale += 4; + + q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); + q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); + q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); + q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0])); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1])); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2])); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3]; +#else + p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); + p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); + p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); + p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + scale += 4; + + if (j == 0) { + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); + } + + } + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i mone = _mm256_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + memcpy(aux, x[i].scales, 12); + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m256i all_scales = _mm256_cvtepi8_epi16(scales128); + const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0); + const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1); + const __m256i scales[2] = {MM256_SET_M128I(l_scales, l_scales), MM256_SET_M128I(h_scales, h_scales)}; + + // high bit + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask); + + // integer accumulator + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits + const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32; + + // prepare low and high bits + const __m256i q3l_0 = _mm256_and_si256(q3bits, m3); + const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3); + const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3); + const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3); + const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2); + ++bit; + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3); + + // accumulate + p16_0 = _mm256_add_epi32(p16_0, p16_1); + p16_2 = _mm256_add_epi32(p16_2, p16_3); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2)); + + } + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m32 = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + const uint32_t *aux; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + // Set up scales + aux = (const uint32_t *)x[i].scales; + __m128i scales128 = _mm_set_epi32( + ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4), + ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4), + (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4), + (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4)); + scales128 = _mm_sub_epi8(scales128, m32); + const __m128i scales_0 = _mm_cvtepi8_epi16(scales128); + const __m128i scales_1 = _mm_cvtepi8_epi16(_mm_unpackhi_epi64(scales128, scales128)); + const __m128i scales[2] = { scales_0, scales_1 }; + + // high bit *128*2 from block_q3_K.hmask[QK_K/8] + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].hmask[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].hmask[16]); + + // integer accumulator + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + for (int j = 0; j < QK_K/128; ++j) { + // load low 2 bits *64*2 from block_q3_K.qs[QK_K/4] + const __m128i q3bits_0 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + const __m128i q3bits_1 = _mm_loadu_si128((const __m128i*)q3); q3 += 16; + + // prepare low and high bits + const int bit = j << 2; + + const __m128i q3l_0 = _mm_and_si128(q3bits_0, m3); + const __m128i q3l_1 = _mm_and_si128(q3bits_1, m3); + const __m128i q3h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit)), bit), 2); + const __m128i q3h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit)), bit), 2); + + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 2), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 2), m3); + const __m128i q3h_2 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + const __m128i q3h_3 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+1)), bit+1), 2); + + const __m128i q3l_4 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 4), m3); + const __m128i q3l_5 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 4), m3); + const __m128i q3h_4 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + const __m128i q3h_5 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+2)), bit+2), 2); + + const __m128i q3l_6 = _mm_and_si128(_mm_srli_epi16(q3bits_0, 6), m3); + const __m128i q3l_7 = _mm_and_si128(_mm_srli_epi16(q3bits_1, 6), m3); + const __m128i q3h_6 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_0, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + const __m128i q3h_7 = _mm_slli_epi16(_mm_srli_epi16(_mm_andnot_si128(hbits_1, _mm_slli_epi16(mone, bit+3)), bit+3), 2); + + // load Q8 quants from block_q8_K.qs[QK_K] + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(q3h_4, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(q3h_5, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(q3h_6, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(q3h_7, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q3l_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q3l_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q3l_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q3l_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + // multiply with scales + __m128i shuffle = _mm_set1_epi16(0x0100); + p16_0 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_0); + shuffle = _mm_add_epi16(shuffle, m2); + p16_1 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_1); + shuffle = _mm_add_epi16(shuffle, m2); + p16_2 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_2); + shuffle = _mm_add_epi16(shuffle, m2); + p16_3 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_3); + shuffle = _mm_add_epi16(shuffle, m2); + p16_4 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_4); + shuffle = _mm_add_epi16(shuffle, m2); + p16_5 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_5); + shuffle = _mm_add_epi16(shuffle, m2); + p16_6 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_6); + shuffle = _mm_add_epi16(shuffle, m2); + p16_7 = _mm_madd_epi16(_mm_shuffle_epi8(scales[j], shuffle), p16_7); + + // accumulate + p16_0 = _mm_add_epi32(p16_0, p16_1); + p16_2 = _mm_add_epi32(p16_2, p16_3); + p16_4 = _mm_add_epi32(p16_4, p16_5); + p16_6 = _mm_add_epi32(p16_6, p16_7); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_4, p16_6)); + + } + + // multiply with block scale and accumulate + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + uint32_t aux[3]; + uint32_t utmp[4]; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + + int sum_t = 0; + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_m(vmask_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_m(vmask_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_m(vmask_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_m(vmask_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + // retreive lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q3 += 32; q8 += 128; scale += 8; + + } + + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + + } + + *s = sumf; + +#else + // scalar version + // This function is written like this so the compiler can manage to vectorize most of it + // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the + // manually vectorized version above. Every other version I tried would run at least 4 times slower. + // The ideal situation would be if we could just write the code once, and the compiler would + // automatically produce the best possible set of machine instructions, instead of us having to manually + // write vectorized versions for AVX, ARM_NEON, etc. + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + uint32_t auxs[4]; + const int8_t * scales = (const int8_t*)auxs; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3; + for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4); + a += 32; m <<= 1; + q3 += 32; + } + a = aux8; + + memcpy(auxs, x[i].scales, 12); + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l]; + q8 += 8; a += 8; + } + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} + +#else + +void wsp_ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q3_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t m3b = vdupq_n_u8(0x3); + const uint8x16_t mh = vdupq_n_u8(4); + + wsp_ggml_int8x16x4_t q3bytes; + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + wsp_ggml_uint8x16x4_t q3h; + + const uint8x8_t hbits = vld1_u8(x[i].hmask); + const uint8x16_t q3bits = vld1q_u8(x[i].qs); + const wsp_ggml_int8x16x4_t q8bytes = wsp_ggml_vld1q_s8_x4(y[i].qs); + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + for (int j = 0; j < 4; ++j) scales[j] -= 8; + + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + + const float d = y[i].d * (float)x[i].d; + + const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); + q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); + q3h.val[1] = vandq_u8(mh, htmp); + q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); + q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4)); + + q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); + q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); + q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2])); + q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[2]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[1]; + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3]; +#else + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[2] + vaddvq_s16(p2) * scales[1] + vaddvq_s16(p3) * scales[3]; +#endif + + sum += d * isum; + + } + + *s = sum; + +#elif defined __AVX2__ + + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m1 = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m256i scale_0 = MM256_SET_M128I(_mm_set1_epi16(aux8[2] - 8), _mm_set1_epi16(aux8[0] - 8)); + const __m256i scale_1 = MM256_SET_M128I(_mm_set1_epi16(aux8[3] - 8), _mm_set1_epi16(aux8[1] - 8)); + + memcpy(&aux64, x[i].hmask, 8); + + const __m128i haux = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m256i q3h_0 = MM256_SET_M128I(_mm_srli_epi16(haux, 2), haux); + __m256i q3h_1 = _mm256_srli_epi16(q3h_0, 4); + q3h_0 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_0, m1), 2); + q3h_1 = _mm256_slli_epi16(_mm256_andnot_si256(q3h_1, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m256i q3aux = MM256_SET_M128I(_mm_srli_epi16(q3bits, 2), q3bits); + const __m256i q3l_0 = _mm256_and_si256(q3aux, m3); + const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3aux, 4), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm256_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0); + const __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + // multiply with scales + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + p16_0 = _mm256_add_epi32(p16_0, p16_1); + + // multiply with block scale and accumulate + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16_0), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m1 = _mm_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + uint64_t aux64; + + uint16_t aux16[2]; + const int8_t * aux8 = (const int8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + const __m128i scale_0 = _mm_set1_epi16(aux8[0] - 8); + const __m128i scale_1 = _mm_set1_epi16(aux8[2] - 8); + const __m128i scale_2 = _mm_set1_epi16(aux8[1] - 8); + const __m128i scale_3 = _mm_set1_epi16(aux8[3] - 8); + + memcpy(&aux64, x[i].hmask, 8); + + __m128i q3h_0 = _mm_set_epi64x(aux64 >> 1, aux64 >> 0); + __m128i q3h_1 = _mm_srli_epi16(q3h_0, 2); + __m128i q3h_2 = _mm_srli_epi16(q3h_0, 4); + __m128i q3h_3 = _mm_srli_epi16(q3h_0, 6); + q3h_0 = _mm_slli_epi16(_mm_andnot_si128(q3h_0, m1), 2); + q3h_1 = _mm_slli_epi16(_mm_andnot_si128(q3h_1, m1), 2); + q3h_2 = _mm_slli_epi16(_mm_andnot_si128(q3h_2, m1), 2); + q3h_3 = _mm_slli_epi16(_mm_andnot_si128(q3h_3, m1), 2); + + // load low 2 bits + const __m128i q3bits = _mm_loadu_si128((const __m128i*)q3); + + // prepare low and high bits + const __m128i q3l_0 = _mm_and_si128(q3bits, m3); + const __m128i q3l_1 = _mm_and_si128(_mm_srli_epi16(q3bits, 2), m3); + const __m128i q3l_2 = _mm_and_si128(_mm_srli_epi16(q3bits, 4), m3); + const __m128i q3l_3 = _mm_and_si128(_mm_srli_epi16(q3bits, 6), m3); + + // load Q8 quants + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + // Dot product: we multiply the 2 low bits and 1 high bit part separately, so we can use _mm_maddubs_epi16, + // and then subtract. The high bit part has the 2 already subtracted (and so, it is zero if the high bit was not set, + // and 2 if the high bit was set) + const __m128i q8s_0 = _mm_maddubs_epi16(q3h_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i q8s_1 = _mm_maddubs_epi16(q3h_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i q8s_2 = _mm_maddubs_epi16(q3h_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i q8s_3 = _mm_maddubs_epi16(q3h_3, _mm256_extractf128_si256(q8_1, 1)); + + __m128i p16_0 = _mm_maddubs_epi16(q3l_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q3l_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q3l_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q3l_3, _mm256_extractf128_si256(q8_1, 1)); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + + // multiply with scales + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_1, p16_1); + p16_2 = _mm_madd_epi16(scale_2, p16_2); + p16_3 = _mm_madd_epi16(scale_3, p16_3); + + p16_0 = _mm_add_epi32(p16_0, p16_2); + p16_1 = _mm_add_epi32(p16_1, p16_3); + __m256i p16 = MM256_SET_M128I(p16_1, p16_0); + + // multiply with block scale and accumulate + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(p16)), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + uint16_t aux16[2]; + int8_t * scales = (int8_t *)aux16; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q3 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t a = *(const uint16_t *)x[i].scales; + aux16[0] = a & 0x0f0f; + aux16[1] = (a >> 4) & 0x0f0f; + + for (int j = 0; j < 4; ++j) scales[j] -= 8; + + int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[2] * y[i].bsums[1] + scales[1] * y[i].bsums[2] + scales[3] * y[i].bsums[3]); + + const float d = y[i].d * (float)x[i].d; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + // load qh + vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(x[i].hmask, 8); + vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); + + size_t vl = 16; + + // extend and combine both qh_x1 and qh_x2 + vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); + + vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); + vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(qh_x, 0x4, vl); + vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x4, vl); + vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x4, vl); + + // load Q3 + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + + vuint8mf2_t q3h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x3, vl), qh_0, vl); + vuint8mf2_t q3h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 2, vl), 0x3, vl), qh_1, vl); + vuint8mf2_t q3h_2 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 4, vl), 0x3, vl), qh_2, vl); + vuint8mf2_t q3h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), qh_3, vl); + + vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_0); + vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_1); + vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_2); + vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(q3h_3); + + // load Q8 and take product with Q3 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q3_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q3_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q3_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q3_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scales[0]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scales[2]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scales[1]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scales[3]; + + sumf += d * isum; + + } + + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + int32_t scales[4]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict hm = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + int8_t * restrict a = aux8; + for (int l = 0; l < 8; ++l) { + a[l+ 0] = (int8_t)((q3[l+0] >> 0) & 3) - (hm[l] & 0x01 ? 0 : 4); + a[l+ 8] = (int8_t)((q3[l+8] >> 0) & 3) - (hm[l] & 0x02 ? 0 : 4); + a[l+16] = (int8_t)((q3[l+0] >> 2) & 3) - (hm[l] & 0x04 ? 0 : 4); + a[l+24] = (int8_t)((q3[l+8] >> 2) & 3) - (hm[l] & 0x08 ? 0 : 4); + a[l+32] = (int8_t)((q3[l+0] >> 4) & 3) - (hm[l] & 0x10 ? 0 : 4); + a[l+40] = (int8_t)((q3[l+8] >> 4) & 3) - (hm[l] & 0x20 ? 0 : 4); + a[l+48] = (int8_t)((q3[l+0] >> 6) & 3) - (hm[l] & 0x40 ? 0 : 4); + a[l+56] = (int8_t)((q3[l+8] >> 6) & 3) - (hm[l] & 0x80 ? 0 : 4); + } + + scales[0] = (x[i].scales[0] & 0xF) - 8; + scales[1] = (x[i].scales[0] >> 4) - 8; + scales[2] = (x[i].scales[1] & 0xF) - 8; + scales[3] = (x[i].scales[1] >> 4) - 8; + + memset(aux32, 0, 8*sizeof(int32_t)); + for (int j = 0; j < QK_K/16; ++j) { + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] += q8[l] * a[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux32[l] += scales[j] * aux16[l]; + } + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; + +#endif + +} +#endif + +#if QK_K == 256 +void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + wsp_ggml_int8x16x2_t q4bytes; + wsp_ggml_int8x16x2_t q8bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + + uint32x2_t mins8 = { 0 }; + mins8 = vset_lane_u32(utmp[1] & kmask1, mins8, 0); + mins8 = vset_lane_u32(((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4), mins8, 1); + + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[0] &= kmask1; + + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8))); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + sumf -= dmin * vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + int32_t sumi1 = 0; + int32_t sumi2 = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4); q4 += 32; + +#ifdef __ARM_FEATURE_DOTPROD + q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + sumi1 += vaddvq_s32(p1) * scales[2*j+0]; + + q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + + sumi2 += vaddvq_s32(p2) * scales[2*j+1]; +#else + q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0]; + + q8bytes = wsp_ggml_vld1q_s8_x2(q8); q8 += 32; + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1]; + +#endif + } + + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + __m256i sumi = _mm256_setzero_si256(); + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + p16l = _mm256_madd_epi16(scale_l, p16l); + + const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + p16h = _mm256_madd_epi16(scale_h, p16h); + const __m256i sumj = _mm256_add_epi32(p16l, p16h); + + sumi = _mm256_add_epi32(sumi, sumj); + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(0x2); + + __m256 acc = _mm256_setzero_ps(); + __m128 acc_m = _mm_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + acc_m = _mm_add_ps(_mm_mul_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod)), acc_m); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_l = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_h = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + __m128i q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_0 = _mm_and_si128(q4bits, m4); + const __m128i q4h_0 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + q4bits = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4l_1 = _mm_and_si128(q4bits, m4); + const __m128i q4h_1 = _mm_and_si128(_mm_srli_epi16(q4bits, 4), m4); + + const __m128i q8l_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16l = _mm_maddubs_epi16(q4l_0, q8l_0); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_0 = _mm_add_epi32(sumi_0, p16l); + const __m128i q8l_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16l = _mm_maddubs_epi16(q4l_1, q8l_1); + p16l = _mm_madd_epi16(scale_l, p16l); + sumi_1 = _mm_add_epi32(sumi_1, p16l); + + const __m128i q8h_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16h = _mm_maddubs_epi16(q4h_0, q8h_0); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_0 = _mm_add_epi32(sumi_0, p16h); + const __m128i q8h_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + p16h = _mm_maddubs_epi16(q4h_1, q8h_1); + p16h = _mm_madd_epi16(scale_h, p16h); + sumi_1 = _mm_add_epi32(sumi_1, p16h); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m)); + acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m)); + + *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m); + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + size_t vl = 8; + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + vl = 32; + + int32_t sum_1 = 0; + int32_t sum_2 = 0; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + + q4 += 32; q8 += 64; + + } + + sumf += d*(sum_1 + sum_2); + + } + + *s = sumf; + +#else + + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + a += 32; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + a += 32; q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = WSP_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#else +void wsp_ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q4_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + +#ifdef __ARM_FEATURE_DOTPROD + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + float sumf = 0; + + wsp_ggml_int8x16x2_t q4bytes; + wsp_ggml_int8x16x4_t q8bytes; + + float sum_mins = 0.f; + + uint16_t aux16[2]; + const uint8_t * restrict scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t * restrict a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + const int32_t summi = scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3]); + sum_mins += y[i].d * (float)x[i].d[1] * summi; + + const float d = y[i].d * (float)x[i].d[0]; + + const wsp_ggml_uint8x16x2_t q4bits = wsp_ggml_vld1q_u8_x2(q4); + +#ifdef __ARM_FEATURE_DOTPROD + q8bytes = wsp_ggml_vld1q_s8_x4(q8); + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + + const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]); + const int32_t sumi1 = vaddvq_s32(p1) * scales[0]; + + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + + const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]); + const int32_t sumi2 = vaddvq_s32(p2) * scales[1]; + +#else + q8bytes = wsp_ggml_vld1q_s8_x4(q8); + q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b)); + q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b)); + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + int32_t sumi1 = vaddvq_s16(vaddq_s16(p0, p1)) * scales[0]; + + q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4)); + q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4)); + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3]))); + int32_t sumi2 = vaddvq_s16(vaddq_s16(p2, p3)) * scales[1]; + +#endif + sumf += d * (sumi1 + sumi2); + + } + + *s = sumf - sum_mins; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = WSP_GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; + const float m = WSP_GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m256i q4l = _mm256_and_si256(q4bits, m4); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4); + + const __m256i q8l = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8h = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16l = _mm256_maddubs_epi16(q4l, q8l); + const __m256i p16h = _mm256_maddubs_epi16(q4h, q8h); + + const __m256i p32l = _mm256_madd_epi16(_mm256_set1_epi16(scales[0]), p16l); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32l), acc); + + const __m256i p32h = _mm256_madd_epi16(_mm256_set1_epi16(scales[1]), p16h); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(p32h), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0; + + uint16_t aux16[2]; + const uint8_t * scales = (const uint8_t *)aux16; + + for (int i = 0; i < nb; ++i) { + + const float d = WSP_GGML_FP16_TO_FP32(x[i].d[0]) * y[i].d; + const float m = WSP_GGML_FP16_TO_FP32(x[i].d[1]) * y[i].d; + const __m256 vd = _mm256_set1_ps(d); + + const uint16_t * a = (const uint16_t *)x[i].scales; + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + + summs += m * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bits_0 = _mm256_extractf128_si256(q4bits, 0); + const __m128i q4bits_1 = _mm256_extractf128_si256(q4bits, 1); + const __m128i q4_0 = _mm_and_si128(q4bits_0, m4); + const __m128i q4_1 = _mm_and_si128(q4bits_1, m4); + const __m128i q4_2 = _mm_and_si128(_mm_srli_epi16(q4bits_0, 4), m4); + const __m128i q4_3 = _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + const __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + const __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + const __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + + const __m128i p32_0 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_0); + const __m128i p32_1 = _mm_madd_epi16(_mm_set1_epi16(scales[0]), p16_1); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_1, p32_0))), acc); + + const __m128i p32_2 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_2); + const __m128i p32_3 = _mm_madd_epi16(_mm_set1_epi16(scales[1]), p16_3); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(MM256_SET_M128I(p32_3, p32_2))), acc); + + } + + *s = hsum_float_8(acc) - summs; + +#elif defined __riscv_v_intrinsic + + uint16_t s16[2]; + const uint8_t * restrict scales = (const uint8_t *)s16; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + sumf -= y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d[0]); + + size_t vl = 32; + + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q4_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t va_0 = __riscv_vwmul_vv_i16m2(q4_a, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m1_t aux1 = __riscv_vredsum_vs_i16m2_i16m1(va_0, vzero, vl); + + sumf += d*scales[0]*__riscv_vmv_x_s_i16m1_i16(aux1); + + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q4_s = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t va_1 = __riscv_vwmul_vv_i16m2(q4_s, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m1_t aux2 = __riscv_vredsum_vs_i16m2_i16m1(va_1, vzero, vl); + + sumf += d*scales[1]*__riscv_vmv_x_s_i16m1_i16(aux2); + + } + + *s = sumf; + +#else + + uint8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + uint16_t s16[2]; + const uint8_t * restrict scales = (const uint8_t *)s16; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + uint8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) a[l+ 0] = q4[l] & 0xF; + for (int l = 0; l < 32; ++l) a[l+32] = q4[l] >> 4; + + const uint16_t * restrict b = (const uint16_t *)x[i].scales; + s16[0] = b[0] & 0x0f0f; + s16[1] = (b[0] >> 4) & 0x0f0f; + + sumf -= y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d[1]) * (scales[2] * (y[i].bsums[0] + y[i].bsums[1]) + scales[3] * (y[i].bsums[2] + y[i].bsums[3])); + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d[0]); + + for (int j = 0; j < QK_K/32; ++j) { + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + q8 += 16; a += 16; + for (int l = 0; l < 16; ++l) aux16[l] += q8[l] * a[l]; + q8 += 16; a += 16; + const float dl = d * scales[j]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[l+8]); + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + +#if QK_K == 256 +void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; + + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + wsp_ggml_int8x16x4_t q5bytes; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8)); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8); + const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8)); + const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)), + vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins))); + int32_t sumi_mins = vaddvq_s32(prod); + + const uint8_t * scales = (const uint8_t *)utmp; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + wsp_ggml_uint8x16x2_t qhbits = wsp_ggml_vld1q_u8_x2(qh); + + wsp_ggml_uint8x16x4_t q5h; + + int32_t sumi = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const wsp_ggml_uint8x16x2_t q5bits = wsp_ggml_vld1q_u8_x2(q5); q5 += 32; + const wsp_ggml_int8x16x4_t q8bytes = wsp_ggml_vld1q_s8_x4(q8); q8 += 64; + + q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3); + q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3); + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2); + + q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); + q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); + q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); + q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++; + sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++; +#else + + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++; + + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++; +#endif + } + + sumf += d * sumi - dmin * sumi_mins; + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + +#if QK_K == 256 + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; +#else + // TODO + const float d = 0, dmin = 0; +#endif + + const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0])); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i sc128 = _mm256_extracti128_si256(mins_and_scales, 0); + const __m256i scales = MM256_SET_M128I(sc128, sc128); + + const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh); + __m256i hmask = mone; + + __m256i sumi = _mm256_setzero_si256(); + + int bit = 0; + + for (int j = 0; j < QK_K/64; ++j) { + + const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0)); + const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1)); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32; + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1); + + p16_0 = _mm256_madd_epi16(scale_0, p16_0); + p16_1 = _mm256_madd_epi16(scale_1, p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + } + + __m256 vd = _mm256_set1_ps(d); + acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mzero = _mm_setzero_si128(); + const __m128i mone = _mm_set1_epi8(1); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + float summs = 0.f; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * WSP_GGML_FP16_TO_FP32(x[i].dmin); + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + const __m128i utmps = _mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]); + const __m128i scales = _mm_cvtepu8_epi16(utmps); + const __m128i mins = _mm_cvtepu8_epi16(_mm_unpackhi_epi64(utmps, utmps)); + + const __m128i q8sums_0 = _mm_loadu_si128((const __m128i*)&y[i].bsums[0]); + const __m128i q8sums_1 = _mm_loadu_si128((const __m128i*)&y[i].bsums[8]); + const __m128i q8s = _mm_hadd_epi16(q8sums_0, q8sums_1); + const __m128i prod = _mm_madd_epi16(mins, q8s); + const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero); + summs += dmin * _mm_extract_epi32(hsum, 0); + + const __m128i hbits_0 = _mm_loadu_si128((const __m128i*)&x[i].qh[0]); + const __m128i hbits_1 = _mm_loadu_si128((const __m128i*)&x[i].qh[16]); + __m128i hmask = mone; + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + int bit = 0; + + __m128i shuffle = _mm_set1_epi16(0x0100); + for (int j = 0; j < QK_K/64; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi16(shuffle, m2); + + const __m128i q5bits_0 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + const __m128i q5bits_1 = _mm_loadu_si128((const __m128i*)q5); q5 += 16; + + __m128i q5l_0 = _mm_and_si128(q5bits_0, m4); + __m128i q5l_1 = _mm_and_si128(q5bits_1, m4); + __m128i q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + __m128i q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + __m128i q5_0 = _mm_add_epi8(q5l_0, q5h_0); + __m128i q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_0 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q5_1, q8_1); + p16_0 = _mm_madd_epi16(scale_0, p16_0); + p16_1 = _mm_madd_epi16(scale_0, p16_1); + + q5l_0 = _mm_and_si128(_mm_srli_epi16(q5bits_0, 4), m4); + q5l_1 = _mm_and_si128(_mm_srli_epi16(q5bits_1, 4), m4); + q5h_0 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_0, hmask), bit), 4); + q5h_1 = _mm_slli_epi16(_mm_srli_epi16(_mm_and_si128(hbits_1, hmask), bit++), 4); + q5_0 = _mm_add_epi8(q5l_0, q5h_0); + q5_1 = _mm_add_epi8(q5l_1, q5h_1); + hmask = _mm_slli_epi16(hmask, 1); + + q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + __m128i p16_2 = _mm_maddubs_epi16(q5_0, q8_0); + __m128i p16_3 = _mm_maddubs_epi16(q5_1, q8_1); + p16_2 = _mm_madd_epi16(scale_1, p16_2); + p16_3 = _mm_madd_epi16(scale_1, p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + } + + __m256 vd = _mm256_set1_ps(d); + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(vd, _mm256_cvtepi32_ps(sumi)), acc); + + } + + *s = hsum_float_8(acc) + summs; + +#elif defined __riscv_v_intrinsic + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + float sumf = 0; + float sums = 0.0; + + size_t vl; + + for (int i = 0; i < nb; ++i) { + + vl = 8; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const float dmin = WSP_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + + vl = 32; + int32_t aux32 = 0; + int is = 0; + + uint8_t m = 1; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(hm, vl); + + for (int j = 0; j < QK_K/64; ++j) { + // load Q5 and Q8 + vuint8m1_t q5_x = __riscv_vle8_v_u8m1(q5, vl); + vint8m1_t q8_y1 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_y2 = __riscv_vle8_v_i8m1(q8+32, vl); + + // compute mask for addition + vint8m1_t q5_a = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q5_x, 0x0F, vl)); + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmsne_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q5_m1 = __riscv_vadd_vx_i8m1_m(vmask_1, q5_a, 16, vl); + m <<= 1; + + vint8m1_t q5_l = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q5_x, 0x04, vl)); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmsne_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q5_m2 = __riscv_vadd_vx_i8m1_m(vmask_2, q5_l, 16, vl); + m <<= 1; + + vint16m2_t v0 = __riscv_vwmul_vv_i16m2(q5_m1, q8_y1, vl); + vint16m2_t v1 = __riscv_vwmul_vv_i16m2(q5_m2, q8_y2, vl); + + vint32m4_t vs1 = __riscv_vwmul_vx_i32m4(v0, scales[is++], vl); + vint32m4_t vs2 = __riscv_vwmul_vx_i32m4(v1, scales[is++], vl); + + vint32m1_t vacc1 = __riscv_vredsum_vs_i32m4_i32m1(vs1, vzero, vl); + vint32m1_t vacc2 = __riscv_vredsum_vs_i32m4_i32m1(vs2, vzero, vl); + + aux32 += __riscv_vmv_x_s_i32m1_i32(vacc1) + __riscv_vmv_x_s_i32m1_i32(vacc2); + q5 += 32; q8 += 64; + + } + + vfloat32m1_t vaux = __riscv_vfmul_vf_f32m1(__riscv_vfmv_v_f_f32m1(aux32, 1), d, 1); + sums += __riscv_vfmv_f_s_f32m1_f32(vaux); + + } + + *s = sumf+sums; + +#else + + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + uint8_t m = 1; + for (int j = 0; j < QK_K/64; ++j) { + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4); + for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0); + a += 32; m <<= 1; + q4 += 32; + } + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + int sumi = 0; + for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2]; + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/32; ++j) { + int32_t scale = scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + const float dmin = WSP_GGML_FP16_TO_FP32(x[i].dmin) * y[i].d; + sumf -= dmin * sumi; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#else + +void wsp_ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q5_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const uint8x16_t mh = vdupq_n_u8(16); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t mzero = vdupq_n_s32(0); +#endif + + wsp_ggml_int8x16x4_t q5bytes; + wsp_ggml_uint8x16x4_t q5h; + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const int8_t * sc = x[i].scales; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const uint8x8_t qhbits = vld1_u8(qh); + + const wsp_ggml_uint8x16x2_t q5bits = wsp_ggml_vld1q_u8_x2(q5); + const wsp_ggml_int8x16x4_t q8bytes = wsp_ggml_vld1q_s8_x4(q8); + + const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); + q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); + q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); + q5h.val[2] = vbicq_u8(mh, htmp); + q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); + + q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); + q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); + q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); + q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + int32_t sumi1 = sc[0] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0])); + int32_t sumi2 = sc[1] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[1], q8bytes.val[1])); + int32_t sumi3 = sc[2] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2])); + int32_t sumi4 = sc[3] * vaddvq_s32(vdotq_s32(mzero, q5bytes.val[3], q8bytes.val[3])); + + sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); + +#else + + const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + int32_t sumi = sc[0] * vaddvq_s16(p0) + sc[1] * vaddvq_s16(p1); + + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + sumi += sc[2] * vaddvq_s16(p2) + sc[3] * vaddvq_s16(p3); + + sumf += d*sumi; +#endif + + } + + *s = sumf; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i mone = _mm256_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); + + const __m256i scale_l = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[1]), _mm_set1_epi16(x[i].scales[0])); + const __m256i scale_h = MM256_SET_M128I(_mm_set1_epi16(x[i].scales[3]), _mm_set1_epi16(x[i].scales[2])); + + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m256i haux256 = MM256_SET_M128I(_mm_srli_epi16(haux128, 2), haux128); + + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_andnot_si256(haux256, mone), 4); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_andnot_si256(_mm256_srli_epi16(haux256, 4), mone), 4); + + const __m256i q5l_0 = _mm256_and_si256(q5bits, m4); + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m256i p16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5l_0, q8_0)); + const __m256i p16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5l_1, q8_1)); + const __m256i s16_0 = _mm256_madd_epi16(scale_l, _mm256_maddubs_epi16(q5h_0, q8_0)); + const __m256i s16_1 = _mm256_madd_epi16(scale_h, _mm256_maddubs_epi16(q5h_1, q8_1)); + + const __m256i dot = _mm256_sub_epi32(_mm256_add_epi32(p16_0, p16_1), _mm256_add_epi32(s16_0, s16_1)); + + acc = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(dot), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i mone = _mm_set1_epi8(1); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const uint8_t * restrict q5 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); + + const __m128i scale_0 = _mm_set1_epi16(x[i].scales[0]); + const __m128i scale_1 = _mm_set1_epi16(x[i].scales[1]); + const __m128i scale_2 = _mm_set1_epi16(x[i].scales[2]); + const __m128i scale_3 = _mm_set1_epi16(x[i].scales[3]); + + int64_t aux64; + memcpy(&aux64, x[i].qh, 8); + const __m128i haux128_0 = _mm_set_epi64x(aux64 >> 1, aux64); + const __m128i haux128_1 = _mm_srli_epi16(haux128_0, 2); + + const __m128i q5h_0 = _mm_slli_epi16(_mm_andnot_si128(haux128_0, mone), 4); + const __m128i q5h_1 = _mm_slli_epi16(_mm_andnot_si128(haux128_1, mone), 4); + const __m128i q5h_2 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_0, 4), mone), 4); + const __m128i q5h_3 = _mm_slli_epi16(_mm_andnot_si128(_mm_srli_epi16(haux128_1, 4), mone), 4); + + const __m128i q5l_0 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 0), m4); + const __m128i q5l_1 = _mm_and_si128(_mm256_extractf128_si256(q5bits, 1), m4); + const __m128i q5l_2 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 0), 4), m4); + const __m128i q5l_3 = _mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q5bits, 1), 4), m4); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + const __m128i p16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5l_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i p16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5l_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i p16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5l_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i p16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5l_3, _mm256_extractf128_si256(q8_1, 1))); + const __m128i s16_0 = _mm_madd_epi16(scale_0, _mm_maddubs_epi16(q5h_0, _mm256_extractf128_si256(q8_0, 0))); + const __m128i s16_1 = _mm_madd_epi16(scale_1, _mm_maddubs_epi16(q5h_1, _mm256_extractf128_si256(q8_0, 1))); + const __m128i s16_2 = _mm_madd_epi16(scale_2, _mm_maddubs_epi16(q5h_2, _mm256_extractf128_si256(q8_1, 0))); + const __m128i s16_3 = _mm_madd_epi16(scale_3, _mm_maddubs_epi16(q5h_3, _mm256_extractf128_si256(q8_1, 1))); + + const __m128i dot_0 = _mm_sub_epi32(_mm_add_epi32(p16_0, p16_2), _mm_add_epi32(s16_0, s16_2)); + const __m128i dot_1 = _mm_sub_epi32(_mm_add_epi32(p16_1, p16_3), _mm_add_epi32(s16_1, s16_3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(MM256_SET_M128I(dot_1, dot_0))), acc); + + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * (float)x[i].d; + const int8_t * sc = x[i].scales; + + const uint8_t * restrict q5 = x[i].qs; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + // load qh + vuint8mf4_t qh_x1 = __riscv_vle8_v_u8mf4(qh, 8); + vuint8mf2_t qh_x2 = __riscv_vlmul_ext_v_u8mf4_u8mf2(__riscv_vsrl_vx_u8mf4(qh_x1, 1, 8)); + + size_t vl = 16; + + // combine both qh_1 and qh_2 + vuint8mf2_t qh_x = __riscv_vslideup_vx_u8mf2(__riscv_vlmul_ext_v_u8mf4_u8mf2(qh_x1), qh_x2, vl/2, vl); + + vuint8mf2_t qh_h0 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); + vuint8mf2_t qh_h1 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsll_vx_u8mf2(qh_x, 0x2, vl), vl), 16, vl); + vuint8mf2_t qh_h2 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(qh_x, vl), 16, vl); + vuint8mf2_t qh_h3 = __riscv_vand_vx_u8mf2(__riscv_vnot_v_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), vl), 16, vl); + + vint8mf2_t qh_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h0); + vint8mf2_t qh_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h1); + vint8mf2_t qh_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h2); + vint8mf2_t qh_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(qh_h3); + + // load q5 + vuint8mf2_t q5_x1 = __riscv_vle8_v_u8mf2(q5, vl); + vuint8mf2_t q5_x2 = __riscv_vle8_v_u8mf2(q5+16, vl); + + vint8mf2_t q5s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x1, 0xF, vl)); + vint8mf2_t q5s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q5_x2, 0xF, vl)); + vint8mf2_t q5s_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x1, 0x4, vl)); + vint8mf2_t q5s_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(q5_x2, 0x4, vl)); + + vint8mf2_t q5_0 = __riscv_vsub_vv_i8mf2(q5s_0, qh_0, vl); + vint8mf2_t q5_1 = __riscv_vsub_vv_i8mf2(q5s_1, qh_1, vl); + vint8mf2_t q5_2 = __riscv_vsub_vv_i8mf2(q5s_2, qh_2, vl); + vint8mf2_t q5_3 = __riscv_vsub_vv_i8mf2(q5s_3, qh_3, vl); + + // load Q8 and multiply it with Q5 + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q5_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q5_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q5_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q5_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + int32_t sumi1 = sc[0] * __riscv_vmv_x_s_i32m1_i32(vs_0); + int32_t sumi2 = sc[1] * __riscv_vmv_x_s_i32m1_i32(vs_1); + int32_t sumi3 = sc[2] * __riscv_vmv_x_s_i32m1_i32(vs_2); + int32_t sumi4 = sc[3] * __riscv_vmv_x_s_i32m1_i32(vs_3); + + sumf += d * (sumi1 + sumi2 + sumi3 + sumi4); + + } + + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[16]; + float sums [8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].qs; + const uint8_t * restrict hm = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + int8_t * restrict a = aux8; + for (int l = 0; l < 32; ++l) { + a[l+ 0] = q4[l] & 0xF; + a[l+32] = q4[l] >> 4; + } + for (int is = 0; is < 8; ++is) { + uint8_t m = 1 << is; + for (int l = 0; l < 8; ++l) a[8*is + l] -= (hm[l] & m ? 0 : 16); + } + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + const int8_t * restrict sc = x[i].scales; + + for (int j = 0; j < QK_K/16; ++j) { + const float dl = d * sc[j]; + for (int l = 0; l < 16; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) sums[l] += dl * (aux16[l] + aux16[8+l]); + q8 += 16; a += 16; + } + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} +#endif + + +#if QK_K == 256 +void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q6_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + //const int8x16_t m32s = vdupq_n_s8(32); + + const uint8x16_t mone = vdupq_n_u8(3); + + wsp_ggml_int8x16x4_t q6bytes; + wsp_ggml_uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + const wsp_ggml_int16x8x2_t q8sums = wsp_ggml_vld1q_s16_x2(y[i].bsums); + const int8x16_t scales = vld1q_s8(scale); + const wsp_ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))}; + + const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])), + vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))), + vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])), + vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1])))); + int32_t isum_mins = vaddvq_s32(prod); + + int32_t isum = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + wsp_ggml_uint8x16x2_t qhbits = wsp_ggml_vld1q_u8_x2(qh); qh += 32; + wsp_ggml_uint8x16x4_t q6bits = wsp_ggml_vld1q_u8_x4(q6); q6 += 64; + wsp_ggml_int8x16x4_t q8bytes = wsp_ggml_vld1q_s8_x4(q8); q8 += 64; + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 2); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + +#else + + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + scale += 2; + + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; + scale += 2; +#endif + + q8bytes = wsp_ggml_vld1q_s8_x4(q8); q8 += 64; + + shifted = vshrq_n_u8(qhbits.val[0], 4); + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 4); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[0], 6); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits.val[1], 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s); + //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s); + //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s); + //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s); + q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])); + q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])); + q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])); + q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; + scale += 4; + + //for (int l = 0; l < 4; ++l) { + // const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]); + // isum += vaddvq_s32(p) * *scale++; + //} +#else + p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + scale += 2; + + p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1]; + scale += 2; +#endif + + } + //sum += isum * d_all * y[i].d; + sum += d_all * y[i].d * (isum - 32 * isum_mins); + + } + *s = sum; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m256i sumi = _mm256_setzero_si256(); + + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; + const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); + const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); + const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); + __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); + __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + p16_2 = _mm256_sub_epi16(p16_2, q8s_2); + p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2); + p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3)); + + } + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m3 = _mm_set1_epi8(3); + const __m128i m32s = _mm_set1_epi8(32); + const __m128i m2 = _mm_set1_epi8(2); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + __m128i shuffle = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i q4bitsH_0 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + const __m128i q4bitsH_1 = _mm_loadu_si128((const __m128i*)qh); qh += 16; + + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH_0, m3), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(q4bitsH_1, m3), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 2), m3), 4); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 2), m3), 4); + const __m128i q4h_4 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 4), m3), 4); + const __m128i q4h_5 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 4), m3), 4); + const __m128i q4h_6 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_0, 6), m3), 4); + const __m128i q4h_7 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH_1, 6), m3), 4); + + const __m128i q4bits1_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits1_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_0 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + const __m128i q4bits2_1 = _mm_loadu_si128((const __m128i*)q4); q4 += 16; + + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(q4bits1_0, m4), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(q4bits1_1, m4), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(q4bits2_0, m4), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(q4bits2_1, m4), q4h_3); + const __m128i q4_4 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_0, 4), m4), q4h_4); + const __m128i q4_5 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits1_1, 4), m4), q4h_5); + const __m128i q4_6 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_0, 4), m4), q4h_6); + const __m128i q4_7 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(q4bits2_1, 4), m4), q4h_7); + + const __m128i q8_0 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_1 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_2 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_3 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_4 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_5 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_6 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + const __m128i q8_7 = _mm_loadu_si128((const __m128i*)q8); q8 += 16; + + __m128i q8s_0 = _mm_maddubs_epi16(m32s, q8_0); + __m128i q8s_1 = _mm_maddubs_epi16(m32s, q8_1); + __m128i q8s_2 = _mm_maddubs_epi16(m32s, q8_2); + __m128i q8s_3 = _mm_maddubs_epi16(m32s, q8_3); + __m128i q8s_4 = _mm_maddubs_epi16(m32s, q8_4); + __m128i q8s_5 = _mm_maddubs_epi16(m32s, q8_5); + __m128i q8s_6 = _mm_maddubs_epi16(m32s, q8_6); + __m128i q8s_7 = _mm_maddubs_epi16(m32s, q8_7); + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, q8_0); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, q8_1); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, q8_2); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, q8_3); + __m128i p16_4 = _mm_maddubs_epi16(q4_4, q8_4); + __m128i p16_5 = _mm_maddubs_epi16(q4_5, q8_5); + __m128i p16_6 = _mm_maddubs_epi16(q4_6, q8_6); + __m128i p16_7 = _mm_maddubs_epi16(q4_7, q8_7); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + p16_4 = _mm_sub_epi16(p16_4, q8s_4); + p16_5 = _mm_sub_epi16(p16_5, q8s_5); + p16_6 = _mm_sub_epi16(p16_6, q8s_6); + p16_7 = _mm_sub_epi16(p16_7, q8s_7); + + const __m128i scale_0 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_1 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_2 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + const __m128i scale_3 = _mm_shuffle_epi8(scales, shuffle); + shuffle = _mm_add_epi8(shuffle, m2); + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + p16_4 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_2), p16_4); + p16_5 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_2, scale_2)), p16_5); + p16_6 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_3), p16_6); + p16_7 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_3, scale_3)), p16_7); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_4, p16_6)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_5, p16_7)); + + } + + __m256i sumi = MM256_SET_M128I(sumi_1, sumi_0); + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi)), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + size_t vl; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + for (int j = 0; j < QK_K/128; ++j) { + + vl = 32; + + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + + vl = 16; + + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + sumf += d * sum_t; + + } + + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int j = 0; j < QK_K; j += 128) { + for (int l = 0; l < 32; ++l) { + a[l + 0] = (int8_t)((q4[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l + 64] = (int8_t)((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l + 96] = (int8_t)((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + a += 128; + q4 += 64; + qh += 32; + } + a = aux8; + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#else + +void wsp_ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + assert(n % QK_K == 0); + + const block_q6_K * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#ifdef __ARM_NEON + + float sum = 0; + + const uint8x16_t m4b = vdupq_n_u8(0xF); + const int8x16_t m32s = vdupq_n_s8(32); +#if defined(__ARM_FEATURE_DOTPROD) + const int32x4_t vzero = vdupq_n_s32(0); +#endif + + const uint8x16_t mone = vdupq_n_u8(3); + + wsp_ggml_int8x16x4_t q6bytes; + wsp_ggml_uint8x16x4_t q6h; + + for (int i = 0; i < nb; ++i) { + + const float d_all = (float)x[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int32_t isum = 0; + + uint8x16_t qhbits = vld1q_u8(qh); + wsp_ggml_uint8x16x2_t q6bits = wsp_ggml_vld1q_u8_x2(q6); + wsp_ggml_int8x16x4_t q8bytes = wsp_ggml_vld1q_s8_x4(q8); + + q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4); + uint8x16_t shifted = vshrq_n_u8(qhbits, 2); + q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 4); + q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + shifted = vshrq_n_u8(qhbits, 6); + q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4); + + q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s); + q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s); + q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[2])), m32s); + q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[3])), m32s); + +#if defined(__ARM_FEATURE_DOTPROD) + + isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] + + vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3]; +#else + + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])), + vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])), + vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1]))); + isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1]; + + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])), + vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])), + vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3]))); + isum += vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3]; +#endif + + sum += isum * d_all * y[i].d; + + } + *s = sum; + +#elif defined __AVX2__ + + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(3); + const __m256i m32s = _mm256_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + + __m256i sumi = _mm256_setzero_si256(); + + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 2), q4bitsH), m2), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(MM256_SET_M128I(_mm_srli_epi16(q4bitsH, 6), _mm_srli_epi16(q4bitsH, 4)), m2), 4); + + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_1); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); + __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); + + __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); + __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); + + p16_0 = _mm256_sub_epi16(p16_0, q8s_0); + p16_1 = _mm256_sub_epi16(p16_1, q8s_1); + + p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); + + sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1)); + + acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __AVX__ + + const __m128i m4 = _mm_set1_epi8(0xF); + const __m128i m2 = _mm_set1_epi8(3); + const __m128i m32s = _mm_set1_epi8(32); + + __m256 acc = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * WSP_GGML_FP16_TO_FP32(x[i].d); + + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const __m64 scales_1 = _mm_set1_pi8(x[i].scales[0]); + const __m64 scales_2 = _mm_set1_pi8(x[i].scales[1]); + const __m64 scales_3 = _mm_set1_pi8(x[i].scales[2]); + const __m64 scales_4 = _mm_set1_pi8(x[i].scales[3]); + + __m128i sumi_0 = _mm_setzero_si128(); + __m128i sumi_1 = _mm_setzero_si128(); + + const __m128i scale_0 = _mm_set_epi64(scales_2, scales_1); + const __m128i scale_1 = _mm_set_epi64(scales_4, scales_3); + + const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); + const __m128i q4bitsH = _mm_loadu_si128((const __m128i*)qh); + + const __m128i q4h_0 = _mm_slli_epi16(_mm_and_si128(q4bitsH, m2), 4); + const __m128i q4h_1 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 2), m2), 4); + const __m128i q4h_2 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 4), m2), 4); + const __m128i q4h_3 = _mm_slli_epi16(_mm_and_si128(_mm_srli_epi16(q4bitsH, 6), m2), 4); + + const __m128i q4_0 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 0), m4), q4h_0); + const __m128i q4_1 = _mm_or_si128(_mm_and_si128(_mm256_extractf128_si256(q4bits1, 1), m4), q4h_1); + const __m128i q4_2 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 0), 4), m4), q4h_2); + const __m128i q4_3 = _mm_or_si128(_mm_and_si128(_mm_srli_epi16(_mm256_extractf128_si256(q4bits1, 1), 4), m4), q4h_3); + + const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)(q8+ 0)); + const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)(q8+32)); + + __m128i q8s_0 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 0)); + __m128i q8s_1 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_0, 1)); + __m128i q8s_2 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 0)); + __m128i q8s_3 = _mm_maddubs_epi16(m32s, _mm256_extractf128_si256(q8_1, 1)); + + __m128i p16_0 = _mm_maddubs_epi16(q4_0, _mm256_extractf128_si256(q8_0, 0)); + __m128i p16_1 = _mm_maddubs_epi16(q4_1, _mm256_extractf128_si256(q8_0, 1)); + __m128i p16_2 = _mm_maddubs_epi16(q4_2, _mm256_extractf128_si256(q8_1, 0)); + __m128i p16_3 = _mm_maddubs_epi16(q4_3, _mm256_extractf128_si256(q8_1, 1)); + + p16_0 = _mm_sub_epi16(p16_0, q8s_0); + p16_1 = _mm_sub_epi16(p16_1, q8s_1); + p16_2 = _mm_sub_epi16(p16_2, q8s_2); + p16_3 = _mm_sub_epi16(p16_3, q8s_3); + + p16_0 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_0), p16_0); + p16_1 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_0, scale_0)), p16_1); + p16_2 = _mm_madd_epi16(_mm_cvtepi8_epi16(scale_1), p16_2); + p16_3 = _mm_madd_epi16(_mm_cvtepi8_epi16(_mm_unpackhi_epi64(scale_1, scale_1)), p16_3); + + sumi_0 = _mm_add_epi32(sumi_0, _mm_add_epi32(p16_0, p16_2)); + sumi_1 = _mm_add_epi32(sumi_1, _mm_add_epi32(p16_1, p16_3)); + + acc = _mm256_add_ps(_mm256_mul_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(MM256_SET_M128I(sumi_1, sumi_0))), acc); + } + + *s = hsum_float_8(acc); + +#elif defined __riscv_v_intrinsic + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + + const float d_all = (float)x[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int32_t isum = 0; + + size_t vl = 16; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + // load Q6 + vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); + vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+16, vl); + + // load qh + vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); + + vuint8mf2_t qh0 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh1 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh2 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + qh_x = __riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl); + vuint8mf2_t qh3 = __riscv_vsll_vx_u8mf2(__riscv_vand_vx_u8mf2(qh_x, 0x3, vl), 0x4, vl); + + vuint8mf2_t q6h_0 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_0, 0xF, vl), qh0, vl); + vuint8mf2_t q6h_1 = __riscv_vor_vv_u8mf2(__riscv_vand_vx_u8mf2(q6_1, 0xF, vl), qh1, vl); + vuint8mf2_t q6h_2 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_0, 0x4, vl), qh2, vl); + vuint8mf2_t q6h_3 = __riscv_vor_vv_u8mf2(__riscv_vsrl_vx_u8mf2(q6_1, 0x4, vl), qh3, vl); + + vint8mf2_t q6v_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_0), 32, vl); + vint8mf2_t q6v_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_1), 32, vl); + vint8mf2_t q6v_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_2), 32, vl); + vint8mf2_t q6v_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(q6h_3), 32, vl); + + // load Q8 and take product + vint16m1_t p0 = __riscv_vwmul_vv_i16m1(q6v_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t p1 = __riscv_vwmul_vv_i16m1(q6v_1, __riscv_vle8_v_i8mf2(q8+16, vl), vl); + vint16m1_t p2 = __riscv_vwmul_vv_i16m1(q6v_2, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t p3 = __riscv_vwmul_vv_i16m1(q6v_3, __riscv_vle8_v_i8mf2(q8+48, vl), vl); + + vint32m1_t vs_0 = __riscv_vwredsum_vs_i16m1_i32m1(p0, vzero, vl); + vint32m1_t vs_1 = __riscv_vwredsum_vs_i16m1_i32m1(p1, vzero, vl); + vint32m1_t vs_2 = __riscv_vwredsum_vs_i16m1_i32m1(p2, vzero, vl); + vint32m1_t vs_3 = __riscv_vwredsum_vs_i16m1_i32m1(p3, vzero, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(vs_0) * scale[0]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_1) * scale[1]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_2) * scale[2]; + isum += __riscv_vmv_x_s_i32m1_i32(vs_3) * scale[3]; + + sumf += isum * d_all * y[i].d; + + } + + *s = sumf; + +#else + + int8_t aux8[QK_K]; + int16_t aux16[8]; + float sums [8]; + int32_t aux32[8]; + memset(sums, 0, 8*sizeof(float)); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q4 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + memset(aux32, 0, 8*sizeof(int32_t)); + int8_t * restrict a = aux8; + for (int l = 0; l < 16; ++l) { + a[l+ 0] = (int8_t)((q4[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + a[l+16] = (int8_t)((q4[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + a[l+32] = (int8_t)((q4[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + a[l+48] = (int8_t)((q4[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + int is = 0; + for (int j = 0; j < QK_K/16; ++j) { + int scale = x[i].scales[is++]; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l]; + for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l]; + q8 += 8; a += 8; + } + const float d = WSP_GGML_FP16_TO_FP32(x[i].d) * y[i].d; + for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l]; + } + for (int l = 0; l < 8; ++l) sumf += sums[l]; + *s = sumf; +#endif +} + +#endif diff --git a/cpp/ggml-quants.h b/cpp/ggml-quants.h new file mode 100644 index 0000000..603f973 --- /dev/null +++ b/cpp/ggml-quants.h @@ -0,0 +1,224 @@ +#pragma once + +#include "ggml-impl.h" + +// GGML internal header + +#include +#include + +#define QK4_0 32 +typedef struct { + wsp_ggml_fp16_t d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(wsp_ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +typedef struct { + wsp_ggml_fp16_t d; // delta + wsp_ggml_fp16_t m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == 2 * sizeof(wsp_ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +typedef struct { + wsp_ggml_fp16_t d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(wsp_ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +typedef struct { + wsp_ggml_fp16_t d; // delta + wsp_ggml_fp16_t m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(wsp_ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +typedef struct { + wsp_ggml_fp16_t d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(wsp_ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); + +#define QK8_1 32 +typedef struct { + float d; // delta + float s; // d * sum(qs[i]) + int8_t qs[QK8_1]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); + +// +// Super-block quantization structures +// + +// Super-block size +#ifdef WSP_GGML_QKK_64 +#define QK_K 64 +#define K_SCALE_SIZE 4 +#else +#define QK_K 256 +#define K_SCALE_SIZE 12 +#endif + +// 2-bit quantization +// weight is represented as x = a * q + b +// 16 blocks of 16 elements each +// Effectively 2.5625 bits per weight +typedef struct { + uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits + uint8_t qs[QK_K/4]; // quants + wsp_ggml_fp16_t d; // super-block scale for quantized scales + wsp_ggml_fp16_t dmin; // super-block scale for quantized mins +} block_q2_K; +static_assert(sizeof(block_q2_K) == 2*sizeof(wsp_ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding"); + +// 3-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 3.4375 bits per weight +#ifdef WSP_GGML_QKK_64 +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[2]; + wsp_ggml_fp16_t d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(wsp_ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 2, "wrong q3_K block size/padding"); +#else +typedef struct { + uint8_t hmask[QK_K/8]; // quants - high bit + uint8_t qs[QK_K/4]; // quants - low 2 bits + uint8_t scales[12]; // scales, quantized with 6 bits + wsp_ggml_fp16_t d; // super-block scale +} block_q3_K; +static_assert(sizeof(block_q3_K) == sizeof(wsp_ggml_fp16_t) + QK_K / 4 + QK_K / 8 + 12, "wrong q3_K block size/padding"); +#endif + +// 4-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 4.5 bits per weight +#ifdef WSP_GGML_QKK_64 +typedef struct { + wsp_ggml_fp16_t d[2]; // super-block scales/mins + uint8_t scales[2]; // 4-bit block scales/mins + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(wsp_ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding"); +#else +typedef struct { + wsp_ggml_fp16_t d; // super-block scale for quantized scales + wsp_ggml_fp16_t dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; +static_assert(sizeof(block_q4_K) == 2*sizeof(wsp_ggml_fp16_t) + K_SCALE_SIZE + QK_K/2, "wrong q4_K block size/padding"); +#endif + +// 5-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 5.5 bits per weight +#ifdef WSP_GGML_QKK_64 +typedef struct { + wsp_ggml_fp16_t d; // super-block scale + int8_t scales[QK_K/16]; // 8-bit block scales + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == sizeof(wsp_ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding"); +#else +typedef struct { + wsp_ggml_fp16_t d; // super-block scale for quantized scales + wsp_ggml_fp16_t dmin; // super-block scale for quantized mins + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K/8]; // quants, high bit + uint8_t qs[QK_K/2]; // quants, low 4 bits +} block_q5_K; +static_assert(sizeof(block_q5_K) == 2*sizeof(wsp_ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding"); +#endif + +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + wsp_ggml_fp16_t d; // super-block scale +} block_q6_K; +static_assert(sizeof(block_q6_K) == sizeof(wsp_ggml_fp16_t) + QK_K / 16 + 3*QK_K/4, "wrong q6_K block size/padding"); + +// This is only used for intermediate quantization and dot products +typedef struct { + float d; // delta + int8_t qs[QK_K]; // quants + int16_t bsums[QK_K/16]; // sum of quants in groups of 16 +} block_q8_K; +static_assert(sizeof(block_q8_K) == sizeof(float) + QK_K + QK_K/16*sizeof(int16_t), "wrong q8_K block size/padding"); + + +// Quantization +void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k); +void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k); +void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k); +void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k); +void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k); +void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k); + +void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k); +void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k); +void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k); +void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k); +void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k); +void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k); + +void quantize_row_q4_0(const float * restrict x, void * restrict y, int k); +void quantize_row_q4_1(const float * restrict x, void * restrict y, int k); +void quantize_row_q5_0(const float * restrict x, void * restrict y, int k); +void quantize_row_q5_1(const float * restrict x, void * restrict y, int k); +void quantize_row_q8_0(const float * restrict x, void * restrict y, int k); +void quantize_row_q8_1(const float * restrict x, void * restrict y, int k); + +void quantize_row_q2_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q3_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q4_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q5_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q6_K(const float * restrict x, void * restrict y, int k); +void quantize_row_q8_K(const float * restrict x, void * restrict y, int k); + +// Dequantization +void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k); +void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k); +void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k); +void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k); +void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int k); +//void dequantize_row_q8_1(const block_q8_1 * restrict x, float * restrict y, int k); + +void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k); +void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k); +void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k); +void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k); +void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k); +void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k); + +// Dot product +void wsp_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void wsp_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void wsp_ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void wsp_ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void wsp_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy); + +void wsp_ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void wsp_ggml_vec_dot_q3_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void wsp_ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void wsp_ggml_vec_dot_q5_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); +void wsp_ggml_vec_dot_q6_K_q8_K(int n, float * restrict s, const void * restrict vx, const void * restrict vy); diff --git a/cpp/ggml.c b/cpp/ggml.c index ded2caf..9c079d8 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -1,10 +1,8 @@ #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows +#define _USE_MATH_DEFINES // For M_PI on MSVC -#include "ggml.h" - -#ifdef WSP_GGML_USE_K_QUANTS -#include "k_quants.h" -#endif +#include "ggml-impl.h" +#include "ggml-quants.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -30,18 +28,6 @@ #include #endif -// static_assert should be a #define, but if it's not, -// fall back to the _Static_assert C11 keyword. -// if C99 - static_assert is noop -// ref: https://stackoverflow.com/a/53923785/4039976 -#ifndef static_assert -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L) -#define static_assert(cond, msg) _Static_assert(cond, msg) -#else -#define static_assert(cond, msg) struct global_scope_noop_trick -#endif -#endif - #if defined(_MSC_VER) // disable "possible loss of data" to avoid hundreds of casts // we should just be careful :) @@ -89,7 +75,9 @@ static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(vo static int pthread_join(pthread_t thread, void * unused) { (void) unused; - return (int) WaitForSingleObject(thread, INFINITE); + int ret = (int) WaitForSingleObject(thread, INFINITE); + CloseHandle(thread); + return ret; } static int sched_yield (void) { @@ -107,21 +95,52 @@ typedef void * thread_ret_t; #include #endif + #ifdef WSP_GGML_USE_CPU_HBM #include #endif -// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512 -#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__)) -#ifndef __FMA__ -#define __FMA__ -#endif -#ifndef __F16C__ -#define __F16C__ -#endif -#ifndef __SSE3__ -#define __SSE3__ +#if defined(__APPLE__) +#include #endif + +#if (defined(__linux__) || defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__)) && \ + (!defined(TARGET_OS_TV) && !defined(TARGET_OS_WATCH)) + +#include + +void wsp_ggml_print_backtrace(void) { + /* + #include + #include + + void * trace[100]; + + int nptrs = backtrace(trace, sizeof(trace)/sizeof(trace[0])); + + backtrace_symbols_fd(trace, nptrs, STDERR_FILENO); + */ + + // backtrack_symbols does not show line numbers, use gdb instead + char attach[32]; + snprintf(attach, sizeof(attach), "attach %d", getpid()); + int pid = fork(); + if (pid == 0) { + execlp("gdb", "gdb", "--batch", + "-ex", "set style enabled on", + "-ex", attach, + "-ex", "bt -frame-info source-and-location", + "-ex", "detach", + "-ex", "quit", + NULL); + } else { + waitpid(pid, NULL, 0); + } +} +#else +void wsp_ggml_print_backtrace(void) { + // platform not supported +} #endif /*#define WSP_GGML_PERF*/ @@ -134,6 +153,7 @@ typedef void * thread_ret_t; #define WSP_GGML_SOFT_MAX_UNROLL 4 #define WSP_GGML_VEC_DOT_UNROLL 2 +#define WSP_GGML_VEC_MAD_UNROLL 32 // // logging @@ -159,40 +179,16 @@ typedef void * thread_ret_t; #define WSP_GGML_PRINT(...) printf(__VA_ARGS__) +// +// end of logging block +// + #ifdef WSP_GGML_USE_ACCELERATE // uncomment to use vDSP for soft max computation // note: not sure if it is actually faster //#define WSP_GGML_SOFT_MAX_ACCELERATE #endif -// -// logging -// - -#if (WSP_GGML_DEBUG >= 1) -#define WSP_GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define WSP_GGML_PRINT_DEBUG(...) -#endif - -#if (WSP_GGML_DEBUG >= 5) -#define WSP_GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) -#else -#define WSP_GGML_PRINT_DEBUG_5(...) -#endif - -#if (WSP_GGML_DEBUG >= 10) -#define WSP_GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) -#else -#define WSP_GGML_PRINT_DEBUG_10(...) -#endif - -#define WSP_GGML_PRINT(...) printf(__VA_ARGS__) - -// -// end of logging block -// - #if defined(_MSC_VER) || defined(__MINGW32__) #define WSP_GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, WSP_GGML_MEM_ALIGN) #define WSP_GGML_ALIGNED_FREE(ptr) _aligned_free(ptr) @@ -242,18 +238,18 @@ inline static void * wsp_ggml_aligned_malloc(size_t size) { // #define WSP_GGML_TENSOR_UNARY_OP_LOCALS \ - WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ - WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ - WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ - WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) #define WSP_GGML_TENSOR_BINARY_OP_LOCALS \ - WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); \ - WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); \ - WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); \ - WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); \ - WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); \ - WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) #if defined(WSP_GGML_USE_ACCELERATE) #include @@ -272,228 +268,33 @@ inline static void * wsp_ggml_aligned_malloc(size_t size) { #include "ggml-opencl.h" #endif -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - // floating point type used to accumulate sums typedef double wsp_ggml_float; -// 16-bit float -// on Arm, we use __fp16 -// on x86, we use uint16_t -#ifdef __ARM_NEON - -// if YCM cannot find , make a symbolic link to it, for example: -// -// $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ -// -#include - -#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x)) -#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) (x) - -#define WSP_GGML_FP16_TO_FP32(x) ((float) (x)) -#define WSP_GGML_FP32_TO_FP16(x) (x) - -#else - -#ifdef __wasm_simd128__ -#include -#else -#ifdef __POWER9_VECTOR__ -#include -#undef bool -#define bool _Bool -#else -#if defined(_MSC_VER) || defined(__MINGW32__) -#include -#else -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) -#if !defined(__riscv) -#include -#endif -#endif -#endif -#endif -#endif - -#ifdef __riscv_v_intrinsic -#include -#endif - -#ifdef __F16C__ - -#ifdef _MSC_VER -#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x))) -#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0) -#else -#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x) -#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0) -#endif - -#elif defined(__POWER9_VECTOR__) - -#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x) -#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x) -/* the inline asm below is about 12% faster than the lookup method */ -#define WSP_GGML_FP16_TO_FP32(x) WSP_GGML_COMPUTE_FP16_TO_FP32(x) -#define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x) - -static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) { - register float f; - register double d; - __asm__( - "mtfprd %0,%2\n" - "xscvhpdp %0,%0\n" - "frsp %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=f"(f): - /* in */ "r"(h)); - return f; -} - -static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) { - register double d; - register wsp_ggml_fp16_t r; - __asm__( /* xscvdphp can work on double or single precision */ - "xscvdphp %0,%2\n" - "mffprd %1,%0\n" : - /* temp */ "=d"(d), - /* out */ "=r"(r): - /* in */ "f"(f)); - return r; -} - -#else - -// FP16 <-> FP32 -// ref: https://github.com/Maratyszcza/FP16 - -static inline float fp32_from_bits(uint32_t w) { - union { - uint32_t as_bits; - float as_value; - } fp32; - fp32.as_bits = w; - return fp32.as_value; -} - -static inline uint32_t fp32_to_bits(float f) { - union { - float as_value; - uint32_t as_bits; - } fp32; - fp32.as_value = f; - return fp32.as_bits; -} - -static inline float wsp_ggml_compute_fp16_to_fp32(wsp_ggml_fp16_t h) { - const uint32_t w = (uint32_t) h << 16; - const uint32_t sign = w & UINT32_C(0x80000000); - const uint32_t two_w = w + w; - - const uint32_t exp_offset = UINT32_C(0xE0) << 23; -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float exp_scale = 0x1.0p-112f; -#else - const float exp_scale = fp32_from_bits(UINT32_C(0x7800000)); -#endif - const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; - - const uint32_t magic_mask = UINT32_C(126) << 23; - const float magic_bias = 0.5f; - const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; - - const uint32_t denormalized_cutoff = UINT32_C(1) << 27; - const uint32_t result = sign | - (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value)); - return fp32_from_bits(result); -} - -static inline wsp_ggml_fp16_t wsp_ggml_compute_fp32_to_fp16(float f) { -#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__) - const float scale_to_inf = 0x1.0p+112f; - const float scale_to_zero = 0x1.0p-110f; -#else - const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000)); - const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000)); -#endif - float base = (fabsf(f) * scale_to_inf) * scale_to_zero; - - const uint32_t w = fp32_to_bits(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } - - base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = fp32_to_bits(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); -} - -#define WSP_GGML_COMPUTE_FP16_TO_FP32(x) wsp_ggml_compute_fp16_to_fp32(x) -#define WSP_GGML_COMPUTE_FP32_TO_FP16(x) wsp_ggml_compute_fp32_to_fp16(x) - -#endif // __F16C__ +#undef MIN +#undef MAX -#endif // __ARM_NEON +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) // // global data // // precomputed gelu table for f16 (128 KB) -static wsp_ggml_fp16_t table_gelu_f16[1 << 16]; +static wsp_ggml_fp16_t wsp_ggml_table_gelu_f16[1 << 16]; // precomputed quick gelu table for f16 (128 KB) -static wsp_ggml_fp16_t table_gelu_quick_f16[1 << 16]; +static wsp_ggml_fp16_t wsp_ggml_table_gelu_quick_f16[1 << 16]; // precomputed silu table for f16 (128 KB) -static wsp_ggml_fp16_t table_silu_f16[1 << 16]; +static wsp_ggml_fp16_t wsp_ggml_table_silu_f16[1 << 16]; // precomputed exp table for f16 (128 KB) -static wsp_ggml_fp16_t table_exp_f16[1 << 16]; - -// precomputed f32 table for f16 (256 KB) -static float table_f32_f16[1 << 16]; - -#if defined(__ARM_NEON) || defined(__wasm_simd128__) -#define B1(c,s,n) 0x ## n ## c , 0x ## n ## s -#define B2(c,s,n) B1(c,s,n ## c), B1(c,s,n ## s) -#define B3(c,s,n) B2(c,s,n ## c), B2(c,s,n ## s) -#define B4(c,s,n) B3(c,s,n ## c), B3(c,s,n ## s) -#define B5(c,s,n) B4(c,s,n ## c), B4(c,s,n ## s) -#define B6(c,s,n) B5(c,s,n ## c), B5(c,s,n ## s) -#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s) -#define B8(c,s ) B7(c,s, c), B7(c,s, s) - -// precomputed tables for expanding 8bits to 8 bytes: -static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4 -static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4 -#endif - -// On ARM NEON, it's quicker to directly convert x -> x instead of calling into wsp_ggml_lookup_fp16_to_fp32, -// so we define WSP_GGML_FP16_TO_FP32 and WSP_GGML_FP32_TO_FP16 elsewhere for NEON. -// This is also true for POWER9. -#if !defined(WSP_GGML_FP16_TO_FP32) || !defined(WSP_GGML_FP32_TO_FP16) - -inline static float wsp_ggml_lookup_fp16_to_fp32(wsp_ggml_fp16_t f) { - uint16_t s; - memcpy(&s, &f, sizeof(uint16_t)); - return table_f32_f16[s]; -} +static wsp_ggml_fp16_t wsp_ggml_table_exp_f16[1 << 16]; -#define WSP_GGML_FP16_TO_FP32(x) wsp_ggml_lookup_fp16_to_fp32(x) -#define WSP_GGML_FP32_TO_FP16(x) WSP_GGML_COMPUTE_FP32_TO_FP16(x) - -#endif +// precomputed f32 table for f16 (256 KB) (ggml-impl.h) +float wsp_ggml_table_f32_f16[1 << 16]; // note: do not use these inside ggml.c // these are meant to be used via the ggml.h API @@ -592,7 +393,6 @@ int64_t wsp_ggml_cycles_per_ms(void) { #define wsp_ggml_perf_cycles_per_ms() 0 #endif - // // cache line // @@ -609,1009 +409,8 @@ int64_t wsp_ggml_cycles_per_ms(void) { static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); -// -// quantization -// - -#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) - -#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) -// multiply int8_t, add results pairwise twice -static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) { - // Get absolute values of x vectors - const __m128i ax = _mm_sign_epi8(x, x); - // Sign the values of the y vectors - const __m128i sy = _mm_sign_epi8(y, x); - // Perform multiplication and create 16-bit values - const __m128i dot = _mm_maddubs_epi16(ax, sy); - const __m128i ones = _mm_set1_epi16(1); - return _mm_madd_epi16(ones, dot); -} - -#if __AVX__ || __AVX2__ || __AVX512F__ -// horizontally add 8 floats -static inline float hsum_float_8(const __m256 x) { - __m128 res = _mm256_extractf128_ps(x, 1); - res = _mm_add_ps(res, _mm256_castps256_ps128(x)); - res = _mm_add_ps(res, _mm_movehl_ps(res, res)); - res = _mm_add_ss(res, _mm_movehdup_ps(res)); - return _mm_cvtss_f32(res); -} - -// horizontally add 8 int32_t -static inline int hsum_i32_8(const __m256i a) { - const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1)); - const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128); - const __m128i sum64 = _mm_add_epi32(hi64, sum128); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -// horizontally add 4 int32_t -static inline int hsum_i32_4(const __m128i a) { - const __m128i hi64 = _mm_unpackhi_epi64(a, a); - const __m128i sum64 = _mm_add_epi32(hi64, a); - const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1)); - return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); -} - -#if defined(__AVX2__) || defined(__AVX512F__) -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m256i shuf_mask = _mm256_set_epi64x( - 0x0303030303030303, 0x0202020202020202, - 0x0101010101010101, 0x0000000000000000); - __m256i bytes = _mm256_shuffle_epi8(_mm256_set1_epi32(x32), shuf_mask); - const __m256i bit_mask = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytes = _mm256_or_si256(bytes, bit_mask); - return _mm256_cmpeq_epi8(bytes, _mm256_set1_epi64x(-1)); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); - const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); - const __m256i lowMask = _mm256_set1_epi8( 0xF ); - return _mm256_and_si256(lowMask, bytes); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m256i x) { - const __m256i ones = _mm256_set1_epi16(1); - const __m256i summed_pairs = _mm256_madd_epi16(ones, x); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { -#if __AVXVNNI__ - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Perform multiplication and create 16-bit values - const __m256i dot = _mm256_maddubs_epi16(ax, sy); - return sum_i16_pairs_float(dot); -#endif -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { -#if __AVXVNNIINT8__ - const __m256i zero = _mm256_setzero_si256(); - const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); - return _mm256_cvtepi32_ps(summed_pairs); -#else - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(x, x); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(y, x); - return mul_sum_us8_pairs_float(ax, sy); -#endif -} - -static inline __m128i packNibbles( __m256i bytes ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh -#if __AVX512F__ - const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000 - bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh - return _mm256_cvtepi16_epi8(bytes); // abcd_efgh -#else - const __m256i lowByte = _mm256_set1_epi16( 0xFF ); - __m256i high = _mm256_andnot_si256( lowByte, bytes ); - __m256i low = _mm256_and_si256( lowByte, bytes ); - high = _mm256_srli_epi16( high, 4 ); - bytes = _mm256_or_si256( low, high ); - - // Compress uint16_t lanes into bytes - __m128i r0 = _mm256_castsi256_si128( bytes ); - __m128i r1 = _mm256_extracti128_si256( bytes, 1 ); - return _mm_packus_epi16( r0, r1 ); -#endif -} -#elif defined(__AVX__) -// spread 32 bits to 32 bytes { 0x00, 0xFF } -static inline __m256i bytes_from_bits_32(const uint8_t * x) { - uint32_t x32; - memcpy(&x32, x, sizeof(uint32_t)); - const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); - const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202); - __m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl); - __m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh); - const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe); - bytesl = _mm_or_si128(bytesl, bit_mask); - bytesh = _mm_or_si128(bytesh, bit_mask); - bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1)); - bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1)); - return MM256_SET_M128I(bytesh, bytesl); -} - -// Unpack 32 4-bit fields into 32 bytes -// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval -static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) -{ - // Load 16 bytes from memory - __m128i tmpl = _mm_loadu_si128((const __m128i *)rsi); - __m128i tmph = _mm_srli_epi16(tmpl, 4); - const __m128i lowMask = _mm_set1_epi8(0xF); - tmpl = _mm_and_si128(lowMask, tmpl); - tmph = _mm_and_si128(lowMask, tmph); - return MM256_SET_M128I(tmph, tmpl); -} - -// add int16_t pairwise and return as float vector -static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { - const __m128i ones = _mm_set1_epi16(1); - const __m128i summed_pairsl = _mm_madd_epi16(ones, xl); - const __m128i summed_pairsh = _mm_madd_epi16(ones, xh); - const __m256i summed_pairs = MM256_SET_M128I(summed_pairsh, summed_pairsl); - return _mm256_cvtepi32_ps(summed_pairs); -} - -static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { - const __m128i axl = _mm256_castsi256_si128(ax); - const __m128i axh = _mm256_extractf128_si256(ax, 1); - const __m128i syl = _mm256_castsi256_si128(sy); - const __m128i syh = _mm256_extractf128_si256(sy, 1); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - const __m128i xl = _mm256_castsi256_si128(x); - const __m128i xh = _mm256_extractf128_si256(x, 1); - const __m128i yl = _mm256_castsi256_si128(y); - const __m128i yh = _mm256_extractf128_si256(y, 1); - // Get absolute values of x vectors - const __m128i axl = _mm_sign_epi8(xl, xl); - const __m128i axh = _mm_sign_epi8(xh, xh); - // Sign the values of the y vectors - const __m128i syl = _mm_sign_epi8(yl, xl); - const __m128i syh = _mm_sign_epi8(yh, xh); - // Perform multiplication and create 16-bit values - const __m128i dotl = _mm_maddubs_epi16(axl, syl); - const __m128i doth = _mm_maddubs_epi16(axh, syh); - return sum_i16_pairs_float(doth, dotl); -} - -static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) -{ - // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh - const __m128i lowByte = _mm_set1_epi16( 0xFF ); - __m128i high = _mm_andnot_si128( lowByte, bytes1 ); - __m128i low = _mm_and_si128( lowByte, bytes1 ); - high = _mm_srli_epi16( high, 4 ); - bytes1 = _mm_or_si128( low, high ); - high = _mm_andnot_si128( lowByte, bytes2 ); - low = _mm_and_si128( lowByte, bytes2 ); - high = _mm_srli_epi16( high, 4 ); - bytes2 = _mm_or_si128( low, high ); - - return _mm_packus_epi16( bytes1, bytes2); -} -#endif -#elif defined(__SSSE3__) -// horizontally add 4x4 floats -static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { - __m128 res_0 =_mm_hadd_ps(a, b); - __m128 res_1 =_mm_hadd_ps(c, d); - __m128 res =_mm_hadd_ps(res_0, res_1); - res =_mm_hadd_ps(res, res); - res =_mm_hadd_ps(res, res); - - return _mm_cvtss_f32(res); -} -#endif // __AVX__ || __AVX2__ || __AVX512F__ -#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) - -#if defined(__ARM_NEON) - -#if !defined(__aarch64__) - -inline static int32_t vaddvq_s32(int32x4_t v) { - return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3); -} - -inline static float vaddvq_f32(float32x4_t v) { - return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); -} - -inline static float vmaxvq_f32(float32x4_t v) { - return - MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)), - MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3))); -} - -inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) { - int32x4_t res; - - res[0] = roundf(vgetq_lane_f32(v, 0)); - res[1] = roundf(vgetq_lane_f32(v, 1)); - res[2] = roundf(vgetq_lane_f32(v, 2)); - res[3] = roundf(vgetq_lane_f32(v, 3)); - - return res; -} - -#endif -#endif - -#define QK4_0 32 -typedef struct { - wsp_ggml_fp16_t d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(wsp_ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); - -#define QK4_1 32 -typedef struct { - wsp_ggml_fp16_t d; // delta - wsp_ggml_fp16_t m; // min - uint8_t qs[QK4_1 / 2]; // nibbles / quants -} block_q4_1; -static_assert(sizeof(block_q4_1) == 2 * sizeof(wsp_ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); - -#define QK5_0 32 -typedef struct { - wsp_ggml_fp16_t d; // delta - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_0 / 2]; // nibbles / quants -} block_q5_0; -static_assert(sizeof(block_q5_0) == sizeof(wsp_ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); - -#define QK5_1 32 -typedef struct { - wsp_ggml_fp16_t d; // delta - wsp_ggml_fp16_t m; // min - uint8_t qh[4]; // 5-th bit of quants - uint8_t qs[QK5_1 / 2]; // nibbles / quants -} block_q5_1; -static_assert(sizeof(block_q5_1) == 2 * sizeof(wsp_ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); - -#define QK8_0 32 -typedef struct { - wsp_ggml_fp16_t d; // delta - int8_t qs[QK8_0]; // quants -} block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(wsp_ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); - -#define QK8_1 32 -typedef struct { - float d; // delta - float s; // d * sum(qs[i]) - int8_t qs[QK8_1]; // quants -} block_q8_1; -static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); - -// reference implementation for deterministic creation of model files -static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) { - static const int qk = QK4_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; - } - } - - const float d = max / -8; - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = WSP_GGML_FP32_TO_FP16(d); - - for (int j = 0; j < qk/2; ++j) { - const float x0 = x[i*qk + 0 + j]*id; - const float x1 = x[i*qk + qk/2 + j]*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); - - y[i].qs[j] = xi0; - y[i].qs[j] |= xi1 << 4; - } - } -} - -static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) { - quantize_row_q4_0_reference(x, y, k); -} - -static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) { - const int qk = QK4_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - - if (v < min) min = v; - if (v > max) max = v; - } - - const float d = (max - min) / ((1 << 4) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = WSP_GGML_FP32_TO_FP16(d); - y[i].m = WSP_GGML_FP32_TO_FP16(min); - - for (int j = 0; j < qk/2; ++j) { - const float x0 = (x[i*qk + 0 + j] - min)*id; - const float x1 = (x[i*qk + qk/2 + j] - min)*id; - - const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f)); - const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f)); - - y[i].qs[j] = xi0; - y[i].qs[j] |= xi1 << 4; - } - } -} - -static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) { - quantize_row_q4_1_reference(x, y, k); -} - -static void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int k) { - static const int qk = QK5_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - float max = 0.0f; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - if (amax < fabsf(v)) { - amax = fabsf(v); - max = v; - } - } - - const float d = max / -16; - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = WSP_GGML_FP32_TO_FP16(d); - - uint32_t qh = 0; - - for (int j = 0; j < qk/2; ++j) { - const float x0 = x[i*qk + 0 + j]*id; - const float x1 = x[i*qk + qk/2 + j]*id; - - const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f)); - const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f)); - - y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); - - // get the 5-th bit and store it in qh at the right position - qh |= ((xi0 & 0x10) >> 4) << (j + 0); - qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); - } - - memcpy(&y[i].qh, &qh, sizeof(qh)); - } -} - -static void quantize_row_q5_0(const float * restrict x, void * restrict y, int k) { - quantize_row_q5_0_reference(x, y, k); -} - -static void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int k) { - const int qk = QK5_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - float min = FLT_MAX; - float max = -FLT_MAX; - - for (int j = 0; j < qk; j++) { - const float v = x[i*qk + j]; - - if (v < min) min = v; - if (v > max) max = v; - } - - const float d = (max - min) / ((1 << 5) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = WSP_GGML_FP32_TO_FP16(d); - y[i].m = WSP_GGML_FP32_TO_FP16(min); - - uint32_t qh = 0; - - for (int j = 0; j < qk/2; ++j) { - const float x0 = (x[i*qk + 0 + j] - min)*id; - const float x1 = (x[i*qk + qk/2 + j] - min)*id; - - const uint8_t xi0 = (uint8_t)(x0 + 0.5f); - const uint8_t xi1 = (uint8_t)(x1 + 0.5f); - - y[i].qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4); - - // get the 5-th bit and store it in qh at the right position - qh |= ((xi0 & 0x10) >> 4) << (j + 0); - qh |= ((xi1 & 0x10) >> 4) << (j + qk/2); - } - - memcpy(&y[i].qh, &qh, sizeof(y[i].qh)); - } -} - -static void quantize_row_q5_1(const float * restrict x, void * restrict y, int k) { - quantize_row_q5_1_reference(x, y, k); -} - -// reference implementation for deterministic creation of model files -static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int k) { - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_0; j++) { - const float v = x[i*QK8_0 + j]; - amax = MAX(amax, fabsf(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = WSP_GGML_FP32_TO_FP16(d); - - for (int j = 0; j < QK8_0; ++j) { - const float x0 = x[i*QK8_0 + j]*id; - - y[i].qs[j] = roundf(x0); - } - } -} - -static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) { - assert(QK8_0 == 32); - assert(k % QK8_0 == 0); - const int nb = k / QK8_0; - - block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = WSP_GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - } - } -#elif defined(__wasm_simd128__) - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = WSP_GGML_FP32_TO_FP16(d); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - } - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = maxScalar / 127.f; - y[i].d = WSP_GGML_FP32_TO_FP16(d); - const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#else - // scalar - quantize_row_q8_0_reference(x, y, k); -#endif -} - -// reference implementation for deterministic creation of model files -static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int k) { - assert(QK8_1 == 32); - assert(k % QK8_1 == 0); - const int nb = k / QK8_1; - - for (int i = 0; i < nb; i++) { - float amax = 0.0f; // absolute max - - for (int j = 0; j < QK8_1; j++) { - const float v = x[i*QK8_1 + j]; - amax = MAX(amax, fabsf(v)); - } - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = d; - - int sum = 0; - - for (int j = 0; j < QK8_1/2; ++j) { - const float v0 = x[i*QK8_1 + j]*id; - const float v1 = x[i*QK8_1 + QK8_1/2 + j]*id; - - y[i].qs[ j] = roundf(v0); - y[i].qs[QK8_1/2 + j] = roundf(v1); - - sum += y[i].qs[ j]; - sum += y[i].qs[QK8_1/2 + j]; - } - - y[i].s = sum*d; - } -} - -static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) { - assert(k % QK8_1 == 0); - const int nb = k / QK8_1; - - block_q8_1 * restrict y = vy; - -#if defined(__ARM_NEON) - for (int i = 0; i < nb; i++) { - float32x4_t srcv [8]; - float32x4_t asrcv[8]; - float32x4_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = vld1q_f32(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = vabsq_f32(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = vmaxq_f32(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = vmaxq_f32(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = vmaxq_f32(amaxv[8*j], amaxv[8*j+4]); - - const float amax = vmaxvq_f32(amaxv[0]); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = d; - - int32x4_t accv = vdupq_n_s32(0); - - for (int j = 0; j < 8; j++) { - const float32x4_t v = vmulq_n_f32(srcv[j], id); - const int32x4_t vi = vcvtnq_s32_f32(v); - - y[i].qs[4*j + 0] = vgetq_lane_s32(vi, 0); - y[i].qs[4*j + 1] = vgetq_lane_s32(vi, 1); - y[i].qs[4*j + 2] = vgetq_lane_s32(vi, 2); - y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3); - - accv = vaddq_s32(accv, vi); - } - - y[i].s = d * vaddvq_s32(accv); - } -#elif defined(__wasm_simd128__) - for (int i = 0; i < nb; i++) { - v128_t srcv [8]; - v128_t asrcv[8]; - v128_t amaxv[8]; - - for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j); - for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]); - - for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]); - for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]); - for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]); - - const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0), - wasm_f32x4_extract_lane(amaxv[0], 1)), - MAX(wasm_f32x4_extract_lane(amaxv[0], 2), - wasm_f32x4_extract_lane(amaxv[0], 3))); - - const float d = amax / ((1 << 7) - 1); - const float id = d ? 1.0f/d : 0.0f; - - y[i].d = d; - - v128_t accv = wasm_i32x4_splat(0); - - for (int j = 0; j < 8; j++) { - const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id)); - const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v); - - y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0); - y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1); - y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2); - y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3); - - accv = wasm_i32x4_add(accv, vi); - } - - y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) + - wasm_i32x4_extract_lane(accv, 1) + - wasm_i32x4_extract_lane(accv, 2) + - wasm_i32x4_extract_lane(accv, 3)); - } -#elif defined(__AVX2__) || defined(__AVX__) - for (int i = 0; i < nb; i++) { - // Load elements into 4 AVX vectors - __m256 v0 = _mm256_loadu_ps( x ); - __m256 v1 = _mm256_loadu_ps( x + 8 ); - __m256 v2 = _mm256_loadu_ps( x + 16 ); - __m256 v3 = _mm256_loadu_ps( x + 24 ); - x += 32; - - // Compute max(abs(e)) for the block - const __m256 signBit = _mm256_set1_ps( -0.0f ); - __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); - maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); - - __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); - max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); - max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); - const float maxScalar = _mm_cvtss_f32( max4 ); - - // Quantize these floats - const float d = maxScalar / 127.f; - y[i].d = d; - const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; - const __m256 mul = _mm256_set1_ps( id ); - - // Apply the multiplier - v0 = _mm256_mul_ps( v0, mul ); - v1 = _mm256_mul_ps( v1, mul ); - v2 = _mm256_mul_ps( v2, mul ); - v3 = _mm256_mul_ps( v3, mul ); - - // Round to nearest integer - v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); - v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); - v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); - v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); - - // Convert floats to integers - __m256i i0 = _mm256_cvtps_epi32( v0 ); - __m256i i1 = _mm256_cvtps_epi32( v1 ); - __m256i i2 = _mm256_cvtps_epi32( v2 ); - __m256i i3 = _mm256_cvtps_epi32( v3 ); - -#if defined(__AVX2__) - // Compute the sum of the quants and set y[i].s - y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3))); - - // Convert int32 to int16 - i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 - i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 - // Convert int16 to int8 - i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 - - // We got our precious signed bytes, but the order is now wrong - // These AVX2 pack instructions process 16-byte pieces independently - // The following instruction is fixing the order - const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); - i0 = _mm256_permutevar8x32_epi32( i0, perm ); - - _mm256_storeu_si256((__m256i *)y[i].qs, i0); -#else - // Since we don't have in AVX some necessary functions, - // we split the registers in half and call AVX2 analogs from SSE - __m128i ni0 = _mm256_castsi256_si128( i0 ); - __m128i ni1 = _mm256_extractf128_si256( i0, 1); - __m128i ni2 = _mm256_castsi256_si128( i1 ); - __m128i ni3 = _mm256_extractf128_si256( i1, 1); - __m128i ni4 = _mm256_castsi256_si128( i2 ); - __m128i ni5 = _mm256_extractf128_si256( i2, 1); - __m128i ni6 = _mm256_castsi256_si128( i3 ); - __m128i ni7 = _mm256_extractf128_si256( i3, 1); - - // Compute the sum of the quants and set y[i].s - const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3)); - const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7)); - y[i].s = d * hsum_i32_4(_mm_add_epi32(s0, s1)); - - // Convert int32 to int16 - ni0 = _mm_packs_epi32( ni0, ni1 ); - ni2 = _mm_packs_epi32( ni2, ni3 ); - ni4 = _mm_packs_epi32( ni4, ni5 ); - ni6 = _mm_packs_epi32( ni6, ni7 ); - // Convert int16 to int8 - ni0 = _mm_packs_epi16( ni0, ni2 ); - ni4 = _mm_packs_epi16( ni4, ni6 ); - - _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); - _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); -#endif - } -#else - // scalar - quantize_row_q8_1_reference(x, y, k); -#endif -} - -static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) { - static const int qk = QK4_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = WSP_GGML_FP16_TO_FP32(x[i].d); - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F) - 8; - const int x1 = (x[i].qs[j] >> 4) - 8; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } - } -} - -static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) { - static const int qk = QK4_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = WSP_GGML_FP16_TO_FP32(x[i].d); - const float m = WSP_GGML_FP16_TO_FP32(x[i].m); - - for (int j = 0; j < qk/2; ++j) { - const int x0 = (x[i].qs[j] & 0x0F); - const int x1 = (x[i].qs[j] >> 4); - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } - } -} - -static void dequantize_row_q5_0(const block_q5_0 * restrict x, float * restrict y, int k) { - static const int qk = QK5_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = WSP_GGML_FP16_TO_FP32(x[i].d); - - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; - - y[i*qk + j + 0 ] = x0*d; - y[i*qk + j + qk/2] = x1*d; - } - } -} - -static void dequantize_row_q5_1(const block_q5_1 * restrict x, float * restrict y, int k) { - static const int qk = QK5_1; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - const float d = WSP_GGML_FP16_TO_FP32(x[i].d); - const float m = WSP_GGML_FP16_TO_FP32(x[i].m); - - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int x0 = (x[i].qs[j] & 0x0F) | xh_0; - const int x1 = (x[i].qs[j] >> 4) | xh_1; - - y[i*qk + j + 0 ] = x0*d + m; - y[i*qk + j + qk/2] = x1*d + m; - } - } -} - -static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, int k) { - static const int qk = QK8_0; - - assert(k % qk == 0); - - const int nb = k / qk; - - const block_q8_0 * restrict x = vx; - - for (int i = 0; i < nb; i++) { - const float d = WSP_GGML_FP16_TO_FP32(x[i].d); - - for (int j = 0; j < qk; ++j) { - y[i*qk + j] = x[i].qs[j]*d; - } - } -} - -static void wsp_ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); -static void wsp_ggml_vec_dot_f16(const int n, float * restrict s, wsp_ggml_fp16_t * restrict x, wsp_ggml_fp16_t * restrict y); -static void wsp_ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void wsp_ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void wsp_ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void wsp_ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void wsp_ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void wsp_ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y); +static void wsp_ggml_vec_dot_f16(const int n, float * restrict s, wsp_ggml_fp16_t * restrict x, wsp_ggml_fp16_t * restrict y); static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = { [WSP_GGML_TYPE_I8] = { @@ -1673,6 +472,28 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = { .vec_dot = wsp_ggml_vec_dot_q4_1_q8_1, .vec_dot_type = WSP_GGML_TYPE_Q8_1, }, + [4] = { // WSP_GGML_TYPE_Q4_2 + .type_name = "DEPRECATED", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + .to_float = NULL, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = NULL, + .vec_dot_type = WSP_GGML_TYPE_COUNT, + }, + [5] = { // WSP_GGML_TYPE_Q4_3 + .type_name = "DEPRECATED", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + .to_float = NULL, + .from_float = NULL, + .from_float_reference = NULL, + .vec_dot = NULL, + .vec_dot_type = WSP_GGML_TYPE_COUNT, + }, [WSP_GGML_TYPE_Q5_0] = { .type_name = "q5_0", .blck_size = QK5_0, @@ -1700,7 +521,7 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = { .blck_size = QK8_0, .type_size = sizeof(block_q8_0), .is_quantized = true, - .to_float = dequantize_row_q8_0, + .to_float = (wsp_ggml_to_float_t) dequantize_row_q8_0, .from_float = quantize_row_q8_0, .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q8_0_reference, .vec_dot = wsp_ggml_vec_dot_q8_0_q8_0, @@ -1715,7 +536,6 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = { .from_float_reference = (wsp_ggml_from_float_t) quantize_row_q8_1_reference, .vec_dot_type = WSP_GGML_TYPE_Q8_1, }, -#ifdef WSP_GGML_USE_K_QUANTS [WSP_GGML_TYPE_Q2_K] = { .type_name = "q2_K", .blck_size = QK_K, @@ -1778,7 +598,6 @@ static const wsp_ggml_type_traits_t type_traits[WSP_GGML_TYPE_COUNT] = { .is_quantized = true, .from_float = quantize_row_q8_K, } -#endif }; // For internal test use @@ -1787,11 +606,22 @@ wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type return type_traits[type]; } - // // simd mappings // +#if defined(__ARM_NEON) +#if !defined(__aarch64__) + +// 64-bit compatibility + +inline static float vaddvq_f32(float32x4_t v) { + return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3); +} + +#endif +#endif + // we define a common set of C macros which map to specific intrinsics based on the current architecture // we then implement the fundamental computation operations below using only these macros // adding support for new architectures requires to define the corresponding SIMD macros @@ -1863,7 +693,7 @@ wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type #define WSP_GGML_F16x8_ADD vaddq_f16 #define WSP_GGML_F16x8_MUL vmulq_f16 #define WSP_GGML_F16x8_REDUCE(res, x) \ - { \ + do { \ int offset = WSP_GGML_F16_ARR >> 1; \ for (int i = 0; i < offset; ++i) { \ x[i] = vaddq_f16(x[i], x[offset+i]); \ @@ -1879,7 +709,7 @@ wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \ const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \ res = (wsp_ggml_float) vaddvq_f32(vaddq_f32(t0, t1)); \ - } + } while (0) #define WSP_GGML_F16_VEC WSP_GGML_F16x8 #define WSP_GGML_F16_VEC_ZERO WSP_GGML_F16x8_ZERO @@ -1940,7 +770,7 @@ wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type #define WSP_GGML_F32x8_ADD _mm256_add_ps #define WSP_GGML_F32x8_MUL _mm256_mul_ps #define WSP_GGML_F32x8_REDUCE(res, x) \ -{ \ +do { \ int offset = WSP_GGML_F32_ARR >> 1; \ for (int i = 0; i < offset; ++i) { \ x[i] = _mm256_add_ps(x[i], x[offset+i]); \ @@ -1957,7 +787,7 @@ wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type _mm256_extractf128_ps(x[0], 1)); \ const __m128 t1 = _mm_hadd_ps(t0, t0); \ res = _mm_cvtss_f32(_mm_hadd_ps(t1, t1)); \ -} +} while (0) // TODO: is this optimal ? #define WSP_GGML_F32_VEC WSP_GGML_F32x8 @@ -2292,1333 +1122,115 @@ static inline void __sse_f16x4_store(wsp_ggml_fp16_t *x, __m128 y) { #define WSP_GGML_F16_VEC_ZERO WSP_GGML_F32Cx4_ZERO #define WSP_GGML_F16_VEC_SET1 WSP_GGML_F32Cx4_SET1 #define WSP_GGML_F16_VEC_LOAD(p, i) WSP_GGML_F32Cx4_LOAD(p) -#define WSP_GGML_F16_VEC_STORE(p, r, i) WSP_GGML_F32Cx4_STORE(p, r[i]) -#define WSP_GGML_F16_VEC_FMA WSP_GGML_F32Cx4_FMA -#define WSP_GGML_F16_VEC_ADD WSP_GGML_F32Cx4_ADD -#define WSP_GGML_F16_VEC_MUL WSP_GGML_F32Cx4_MUL -#define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32Cx4_REDUCE - -#endif - -// WSP_GGML_F32_ARR / WSP_GGML_F16_ARR -// number of registers to use per step -#ifdef WSP_GGML_SIMD -#define WSP_GGML_F32_ARR (WSP_GGML_F32_STEP/WSP_GGML_F32_EPR) -#define WSP_GGML_F16_ARR (WSP_GGML_F16_STEP/WSP_GGML_F16_EPR) -#endif - -// -// fundamental operations -// - -inline static void wsp_ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void wsp_ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void wsp_ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void wsp_ggml_vec_set_f16(const int n, wsp_ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - -inline static void wsp_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } -inline static void wsp_ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } -inline static void wsp_ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } -inline static void wsp_ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } -inline static void wsp_ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } -inline static void wsp_ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void wsp_ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } -inline static void wsp_ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } -inline static void wsp_ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } -inline static void wsp_ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } - -static void wsp_ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { -#ifdef WSP_GGML_SIMD - float sumf = 0.0f; - const int np = (n & ~(WSP_GGML_F32_STEP - 1)); - - WSP_GGML_F32_VEC sum[WSP_GGML_F32_ARR] = { WSP_GGML_F32_VEC_ZERO }; - - WSP_GGML_F32_VEC ax[WSP_GGML_F32_ARR]; - WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR]; - - for (int i = 0; i < np; i += WSP_GGML_F32_STEP) { - for (int j = 0; j < WSP_GGML_F32_ARR; j++) { - ax[j] = WSP_GGML_F32_VEC_LOAD(x + i + j*WSP_GGML_F32_EPR); - ay[j] = WSP_GGML_F32_VEC_LOAD(y + i + j*WSP_GGML_F32_EPR); - - sum[j] = WSP_GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - WSP_GGML_F32_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += x[i]*y[i]; - } -#else - // scalar - wsp_ggml_float sumf = 0.0; - for (int i = 0; i < n; ++i) { - sumf += (wsp_ggml_float)(x[i]*y[i]); - } -#endif - - *s = sumf; -} - -static void wsp_ggml_vec_dot_f16(const int n, float * restrict s, wsp_ggml_fp16_t * restrict x, wsp_ggml_fp16_t * restrict y) { - wsp_ggml_float sumf = 0.0; - -#if defined(WSP_GGML_SIMD) - const int np = (n & ~(WSP_GGML_F16_STEP - 1)); - - WSP_GGML_F16_VEC sum[WSP_GGML_F16_ARR] = { WSP_GGML_F16_VEC_ZERO }; - - WSP_GGML_F16_VEC ax[WSP_GGML_F16_ARR]; - WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR]; - - for (int i = 0; i < np; i += WSP_GGML_F16_STEP) { - for (int j = 0; j < WSP_GGML_F16_ARR; j++) { - ax[j] = WSP_GGML_F16_VEC_LOAD(x + i + j*WSP_GGML_F16_EPR, j); - ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j); - - sum[j] = WSP_GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); - } - } - - // reduce sum0..sum3 to sum0 - WSP_GGML_F16_VEC_REDUCE(sumf, sum); - - // leftovers - for (int i = np; i < n; ++i) { - sumf += (wsp_ggml_float)(WSP_GGML_FP16_TO_FP32(x[i])*WSP_GGML_FP16_TO_FP32(y[i])); - } -#else - for (int i = 0; i < n; ++i) { - sumf += (wsp_ggml_float)(WSP_GGML_FP16_TO_FP32(x[i])*WSP_GGML_FP16_TO_FP32(y[i])); - } -#endif - - *s = sumf; -} - -static void wsp_ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - - const block_q4_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - WSP_GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q4_0 * restrict x0 = &x[i + 0]; - const block_q4_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // sub 8 - const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); - const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; ++i) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps( WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d) ); - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - - // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. - const __m256i off = _mm256_set1_epi8( 8 ); - bx = _mm256_sub_epi8( bx, off ); - - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps( d, q, acc ); - } - - *s = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps( WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d) ); - - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); - - __m128i bx = _mm_and_si128(lowMask, tmp); - __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx, by); - - bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); - by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx = _mm_sub_epi8(bx, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx, by); - - // Convert int32_t to float - __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); - - // Apply the scale, and accumulate - acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); - } - - *s = hsum_float_8(acc); -#elif defined(__SSSE3__) - // set constants - const __m128i lowMask = _mm_set1_epi8(0xF); - const __m128i off = _mm_set1_epi8(8); - - // Initialize accumulator with zeros - __m128 acc_0 = _mm_setzero_ps(); - __m128 acc_1 = _mm_setzero_ps(); - __m128 acc_2 = _mm_setzero_ps(); - __m128 acc_3 = _mm_setzero_ps(); - - // First round without accumulation - { - _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[0].d) * WSP_GGML_FP16_TO_FP32(y[0].d) ); - - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[1].d) * WSP_GGML_FP16_TO_FP32(y[1].d) ); - - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); - - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); - - // Apply the scale - acc_0 = _mm_mul_ps( d_0_1, p0 ); - acc_1 = _mm_mul_ps( d_0_1, p1 ); - acc_2 = _mm_mul_ps( d_2_3, p2 ); - acc_3 = _mm_mul_ps( d_2_3, p3 ); - } - - // Main loop - WSP_GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 2; i < nb; i+=2) { - _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d) ); - - const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); - - __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); - __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); - bx_0 = _mm_sub_epi8(bx_0, off); - const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); - - __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); - __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); - bx_1 = _mm_sub_epi8(bx_1, off); - const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); - - _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); - _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); - - // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_set1_ps( WSP_GGML_FP16_TO_FP32(x[i + 1].d) * WSP_GGML_FP16_TO_FP32(y[i + 1].d) ); - - const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); - - __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); - __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); - bx_2 = _mm_sub_epi8(bx_2, off); - const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); - - __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); - __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); - bx_3 = _mm_sub_epi8(bx_3, off); - const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); - - // Convert int32_t to float - __m128 p0 = _mm_cvtepi32_ps(i32_0); - __m128 p1 = _mm_cvtepi32_ps(i32_1); - __m128 p2 = _mm_cvtepi32_ps(i32_2); - __m128 p3 = _mm_cvtepi32_ps(i32_3); - - // Apply the scale - __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); - __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); - __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); - __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); - - // Acummulate - acc_0 = _mm_add_ps(p0_d, acc_0); - acc_1 = _mm_add_ps(p1_d, acc_1); - acc_2 = _mm_add_ps(p2_d, acc_2); - acc_3 = _mm_add_ps(p3_d, acc_3); - } - - *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (int i = 0; i < nb; i++) { - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); - - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); - - vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - - vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); - - vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 8, vl); - vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 8, vl); - - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += sumi*WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d); - } - - *s = sumf; -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F) - 8; - const int v1 = (x[i].qs[j] >> 4) - 8; - - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); - } - - sumf += sumi*WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d); - } - - *s = sumf; -#endif -} - -static void wsp_ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_1; - const int nb = n / qk; - - assert(n % qk == 0); - - const block_q4_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; - - // TODO: add WASM SIMD -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs = 0; - - WSP_GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q4_1 * restrict x0 = &x[i + 0]; - const block_q4_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i + 0]; - const block_q8_1 * restrict y1 = &y[i + 1]; - - summs += WSP_GGML_FP16_TO_FP32(x0->m) * y0->s + WSP_GGML_FP16_TO_FP32(x1->m) * y1->s; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs; -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0; - - // Main loop - for (int i = 0; i < nb; ++i) { - const float d0 = WSP_GGML_FP16_TO_FP32(x[i].d); - const float d1 = y[i].d; - - summs += WSP_GGML_FP16_TO_FP32(x[i].m) * y[i].s; - - const __m256 d0v = _mm256_set1_ps( d0 ); - const __m256 d1v = _mm256_set1_ps( d1 ); - - // Compute combined scales - const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); - - // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes - const __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); - - const __m256 xy = mul_sum_us8_pairs_float(bx, by); - - // Accumulate d0*d1*x*y -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d0d1, xy, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc ); -#endif - } - - *s = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (int i = 0; i < nb; i++) { - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); - - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); - - vuint8m1_t x_a = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_l = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - - vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); - - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + WSP_GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } - - *s = sumf; -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const int v0 = (x[i].qs[j] & 0x0F); - const int v1 = (x[i].qs[j] >> 4); - - sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); - } - - sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + WSP_GGML_FP16_TO_FP32(x[i].m)*y[i].s; - } - - *s = sumf; -#endif -} - -static void wsp_ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - assert(qk == QK5_0); - - const block_q5_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - WSP_GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q5_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i]; - const block_q8_0 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - // extract the 5th bit via lookup table ((!b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_1[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_1[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); - - uint32_t qh; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_0 * restrict x0 = &x[i]; - const block_q8_0 * restrict y0 = &y[i]; - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); - - tmp[0] = table_b2b_1[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_1[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_1[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_1[(qh >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit and sub 16 (equivalent to sub 0x10 when bit is zero) - const v128_t v0lf = wasm_i8x16_sub(v0l, qhl); - const v128_t v0hf = wasm_i8x16_sub(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( - wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(WSP_GGML_FP16_TO_FP32(x0->d) * WSP_GGML_FP16_TO_FP32(y0->d)))); - } - - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3); -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - // Main loop - for (int i = 0; i < nb; i++) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d)); - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); - bxhi = _mm256_andnot_si256(bxhi, _mm256_set1_epi8((char)0xF0)); - bx = _mm256_or_si256(bx, bxhi); - - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - /* Multiply q with scale and accumulate */ - acc = _mm256_fmadd_ps(d, q, acc); - } - - *s = hsum_float_8(acc); -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8((char)0xF0); - - // Main loop - for (int i = 0; i < nb; i++) { - /* Compute combined scale for the block */ - const __m256 d = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d)); - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_andnot_si128(bxhil, mask); - bxhih = _mm_andnot_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); - - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_i8_pairs_float(bx, by); - - /* Multiply q with scale and accumulate */ - acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); - } - - *s = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - uint32_t qh; - - // These temp values are for masking and shift operations - uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - uint32_t temp_2[16] = {0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80, - 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000}; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); - - // temporary registers - vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_2, vl); - vuint32m4_t vt_2 = __riscv_vle32_v_u32m4(temp_1, vl); - vuint32m4_t vt_3 = __riscv_vsll_vx_u32m4(vt_1, 16, vl); - vuint32m4_t vt_4 = __riscv_vadd_vx_u32m4(vt_2, 12, vl); - - // ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(vt_1, qh, vl); - vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(xha_0, vt_2, vl); - vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl); - - // ((qh & (1u << (j + 16))) >> (j + 12)); - vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(vt_3, qh, vl); - vuint32m4_t xhl_1 = __riscv_vsrl_vv_u32m4(xha_1, vt_4, vl); - - // narrowing - vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xhl_0, vl); - vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl); - - vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xhl_1, vl); - vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl); - - // load - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); - - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); - - vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); - - vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl); - vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl); - - vint8m1_t x_ai = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t x_li = __riscv_vreinterpret_v_u8m1_i8m1(x_l); - - vint8m1_t v0 = __riscv_vsub_vx_i8m1(x_ai, 16, vl); - vint8m1_t v1 = __riscv_vsub_vx_i8m1(x_li, 16, vl); - - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); - - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); - - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); - - sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d)) * sumi; - } - - *s = sumf; -#else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12)); - - const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16; - const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16; - - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); - } - - sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d)) * sumi; - } - - *s = sumf; -#endif -} - -static void wsp_ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_1; - const int nb = n / qk; - - assert(n % qk == 0); - assert(qk == QK5_1); - - const block_q5_1 * restrict x = vx; - const block_q8_1 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float summs0 = 0.0f; - float summs1 = 0.0f; - - uint32_t qh0; - uint32_t qh1; - - uint64_t tmp0[4]; - uint64_t tmp1[4]; - - WSP_GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q5_1 * restrict x1 = &x[i + 1]; - const block_q8_1 * restrict y0 = &y[i]; - const block_q8_1 * restrict y1 = &y[i + 1]; - - const uint8x16_t m4b = vdupq_n_u8(0x0F); - - summs0 += WSP_GGML_FP16_TO_FP32(x0->m) * y0->s; - summs1 += WSP_GGML_FP16_TO_FP32(x1->m) * y1->s; - - // extract the 5th bit via lookup table ((b) << 4) - memcpy(&qh0, x0->qh, sizeof(qh0)); - memcpy(&qh1, x1->qh, sizeof(qh1)); - - tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF]; - tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF]; - tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF]; - tmp0[3] = table_b2b_0[(qh0 >> 24) ]; - - tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF]; - tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF]; - tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF]; - tmp1[3] = table_b2b_0[(qh1 >> 24) ]; - - const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0)); - const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2)); - const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0)); - const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2)); - - const uint8x16_t v0_0 = vld1q_u8(x0->qs); - const uint8x16_t v0_1 = vld1q_u8(x1->qs); - - // 4-bit -> 8-bit - const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); - const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); - - // add high bit - const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0); - const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0); - const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1); - const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1); - - // load y - const int8x16_t v1_0l = vld1q_s8(y0->qs); - const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); - const int8x16_t v1_1l = vld1q_s8(y1->qs); - const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); -#else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h)); - - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h)); - - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), WSP_GGML_FP16_TO_FP32(x0->d)*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), WSP_GGML_FP16_TO_FP32(x1->d)*y1->d); -#endif - } - - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1; -#elif defined(__wasm_simd128__) - v128_t sumv = wasm_f32x4_splat(0.0f); - - float summs = 0.0f; - - uint32_t qh; - uint64_t tmp[4]; - - // TODO: check if unrolling this is better - for (int i = 0; i < nb; ++i) { - const block_q5_1 * restrict x0 = &x[i]; - const block_q8_1 * restrict y0 = &y[i]; - - summs += WSP_GGML_FP16_TO_FP32(x0->m) * y0->s; - - const v128_t m4b = wasm_i8x16_splat(0x0F); - - // extract the 5th bit - memcpy(&qh, x0->qh, sizeof(qh)); - - tmp[0] = table_b2b_0[(qh >> 0) & 0xFF]; - tmp[1] = table_b2b_0[(qh >> 8) & 0xFF]; - tmp[2] = table_b2b_0[(qh >> 16) & 0xFF]; - tmp[3] = table_b2b_0[(qh >> 24) ]; - - const v128_t qhl = wasm_v128_load(tmp + 0); - const v128_t qhh = wasm_v128_load(tmp + 2); - - const v128_t v0 = wasm_v128_load(x0->qs); - - // 4-bit -> 8-bit - const v128_t v0l = wasm_v128_and (v0, m4b); - const v128_t v0h = wasm_u8x16_shr(v0, 4); - - // add high bit - const v128_t v0lf = wasm_v128_or(v0l, qhl); - const v128_t v0hf = wasm_v128_or(v0h, qhh); - - // load y - const v128_t v1l = wasm_v128_load(y0->qs); - const v128_t v1h = wasm_v128_load(y0->qs + 16); - - // int8x16 -> int16x8 - const v128_t v0lfl = wasm_i16x8_extend_low_i8x16 (v0lf); - const v128_t v0lfh = wasm_i16x8_extend_high_i8x16(v0lf); - const v128_t v0hfl = wasm_i16x8_extend_low_i8x16 (v0hf); - const v128_t v0hfh = wasm_i16x8_extend_high_i8x16(v0hf); - - const v128_t v1ll = wasm_i16x8_extend_low_i8x16 (v1l); - const v128_t v1lh = wasm_i16x8_extend_high_i8x16(v1l); - const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); - const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - - // dot product - sumv = wasm_f32x4_add(sumv, - wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), - wasm_f32x4_splat(WSP_GGML_FP16_TO_FP32(x0->d) * y0->d))); - } - - *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + - wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; -#elif defined(__AVX2__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - - float summs = 0.0f; - - // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d)); - - summs += WSP_GGML_FP16_TO_FP32(x[i].m) * y[i].s; - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - __m256i bxhi = bytes_from_bits_32(x[i].qh); - bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); - bx = _mm256_or_si256(bx, bxhi); - - const __m256 dy = _mm256_set1_ps(y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_us8_pairs_float(bx, by); - - acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); - } - - *s = hsum_float_8(acc) + summs; -#elif defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); - __m128i mask = _mm_set1_epi8(0x10); - - float summs = 0.0f; - - // Main loop - for (int i = 0; i < nb; i++) { - const __m256 dx = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d)); - - summs += WSP_GGML_FP16_TO_FP32(x[i].m) * y[i].s; - - __m256i bx = bytes_from_nibbles_32(x[i].qs); - const __m256i bxhi = bytes_from_bits_32(x[i].qh); - __m128i bxhil = _mm256_castsi256_si128(bxhi); - __m128i bxhih = _mm256_extractf128_si256(bxhi, 1); - bxhil = _mm_and_si128(bxhil, mask); - bxhih = _mm_and_si128(bxhih, mask); - __m128i bxl = _mm256_castsi256_si128(bx); - __m128i bxh = _mm256_extractf128_si256(bx, 1); - bxl = _mm_or_si128(bxl, bxhil); - bxh = _mm_or_si128(bxh, bxhih); - bx = MM256_SET_M128I(bxh, bxl); - - const __m256 dy = _mm256_set1_ps(y[i].d); - const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - - const __m256 q = mul_sum_us8_pairs_float(bx, by); - - acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); - } - - *s = hsum_float_8(acc) + summs; -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - - uint32_t qh; - - // These temp values are for shift operations - uint32_t temp_1[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; - - size_t vl = __riscv_vsetvl_e8m1(qk/2); - - for (int i = 0; i < nb; i++) { - memcpy(&qh, x[i].qh, sizeof(uint32_t)); - - // temporary registers - vuint32m4_t vt_1 = __riscv_vle32_v_u32m4(temp_1, vl); - vuint32m4_t vt_2 = __riscv_vadd_vx_u32m4(vt_1, 12, vl); +#define WSP_GGML_F16_VEC_STORE(p, r, i) WSP_GGML_F32Cx4_STORE(p, r[i]) +#define WSP_GGML_F16_VEC_FMA WSP_GGML_F32Cx4_FMA +#define WSP_GGML_F16_VEC_ADD WSP_GGML_F32Cx4_ADD +#define WSP_GGML_F16_VEC_MUL WSP_GGML_F32Cx4_MUL +#define WSP_GGML_F16_VEC_REDUCE WSP_GGML_F32Cx4_REDUCE - // load qh - vuint32m4_t vqh = __riscv_vmv_v_x_u32m4(qh, vl); +#endif - // ((qh >> (j + 0)) << 4) & 0x10; - vuint32m4_t xhr_0 = __riscv_vsrl_vv_u32m4(vqh, vt_1, vl); - vuint32m4_t xhl_0 = __riscv_vsll_vx_u32m4(xhr_0, 4, vl); - vuint32m4_t xha_0 = __riscv_vand_vx_u32m4(xhl_0, 0x10, vl); +// WSP_GGML_F32_ARR / WSP_GGML_F16_ARR +// number of registers to use per step +#ifdef WSP_GGML_SIMD +#define WSP_GGML_F32_ARR (WSP_GGML_F32_STEP/WSP_GGML_F32_EPR) +#define WSP_GGML_F16_ARR (WSP_GGML_F16_STEP/WSP_GGML_F16_EPR) +#endif - // ((qh >> (j + 12)) ) & 0x10; - vuint32m4_t xhr_1 = __riscv_vsrl_vv_u32m4(vqh, vt_2, vl); - vuint32m4_t xha_1 = __riscv_vand_vx_u32m4(xhr_1, 0x10, vl); +// +// fundamental operations +// - // narrowing - vuint16m2_t xhc_0 = __riscv_vncvt_x_x_w_u16m2(xha_0, vl); - vuint8m1_t xh_0 = __riscv_vncvt_x_x_w_u8m1(xhc_0, vl); +inline static void wsp_ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - vuint16m2_t xhc_1 = __riscv_vncvt_x_x_w_u16m2(xha_1, vl); - vuint8m1_t xh_1 = __riscv_vncvt_x_x_w_u8m1(xhc_1, vl); +inline static void wsp_ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - // load - vuint8m1_t tx = __riscv_vle8_v_u8m1(x[i].qs, vl); +inline static void wsp_ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - vint8m1_t y0 = __riscv_vle8_v_i8m1(y[i].qs, vl); - vint8m1_t y1 = __riscv_vle8_v_i8m1(y[i].qs+16, vl); +inline static void wsp_ggml_vec_set_f16(const int n, wsp_ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - vuint8m1_t x_at = __riscv_vand_vx_u8m1(tx, 0x0F, vl); - vuint8m1_t x_lt = __riscv_vsrl_vx_u8m1(tx, 0x04, vl); +inline static void wsp_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } +inline static void wsp_ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } +inline static void wsp_ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } +inline static void wsp_ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; } +inline static void wsp_ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] - y[i]; } +inline static void wsp_ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void wsp_ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void wsp_ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } +inline static void wsp_ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void wsp_ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; } - vuint8m1_t x_a = __riscv_vor_vv_u8m1(x_at, xh_0, vl); - vuint8m1_t x_l = __riscv_vor_vv_u8m1(x_lt, xh_1, vl); +static void wsp_ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) { +#ifdef WSP_GGML_SIMD + float sumf = 0.0f; + const int np = (n & ~(WSP_GGML_F32_STEP - 1)); - vint8m1_t v0 = __riscv_vreinterpret_v_u8m1_i8m1(x_a); - vint8m1_t v1 = __riscv_vreinterpret_v_u8m1_i8m1(x_l); + WSP_GGML_F32_VEC sum[WSP_GGML_F32_ARR] = { WSP_GGML_F32_VEC_ZERO }; - vint16m2_t vec_mul1 = __riscv_vwmul_vv_i16m2(v0, y0, vl); - vint16m2_t vec_mul2 = __riscv_vwmul_vv_i16m2(v1, y1, vl); + WSP_GGML_F32_VEC ax[WSP_GGML_F32_ARR]; + WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR]; - vint32m1_t vec_zero = __riscv_vmv_v_x_i32m1(0, vl); + for (int i = 0; i < np; i += WSP_GGML_F32_STEP) { + for (int j = 0; j < WSP_GGML_F32_ARR; j++) { + ax[j] = WSP_GGML_F32_VEC_LOAD(x + i + j*WSP_GGML_F32_EPR); + ay[j] = WSP_GGML_F32_VEC_LOAD(y + i + j*WSP_GGML_F32_EPR); - vint32m1_t vs1 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul1, vec_zero, vl); - vint32m1_t vs2 = __riscv_vwredsum_vs_i16m2_i32m1(vec_mul2, vec_zero, vl); + sum[j] = WSP_GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } - int sumi = __riscv_vmv_x_s_i32m1_i32(vs1); - sumi += __riscv_vmv_x_s_i32m1_i32(vs2); + // reduce sum0..sum3 to sum0 + WSP_GGML_F32_VEC_REDUCE(sumf, sum); - sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + WSP_GGML_FP16_TO_FP32(x[i].m)*y[i].s; + // leftovers + for (int i = np; i < n; ++i) { + sumf += x[i]*y[i]; } - - *s = sumf; #else // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); - - int sumi = 0; - - for (int j = 0; j < qk/2; ++j) { - const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10; - const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10; - - const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0; - const int32_t x1 = (x[i].qs[j] >> 4) | xh_1; - - sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); - } - - sumf += (WSP_GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + WSP_GGML_FP16_TO_FP32(x[i].m)*y[i].s; + wsp_ggml_float sumf = 0.0; + for (int i = 0; i < n; ++i) { + sumf += (wsp_ggml_float)(x[i]*y[i]); } +#endif *s = sumf; -#endif } -static void wsp_ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - - const block_q8_0 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__ARM_NEON) - float32x4_t sumv0 = vdupq_n_f32(0.0f); - float32x4_t sumv1 = vdupq_n_f32(0.0f); - - WSP_GGML_ASSERT(nb % 2 == 0); // TODO: handle odd nb - for (int i = 0; i < nb; i += 2) { - const block_q8_0 * restrict x0 = &x[i + 0]; - const block_q8_0 * restrict x1 = &x[i + 1]; - const block_q8_0 * restrict y0 = &y[i + 0]; - const block_q8_0 * restrict y1 = &y[i + 1]; - - const int8x16_t x0_0 = vld1q_s8(x0->qs); - const int8x16_t x0_1 = vld1q_s8(x0->qs + 16); - const int8x16_t x1_0 = vld1q_s8(x1->qs); - const int8x16_t x1_1 = vld1q_s8(x1->qs + 16); - - // load y - const int8x16_t y0_0 = vld1q_s8(y0->qs); - const int8x16_t y0_1 = vld1q_s8(y0->qs + 16); - const int8x16_t y1_0 = vld1q_s8(y1->qs); - const int8x16_t y1_1 = vld1q_s8(y1->qs + 16); - -#if defined(__ARM_FEATURE_DOTPROD) - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), - vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); - - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( - vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), - vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); +static void wsp_ggml_vec_dot_f16(const int n, float * restrict s, wsp_ggml_fp16_t * restrict x, wsp_ggml_fp16_t * restrict y) { + wsp_ggml_float sumf = 0.0; -#else - const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); - const int16x8_t p0_1 = vmull_s8(vget_high_s8(x0_0), vget_high_s8(y0_0)); - const int16x8_t p0_2 = vmull_s8(vget_low_s8 (x0_1), vget_low_s8 (y0_1)); - const int16x8_t p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1)); - - const int16x8_t p1_0 = vmull_s8(vget_low_s8 (x1_0), vget_low_s8 (y1_0)); - const int16x8_t p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0)); - const int16x8_t p1_2 = vmull_s8(vget_low_s8 (x1_1), vget_low_s8 (y1_1)); - const int16x8_t p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1)); - - const int32x4_t p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1)); - const int32x4_t p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3)); - const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); - const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), WSP_GGML_FP16_TO_FP32(x0->d)*WSP_GGML_FP16_TO_FP32(y0->d)); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), WSP_GGML_FP16_TO_FP32(x1->d)*WSP_GGML_FP16_TO_FP32(y1->d)); -#endif - } +#if defined(WSP_GGML_SIMD) + const int np = (n & ~(WSP_GGML_F16_STEP - 1)); - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); -#elif defined(__AVX2__) || defined(__AVX__) - // Initialize accumulator with zeros - __m256 acc = _mm256_setzero_ps(); + WSP_GGML_F16_VEC sum[WSP_GGML_F16_ARR] = { WSP_GGML_F16_VEC_ZERO }; - // Main loop - for (int i = 0; i < nb; ++i) { - // Compute combined scale for the block - const __m256 d = _mm256_set1_ps(WSP_GGML_FP16_TO_FP32(x[i].d) * WSP_GGML_FP16_TO_FP32(y[i].d)); - __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); - __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); + WSP_GGML_F16_VEC ax[WSP_GGML_F16_ARR]; + WSP_GGML_F16_VEC ay[WSP_GGML_F16_ARR]; - const __m256 q = mul_sum_i8_pairs_float(bx, by); + for (int i = 0; i < np; i += WSP_GGML_F16_STEP) { + for (int j = 0; j < WSP_GGML_F16_ARR; j++) { + ax[j] = WSP_GGML_F16_VEC_LOAD(x + i + j*WSP_GGML_F16_EPR, j); + ay[j] = WSP_GGML_F16_VEC_LOAD(y + i + j*WSP_GGML_F16_EPR, j); - // Multiply q with scale and accumulate -#if defined(__AVX2__) - acc = _mm256_fmadd_ps( d, q, acc ); -#else - acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc ); -#endif + sum[j] = WSP_GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); + } } - *s = hsum_float_8(acc); -#elif defined(__riscv_v_intrinsic) - float sumf = 0.0; - size_t vl = __riscv_vsetvl_e8m1(qk); - - for (int i = 0; i < nb; i++) { - // load elements - vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); - vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); - - vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); - - vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); - vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); - - int sumi = __riscv_vmv_x_s_i32m1_i32(v_sum); + // reduce sum0..sum3 to sum0 + WSP_GGML_F16_VEC_REDUCE(sumf, sum); - sumf += sumi*(WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d)); + // leftovers + for (int i = np; i < n; ++i) { + sumf += (wsp_ggml_float)(WSP_GGML_FP16_TO_FP32(x[i])*WSP_GGML_FP16_TO_FP32(y[i])); } - - *s = sumf; #else - // scalar - float sumf = 0.0; - - for (int i = 0; i < nb; i++) { - int sumi = 0; - - for (int j = 0; j < qk; j++) { - sumi += x[i].qs[j]*y[i].qs[j]; - } - - sumf += sumi*(WSP_GGML_FP16_TO_FP32(x[i].d)*WSP_GGML_FP16_TO_FP32(y[i].d)); + for (int i = 0; i < n; ++i) { + sumf += (wsp_ggml_float)(WSP_GGML_FP16_TO_FP32(x[i])*WSP_GGML_FP16_TO_FP32(y[i])); } +#endif *s = sumf; -#endif } // compute WSP_GGML_VEC_DOT_UNROLL dot products at once @@ -3707,6 +1319,58 @@ inline static void wsp_ggml_vec_mad_f32(const int n, float * restrict y, const f #endif } +// xs and vs are byte strides of x and v +inline static void wsp_ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { + + const float * restrict x[WSP_GGML_VEC_MAD_UNROLL]; + const float * restrict v[WSP_GGML_VEC_MAD_UNROLL]; + + for (int i = 0; i < WSP_GGML_VEC_MAD_UNROLL; ++i) { + x[i] = (const float *) ((const char *) xv + i*xs); + v[i] = (const float *) ((const char *) vv + i*vs); + } + +#if defined(WSP_GGML_SIMD) + const int np = (n & ~(WSP_GGML_F32_STEP - 1)); + + WSP_GGML_F32_VEC vx[WSP_GGML_VEC_MAD_UNROLL]; + + for (int k = 0; k < WSP_GGML_VEC_MAD_UNROLL; ++k) { + vx[k] = WSP_GGML_F32_VEC_SET1(v[k][0]); + } + + WSP_GGML_F32_VEC ax[WSP_GGML_VEC_MAD_UNROLL][WSP_GGML_F32_ARR]; + WSP_GGML_F32_VEC ay[WSP_GGML_F32_ARR]; + + for (int i = 0; i < np; i += WSP_GGML_F32_STEP) { + for (int j = 0; j < WSP_GGML_F32_ARR; j++) { + ay[j] = WSP_GGML_F32_VEC_LOAD(y + i + j*WSP_GGML_F32_EPR); + + for (int k = 0; k < WSP_GGML_VEC_MAD_UNROLL; ++k) { + ax[k][j] = WSP_GGML_F32_VEC_LOAD(x[k] + i + j*WSP_GGML_F32_EPR); + ay[j] = WSP_GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); + } + + WSP_GGML_F32_VEC_STORE(y + i + j*WSP_GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int k = 0; k < WSP_GGML_VEC_MAD_UNROLL; ++k) { + for (int i = np; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } +#else + // scalar + for (int k = 0; k < WSP_GGML_VEC_MAD_UNROLL; ++k) { + for (int i = 0; i < n; ++i) { + y[i] += x[k][i]*v[k][0]; + } + } +#endif +} + //inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; } inline static void wsp_ggml_vec_scale_f32(const int n, float * y, const float v) { #if defined(WSP_GGML_USE_ACCELERATE) @@ -3749,6 +1413,7 @@ inline static void wsp_ggml_vec_step_f32 (const int n, float * y, const float * inline static void wsp_ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); } inline static void wsp_ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; } inline static void wsp_ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } +inline static void wsp_ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; } static const float GELU_COEF_A = 0.044715f; static const float GELU_QUICK_COEF = -1.702f; @@ -3761,7 +1426,7 @@ inline static float wsp_ggml_gelu_f32(float x) { inline static void wsp_ggml_vec_gelu_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x) { const uint16_t * i16 = (const uint16_t *) x; for (int i = 0; i < n; ++i) { - y[i] = table_gelu_f16[i16[i]]; + y[i] = wsp_ggml_table_gelu_f16[i16[i]]; } } @@ -3771,7 +1436,7 @@ inline static void wsp_ggml_vec_gelu_f32(const int n, float * y, const float * x for (int i = 0; i < n; ++i) { wsp_ggml_fp16_t fp16 = WSP_GGML_FP32_TO_FP16(x[i]); memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = WSP_GGML_FP16_TO_FP32(table_gelu_f16[t]); + y[i] = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_gelu_f16[t]); } } #else @@ -3789,7 +1454,7 @@ inline static float wsp_ggml_gelu_quick_f32(float x) { //inline static void wsp_ggml_vec_gelu_quick_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x) { // const uint16_t * i16 = (const uint16_t *) x; // for (int i = 0; i < n; ++i) { -// y[i] = table_gelu_quick_f16[i16[i]]; +// y[i] = wsp_ggml_table_gelu_quick_f16[i16[i]]; // } //} @@ -3799,7 +1464,7 @@ inline static void wsp_ggml_vec_gelu_quick_f32(const int n, float * y, const flo for (int i = 0; i < n; ++i) { wsp_ggml_fp16_t fp16 = WSP_GGML_FP32_TO_FP16(x[i]); memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = WSP_GGML_FP16_TO_FP32(table_gelu_quick_f16[t]); + y[i] = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_gelu_quick_f16[t]); } } #else @@ -3818,7 +1483,7 @@ inline static float wsp_ggml_silu_f32(float x) { //inline static void wsp_ggml_vec_silu_f16(const int n, wsp_ggml_fp16_t * y, const wsp_ggml_fp16_t * x) { // const uint16_t * i16 = (const uint16_t *) x; // for (int i = 0; i < n; ++i) { -// y[i] = table_silu_f16[i16[i]]; +// y[i] = wsp_ggml_table_silu_f16[i16[i]]; // } //} @@ -3828,7 +1493,7 @@ inline static void wsp_ggml_vec_silu_f32(const int n, float * y, const float * x for (int i = 0; i < n; ++i) { wsp_ggml_fp16_t fp16 = WSP_GGML_FP32_TO_FP16(x[i]); memcpy(&t, &fp16, sizeof(uint16_t)); - y[i] = WSP_GGML_FP16_TO_FP32(table_silu_f16[t]); + y[i] = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_silu_f16[t]); } } #else @@ -3970,7 +1635,12 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = { "ALIBI", "CLAMP", "CONV_1D", + "CONV_1D_STAGE_0", + "CONV_1D_STAGE_1", + "CONV_TRANSPOSE_1D", "CONV_2D", + "CONV_2D_STAGE_0", + "CONV_2D_STAGE_1", "CONV_TRANSPOSE_2D", "POOL_1D", "POOL_2D", @@ -4001,7 +1671,7 @@ static const char * WSP_GGML_OP_NAME[WSP_GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(WSP_GGML_OP_COUNT == 68, "WSP_GGML_OP_COUNT != 68"); +static_assert(WSP_GGML_OP_COUNT == 73, "WSP_GGML_OP_COUNT != 73"); static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = { "none", @@ -4052,7 +1722,12 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = { "alibi(x)", "clamp(x)", "conv_1d(x)", + "conv_1d_stage_0(x)", + "conv_1d_stage_1(x)", + "conv_transpose_1d(x)", "conv_2d(x)", + "conv_2d_stage_0(x)", + "conv_2d_stage_1(x)", "conv_transpose_2d(x)", "pool_1d(x)", "pool_2d(x)", @@ -4083,7 +1758,7 @@ static const char * WSP_GGML_OP_SYMBOL[WSP_GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(WSP_GGML_OP_COUNT == 68, "WSP_GGML_OP_COUNT != 68"); +static_assert(WSP_GGML_OP_COUNT == 73, "WSP_GGML_OP_COUNT != 73"); static_assert(WSP_GGML_OP_POOL_COUNT == 2, "WSP_GGML_OP_POOL_COUNT != 2"); @@ -4112,7 +1787,12 @@ static void wsp_ggml_setup_op_has_task_pass(void) { p[WSP_GGML_OP_DIAG_MASK_INF ] = true; p[WSP_GGML_OP_DIAG_MASK_ZERO ] = true; p[WSP_GGML_OP_CONV_1D ] = true; + p[WSP_GGML_OP_CONV_1D_STAGE_0 ] = true; + p[WSP_GGML_OP_CONV_1D_STAGE_1 ] = true; + p[WSP_GGML_OP_CONV_TRANSPOSE_1D ] = true; p[WSP_GGML_OP_CONV_2D ] = true; + p[WSP_GGML_OP_CONV_2D_STAGE_0 ] = true; + p[WSP_GGML_OP_CONV_2D_STAGE_1 ] = true; p[WSP_GGML_OP_CONV_TRANSPOSE_2D ] = true; p[WSP_GGML_OP_FLASH_ATTN_BACK ] = true; p[WSP_GGML_OP_CROSS_ENTROPY_LOSS ] = true; @@ -4392,10 +2072,9 @@ static inline bool wsp_ggml_can_mul_mat(const struct wsp_ggml_tensor * t0, const static inline bool wsp_ggml_can_out_prod(const struct wsp_ggml_tensor * t0, const struct wsp_ggml_tensor * t1) { static_assert(WSP_GGML_MAX_DIMS == 4, "WSP_GGML_MAX_DIMS is not 4 - update this function"); - return - (t0->ne[1] == t1->ne[1]) && - (t0->ne[2] == t1->ne[2]) && - (t0->ne[3] == t1->ne[3]); + return (t0->ne[1] == t1->ne[1]) && + (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable + (t1->ne[3]%t0->ne[3] == 0); } enum wsp_ggml_type wsp_ggml_ftype_to_wsp_ggml_type(enum wsp_ggml_ftype ftype) { @@ -4530,11 +2209,11 @@ struct wsp_ggml_context * wsp_ggml_init(struct wsp_ggml_init_params params) { for (int i = 0; i < (1 << 16); ++i) { uint16_t ui = i; memcpy(&ii, &ui, sizeof(ii)); - const float f = table_f32_f16[i] = WSP_GGML_COMPUTE_FP16_TO_FP32(ii); - table_gelu_f16[i] = WSP_GGML_FP32_TO_FP16(wsp_ggml_gelu_f32(f)); - table_gelu_quick_f16[i] = WSP_GGML_FP32_TO_FP16(wsp_ggml_gelu_quick_f32(f)); - table_silu_f16[i] = WSP_GGML_FP32_TO_FP16(wsp_ggml_silu_f32(f)); - table_exp_f16[i] = WSP_GGML_FP32_TO_FP16(expf(f)); + const float f = wsp_ggml_table_f32_f16[i] = WSP_GGML_COMPUTE_FP16_TO_FP32(ii); + wsp_ggml_table_gelu_f16[i] = WSP_GGML_FP32_TO_FP16(wsp_ggml_gelu_f32(f)); + wsp_ggml_table_gelu_quick_f16[i] = WSP_GGML_FP32_TO_FP16(wsp_ggml_gelu_quick_f32(f)); + wsp_ggml_table_silu_f16[i] = WSP_GGML_FP32_TO_FP16(wsp_ggml_silu_f32(f)); + wsp_ggml_table_exp_f16[i] = WSP_GGML_FP32_TO_FP16(expf(f)); } const uint64_t t_end = wsp_ggml_time_us(); UNUSED(t_end); @@ -4830,6 +2509,7 @@ static struct wsp_ggml_tensor * wsp_ggml_new_tensor_impl( *result = (struct wsp_ggml_tensor) { /*.type =*/ type, /*.backend =*/ WSP_GGML_BACKEND_CPU, + /*.buffer =*/ NULL, /*.n_dims =*/ n_dims, /*.ne =*/ { 1, 1, 1, 1 }, /*.nb =*/ { 0, 0, 0, 0 }, @@ -5065,43 +2745,78 @@ struct wsp_ggml_tensor * wsp_ggml_set_f32(struct wsp_ggml_tensor * tensor, float return tensor; } +void wsp_ggml_unravel_index(const struct wsp_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3) { + const int64_t ne2 = tensor->ne[2]; + const int64_t ne1 = tensor->ne[1]; + const int64_t ne0 = tensor->ne[0]; + + const int64_t i3_ = (i/(ne2*ne1*ne0)); + const int64_t i2_ = (i - i3_*ne2*ne1*ne0)/(ne1*ne0); + const int64_t i1_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0)/ne0; + const int64_t i0_ = (i - i3_*ne2*ne1*ne0 - i2_*ne1*ne0 - i1_*ne0); + + if (i0) { + * i0 = i0_; + } + if (i1) { + * i1 = i1_; + } + if (i2) { + * i2 = i2_; + } + if (i3) { + * i3 = i3_; + } +} + int32_t wsp_ggml_get_i32_1d(const struct wsp_ggml_tensor * tensor, int i) { + if (!wsp_ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + wsp_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return wsp_ggml_get_i32_nd(tensor, id[0], id[1], id[2], id[3]); + } switch (tensor->type) { case WSP_GGML_TYPE_I8: { WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); return ((int8_t *)(tensor->data))[i]; - } break; + } case WSP_GGML_TYPE_I16: { WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); return ((int16_t *)(tensor->data))[i]; - } break; + } case WSP_GGML_TYPE_I32: { WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); return ((int32_t *)(tensor->data))[i]; - } break; + } case WSP_GGML_TYPE_F16: { WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_fp16_t)); return WSP_GGML_FP16_TO_FP32(((wsp_ggml_fp16_t *)(tensor->data))[i]); - } break; + } case WSP_GGML_TYPE_F32: { WSP_GGML_ASSERT(tensor->nb[0] == sizeof(float)); return ((float *)(tensor->data))[i]; - } break; + } default: { WSP_GGML_ASSERT(false); - } break; + } } return 0.0f; } void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t value) { + if (!wsp_ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + wsp_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + wsp_ggml_set_i32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } switch (tensor->type) { case WSP_GGML_TYPE_I8: { @@ -5135,68 +2850,179 @@ void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t v } } +int32_t wsp_ggml_get_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case WSP_GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case WSP_GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case WSP_GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case WSP_GGML_TYPE_F16: + return WSP_GGML_FP16_TO_FP32(((wsp_ggml_fp16_t *) data)[0]); + case WSP_GGML_TYPE_F32: + return ((float *) data)[0]; + default: + WSP_GGML_ASSERT(false); + } + + return 0.0f; +} + +void wsp_ggml_set_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case WSP_GGML_TYPE_I8: + { + ((int8_t *)(data))[0] = value; + } break; + case WSP_GGML_TYPE_I16: + { + ((int16_t *)(data))[0] = value; + } break; + case WSP_GGML_TYPE_I32: + { + ((int32_t *)(data))[0] = value; + } break; + case WSP_GGML_TYPE_F16: + { + ((wsp_ggml_fp16_t *)(data))[0] = WSP_GGML_FP32_TO_FP16(value); + } break; + case WSP_GGML_TYPE_F32: + { + ((float *)(data))[0] = value; + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + float wsp_ggml_get_f32_1d(const struct wsp_ggml_tensor * tensor, int i) { + if (!wsp_ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + wsp_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + return wsp_ggml_get_f32_nd(tensor, id[0], id[1], id[2], id[3]); + } switch (tensor->type) { case WSP_GGML_TYPE_I8: { WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); return ((int8_t *)(tensor->data))[i]; - } break; + } case WSP_GGML_TYPE_I16: { WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); return ((int16_t *)(tensor->data))[i]; - } break; + } case WSP_GGML_TYPE_I32: { WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); return ((int32_t *)(tensor->data))[i]; - } break; + } case WSP_GGML_TYPE_F16: { WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_fp16_t)); return WSP_GGML_FP16_TO_FP32(((wsp_ggml_fp16_t *)(tensor->data))[i]); - } break; + } case WSP_GGML_TYPE_F32: { WSP_GGML_ASSERT(tensor->nb[0] == sizeof(float)); return ((float *)(tensor->data))[i]; + } + default: + { + WSP_GGML_ASSERT(false); + } + } + + return 0.0f; +} + +void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value) { + if (!wsp_ggml_is_contiguous(tensor)) { + int64_t id[4] = { 0, 0, 0, 0 }; + wsp_ggml_unravel_index(tensor, i, &id[0], &id[1], &id[2], &id[3]); + wsp_ggml_set_f32_nd(tensor, id[0], id[1], id[2], id[3], value); + return; + } + switch (tensor->type) { + case WSP_GGML_TYPE_I8: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); + ((int8_t *)(tensor->data))[i] = value; + } break; + case WSP_GGML_TYPE_I16: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); + ((int16_t *)(tensor->data))[i] = value; + } break; + case WSP_GGML_TYPE_I32: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); + ((int32_t *)(tensor->data))[i] = value; + } break; + case WSP_GGML_TYPE_F16: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_fp16_t)); + ((wsp_ggml_fp16_t *)(tensor->data))[i] = WSP_GGML_FP32_TO_FP16(value); + } break; + case WSP_GGML_TYPE_F32: + { + WSP_GGML_ASSERT(tensor->nb[0] == sizeof(float)); + ((float *)(tensor->data))[i] = value; } break; default: { WSP_GGML_ASSERT(false); } break; } +} + +float wsp_ggml_get_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; + switch (tensor->type) { + case WSP_GGML_TYPE_I8: + return ((int8_t *) data)[0]; + case WSP_GGML_TYPE_I16: + return ((int16_t *) data)[0]; + case WSP_GGML_TYPE_I32: + return ((int32_t *) data)[0]; + case WSP_GGML_TYPE_F16: + return WSP_GGML_FP16_TO_FP32(((wsp_ggml_fp16_t *) data)[0]); + case WSP_GGML_TYPE_F32: + return ((float *) data)[0]; + default: + WSP_GGML_ASSERT(false); + } return 0.0f; } -void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value) { +void wsp_ggml_set_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value) { + void * data = (char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1] + i2*tensor->nb[2] + i3*tensor->nb[3]; switch (tensor->type) { case WSP_GGML_TYPE_I8: { - WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int8_t)); - ((int8_t *)(tensor->data))[i] = value; + ((int8_t *)(data))[0] = value; } break; case WSP_GGML_TYPE_I16: { - WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int16_t)); - ((int16_t *)(tensor->data))[i] = value; + ((int16_t *)(data))[0] = value; } break; case WSP_GGML_TYPE_I32: { - WSP_GGML_ASSERT(tensor->nb[0] == sizeof(int32_t)); - ((int32_t *)(tensor->data))[i] = value; + ((int32_t *)(data))[0] = value; } break; case WSP_GGML_TYPE_F16: { - WSP_GGML_ASSERT(tensor->nb[0] == sizeof(wsp_ggml_fp16_t)); - ((wsp_ggml_fp16_t *)(tensor->data))[i] = WSP_GGML_FP32_TO_FP16(value); + ((wsp_ggml_fp16_t *)(data))[0] = WSP_GGML_FP32_TO_FP16(value); } break; case WSP_GGML_TYPE_F32: { - WSP_GGML_ASSERT(tensor->nb[0] == sizeof(float)); - ((float *)(tensor->data))[i] = value; + ((float *)(data))[0] = value; } break; default: { @@ -5250,6 +3076,39 @@ struct wsp_ggml_tensor * wsp_ggml_view_tensor( return result; } +struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(struct wsp_ggml_context * ctx) { + struct wsp_ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == WSP_GGML_OBJECT_TENSOR) { + return (struct wsp_ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + +struct wsp_ggml_tensor * wsp_ggml_get_next_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor) { + struct wsp_ggml_object * obj = (struct wsp_ggml_object *) ((char *)tensor - WSP_GGML_OBJECT_SIZE); + obj = obj->next; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == WSP_GGML_OBJECT_TENSOR) { + return (struct wsp_ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name) { struct wsp_ggml_object * obj = ctx->objects_begin; @@ -5347,6 +3206,44 @@ struct wsp_ggml_tensor * wsp_ggml_add_inplace( return wsp_ggml_add_impl(ctx, a, b, true); } +// wsp_ggml_add_cast + +static struct wsp_ggml_tensor * wsp_ggml_add_cast_impl( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + enum wsp_ggml_type type) { + // TODO: support less-strict constraint + // WSP_GGML_ASSERT(wsp_ggml_can_repeat(b, a)); + WSP_GGML_ASSERT(wsp_ggml_can_repeat_rows(b, a)); + WSP_GGML_ASSERT(wsp_ggml_is_quantized(a->type) || a->type == WSP_GGML_TYPE_F16); // currently only supported for quantized input and f16 + + bool is_node = false; + + if (a->grad || b->grad) { + // TODO: support backward pass for broadcasting + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(a, b)); + is_node = true; + } + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, type, a->n_dims, a->ne); + + result->op = WSP_GGML_OP_ADD; + result->grad = is_node ? wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, a->n_dims, a->ne) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct wsp_ggml_tensor * wsp_ggml_add_cast( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + enum wsp_ggml_type type) { + return wsp_ggml_add_cast_impl(ctx, a, b, type); +} + // wsp_ggml_add1 static struct wsp_ggml_tensor * wsp_ggml_add1_impl( @@ -5639,7 +3536,6 @@ struct wsp_ggml_tensor * wsp_ggml_sqrt_inplace( return wsp_ggml_sqrt_impl(ctx, a, true); } - // wsp_ggml_log static struct wsp_ggml_tensor * wsp_ggml_log_impl( @@ -5693,7 +3589,6 @@ struct wsp_ggml_tensor * wsp_ggml_sum( return result; } - // wsp_ggml_sum_rows struct wsp_ggml_tensor * wsp_ggml_sum_rows( @@ -5783,7 +3678,6 @@ struct wsp_ggml_tensor * wsp_ggml_repeat( result->op = WSP_GGML_OP_REPEAT; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -5811,7 +3705,6 @@ struct wsp_ggml_tensor * wsp_ggml_repeat_back( result->op = WSP_GGML_OP_REPEAT_BACK; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; - result->src[1] = b; return result; } @@ -5938,6 +3831,14 @@ struct wsp_ggml_tensor * wsp_ggml_relu_inplace( return wsp_ggml_unary_inplace(ctx, a, WSP_GGML_UNARY_OP_RELU); } +// wsp_ggml_leaky + +struct wsp_ggml_tensor * wsp_ggml_leaky( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a) { + return wsp_ggml_unary(ctx, a, WSP_GGML_UNARY_OP_LEAKY); +} + // wsp_ggml_gelu struct wsp_ggml_tensor * wsp_ggml_gelu( @@ -6186,8 +4087,9 @@ struct wsp_ggml_tensor * wsp_ggml_out_prod( is_node = true; } - const int64_t ne[4] = { a->ne[0], b->ne[0], a->ne[2], b->ne[3] }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); + // a is broadcastable to b for ne[2] and ne[3] -> use b->ne[2] and b->ne[3] + const int64_t ne[4] = { a->ne[0], b->ne[0], b->ne[2], b->ne[3] }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne); result->op = WSP_GGML_OP_OUT_PROD; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; @@ -6326,7 +4228,6 @@ struct wsp_ggml_tensor * wsp_ggml_set_2d_inplace( return wsp_ggml_set_impl(ctx, a, b, nb1, a->nb[2], a->nb[3], offset, false); } - // wsp_ggml_cpy static struct wsp_ggml_tensor * wsp_ggml_cpy_impl( @@ -6406,6 +4307,52 @@ struct wsp_ggml_tensor * wsp_ggml_cont_inplace( return wsp_ggml_cont_impl(ctx, a, true); } +// make contiguous, with new shape +WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0) { + return wsp_ggml_cont_4d(ctx, a, ne0, 1, 1, 1); +} + +WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_2d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1) { + return wsp_ggml_cont_4d(ctx, a, ne0, ne1, 1, 1); +} + +WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_3d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + return wsp_ggml_cont_4d(ctx, a, ne0, ne1, ne2, 1); +} + +struct wsp_ggml_tensor * wsp_ggml_cont_4d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3) { + WSP_GGML_ASSERT(wsp_ggml_nelements(a) == (ne0*ne1*ne2*ne3)); + + bool is_node = false; + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); + wsp_ggml_format_name(result, "%s (cont)", a->name); + + result->op = WSP_GGML_OP_CONT; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + // wsp_ggml_reshape struct wsp_ggml_tensor * wsp_ggml_reshape( @@ -6413,7 +4360,7 @@ struct wsp_ggml_tensor * wsp_ggml_reshape( struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b) { WSP_GGML_ASSERT(wsp_ggml_is_contiguous(a)); - WSP_GGML_ASSERT(wsp_ggml_is_contiguous(b)); + // as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous. WSP_GGML_ASSERT(wsp_ggml_nelements(a) == wsp_ggml_nelements(b)); bool is_node = false; @@ -6786,7 +4733,6 @@ struct wsp_ggml_tensor * wsp_ggml_get_rows_back( result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; - result->src[2] = c; return result; } @@ -6813,7 +4759,6 @@ struct wsp_ggml_tensor * wsp_ggml_diag( return result; } - // wsp_ggml_diag_mask_inf static struct wsp_ggml_tensor * wsp_ggml_diag_mask_inf_impl( @@ -6925,7 +4870,6 @@ struct wsp_ggml_tensor * wsp_ggml_soft_max_inplace( return wsp_ggml_soft_max_impl(ctx, a, true); } - // wsp_ggml_soft_max_back static struct wsp_ggml_tensor * wsp_ggml_soft_max_back_impl( @@ -6968,16 +4912,24 @@ struct wsp_ggml_tensor * wsp_ggml_soft_max_back_inplace( static struct wsp_ggml_tensor * wsp_ggml_rope_impl( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, int mode, int n_ctx, + int n_orig_ctx, float freq_base, float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, float xpos_base, bool xpos_down, bool inplace) { - WSP_GGML_ASSERT(n_past >= 0); + WSP_GGML_ASSERT(wsp_ggml_is_vector(b)); + WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32); + WSP_GGML_ASSERT(a->ne[2] == b->ne[0]); + bool is_node = false; if (a->grad) { @@ -6986,16 +4938,21 @@ static struct wsp_ggml_tensor * wsp_ggml_rope_impl( struct wsp_ggml_tensor * result = inplace ? wsp_ggml_view_tensor(ctx, a) : wsp_ggml_dup_tensor(ctx, a); - int32_t params[8] = { n_past, n_dims, mode, n_ctx }; - memcpy(params + 4, &freq_base, sizeof(float)); - memcpy(params + 5, &freq_scale, sizeof(float)); - memcpy(params + 6, &xpos_base, sizeof(float)); - memcpy(params + 7, &xpos_down, sizeof(bool)); + int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx }; + memcpy(params + 5, &freq_base, sizeof(float)); + memcpy(params + 6, &freq_scale, sizeof(float)); + memcpy(params + 7, &ext_factor, sizeof(float)); + memcpy(params + 8, &attn_factor, sizeof(float)); + memcpy(params + 9, &beta_fast, sizeof(float)); + memcpy(params + 10, &beta_slow, sizeof(float)); + memcpy(params + 11, &xpos_base, sizeof(float)); + memcpy(params + 12, &xpos_down, sizeof(bool)); wsp_ggml_set_op_params(result, params, sizeof(params)); result->op = WSP_GGML_OP_ROPE; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = b; return result; } @@ -7003,55 +4960,75 @@ static struct wsp_ggml_tensor * wsp_ggml_rope_impl( struct wsp_ggml_tensor * wsp_ggml_rope( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, int mode, int n_ctx) { - return wsp_ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, false); + return wsp_ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, false + ); } struct wsp_ggml_tensor * wsp_ggml_rope_inplace( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, int mode, int n_ctx) { - return wsp_ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, 0.0f, false, true); + return wsp_ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, false, true + ); } struct wsp_ggml_tensor * wsp_ggml_rope_custom( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, int mode, int n_ctx, + int n_orig_ctx, float freq_base, - float freq_scale) { - return wsp_ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, false); + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return wsp_ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, false + ); } struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, int mode, int n_ctx, + int n_orig_ctx, float freq_base, - float freq_scale) { - return wsp_ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, 0.0f, false, true); + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + return wsp_ggml_rope_impl( + ctx, a, b, n_dims, mode, n_ctx, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow, 0.0f, false, true + ); } struct wsp_ggml_tensor * wsp_ggml_rope_xpos_inplace( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, float base, bool down) { - return wsp_ggml_rope_impl(ctx, a, n_past, n_dims, 0, 0, 10000.0f, 1.0f, base, down, true); + return wsp_ggml_rope_impl(ctx, a, b, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true); } // wsp_ggml_rope_back @@ -7059,7 +5036,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_xpos_inplace( struct wsp_ggml_tensor * wsp_ggml_rope_back( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -7067,7 +5044,10 @@ struct wsp_ggml_tensor * wsp_ggml_rope_back( float freq_scale, float xpos_base, bool xpos_down) { - WSP_GGML_ASSERT(n_past >= 0); + WSP_GGML_ASSERT(wsp_ggml_is_vector(b)); + WSP_GGML_ASSERT(b->type == WSP_GGML_TYPE_I32); + WSP_GGML_ASSERT(a->ne[2] == b->ne[0]); + WSP_GGML_ASSERT((mode & 4) == 0 && "wsp_ggml_rope_back() for ChatGLM not implemented yet"); bool is_node = false; @@ -7078,7 +5058,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_back( struct wsp_ggml_tensor * result = wsp_ggml_dup_tensor(ctx, a); - int32_t params[8] = { n_past, n_dims, mode, n_ctx }; + int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx }; memcpy(params + 4, &freq_base, sizeof(float)); memcpy(params + 5, &freq_scale, sizeof(float)); memcpy(params + 6, &xpos_base, sizeof(float)); @@ -7088,6 +5068,7 @@ struct wsp_ggml_tensor * wsp_ggml_rope_back( result->op = WSP_GGML_OP_ROPE_BACK; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; + result->src[1] = b; return result; } @@ -7156,14 +5137,17 @@ static int64_t wsp_ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, in return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; } -WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d( - struct wsp_ggml_context * ctx, - struct wsp_ggml_tensor * a, - struct wsp_ggml_tensor * b, - int s0, - int p0, - int d0) { - WSP_GGML_ASSERT(wsp_ggml_is_matrix(b)); +// im2col: [N, IC, IL] => [N, OL, IC*K] +// a: [OC,IC, K] +// b: [N, IC, IL] +// result: [N, OL, IC*K] +static struct wsp_ggml_tensor * wsp_ggml_conv_1d_stage_0( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s0, + int p0, + int d0) { WSP_GGML_ASSERT(a->ne[1] == b->ne[1]); bool is_node = false; @@ -7172,16 +5156,54 @@ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d( is_node = true; } + const int64_t OL = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + const int64_t ne[4] = { - wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), - a->ne[2], 1, 1, + a->ne[1] * a->ne[0], + OL, + b->ne[2], + 1, }; - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 2, ne); + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne); int32_t params[] = { s0, p0, d0 }; wsp_ggml_set_op_params(result, params, sizeof(params)); - result->op = WSP_GGML_OP_CONV_1D; + result->op = WSP_GGML_OP_CONV_1D_STAGE_0; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// wsp_ggml_conv_1d_stage_1 + +// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] +// a: [OC, IC, K] +// b: [N, OL, IC * K] +// result: [N, OC, OL] +static struct wsp_ggml_tensor * wsp_ggml_conv_1d_stage_1( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + + bool is_node = false; + + if (a->grad || b->grad) { + WSP_GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + b->ne[1], + a->ne[2], + b->ne[2], + 1, + }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); + + result->op = WSP_GGML_OP_CONV_1D_STAGE_1; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; @@ -7189,6 +5211,53 @@ WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d( return result; } +// wsp_ggml_conv_1d + +WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s0, + int p0, + int d0) { + struct wsp_ggml_tensor * result = wsp_ggml_conv_1d_stage_0(ctx, a, b, s0, p0, d0); + result = wsp_ggml_conv_1d_stage_1(ctx, a, result); + return result; +} + +// WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_1d( +// struct wsp_ggml_context * ctx, +// struct wsp_ggml_tensor * a, +// struct wsp_ggml_tensor * b, +// int s0, +// int p0, +// int d0) { +// WSP_GGML_ASSERT(wsp_ggml_is_matrix(b)); +// WSP_GGML_ASSERT(a->ne[1] == b->ne[1]); +// bool is_node = false; + +// if (a->grad || b->grad) { +// WSP_GGML_ASSERT(false); // TODO: implement backward +// is_node = true; +// } + +// const int64_t ne[4] = { +// wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), +// a->ne[2], 1, 1, +// }; +// struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 2, ne); + +// int32_t params[] = { s0, p0, d0 }; +// wsp_ggml_set_op_params(result, params, sizeof(params)); + +// result->op = WSP_GGML_OP_CONV_1D; +// result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; +// result->src[0] = a; +// result->src[1] = b; + +// return result; +// } + // wsp_ggml_conv_1d_ph struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph( @@ -7200,9 +5269,57 @@ struct wsp_ggml_tensor* wsp_ggml_conv_1d_ph( return wsp_ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); } +// wsp_ggml_conv_transpose_1d + +static int64_t wsp_ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins - 1) * s - 2 * p + d * (ks - 1) + 1; +} + +WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s0, + int p0, + int d0) { + WSP_GGML_ASSERT(wsp_ggml_is_matrix(b)); + WSP_GGML_ASSERT(a->ne[2] == b->ne[1]); + WSP_GGML_ASSERT(a->ne[3] == 1); + + WSP_GGML_ASSERT(p0 == 0); + WSP_GGML_ASSERT(d0 == 1); + + bool is_node = false; + + if (a->grad || b->grad) { + WSP_GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t ne[4] = { + wsp_ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), + a->ne[1], b->ne[2], 1, + }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); + + int32_t params[] = { s0, p0, d0 }; + wsp_ggml_set_op_params(result, params, sizeof(params)); + + result->op = WSP_GGML_OP_CONV_TRANSPOSE_1D; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + // wsp_ggml_conv_2d -struct wsp_ggml_tensor * wsp_ggml_conv_2d( +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OH, OW, IC*KH*KW] +static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_0( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b, @@ -7213,7 +5330,46 @@ struct wsp_ggml_tensor * wsp_ggml_conv_2d( int d0, int d1) { - WSP_GGML_ASSERT(a->ne[2] == b->ne[2]); + WSP_GGML_ASSERT(a->ne[2] == b->ne[2]); + bool is_node = false; + + if (a->grad || b->grad) { + WSP_GGML_ASSERT(false); // TODO: implement backward + is_node = true; + } + + const int64_t OH = wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1); + const int64_t OW = wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + + const int64_t ne[4] = { + a->ne[2] * a->ne[1] * a->ne[0], + OW, + OH, + b->ne[3], + }; + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F16, 4, ne); + + int32_t params[] = { s0, s1, p0, p1, d0, d1 }; + wsp_ggml_set_op_params(result, params, sizeof(params)); + + result->op = WSP_GGML_OP_CONV_2D_STAGE_0; + result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; + +} + +// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] +// a: [OC, IC, KH, KW] +// b: [N, OH, OW, IC * KH * KW] +// result: [N, OC, OH, OW] +static struct wsp_ggml_tensor * wsp_ggml_conv_2d_stage_1( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b) { + bool is_node = false; if (a->grad || b->grad) { @@ -7222,16 +5378,14 @@ struct wsp_ggml_tensor * wsp_ggml_conv_2d( } const int64_t ne[4] = { - wsp_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), - wsp_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1), - a->ne[3], b->ne[3], + b->ne[1], + b->ne[2], + a->ne[3], + b->ne[3], }; struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); - int32_t params[] = { s0, s1, p0, p1, d0, d1 }; - wsp_ggml_set_op_params(result, params, sizeof(params)); - - result->op = WSP_GGML_OP_CONV_2D; + result->op = WSP_GGML_OP_CONV_2D_STAGE_1; result->grad = is_node ? wsp_ggml_dup_tensor(ctx, result) : NULL; result->src[0] = a; result->src[1] = b; @@ -7240,8 +5394,28 @@ struct wsp_ggml_tensor * wsp_ggml_conv_2d( } -// wsp_ggml_conv_2d_sk_p0 +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OC, OH, OW] +struct wsp_ggml_tensor * wsp_ggml_conv_2d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + + struct wsp_ggml_tensor * result = wsp_ggml_conv_2d_stage_0(ctx, a, b, s0, s1, p0, p1, d0, d1); // [N, OH, OW, IC * KH * KW] + result = wsp_ggml_conv_2d_stage_1(ctx, a, result); + + return result; + +} +// wsp_ggml_conv_2d_sk_p0 struct wsp_ggml_tensor * wsp_ggml_conv_2d_sk_p0( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, @@ -7298,7 +5472,7 @@ struct wsp_ggml_tensor * wsp_ggml_conv_transpose_2d_p0( // wsp_ggml_pool_* -static int64_t wsp_ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) { +static int64_t wsp_ggml_calc_pool_output_size(int64_t ins, int ks, int s, float p) { return (ins + 2 * p - ks) / s + 1; } @@ -7345,8 +5519,8 @@ struct wsp_ggml_tensor * wsp_ggml_pool_2d( int k1, int s0, int s1, - int p0, - int p1) { + float p0, + float p1) { bool is_node = false; @@ -7484,27 +5658,30 @@ struct wsp_ggml_tensor * wsp_ggml_flash_attn_back( // d shape [D,N,ne2,ne3] // q shape [D,N,ne2,ne3] - // k shape [D,M,ne2,ne3] - // v shape [M,D,ne2,ne3] + // k shape [D,M,kvne2,ne3] + // v shape [M,D,kvne2,ne3] - const int64_t D = q->ne[0]; - const int64_t N = q->ne[1]; - const int64_t M = k->ne[1]; - const int64_t ne2 = q->ne[2]; - const int64_t ne3 = q->ne[3]; + const int64_t D = q->ne[0]; + const int64_t N = q->ne[1]; + const int64_t M = k->ne[1]; + const int64_t ne2 = q->ne[2]; + const int64_t ne3 = q->ne[3]; + const int64_t kvne2 = k->ne[2]; WSP_GGML_ASSERT(k->ne[0] == D); WSP_GGML_ASSERT(v->ne[0] == M); WSP_GGML_ASSERT(v->ne[1] == D); WSP_GGML_ASSERT(d->ne[0] == D); WSP_GGML_ASSERT(d->ne[1] == N); - WSP_GGML_ASSERT(k->ne[2] == ne2); + WSP_GGML_ASSERT(k->ne[2] == kvne2); WSP_GGML_ASSERT(k->ne[3] == ne3); - WSP_GGML_ASSERT(v->ne[2] == ne2); + WSP_GGML_ASSERT(v->ne[2] == kvne2); WSP_GGML_ASSERT(v->ne[3] == ne3); WSP_GGML_ASSERT(d->ne[2] == ne2); WSP_GGML_ASSERT(d->ne[3] == ne3); + WSP_GGML_ASSERT(ne2 % kvne2 == 0); + bool is_node = false; if (q->grad || k->grad || v->grad) { @@ -7514,14 +5691,23 @@ struct wsp_ggml_tensor * wsp_ggml_flash_attn_back( } // store gradients of q, k and v as continuous tensors concatenated in result. - // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3] - // gradq->data = result->data - // gradk->data = result->data + nb0*D*N*ne2*ne3 - // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3 // note: v and gradv are actually transposed, i.e. v->ne[0] != D. - int64_t ne[4] = {D,M+N+M,ne2,ne3}; + const int64_t elem_q = wsp_ggml_nelements(q); + const int64_t elem_k = wsp_ggml_nelements(k); + const int64_t elem_v = wsp_ggml_nelements(v); - struct wsp_ggml_tensor * result = wsp_ggml_new_tensor(ctx, WSP_GGML_TYPE_F32, 4, ne); + enum wsp_ggml_type result_type = WSP_GGML_TYPE_F32; + WSP_GGML_ASSERT(wsp_ggml_blck_size(result_type) == 1); + const size_t tsize = wsp_ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + WSP_GGML_PAD(elem_q * tsize, WSP_GGML_MEM_ALIGN); + const size_t offs_v = offs_k + WSP_GGML_PAD(elem_k * tsize, WSP_GGML_MEM_ALIGN); + const size_t end = offs_v + WSP_GGML_PAD(elem_v * tsize, WSP_GGML_MEM_ALIGN); + + const size_t nelements = (end + tsize - 1)/tsize; + + struct wsp_ggml_tensor * result = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nelements); int32_t masked_i = masked ? 1 : 0; wsp_ggml_set_op_params(result, &masked_i, sizeof(masked_i)); @@ -7668,7 +5854,6 @@ static struct wsp_ggml_tensor * wsp_ggml_add_rel_pos_impl( return result; } - struct wsp_ggml_tensor * wsp_ggml_add_rel_pos( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, @@ -8113,8 +6298,6 @@ struct wsp_ggml_tensor * wsp_ggml_map_custom3_inplace( return wsp_ggml_map_custom3_impl(ctx, a, b, c, fun, n_tasks, userdata, true); } - - // wsp_ggml_cross_entropy_loss struct wsp_ggml_tensor * wsp_ggml_cross_entropy_loss( @@ -8168,6 +6351,7 @@ void wsp_ggml_set_param( WSP_GGML_ASSERT(tensor->grad == NULL); tensor->grad = wsp_ggml_dup_tensor(ctx, tensor); + wsp_ggml_format_name(tensor->grad, "%s (grad)", tensor->name); } // wsp_ggml_compute_forward_dup @@ -8214,7 +6398,7 @@ static void wsp_ggml_compute_forward_dup_f16( return; } - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index const int nth = params->nth; // number of threads @@ -8485,7 +6669,7 @@ static void wsp_ggml_compute_forward_dup_f32( return; } - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS const int ith = params->ith; // thread index const int nth = params->nth; // number of threads @@ -8766,7 +6950,7 @@ static void wsp_ggml_compute_forward_add_f32( const int nr = wsp_ggml_nrows(src0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS WSP_GGML_ASSERT( nb0 == sizeof(float)); WSP_GGML_ASSERT(nb00 == sizeof(float)); @@ -8798,8 +6982,6 @@ static void wsp_ggml_compute_forward_add_f32( #else wsp_ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr); #endif - // } - // } } } else { // src1 is not contiguous @@ -8841,13 +7023,19 @@ static void wsp_ggml_compute_forward_add_f16_f32( const int nr = wsp_ggml_nrows(src0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); - WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16); - WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t)); + if (dst->type == WSP_GGML_TYPE_F32) { + WSP_GGML_ASSERT( nb0 == sizeof(float)); + } + else { + WSP_GGML_ASSERT(dst->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT( nb0 == sizeof(wsp_ggml_fp16_t)); + } + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); // rows per thread @@ -8858,18 +7046,35 @@ static void wsp_ggml_compute_forward_add_f16_f32( const int ir1 = MIN(ir0 + dr, nr); if (nb10 == sizeof(float)) { - for (int ir = ir0; ir < ir1; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); - wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); - - for (int i = 0; i < ne0; i++) { - dst_ptr[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + if (dst->type == WSP_GGML_TYPE_F16) { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + wsp_ggml_fp16_t * dst_ptr = (wsp_ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = WSP_GGML_FP32_TO_FP16(WSP_GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]); + } + } + } else { + for (int ir = ir0; ir < ir1; ++ir) { + // src0, src1 and dst are same shape => same indices + const int i3 = ir/(ne2*ne1); + const int i2 = (ir - i3*ne2*ne1)/ne1; + const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + + float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); + wsp_ggml_fp16_t * src0_ptr = (wsp_ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11); + + for (int i = 0; i < ne0; i++) { + dst_ptr[i] = WSP_GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]; + } } } } @@ -8895,7 +7100,7 @@ static void wsp_ggml_compute_forward_add_f16_f16( const int nr = wsp_ggml_nrows(src0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16); @@ -8946,14 +7151,15 @@ static void wsp_ggml_compute_forward_add_q_f32( const int nr = wsp_ggml_nrows(src0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; const enum wsp_ggml_type type = src0->type; + const enum wsp_ggml_type dtype = dst->type; wsp_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; - wsp_ggml_from_float_t const quantize_row_q = type_traits[type].from_float; + wsp_ggml_from_float_t const quantize_row_q = type_traits[dtype].from_float; // we don't support permuted src0 or src1 WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type)); @@ -8965,7 +7171,6 @@ static void wsp_ggml_compute_forward_add_q_f32( WSP_GGML_ASSERT(nb2 <= nb3); WSP_GGML_ASSERT(wsp_ggml_is_quantized(src0->type)); - WSP_GGML_ASSERT(dst->type == src0->type); WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); // rows per thread @@ -9003,7 +7208,11 @@ static void wsp_ggml_compute_forward_add_q_f32( // add src1 wsp_ggml_vec_acc_f32(ne00, wdata, src1_row); // quantize row to dst - quantize_row_q(wdata, dst_row, ne00); + if (quantize_row_q != NULL) { + quantize_row_q(wdata, dst_row, ne00); + } else { + memcpy(dst_row, wdata, ne0*nb0); + } } } @@ -9068,7 +7277,7 @@ static void wsp_ggml_compute_forward_add1_f32( const int nr = wsp_ggml_nrows(src0); - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS WSP_GGML_ASSERT( nb0 == sizeof(float)); WSP_GGML_ASSERT(nb00 == sizeof(float)); @@ -9123,7 +7332,7 @@ static void wsp_ggml_compute_forward_add1_f16_f32( const int nr = wsp_ggml_nrows(src0); - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); @@ -9173,7 +7382,7 @@ static void wsp_ggml_compute_forward_add1_f16_f16( const int nr = wsp_ggml_nrows(src0); - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16); @@ -9223,7 +7432,7 @@ static void wsp_ggml_compute_forward_add1_q_f32( const int nr = wsp_ggml_nrows(src0); - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS const enum wsp_ggml_type type = src0->type; wsp_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; @@ -9313,7 +7522,6 @@ static void wsp_ggml_compute_forward_add1( } } - // wsp_ggml_compute_forward_acc static void wsp_ggml_compute_forward_acc_f32( @@ -9351,8 +7559,8 @@ static void wsp_ggml_compute_forward_acc_f32( const int nr = wsp_ggml_nrows(src1); const int nc = src1->ne[0]; - WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) // src0 and dst as viewed during acc const size_t nb0 = wsp_ggml_element_size(src0); @@ -9441,7 +7649,7 @@ static void wsp_ggml_compute_forward_sub_f32( const int nr = wsp_ggml_nrows(src0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS WSP_GGML_ASSERT( nb0 == sizeof(float)); WSP_GGML_ASSERT(nb00 == sizeof(float)); @@ -9453,7 +7661,6 @@ static void wsp_ggml_compute_forward_sub_f32( const int i2 = (ir - i3*ne2*ne1)/ne1; const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - #ifdef WSP_GGML_USE_ACCELERATE vDSP_vsub( (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, @@ -9531,7 +7738,7 @@ static void wsp_ggml_compute_forward_mul_f32( const int64_t nr = wsp_ggml_nrows(src0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS WSP_GGML_ASSERT( nb0 == sizeof(float)); WSP_GGML_ASSERT(nb00 == sizeof(float)); @@ -9622,7 +7829,7 @@ static void wsp_ggml_compute_forward_div_f32( const int nr = wsp_ggml_nrows(src0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS WSP_GGML_ASSERT( nb0 == sizeof(float)); WSP_GGML_ASSERT(nb00 == sizeof(float)); @@ -9634,7 +7841,6 @@ static void wsp_ggml_compute_forward_div_f32( const int i2 = (ir - i3*ne2*ne1)/ne1; const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - #ifdef WSP_GGML_USE_ACCELERATE UNUSED(wsp_ggml_vec_div_f32); @@ -9772,7 +7978,6 @@ static void wsp_ggml_compute_forward_sqrt( } } - // wsp_ggml_compute_forward_log static void wsp_ggml_compute_forward_log_f32( @@ -9831,8 +8036,8 @@ static void wsp_ggml_compute_forward_sum_f32( assert(wsp_ggml_is_scalar(dst)); assert(src0->nb[0] == sizeof(float)); - WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) wsp_ggml_float sum = 0; wsp_ggml_float row_sum = 0; @@ -9863,8 +8068,8 @@ static void wsp_ggml_compute_forward_sum_f16( assert(src0->nb[0] == sizeof(wsp_ggml_fp16_t)); - WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) float sum = 0; float row_sum = 0; @@ -9917,7 +8122,7 @@ static void wsp_ggml_compute_forward_sum_rows_f32( WSP_GGML_ASSERT(src0->nb[0] == sizeof(float)); WSP_GGML_ASSERT(dst->nb[0] == sizeof(float)); - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS WSP_GGML_ASSERT(ne0 == 1); WSP_GGML_ASSERT(ne1 == ne01); @@ -9967,7 +8172,7 @@ static void wsp_ggml_compute_forward_mean_f32( assert(src0->nb[0] == sizeof(float)); - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS assert(ne0 == 1); assert(ne1 == ne01); @@ -10067,7 +8272,7 @@ static void wsp_ggml_compute_forward_repeat_f32( return; } - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in wsp_ggml_can_repeat const int nr0 = (int)(ne0/ne00); @@ -10099,11 +8304,61 @@ static void wsp_ggml_compute_forward_repeat_f32( } } +static void wsp_ggml_compute_forward_repeat_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(params->ith == 0); + WSP_GGML_ASSERT(wsp_ggml_can_repeat(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + WSP_GGML_TENSOR_UNARY_OP_LOCALS; + + // guaranteed to be an integer due to the check in wsp_ggml_can_repeat + const int nr0 = (int)(ne0/ne00); + const int nr1 = (int)(ne1/ne01); + const int nr2 = (int)(ne2/ne02); + const int nr3 = (int)(ne3/ne03); + + // TODO: support for transposed / permuted tensors + WSP_GGML_ASSERT(nb0 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + + // TODO: maybe this is not optimal? + for (int i3 = 0; i3 < nr3; i3++) { + for (int k3 = 0; k3 < ne03; k3++) { + for (int i2 = 0; i2 < nr2; i2++) { + for (int k2 = 0; k2 < ne02; k2++) { + for (int i1 = 0; i1 < nr1; i1++) { + for (int k1 = 0; k1 < ne01; k1++) { + for (int i0 = 0; i0 < nr0; i0++) { + wsp_ggml_fp16_t * y = (wsp_ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0); + wsp_ggml_fp16_t * x = (wsp_ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01); + // wsp_ggml_vec_cpy_f16(ne00, y, x) + for (int i = 0; i < ne00; ++i) { + y[i] = x[i]; + } + } + } + } + } + } + } + } +} + static void wsp_ggml_compute_forward_repeat( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, struct wsp_ggml_tensor * dst) { switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_repeat_f16(params, src0, dst); + } break; case WSP_GGML_TYPE_F32: { wsp_ggml_compute_forward_repeat_f32(params, src0, dst); @@ -10128,7 +8383,7 @@ static void wsp_ggml_compute_forward_repeat_back_f32( return; } - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in wsp_ggml_can_repeat const int nr0 = (int)(ne00/ne0); @@ -10206,7 +8461,7 @@ static void wsp_ggml_compute_forward_concat_f32( const int ith = params->ith; - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS // TODO: support for transposed / permuted tensors WSP_GGML_ASSERT(nb0 == sizeof(float)); @@ -10702,7 +8957,7 @@ static void wsp_ggml_compute_forward_silu_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); @@ -10727,6 +8982,48 @@ static void wsp_ggml_compute_forward_silu( } } +// wsp_ggml_compute_forward_leaky + +static void wsp_ggml_compute_forward_leaky_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + assert(params->ith == 0); + assert(wsp_ggml_are_same_shape(src0, dst)); + + if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + const int n = wsp_ggml_nrows(src0); + const int nc = src0->ne[0]; + + assert(dst->nb[0] == sizeof(float)); + assert(src0->nb[0] == sizeof(float)); + + for (int i = 0; i < n; i++) { + wsp_ggml_vec_leaky_f32(nc, + (float *) ((char *) dst->data + i*( dst->nb[1])), + (float *) ((char *) src0->data + i*(src0->nb[1]))); + } +} + +static void wsp_ggml_compute_forward_leaky( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F32: + { + wsp_ggml_compute_forward_leaky_f32(params, src0, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + // wsp_ggml_compute_forward_silu_back static void wsp_ggml_compute_forward_silu_back_f32( @@ -10808,7 +9105,7 @@ static void wsp_ggml_compute_forward_norm_f32( const int ith = params->ith; const int nth = params->nth; - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -10877,7 +9174,7 @@ static void wsp_ggml_compute_forward_rms_norm_f32( const int ith = params->ith; const int nth = params->nth; - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -10942,7 +9239,7 @@ static void wsp_ggml_compute_forward_rms_norm_back_f32( const int ith = params->ith; const int nth = params->nth; - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -11117,7 +9414,7 @@ static void wsp_ggml_compute_forward_group_norm_f32( const int ith = params->ith; const int nth = params->nth; - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS const float eps = 1e-6f; // TODO: make this a parameter @@ -11228,7 +9525,7 @@ static void wsp_ggml_compute_forward_mul_mat( int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -11265,11 +9562,6 @@ static void wsp_ggml_compute_forward_mul_mat( #if defined(WSP_GGML_USE_CLBLAST) if (wsp_ggml_cl_can_mul_mat(src0, src1, dst)) { - // TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension - // ref: https://github.com/ggerganov/ggml/pull/224 - WSP_GGML_ASSERT(ne02 == ne12); - WSP_GGML_ASSERT(ne03 == ne13); - if (params->ith == 0 && params->type == WSP_GGML_TASK_COMPUTE) { wsp_ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); } @@ -11430,36 +9722,175 @@ static void wsp_ggml_compute_forward_mul_mat( for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) { vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col); } - memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float)); + } + } + } +} + +// wsp_ggml_compute_forward_out_prod + +static void wsp_ggml_compute_forward_out_prod_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + // int64_t t0 = wsp_ggml_perf_time_us(); + // UNUSED(t0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + WSP_GGML_ASSERT(ne02 == ne12); + WSP_GGML_ASSERT(ne03 == ne13); + WSP_GGML_ASSERT(ne2 == ne12); + WSP_GGML_ASSERT(ne3 == ne13); + + // we don't support permuted src0 or src1 + WSP_GGML_ASSERT(nb00 == sizeof(float)); + + // dst cannot be transposed or permuted + WSP_GGML_ASSERT(nb0 == sizeof(float)); + // WSP_GGML_ASSERT(nb0 <= nb1); + // WSP_GGML_ASSERT(nb1 <= nb2); + // WSP_GGML_ASSERT(nb2 <= nb3); + + WSP_GGML_ASSERT(ne0 == ne00); + WSP_GGML_ASSERT(ne1 == ne10); + WSP_GGML_ASSERT(ne2 == ne02); + WSP_GGML_ASSERT(ne3 == ne03); + + // nb01 >= nb00 - src0 is not transposed + // compute by src0 rows + + // TODO: #if defined(WSP_GGML_USE_CUBLAS) wsp_ggml_cuda_out_prod + // TODO: #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) || defined(WSP_GGML_USE_CLBLAST) + + if (params->type == WSP_GGML_TASK_INIT) { + wsp_ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0); + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // dst[:,:,:,:] = 0 + // for i2,i3: + // for i1: + // for i01: + // for i0: + // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + + // parallelize by last three dimensions + + // total rows in dst + const int64_t nr = ne1*ne2*ne3; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + // block-tiling attempt + const int64_t blck_0 = MAX(WSP_GGML_VEC_MAD_UNROLL, 32); + const int64_t blck_1 = 16; + + for (int64_t bir = ir0; bir < ir1; bir += blck_1) { + const int64_t bir1 = MIN(bir + blck_1, ir1); + for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) { + const int64_t bne01 = MIN(bi01 + blck_0, ne01); + for (int64_t ir = bir; ir < bir1; ++ir) { + // dst indices + const int64_t i3 = ir/(ne2*ne1); + const int64_t i2 = (ir - i3*ne2*ne1)/ne1; + const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1); + + const int64_t i02 = i2; + const int64_t i03 = i3; + + //const int64_t i10 = i1; + const int64_t i12 = i2; + const int64_t i13 = i3; + +#if WSP_GGML_VEC_MAD_UNROLL > 2 + const int64_t bne01_unroll = bne01 - (bne01 % WSP_GGML_VEC_MAD_UNROLL); + for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += WSP_GGML_VEC_MAD_UNROLL) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + wsp_ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1); + } + for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + wsp_ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#else + for (int64_t i01 = bi01; i01 < bne01; ++i01) { + const int64_t i11 = i01; + + float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03)); + float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); + float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); + + wsp_ggml_vec_mad_f32(ne0, d, s0, *s1); + } +#endif } } } -} -// wsp_ggml_compute_forward_out_prod + //int64_t t1 = wsp_ggml_perf_time_us(); + //static int64_t acc = 0; + //acc += t1 - t0; + //if (t1 - t0 > 10) { + // printf("\n"); + // printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03); + // printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03); + // printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13); + // printf("nb10 = %5d, nb11 = %5d, nb12 = %5d, nb13 = %5d\n", nb10, nb11, nb12, nb13); -static void wsp_ggml_compute_forward_out_prod_f32( + // printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc); + //} +} + +static void wsp_ggml_compute_forward_out_prod_q_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { - int64_t t0 = wsp_ggml_perf_time_us(); - UNUSED(t0); + // int64_t t0 = wsp_ggml_perf_time_us(); + // UNUSED(t0); WSP_GGML_TENSOR_BINARY_OP_LOCALS; const int ith = params->ith; const int nth = params->nth; + const enum wsp_ggml_type type = src0->type; + wsp_ggml_to_float_t const dequantize_row_q = type_traits[type].to_float; + WSP_GGML_ASSERT(ne02 == ne12); WSP_GGML_ASSERT(ne03 == ne13); WSP_GGML_ASSERT(ne2 == ne12); WSP_GGML_ASSERT(ne3 == ne13); - // we don't support permuted src0 or src1 - WSP_GGML_ASSERT(nb00 == sizeof(float)); + // we don't support permuted src0 dim0 + WSP_GGML_ASSERT(nb00 == wsp_ggml_type_size(type)); - // dst cannot be transposed or permuted + // dst dim0 cannot be transposed or permuted WSP_GGML_ASSERT(nb0 == sizeof(float)); // WSP_GGML_ASSERT(nb0 <= nb1); // WSP_GGML_ASSERT(nb1 <= nb2); @@ -11504,6 +9935,8 @@ static void wsp_ggml_compute_forward_out_prod_f32( // for i0: // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3] + float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith; + for (int64_t ir = ir0; ir < ir1; ++ir) { // dst indices const int64_t i3 = ir/(ne2*ne1); @@ -11524,10 +9957,8 @@ static void wsp_ggml_compute_forward_out_prod_f32( float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13)); float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3)); - wsp_ggml_vec_mad_f32(ne0, d, s0, *s1); - // for (int64_t i0 = 0; i0 < ne0; ++i0) { - // d[i0] += s0[i0] * s1[i1]; - // } + dequantize_row_q(s0, wdata, ne0); + wsp_ggml_vec_mad_f32(ne0, d, wdata, *s1); } } @@ -11556,10 +9987,13 @@ static void wsp_ggml_compute_forward_out_prod( case WSP_GGML_TYPE_Q5_0: case WSP_GGML_TYPE_Q5_1: case WSP_GGML_TYPE_Q8_0: - case WSP_GGML_TYPE_Q8_1: + case WSP_GGML_TYPE_Q2_K: + case WSP_GGML_TYPE_Q3_K: + case WSP_GGML_TYPE_Q4_K: + case WSP_GGML_TYPE_Q5_K: + case WSP_GGML_TYPE_Q6_K: { - WSP_GGML_ASSERT(false); // todo - // wsp_ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); + wsp_ggml_compute_forward_out_prod_q_f32(params, src0, src1, dst); } break; case WSP_GGML_TYPE_F16: { @@ -11613,7 +10047,6 @@ static void wsp_ggml_compute_forward_scale_f32( const size_t nb1 = dst->nb[1]; - for (int i1 = ir0; i1 < ir1; i1++) { if (dst->data != src0->data) { // src0 is same shape as dst => same indices @@ -11677,8 +10110,8 @@ static void wsp_ggml_compute_forward_set_f32( const int nr = wsp_ggml_nrows(src1); const int nc = src1->ne[0]; - WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) // src0 and dst as viewed during set const size_t nb0 = wsp_ggml_element_size(src0); @@ -11947,14 +10380,15 @@ static void wsp_ggml_compute_forward_get_rows_back_f32_f16( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, - const struct wsp_ggml_tensor * opt0, struct wsp_ggml_tensor * dst) { WSP_GGML_ASSERT(params->ith == 0); - WSP_GGML_ASSERT(wsp_ggml_are_same_shape(opt0, dst)); - WSP_GGML_ASSERT(wsp_ggml_is_contiguous(opt0)); WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); - wsp_ggml_compute_forward_dup_same_cont(params, opt0, dst); + // wsp_ggml_compute_forward_dup_same_cont(params, opt0, dst); + + if (params->type == WSP_GGML_TASK_INIT) { + memset(dst->data, 0, wsp_ggml_nbytes(dst)); + } if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { return; @@ -11980,11 +10414,8 @@ static void wsp_ggml_compute_forward_get_rows_back_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, - const struct wsp_ggml_tensor * opt0, struct wsp_ggml_tensor * dst) { WSP_GGML_ASSERT(params->ith == 0); - WSP_GGML_ASSERT(wsp_ggml_are_same_shape(opt0, dst)); - WSP_GGML_ASSERT(wsp_ggml_is_contiguous(opt0)); WSP_GGML_ASSERT(wsp_ggml_is_contiguous(dst)); // wsp_ggml_compute_forward_dup_same_cont(params, opt0, dst); @@ -12013,21 +10444,19 @@ static void wsp_ggml_compute_forward_get_rows_back_f32( } } - static void wsp_ggml_compute_forward_get_rows_back( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, - const struct wsp_ggml_tensor * opt0, struct wsp_ggml_tensor * dst) { switch (src0->type) { case WSP_GGML_TYPE_F16: { - wsp_ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, opt0, dst); + wsp_ggml_compute_forward_get_rows_back_f32_f16(params, src0, src1, dst); } break; case WSP_GGML_TYPE_F32: { - wsp_ggml_compute_forward_get_rows_back_f32(params, src0, src1, opt0, dst); + wsp_ggml_compute_forward_get_rows_back_f32(params, src0, src1, dst); } break; default: { @@ -12068,7 +10497,7 @@ static void wsp_ggml_compute_forward_diag_f32( // TODO: handle transposed/permuted matrices - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS WSP_GGML_ASSERT(ne00 == ne0); WSP_GGML_ASSERT(ne00 == ne1); @@ -12249,7 +10678,7 @@ static void wsp_ggml_compute_forward_soft_max_f32( // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max); wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(sp[i] - max); memcpy(&scvt, &s, sizeof(scvt)); - const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt]); + const float val = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_exp_f16[scvt]); sum += (wsp_ggml_float)val; dp[i] = val; } @@ -12393,28 +10822,25 @@ static void wsp_ggml_compute_forward_alibi_f32( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - assert(n_past >= 0); - - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int ne1 = src0->ne[1]; // seq_len_without_past - const int ne2 = src0->ne[2]; // n_head -> this is k - //const int ne3 = src0->ne[3]; // 1 -> bsz + const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int64_t ne1 = src0->ne[1]; // seq_len_without_past + const int64_t ne2 = src0->ne[2]; // n_head -> this is k + //const int64_t ne3 = src0->ne[3]; // 1 -> bsz - const int n = wsp_ggml_nrows(src0); - const int ne2_ne3 = n/ne1; // ne2*ne3 + const int64_t n = wsp_ggml_nrows(src0); + const int64_t ne2_ne3 = n/ne1; // ne2*ne3 - const int nb0 = src0->nb[0]; - const int nb1 = src0->nb[1]; - const int nb2 = src0->nb[2]; + const size_t nb0 = src0->nb[0]; + const size_t nb1 = src0->nb[1]; + const size_t nb2 = src0->nb[2]; //const int nb3 = src0->nb[3]; WSP_GGML_ASSERT(nb0 == sizeof(float)); - WSP_GGML_ASSERT(ne1 + n_past == ne0); WSP_GGML_ASSERT(n_head == ne2); // add alibi to src0 (KQ_scaled) @@ -12423,9 +10849,9 @@ static void wsp_ggml_compute_forward_alibi_f32( const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - for (int i = 0; i < ne0; i++) { - for (int j = 0; j < ne1; j++) { - for (int k = 0; k < ne2_ne3; k++) { + for (int64_t i = 0; i < ne0; i++) { + for (int64_t j = 0; j < ne1; j++) { + for (int64_t k = 0; k < ne2_ne3; k++) { float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); @@ -12440,7 +10866,6 @@ static void wsp_ggml_compute_forward_alibi_f32( } pdst[0] = i * m_k + src[0]; - } } } @@ -12456,13 +10881,11 @@ static void wsp_ggml_compute_forward_alibi_f16( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - assert(n_past >= 0); - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int ne1 = src0->ne[1]; // seq_len_without_past const int ne2 = src0->ne[2]; // n_head -> this is k @@ -12477,7 +10900,7 @@ static void wsp_ggml_compute_forward_alibi_f16( //const int nb3 = src0->nb[3]; WSP_GGML_ASSERT(nb0 == sizeof(wsp_ggml_fp16_t)); - WSP_GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; + //WSP_GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; WSP_GGML_ASSERT(n_head == ne2); // add alibi to src0 (KQ_scaled) @@ -12620,34 +11043,76 @@ static void wsp_ggml_compute_forward_clamp( // wsp_ggml_compute_forward_rope +static float rope_yarn_ramp(const float low, const float high, const int i0) { + const float y = (i0 / 2 - low) / MAX(0.001f, high - low); + return 1 - MIN(1, MAX(0, y)); +} + +// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn +// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. +static void rope_yarn( + float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float * cos_theta, float * sin_theta +) { + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta = theta_interp; + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + *cos_theta = cosf(theta) * mscale; + *sin_theta = sinf(theta) * mscale; +} + +// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get +// `corr_dim(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))` +static float wsp_ggml_rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float base) { + return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float)M_PI)) / (2 * logf(base)); +} + +void wsp_ggml_rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2] +) { + // start and end correction dims + dims[0] = MAX(0, floorf(wsp_ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_fast, freq_base))); + dims[1] = MIN(n_dims - 1, ceilf(wsp_ggml_rope_yarn_corr_dim(n_dims, n_orig_ctx, beta_slow, freq_base))); +} + static void wsp_ggml_compute_forward_rope_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { - if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { return; } - float freq_base; - float freq_scale; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; // these two only relevant for xPos RoPE: float xpos_base; - bool xpos_down; + bool xpos_down; - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; - assert(n_past >= 0); + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(&xpos_base, (int32_t *) dst->op_params + 11, sizeof(float)); + memcpy(&xpos_down, (int32_t *) dst->op_params + 12, sizeof(bool)); - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12673,29 +11138,34 @@ static void wsp_ggml_compute_forward_rope_f32( int ir = 0; const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; + wsp_ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = freq_scale * (float)p; + float theta_base = (float)p; if (is_glm) { - theta = MIN(p, n_ctx - 2); + theta_base = MIN(p, n_ctx - 2); float block_theta = MAX(p - (n_ctx - 2), 0); for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); const float cos_block_theta = cosf(block_theta); const float sin_block_theta = sinf(block_theta); - theta *= theta_scale; + theta_base *= theta_scale; block_theta *= theta_scale; const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); @@ -12713,13 +11183,16 @@ static void wsp_ggml_compute_forward_rope_f32( } } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta + ); + // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; if (xpos_down) zeta = 1.0f / zeta; - theta *= theta_scale; + theta_base *= theta_scale; const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -12733,12 +11206,19 @@ static void wsp_ggml_compute_forward_rope_f32( } else { // TODO: this might be wrong for ne0 != n_dims - need double check // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + theta_base *= freq_scale; for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + // simplified from `(ib * n_dims + ic) * inv_ndims` + float cur_rot = inv_ndims * ic - ib; + + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + &cos_theta, &sin_theta + ); - theta *= theta_scale; + theta_base *= theta_scale; const int64_t i0 = ib*n_dims + ic/2; @@ -12761,25 +11241,27 @@ static void wsp_ggml_compute_forward_rope_f32( static void wsp_ggml_compute_forward_rope_f16( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { - if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { return; } - float freq_base; - float freq_scale; - - const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - const int n_ctx = ((int32_t *) dst->op_params)[3]; - memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; - assert(n_past >= 0); + //const int n_past = ((int32_t *) dst->op_params)[0]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_orig_ctx = ((int32_t *) dst->op_params)[4]; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12805,29 +11287,34 @@ static void wsp_ggml_compute_forward_rope_f16( int ir = 0; const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float inv_ndims = -1.f/n_dims; + float corr_dims[2]; + wsp_ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = freq_scale * (float)p; + float theta_base = (float)p; if (is_glm) { - theta = MIN(p, n_ctx - 2); + theta_base = MIN(p, n_ctx - 2); float block_theta = MAX(p - (n_ctx - 2), 0); for (int64_t i0 = 0; i0 < ne0 / 4; i0++) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); const float cos_block_theta = cosf(block_theta); const float sin_block_theta = sinf(block_theta); - theta *= theta_scale; + theta_base *= theta_scale; block_theta *= theta_scale; const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); @@ -12843,12 +11330,14 @@ static void wsp_ggml_compute_forward_rope_f16( dst_data[n_dims] = WSP_GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); dst_data[n_dims/2*3] = WSP_GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); } - } if (!is_neox) { + } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta + ); - theta *= theta_scale; + theta_base *= theta_scale; const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); wsp_ggml_fp16_t * dst_data = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -12862,12 +11351,19 @@ static void wsp_ggml_compute_forward_rope_f16( } else { // TODO: this might be wrong for ne0 != n_dims - need double check // ref: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + theta_base *= freq_scale; for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + // simplified from `(ib * n_dims + ic) * inv_ndims` + float cur_rot = inv_ndims * ic - ib; - theta *= theta_scale; + float cos_theta, sin_theta; + rope_yarn( + theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, + &cos_theta, &sin_theta + ); + + theta_base *= theta_scale; const int64_t i0 = ib*n_dims + ic/2; @@ -12890,15 +11386,16 @@ static void wsp_ggml_compute_forward_rope_f16( static void wsp_ggml_compute_forward_rope( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { switch (src0->type) { case WSP_GGML_TYPE_F16: { - wsp_ggml_compute_forward_rope_f16(params, src0, dst); + wsp_ggml_compute_forward_rope_f16(params, src0, src1, dst); } break; case WSP_GGML_TYPE_F32: { - wsp_ggml_compute_forward_rope_f32(params, src0, dst); + wsp_ggml_compute_forward_rope_f32(params, src0, src1, dst); } break; default: { @@ -12912,6 +11409,7 @@ static void wsp_ggml_compute_forward_rope( static void wsp_ggml_compute_forward_rope_back_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { @@ -12929,7 +11427,7 @@ static void wsp_ggml_compute_forward_rope_back_f32( float xpos_base; bool xpos_down; - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx); @@ -12938,9 +11436,7 @@ static void wsp_ggml_compute_forward_rope_back_f32( memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool)); - assert(n_past >= 0); - - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -12966,24 +11462,27 @@ static void wsp_ggml_compute_forward_rope_back_f32( const bool is_neox = mode & 2; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = freq_scale * (float)p; + float theta_base = freq_scale * (float)p; if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); + // zeta scaling for xPos only: - float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), (n_past + i2) / xpos_base) : 1.0f; + float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f; if (xpos_down) zeta = 1.0f / zeta; - theta *= theta_scale; + theta_base *= theta_scale; const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -12997,10 +11496,10 @@ static void wsp_ggml_compute_forward_rope_back_f32( } else { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); - theta *= theta_scale; + theta_base *= theta_scale; const int64_t i0 = ib*n_dims + ic/2; @@ -13023,6 +11522,7 @@ static void wsp_ggml_compute_forward_rope_back_f32( static void wsp_ggml_compute_forward_rope_back_f16( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { if (params->type == WSP_GGML_TASK_INIT || params->type == WSP_GGML_TASK_FINALIZE) { @@ -13033,13 +11533,11 @@ static void wsp_ggml_compute_forward_rope_back_f16( // dx = rope_back(dy, src1) // src0 is dy, src1 contains options - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; - assert(n_past >= 0); - - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3); //printf("n_past = %d, ne2 = %d\n", n_past, ne2); @@ -13065,21 +11563,23 @@ static void wsp_ggml_compute_forward_rope_back_f16( const bool is_neox = mode & 2; + const int32_t * pos = (const int32_t *) src1->data; + for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = ((mode & 1) == 0 ? 0 : n_past); i2 < ne2; i2++) { - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + for (int64_t i2 = 0; i2 < ne2; i2++) { + const int64_t p = pos[i2]; for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - float theta = (float)p; + float theta_base = (float)p; if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); - theta *= theta_scale; + theta_base *= theta_scale; const wsp_ggml_fp16_t * const dy = (wsp_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); wsp_ggml_fp16_t * dx = (wsp_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); @@ -13093,10 +11593,10 @@ static void wsp_ggml_compute_forward_rope_back_f16( } else { for (int64_t ib = 0; ib < ne0/n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { - const float cos_theta = cosf(theta); - const float sin_theta = sinf(theta); + const float cos_theta = cosf(theta_base); + const float sin_theta = sinf(theta_base); - theta *= theta_scale; + theta_base *= theta_scale; const int64_t i0 = ib*n_dims + ic/2; @@ -13119,15 +11619,16 @@ static void wsp_ggml_compute_forward_rope_back_f16( static void wsp_ggml_compute_forward_rope_back( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { switch (src0->type) { case WSP_GGML_TYPE_F16: { - wsp_ggml_compute_forward_rope_back_f16(params, src0, dst); + wsp_ggml_compute_forward_rope_back_f16(params, src0, src1, dst); } break; case WSP_GGML_TYPE_F32: { - wsp_ggml_compute_forward_rope_back_f32(params, src0, dst); + wsp_ggml_compute_forward_rope_back_f32(params, src0, src1, dst); } break; default: { @@ -13138,7 +11639,7 @@ static void wsp_ggml_compute_forward_rope_back( // wsp_ggml_compute_forward_conv_1d -static void wsp_ggml_compute_forward_conv_1d_s1_ph_f16_f32( +static void wsp_ggml_compute_forward_conv_1d_f16_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, @@ -13150,48 +11651,39 @@ static void wsp_ggml_compute_forward_conv_1d_s1_ph_f16_f32( int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; const int nk = ne00; - const int nh = nk/2; - const int ew0 = wsp_ggml_up32(ne01); + // size of the convolution row - the kernel size unrolled across all input channels + const int ew0 = nk*ne01; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - WSP_GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); WSP_GGML_ASSERT(nb10 == sizeof(float)); if (params->type == WSP_GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) - { - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - wsp_ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; - } - } - } - } + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + wsp_ggml_fp16_t * dst_data = wdata; - // prepare source data (src1) - { - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + for (int64_t i0 = 0; i0 < ne0; i0++) { + for (int64_t ik = 0; ik < nk; ik++) { + const int idx0 = i0*s0 + ik*d0 - p0; - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - wsp_ggml_fp16_t * dst_data = wdata; - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]); + if(!(idx0 < 0 || idx0 >= ne10)) { + dst_data[i0*ew0 + i11*nk + ik] = WSP_GGML_FP32_TO_FP16(src[idx0]); + } } } } @@ -13204,7 +11696,7 @@ static void wsp_ggml_compute_forward_conv_1d_s1_ph_f16_f32( } // total rows in dst - const int nr = ne02; + const int nr = ne2; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -13213,23 +11705,22 @@ static void wsp_ggml_compute_forward_conv_1d_s1_ph_f16_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; ++i0) { - dst_data[i0] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - wsp_ggml_vec_dot_f16(ew0, &v, - (wsp_ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (wsp_ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0] += v; + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; + + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); + + for (int i0 = 0; i0 < ne0; i0++) { + wsp_ggml_vec_dot_f16(ew0, dst_data + i0, + (wsp_ggml_fp16_t *) ((char *) src0->data + i1*nb02), + (wsp_ggml_fp16_t *) wdata + i2*nb2 + i0*ew0); } } } } -static void wsp_ggml_compute_forward_conv_1d_s1_ph_f32( +static void wsp_ggml_compute_forward_conv_1d_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, @@ -13241,52 +11732,229 @@ static void wsp_ggml_compute_forward_conv_1d_s1_ph_f32( int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; const int nk = ne00; - const int nh = nk/2; - const int ew0 = wsp_ggml_up32(ne01); + const int ew0 = nk*ne01; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - WSP_GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes WSP_GGML_ASSERT(nb00 == sizeof(float)); WSP_GGML_ASSERT(nb10 == sizeof(float)); if (params->type == WSP_GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) - { - float * const wdata = (float *) params->wdata + 0; + float * const wdata = (float *) params->wdata + 0; - for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i02*ew0*ne00; - for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; + for (int64_t i11 = 0; i11 < ne11; i11++) { + const float * const src = (float *)((char *) src1->data + i11*nb11); + float * dst_data = wdata; + + for (int64_t i0 = 0; i0 < ne0; i0++) { + for (int64_t ik = 0; ik < nk; ik++) { + const int idx0 = i0*s0 + ik*d0 - p0; + + if(!(idx0 < 0 || idx0 >= ne10)) { + dst_data[i0*ew0 + i11*nk + ik] = src[idx0]; } } } } - // prepare source data (src1) - { - float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + return; + } - for (int64_t i11 = 0; i11 < ne11; i11++) { - const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; - for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // total rows in dst + const int nr = ne02; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float * const wdata = (float *) params->wdata + 0; + + for (int i2 = 0; i2 < ne2; i2++) { + for (int i1 = ir0; i1 < ir1; i1++) { + float * dst_data = (float *)((char *) dst->data + i2*nb2 + i1*nb1); + + for (int i0 = 0; i0 < ne0; i0++) { + wsp_ggml_vec_dot_f32(ew0, dst_data + i0, + (float *) ((char *) src0->data + i1*nb02), + (float *) wdata + i2*nb2 + i0*ew0); + } + } + } +} + +// TODO: reuse wsp_ggml_mul_mat or implement wsp_ggml_im2col and remove stage_0 and stage_1 +static void gemm_f16_out_f32(int64_t m, int64_t n, int64_t k, + wsp_ggml_fp16_t * A, + wsp_ggml_fp16_t * B, + float * C, + const int ith, const int nth) { + // does not seem to make a difference + int64_t m0, m1, n0, n1; + // patches per thread + if (m > n) { + n0 = 0; + n1 = n; + + // total patches in dst + const int np = m; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + m0 = dp*ith; + m1 = MIN(m0 + dp, np); + } else { + m0 = 0; + m1 = m; + + // total patches in dst + const int np = n; + + // patches per thread + const int dp = (np + nth - 1)/nth; + + // patch range for this thread + n0 = dp*ith; + n1 = MIN(n0 + dp, np); + } + + // block-tiling attempt + int64_t blck_n = 16; + int64_t blck_m = 16; + + // int64_t CACHE_SIZE = 2 * 1024 * 1024; // 2MB + // int64_t blck_size = CACHE_SIZE / (sizeof(float) + 2 * sizeof(wsp_ggml_fp16_t) * K); + // if (blck_size > 0) { + // blck_0 = 4; + // blck_1 = blck_size / blck_0; + // if (blck_1 < 0) { + // blck_1 = 1; + // } + // // blck_0 = (int64_t)sqrt(blck_size); + // // blck_1 = blck_0; + // } + // // printf("%zd %zd %zd %zd\n", blck_size, K, blck_0, blck_1); + + for (int j = n0; j < n1; j+=blck_n) { + for (int i = m0; i < m1; i+=blck_m) { + // printf("i j k => %d %d %d\n", i, j, K); + for (int ii = i; ii < i + blck_m && ii < m1; ii++) { + for (int jj = j; jj < j + blck_n && jj < n1; jj++) { + wsp_ggml_vec_dot_f16(k, + C + ii*n + jj, + A + ii * k, + B + jj * k); + } + } + } + } +} + +// src0: kernel [OC, IC, K] +// src1: signal [N, IC, IL] +// dst: result [N, OL, IC*K] +static void wsp_ggml_compute_forward_conv_1d_stage_0_f32( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16); + + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const int64_t N = ne12; + const int64_t IC = ne11; + const int64_t IL = ne10; + + const int64_t K = ne00; + + const int64_t OL = ne1; + + const int ith = params->ith; + const int nth = params->nth; + + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; + + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == WSP_GGML_TASK_INIT) { + memset(dst->data, 0, wsp_ggml_nbytes(dst)); + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // im2col: [N, IC, IL] => [N, OL, IC*K] + { + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t iol = 0; iol < OL; iol++) { + for (int64_t iic = ith; iic < IC; iic+=nth) { + + // micro kernel + wsp_ggml_fp16_t * dst_data = wdata + (in*OL + iol)*(IC*K); // [IC, K] + const float * const src_data = (float *)((char *) src1->data + in*nb12 + iic*nb11); // [IL] + + for (int64_t ik = 0; ik < K; ik++) { + const int64_t iil = iol*s0 + ik*d0 - p0; + + if (!(iil < 0 || iil >= IL)) { + dst_data[iic*K + ik] = WSP_GGML_FP32_TO_FP16(src_data[iil]); + } + } } } } + } +} + +// gemm: [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] +// src0: [OC, IC, K] +// src1: [N, OL, IC * K] +// result: [N, OC, OL] +static void wsp_ggml_compute_forward_conv_1d_stage_1_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + if (params->type == WSP_GGML_TASK_INIT) { return; } @@ -13294,45 +11962,83 @@ static void wsp_ggml_compute_forward_conv_1d_s1_ph_f32( return; } - // total rows in dst - const int nr = ne02; + WSP_GGML_TENSOR_BINARY_OP_LOCALS; - // rows per thread - const int dr = (nr + nth - 1)/nth; + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb10 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb0 == sizeof(float)); - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); + const int N = ne12; + const int OL = ne11; - for (int i1 = ir0; i1 < ir1; i1++) { - float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; ++i0) { - dst_data[i0] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - wsp_ggml_vec_dot_f32(ew0, &v, - (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0] += v; - } - } + const int OC = ne02; + const int IC = ne01; + const int K = ne00; + + const int ith = params->ith; + const int nth = params->nth; + + int64_t m = OC; + int64_t n = OL; + int64_t k = IC * K; + + // [N, OC, OL] = [OC, IC * K] x [N*OL, IC * K] + for (int i = 0; i < N; i++) { + wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k] + wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)src1->data + i * m * k; // [n, k] + float * C = (float *)dst->data + i * m * n; // [m, n] + + gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); } } -static void wsp_ggml_compute_forward_conv_1d_s1_ph( +static void wsp_ggml_compute_forward_conv_1d( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { - switch (src0->type) { + switch(src0->type) { case WSP_GGML_TYPE_F16: { - wsp_ggml_compute_forward_conv_1d_s1_ph_f16_f32(params, src0, src1, dst); + wsp_ggml_compute_forward_conv_1d_f16_f32(params, src0, src1, dst); } break; case WSP_GGML_TYPE_F32: { - wsp_ggml_compute_forward_conv_1d_s1_ph_f32(params, src0, src1, dst); + wsp_ggml_compute_forward_conv_1d_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +static void wsp_ggml_compute_forward_conv_1d_stage_0( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch(src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_conv_1d_stage_0_f32(params, src0, src1, dst); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +static void wsp_ggml_compute_forward_conv_1d_stage_1( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch(src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_conv_1d_stage_1_f16(params, src0, src1, dst); } break; default: { @@ -13341,7 +12047,9 @@ static void wsp_ggml_compute_forward_conv_1d_s1_ph( } } -static void wsp_ggml_compute_forward_conv_1d_s2_ph_f16_f32( +// wsp_ggml_compute_forward_conv_transpose_1d + +static void wsp_ggml_compute_forward_conv_transpose_1d_f16_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, @@ -13353,52 +12061,50 @@ static void wsp_ggml_compute_forward_conv_1d_s2_ph_f16_f32( int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = wsp_ggml_up32(ne01); + const int nk = ne00*ne01*ne02; - WSP_GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); WSP_GGML_ASSERT(nb10 == sizeof(float)); if (params->type == WSP_GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) + // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) { wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { const wsp_ggml_fp16_t * const src = (wsp_ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01); - wsp_ggml_fp16_t * dst_data = wdata + i02*ew0*ne00; + wsp_ggml_fp16_t * dst_data = wdata + i01*ne00*ne02; for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; + dst_data[i00*ne02 + i02] = src[i00]; } } } } - // prepare source data (src1) + // permute source data (src1) from (L x Cin) to (Cin x L) { - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + ne02*ew0*ne00; + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + nk; + wsp_ggml_fp16_t * dst_data = wdata; for (int64_t i11 = 0; i11 < ne11; i11++) { const float * const src = (float *)((char *) src1->data + i11*nb11); - wsp_ggml_fp16_t * dst_data = wdata; for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]); + dst_data[i10*ne11 + i11] = WSP_GGML_FP32_TO_FP16(src[i10]); } } } + // need to zero dst since we are accumulating into it + memset(dst->data, 0, wsp_ggml_nbytes(dst)); + return; } @@ -13406,8 +12112,10 @@ static void wsp_ggml_compute_forward_conv_1d_s2_ph_f16_f32( return; } + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + // total rows in dst - const int nr = ne02; + const int nr = ne1; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -13416,23 +12124,26 @@ static void wsp_ggml_compute_forward_conv_1d_s2_ph_f16_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; + wsp_ggml_fp16_t * const wdata_src = wdata + nk; + for (int i1 = ir0; i1 < ir1; i1++) { float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; i0 += 2) { - dst_data[i0/2] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - wsp_ggml_vec_dot_f16(ew0, &v, - (wsp_ggml_fp16_t *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (wsp_ggml_fp16_t *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0/2] += v; + wsp_ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + wsp_ggml_vec_dot_f16(ne02, &v, + (wsp_ggml_fp16_t *) wdata_src + i1n, + (wsp_ggml_fp16_t *) wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; } } } } -static void wsp_ggml_compute_forward_conv_1d_s2_ph_f32( +static void wsp_ggml_compute_forward_conv_transpose_1d_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, @@ -13444,34 +12155,29 @@ static void wsp_ggml_compute_forward_conv_1d_s2_ph_f32( int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; - const int nk = ne00; - const int nh = nk/2; - - const int ew0 = wsp_ggml_up32(ne01); + const int nk = ne00*ne01*ne02; - WSP_GGML_ASSERT(ne00 % 2 == 1); // TODO: support even kernel sizes WSP_GGML_ASSERT(nb00 == sizeof(float)); WSP_GGML_ASSERT(nb10 == sizeof(float)); if (params->type == WSP_GGML_TASK_INIT) { - // TODO: fix this memset (wsize is overestimated) memset(params->wdata, 0, params->wsize); - // prepare kernel data (src0) + // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout) { float * const wdata = (float *) params->wdata + 0; for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = 0; i01 < ne01; i01++) { const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01); - float * dst_data = wdata + i02*ew0*ne00; + float * dst_data = wdata + i01*ne00*ne02; for (int64_t i00 = 0; i00 < ne00; i00++) { - dst_data[i00*ew0 + i01] = src[i00]; + dst_data[i00*ne02 + i02] = src[i00]; } } } @@ -13479,17 +12185,20 @@ static void wsp_ggml_compute_forward_conv_1d_s2_ph_f32( // prepare source data (src1) { - float * const wdata = (float *) params->wdata + ne02*ew0*ne00; + float * const wdata = (float *) params->wdata + nk; + float * dst_data = wdata; for (int64_t i11 = 0; i11 < ne11; i11++) { const float * const src = (float *)((char *) src1->data + i11*nb11); - float * dst_data = wdata; for (int64_t i10 = 0; i10 < ne10; i10++) { - dst_data[(i10 + nh)*ew0 + i11] = src[i10]; + dst_data[i10*ne11 + i11] = src[i10]; } } } + // need to zero dst since we are accumulating into it + memset(dst->data, 0, wsp_ggml_nbytes(dst)); + return; } @@ -13497,8 +12206,10 @@ static void wsp_ggml_compute_forward_conv_1d_s2_ph_f32( return; } + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + // total rows in dst - const int nr = ne02; + const int nr = ne1; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -13507,23 +12218,26 @@ static void wsp_ggml_compute_forward_conv_1d_s2_ph_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); + float * const wdata = (float *) params->wdata + 0; + float * const wdata_src = wdata + nk; + for (int i1 = ir0; i1 < ir1; i1++) { float * dst_data = (float *)((char *) dst->data + i1*nb1); - for (int64_t i0 = 0; i0 < ne10; i0 += 2) { - dst_data[i0/2] = 0; - for (int k = -nh; k <= nh; k++) { - float v = 0.0f; - wsp_ggml_vec_dot_f32(ew0, &v, - (float *) params->wdata + i1*ew0*ne00 + (nh + k)*ew0, - (float *) params->wdata + ne02*ew0*ne00 + (i0 + nh + k)*ew0); - - dst_data[i0/2] += v; + float * wdata_kernel = wdata + i1*ne02*ne00; + for (int i10 = 0; i10 < ne10; i10++) { + const int i1n = i10*ne11; + for (int i00 = 0; i00 < ne00; i00++) { + float v = 0; + wsp_ggml_vec_dot_f32(ne02, &v, + wdata_src + i1n, + wdata_kernel + i00*ne02); + dst_data[i10*s0 + i00] += v; } } } } -static void wsp_ggml_compute_forward_conv_1d_s2_ph( +static void wsp_ggml_compute_forward_conv_transpose_1d( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, @@ -13531,11 +12245,11 @@ static void wsp_ggml_compute_forward_conv_1d_s2_ph( switch (src0->type) { case WSP_GGML_TYPE_F16: { - wsp_ggml_compute_forward_conv_1d_s2_ph_f16_f32(params, src0, src1, dst); + wsp_ggml_compute_forward_conv_transpose_1d_f16_f32(params, src0, src1, dst); } break; case WSP_GGML_TYPE_F32: { - wsp_ggml_compute_forward_conv_1d_s2_ph_f32(params, src0, src1, dst); + wsp_ggml_compute_forward_conv_transpose_1d_f32(params, src0, src1, dst); } break; default: { @@ -13544,28 +12258,145 @@ static void wsp_ggml_compute_forward_conv_1d_s2_ph( } } -// wsp_ggml_compute_forward_conv_1d +// wsp_ggml_compute_forward_conv_2d -static void wsp_ggml_compute_forward_conv_1d( +// src0: kernel [OC, IC, KH, KW] +// src1: image [N, IC, IH, IW] +// dst: result [N, OH, OW, IC*KH*KW] +static void wsp_ggml_compute_forward_conv_2d_stage_0_f32( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, const struct wsp_ggml_tensor * src1, struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F32); + WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F16); + + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + const int64_t N = ne13; + const int64_t IC = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + // const int64_t OC = ne03; + // const int64_t IC = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OH = ne2; + const int64_t OW = ne1; + + const int ith = params->ith; + const int nth = params->nth; + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t p0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t d0 = ((const int32_t*)(dst->op_params))[2]; - WSP_GGML_ASSERT(d0 == 1); // dilation not supported - WSP_GGML_ASSERT(p0 == src0->ne[0]/2); // only half padding supported - if (s0 == 1) { - wsp_ggml_compute_forward_conv_1d_s1_ph(params, src0, src1, dst); - } else if (s0 == 2) { - wsp_ggml_compute_forward_conv_1d_s2_ph(params, src0, src1, dst); - } else { - WSP_GGML_ASSERT(false); // only stride 1 and 2 supported - }; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; + + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb10 == sizeof(float)); + + if (params->type == WSP_GGML_TASK_INIT) { + memset(dst->data, 0, wsp_ggml_nbytes(dst)); + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] + { + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) dst->data; + + for (int64_t in = 0; in < N; in++) { + for (int64_t ioh = 0; ioh < OH; ioh++) { + for (int64_t iow = 0; iow < OW; iow++) { + for (int64_t iic = ith; iic < IC; iic+=nth) { + + // micro kernel + wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW] + + for (int64_t ikh = 0; ikh < KH; ikh++) { + for (int64_t ikw = 0; ikw < KW; ikw++) { + const int64_t iiw = iow*s0 + ikw*d0 - p0; + const int64_t iih = ioh*s1 + ikh*d1 - p1; + + if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); + } + } + } + } + } + } + } + } } -// wsp_ggml_compute_forward_conv_2d +// gemm: [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] +// src0: [OC, IC, KH, KW] +// src1: [N, OH, OW, IC * KH * KW] +// result: [N, OC, OH, OW] +static void wsp_ggml_compute_forward_conv_2d_stage_1_f16( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + WSP_GGML_ASSERT(src0->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT(src1->type == WSP_GGML_TYPE_F16); + WSP_GGML_ASSERT( dst->type == WSP_GGML_TYPE_F32); + + int64_t t0 = wsp_ggml_perf_time_us(); + UNUSED(t0); + + if (params->type == WSP_GGML_TASK_INIT) { + return; + } + + if (params->type == WSP_GGML_TASK_FINALIZE) { + return; + } + + WSP_GGML_TENSOR_BINARY_OP_LOCALS; + + WSP_GGML_ASSERT(nb00 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb10 == sizeof(wsp_ggml_fp16_t)); + WSP_GGML_ASSERT(nb0 == sizeof(float)); + + const int N = ne13; + const int OH = ne12; + const int OW = ne11; + + const int OC = ne03; + const int IC = ne02; + const int KH = ne01; + const int KW = ne00; + + const int ith = params->ith; + const int nth = params->nth; + + int64_t m = OC; + int64_t n = OH * OW; + int64_t k = IC * KH * KW; + + // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] + for (int i = 0; i < N; i++) { + wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k] + wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)src1->data + i * m * k; // [n, k] + float * C = (float *)dst->data + i * m * n; // [m, n] + + gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); + } +} static void wsp_ggml_compute_forward_conv_2d_f16_f32( const struct wsp_ggml_compute_params * params, @@ -13579,16 +12410,40 @@ static void wsp_ggml_compute_forward_conv_2d_f16_f32( int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS + + // src1: image [N, IC, IH, IW] + // src0: kernel [OC, IC, KH, KW] + // dst: result [N, OC, OH, OW] + // ne12: IC + // ne0: OW + // ne1: OH + // nk0: KW + // nk1: KH + // ne13: N + + const int N = ne13; + const int IC = ne12; + const int IH = ne11; + const int IW = ne10; + + const int OC = ne03; + // const int IC = ne02; + const int KH = ne01; + const int KW = ne00; + + const int OH = ne1; + const int OW = ne0; const int ith = params->ith; const int nth = params->nth; - const int nk0 = ne00; - const int nk1 = ne01; + // const int nk0 = ne00; + // const int nk1 = ne01; // size of the convolution row - the kernel size unrolled across all channels - const int ew0 = nk0*nk1*ne02; + // const int ew0 = nk0*nk1*ne02; + // ew0: IC*KH*KW const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; @@ -13604,23 +12459,28 @@ static void wsp_ggml_compute_forward_conv_2d_f16_f32( memset(params->wdata, 0, params->wsize); // prepare source data (src1) + // im2col: [N, IC, IH, IW] => [N*OH*OW, IC*KH*KW] + { wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; - for (int i12 = 0; i12 < ne12; i12++) { - const float * const src = (float *)((char *) src1->data + i12*nb12); - wsp_ggml_fp16_t * dst_data = wdata; + for (int in = 0; in < N; in++) { + for (int iic = 0; iic < IC; iic++) { + for (int ioh = 0; ioh < OH; ioh++) { + for (int iow = 0; iow < OW; iow++) { - for (int i1 = 0; i1 < ne1; i1++) { - for (int i0 = 0; i0 < ne0; i0++) { - for (int ik1 = 0; ik1 < nk1; ik1++) { - for (int ik0 = 0; ik0 < nk0; ik0++) { - const int idx0 = i0*s0 + ik0*d0 - p0; - const int idx1 = i1*s1 + ik1*d1 - p1; - - if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) { - dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] = - WSP_GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]); + // micro kernel + wsp_ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] + const float * const src_data = (float *)((char *) src1->data + in*nb13 + iic*nb12); // [IH, IW] + + for (int ikh = 0; ikh < KH; ikh++) { + for (int ikw = 0; ikw < KW; ikw++) { + const int iiw = iow*s0 + ikw*d0 - p0; + const int iih = ioh*s1 + ikh*d1 - p1; + + if (!(iih < 0 || iih >= IH || iiw < 0 || iiw >= IW)) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = WSP_GGML_FP32_TO_FP16(src_data[iih*IW + iiw]); + } } } } @@ -13636,30 +12496,22 @@ static void wsp_ggml_compute_forward_conv_2d_f16_f32( return; } - // total patches in dst - const int np = ne2; - - // patches per thread - const int dp = (np + nth - 1)/nth; + wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; + // wdata: [N*OH*OW, IC*KH*KW] + // dst: result [N, OC, OH, OW] + // src0: kernel [OC, IC, KH, KW] - // patch range for this thread - const int ip0 = dp*ith; - const int ip1 = MIN(ip0 + dp, np); + int64_t m = OC; + int64_t n = OH * OW; + int64_t k = IC * KH * KW; - wsp_ggml_fp16_t * const wdata = (wsp_ggml_fp16_t *) params->wdata + 0; + // [N, OC, OH, OW] = [OC, IC * KH * KW] x [N*OH*OW, IC * KH * KW] + for (int i = 0; i < N; i++) { + wsp_ggml_fp16_t * A = (wsp_ggml_fp16_t *)src0->data; // [m, k] + wsp_ggml_fp16_t * B = (wsp_ggml_fp16_t *)wdata + i * m * k; // [n, k] + float * C = (float *)dst->data + i * m * n; // [m * k] - for (int i3 = 0; i3 < ne3; i3++) { - for (int i2 = ip0; i2 < ip1; i2++) { - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2); - - for (int i1 = 0; i1 < ne1; ++i1) { - for (int i0 = 0; i0 < ne0; ++i0) { - wsp_ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0, - (wsp_ggml_fp16_t *) ((char *) src0->data + i2*nb03), - (wsp_ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0); - } - } - } + gemm_f16_out_f32(m, n, k, A, B, C, ith, nth); } } @@ -13685,6 +12537,48 @@ static void wsp_ggml_compute_forward_conv_2d( } } +static void wsp_ggml_compute_forward_conv_2d_stage_0( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_conv_2d_stage_0_f32(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + WSP_GGML_ASSERT(false); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + +static void wsp_ggml_compute_forward_conv_2d_stage_1( + const struct wsp_ggml_compute_params * params, + const struct wsp_ggml_tensor * src0, + const struct wsp_ggml_tensor * src1, + struct wsp_ggml_tensor * dst) { + switch (src0->type) { + case WSP_GGML_TYPE_F16: + { + wsp_ggml_compute_forward_conv_2d_stage_1_f16(params, src0, src1, dst); + } break; + case WSP_GGML_TYPE_F32: + { + WSP_GGML_ASSERT(false); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } +} + // wsp_ggml_compute_forward_conv_transpose_2d static void wsp_ggml_compute_forward_conv_transpose_2d( @@ -13699,7 +12593,7 @@ static void wsp_ggml_compute_forward_conv_transpose_2d( int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); - WSP_GGML_TENSOR_BINARY_OP_LOCALS; + WSP_GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; const int nth = params->nth; @@ -13743,6 +12637,8 @@ static void wsp_ggml_compute_forward_conv_transpose_2d( } } + memset(dst->data, 0, wsp_ggml_nbytes(dst)); + return; } @@ -13855,14 +12751,11 @@ static void wsp_ggml_compute_forward_pool_1d( wsp_ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst); } -// wsp_ggml_compute_forward_pool_2d_sk_p0 +// wsp_ggml_compute_forward_pool_2d -static void wsp_ggml_compute_forward_pool_2d_sk_p0( +static void wsp_ggml_compute_forward_pool_2d( const struct wsp_ggml_compute_params * params, - const enum wsp_ggml_op_pool op, const struct wsp_ggml_tensor * src, - const int k0, - const int k1, struct wsp_ggml_tensor * dst) { assert(src->type == WSP_GGML_TYPE_F32); assert(params->ith == 0); @@ -13871,6 +12764,14 @@ static void wsp_ggml_compute_forward_pool_2d_sk_p0( return; } + const int32_t * opts = (const int32_t *)dst->op_params; + enum wsp_ggml_op_pool op = opts[0]; + const int k0 = opts[1]; + const int k1 = opts[2]; + const int s0 = opts[3]; + const int s1 = opts[4]; + const int p0 = opts[5]; + const int p1 = opts[6]; const char * cdata = (const char*)src->data; const char * const data_end = cdata + wsp_ggml_nbytes(src); @@ -13881,6 +12782,8 @@ static void wsp_ggml_compute_forward_pool_2d_sk_p0( float * dplane = (float *)dst->data; const int ka = k0 * k1; + const int offset0 = -p0; + const int offset1 = -p1; while (cdata < data_end) { for (int oy = 0; oy < py; ++oy) { @@ -13893,13 +12796,15 @@ static void wsp_ggml_compute_forward_pool_2d_sk_p0( case WSP_GGML_OP_POOL_COUNT: WSP_GGML_ASSERT(false); break; } - const int ix = ox * k0; - const int iy = oy * k1; + const int ix = offset0 + ox * s0; + const int iy = offset1 + oy * s1; for (int ky = 0; ky < k1; ++ky) { + if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky)); for (int kx = 0; kx < k0; ++kx) { int j = ix + kx; + if (j < 0 || j >= src->ne[0]) continue; switch (op) { case WSP_GGML_OP_POOL_AVG: *out += srow[j]; break; case WSP_GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break; @@ -13916,31 +12821,8 @@ static void wsp_ggml_compute_forward_pool_2d_sk_p0( } cdata += src->nb[2]; - dplane += pa; - } -} - -// wsp_ggml_compute_forward_pool_2d - -static void wsp_ggml_compute_forward_pool_2d( - const struct wsp_ggml_compute_params * params, - const struct wsp_ggml_tensor * src0, - struct wsp_ggml_tensor * dst) { - - const int32_t * opts = (const int32_t *)dst->op_params; - enum wsp_ggml_op_pool op = opts[0]; - const int k0 = opts[1]; - const int k1 = opts[2]; - const int s0 = opts[3]; - const int s1 = opts[4]; - const int p0 = opts[5]; - const int p1 = opts[6]; - WSP_GGML_ASSERT(p0 == 0); - WSP_GGML_ASSERT(p1 == 0); // padding not supported - WSP_GGML_ASSERT(k0 == s0); - WSP_GGML_ASSERT(k1 == s1); // only s = k supported - - wsp_ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst); + dplane += pa; + } } // wsp_ggml_compute_forward_upscale @@ -13958,7 +12840,7 @@ static void wsp_ggml_compute_forward_upscale_f32( const int ith = params->ith; - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS const int scale_factor = dst->op_params[0]; @@ -14010,14 +12892,14 @@ static void wsp_ggml_compute_forward_flash_attn_f32( int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); - WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbk, k, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbv, v, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14087,10 +12969,11 @@ static void wsp_ggml_compute_forward_flash_attn_f32( S[i] = -INFINITY; } - for (int64_t ic = 0; ic < nek1; ++ic) { + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { // k indices const int ik3 = iq3; - const int ik2 = iq2; + const int ik2 = iq2 % nek2; const int ik1 = ic; // S indices @@ -14103,20 +12986,18 @@ static void wsp_ggml_compute_forward_flash_attn_f32( } // scale - wsp_ggml_vec_scale_f32(nek1, S, scale); + wsp_ggml_vec_scale_f32(masked_begin, S, scale); - if (masked) { - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; } // softmax + // exclude known -INF S[..] values from max and loop + // dont forget to set their SW values to zero { float max = -INFINITY; - wsp_ggml_vec_max_f32(M, &max, S); + wsp_ggml_vec_max_f32(masked_begin, &max, S); wsp_ggml_float sum = 0.0; { @@ -14130,10 +13011,15 @@ static void wsp_ggml_compute_forward_flash_attn_f32( wsp_ggml_float sump[WSP_GGML_SOFT_MAX_UNROLL] = { 0.0 }; for (int i = 0; i < Mup; i += WSP_GGML_SOFT_MAX_UNROLL) { + if (i >= masked_begin) { + break; + } float * SS = S + i; for (int j = 0; j < WSP_GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { + if (i + j >= masked_begin) { + break; + } else if (SS[j] == -INFINITY) { SS[j] = 0.0f; } else { #ifndef WSP_GGML_FLASH_ATTN_EXP_FP16 @@ -14141,7 +13027,7 @@ static void wsp_ggml_compute_forward_flash_attn_f32( #else wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(SS[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + const float val = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_exp_f16[scvt[j]]); #endif sump[j] += (wsp_ggml_float)val; SS[j] = val; @@ -14158,10 +13044,10 @@ static void wsp_ggml_compute_forward_flash_attn_f32( assert(sum > 0.0); sum = 1.0/sum; - wsp_ggml_vec_scale_f32(M, S, sum); + wsp_ggml_vec_scale_f32(masked_begin, S, sum); #ifndef NDEBUG - for (int i = 0; i < M; ++i) { + for (int i = 0; i < masked_begin; ++i) { assert(!isnan(S[i])); assert(!isinf(S[i])); } @@ -14174,9 +13060,13 @@ static void wsp_ggml_compute_forward_flash_attn_f32( const int i2 = iq2; const int i3 = iq3; - wsp_ggml_vec_dot_f32(nek1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + wsp_ggml_vec_dot_f32(masked_begin, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), S); } } @@ -14192,14 +13082,14 @@ static void wsp_ggml_compute_forward_flash_attn_f16( int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); - WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbk, k, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbv, v, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14273,7 +13163,7 @@ static void wsp_ggml_compute_forward_flash_attn_f16( for (int64_t ic = 0; ic < nek1; ++ic) { // k indices const int ik3 = iq3; - const int ik2 = iq2; + const int ik2 = iq2 % nek2; const int ik1 = ic; // S indices @@ -14288,7 +13178,7 @@ static void wsp_ggml_compute_forward_flash_attn_f16( for (int64_t ic = 0; ic < nek1; ic += WSP_GGML_VEC_DOT_UNROLL) { // k indices const int ik3 = iq3; - const int ik2 = iq2; + const int ik2 = iq2 % nek2; const int ik1 = ic; // S indices @@ -14313,6 +13203,8 @@ static void wsp_ggml_compute_forward_flash_attn_f16( } // softmax + // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // dont forget to set their S values to zero { float max = -INFINITY; wsp_ggml_vec_max_f32(M, &max, S); @@ -14337,7 +13229,7 @@ static void wsp_ggml_compute_forward_flash_attn_f16( } else { wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(SS[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + const float val = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_exp_f16[scvt[j]]); sump[j] += (wsp_ggml_float)val; SS[j] = val; } @@ -14369,6 +13261,7 @@ static void wsp_ggml_compute_forward_flash_attn_f16( S16[i] = WSP_GGML_FP32_TO_FP16(S[i]); } + // todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16). if (WSP_GGML_VEC_DOT_UNROLL == 1 || (nev1 % WSP_GGML_VEC_DOT_UNROLL != 0)) { for (int64_t ic = 0; ic < nev1; ++ic) { // dst indices @@ -14376,9 +13269,13 @@ static void wsp_ggml_compute_forward_flash_attn_f16( const int i2 = iq2; const int i3 = iq3; - wsp_ggml_vec_dot_f16(nek1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - (wsp_ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + wsp_ggml_vec_dot_f16(nev0, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + (wsp_ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), S16); } } else { @@ -14388,9 +13285,13 @@ static void wsp_ggml_compute_forward_flash_attn_f16( const int i2 = iq2; const int i3 = iq3; - wsp_ggml_vec_dot_f16_unroll(nek1, nbv1, - (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), - ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), + // v indices + const int iv2 = iq2 % nev2; + const int iv3 = iq3; + + wsp_ggml_vec_dot_f16_unroll(nev0, nbv1, + (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), + ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), S16); } } @@ -14433,18 +13334,18 @@ static void wsp_ggml_compute_forward_flash_ff_f16( int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); - WSP_GGML_TENSOR_LOCALS(int64_t, nea, a, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nba, a, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, nea, a, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nba, a, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, neb0, b0, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbb0, b0, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, neb1, b1, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbb1, b1, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, nec0, c0, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbc0, c0, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, nec1, c1, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbc1, c1, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14592,16 +13493,16 @@ static void wsp_ggml_compute_forward_flash_attn_back_f32( int64_t t0 = wsp_ggml_perf_time_us(); UNUSED(t0); - WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbk, k, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbv, v, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, ned, d, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nbd, d, nb); - WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); - WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + WSP_GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, ned, d, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nbd, d, nb) + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + WSP_GGML_TENSOR_LOCALS(size_t, nb, dst, nb) const int ith = params->ith; const int nth = params->nth; @@ -14649,10 +13550,37 @@ static void wsp_ggml_compute_forward_flash_attn_back_f32( return; } - // parallelize by q rows using wsp_ggml_vec_dot_f32 + const int64_t elem_q = wsp_ggml_nelements(q); + const int64_t elem_k = wsp_ggml_nelements(k); - // total rows in q - const int nr = neq2*neq3; + enum wsp_ggml_type result_type = dst->type; + WSP_GGML_ASSERT(wsp_ggml_blck_size(result_type) == 1); + const size_t tsize = wsp_ggml_type_size(result_type); + + const size_t offs_q = 0; + const size_t offs_k = offs_q + WSP_GGML_PAD(elem_q * tsize, WSP_GGML_MEM_ALIGN); + const size_t offs_v = offs_k + WSP_GGML_PAD(elem_k * tsize, WSP_GGML_MEM_ALIGN); + + void * grad_q = (char *) dst->data; + void * grad_k = (char *) dst->data + offs_k; + void * grad_v = (char *) dst->data + offs_v; + + const size_t nbgq1 = nb0*neq0; + const size_t nbgq2 = nb0*neq0*neq1; + const size_t nbgq3 = nb0*neq0*neq1*neq2; + + const size_t nbgk1 = nb0*nek0; + const size_t nbgk2 = nb0*nek0*nek1; + const size_t nbgk3 = nb0*nek0*nek1*neq2; + + const size_t nbgv1 = nb0*nev0; + const size_t nbgv2 = nb0*nev0*nev1; + const size_t nbgv3 = nb0*nev0*nev1*neq2; + + // parallelize by k rows using wsp_ggml_vec_dot_f32 + + // total rows in k + const int nr = nek2*nek3; // rows per thread const int dr = (nr + nth - 1)/nth; @@ -14665,268 +13593,243 @@ static void wsp_ggml_compute_forward_flash_attn_back_f32( //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale); + // how often k2 (and v2) is repeated in q2 + int nrep = neq2/nek2; + for (int ir = ir0; ir < ir1; ++ir) { // q indices - const int iq3 = ir/(neq2); - const int iq2 = ir - iq3*neq2; - for ( int iq1 = 0; iq1 < neq1; ++iq1) { + const int ik3 = ir/(nek2); + const int ik2 = ir - ik3*nek2; + const int iq3 = ik3; + const int id3 = ik3; + const int iv3 = ik3; + const int iv2 = ik2; - // not sure about CACHE_LINE_SIZE_F32.. - // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? - float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); - float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); + for (int irep = 0; irep < nrep; ++irep) { + const int iq2 = ik2 + irep*nek2; + const int id2 = iq2; - for (int i = M; i < Mup; ++i) { - S[i] = -INFINITY; - } + // (ik2 + irep*nek2) % nek2 == ik2 + for (int iq1 = 0; iq1 < neq1; ++iq1) { + const int id1 = iq1; - for (int64_t ic = 0; ic < nek1; ++ic) { - // k indices - const int ik3 = iq3; - const int ik2 = iq2; - const int ik1 = ic; + // not sure about CACHE_LINE_SIZE_F32.. + // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset? + float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32); + float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32); - // S indices - const int i1 = ik1; + for (int i = M; i < Mup; ++i) { + S[i] = -INFINITY; + } - wsp_ggml_vec_dot_f32(neq0, - S + i1, - (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), - (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); - } + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { + // k indices + const int ik1 = ic; - // scale - wsp_ggml_vec_scale_f32(nek1, S, scale); + // S indices + const int i1 = ik1; - if (masked) { - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } + wsp_ggml_vec_dot_f32(neq0, + S + i1, + (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3))); } - } - // softmax - { - float max = -INFINITY; - wsp_ggml_vec_max_f32(M, &max, S); + // scale + wsp_ggml_vec_scale_f32(masked_begin, S, scale); - wsp_ggml_float sum = 0.0; + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; + } + + // softmax + // exclude known -INF S[..] values from max and loop + // dont forget to set their SM values to zero { + float max = -INFINITY; + wsp_ggml_vec_max_f32(masked_begin, &max, S); + + wsp_ggml_float sum = 0.0; + { #ifdef WSP_GGML_SOFT_MAX_ACCELERATE - max = -max; - vDSP_vsadd(SM, 1, &max, SM, 1, Mup); - vvexpf(SM, SM, &Mup); - wsp_ggml_vec_sum_f32(Mup, &sum, SM); + max = -max; + vDSP_vsadd(SM, 1, &max, SM, 1, Mup); + vvexpf(SM, SM, &Mup); + wsp_ggml_vec_sum_f32(Mup, &sum, SM); #else - uint16_t scvt[WSP_GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); - wsp_ggml_float sump[WSP_GGML_SOFT_MAX_UNROLL] = { 0.0 }; - - for (int i = 0; i < Mup; i += WSP_GGML_SOFT_MAX_UNROLL) { - float * SR = S + i; - float * SW = SM + i; + uint16_t scvt[WSP_GGML_SOFT_MAX_UNROLL]; UNUSED(scvt); + wsp_ggml_float sump[WSP_GGML_SOFT_MAX_UNROLL] = { 0.0 }; - for (int j = 0; j < WSP_GGML_SOFT_MAX_UNROLL; ++j) { - if (SR[j] == -INFINITY) { - SW[j] = 0.0f; - } else { + for (int i = 0; i < Mup; i += WSP_GGML_SOFT_MAX_UNROLL) { + if (i >= masked_begin) { + break; + } + float * SR = S + i; + float * SW = SM + i; + + for (int j = 0; j < WSP_GGML_SOFT_MAX_UNROLL; ++j) { + if (i + j >= masked_begin) { + break; + } else if (SR[j] == -INFINITY) { + SW[j] = 0.0f; + } else { #ifndef WSP_GGML_FLASH_ATTN_EXP_FP16 - const float val = expf(SR[j] - max); + const float val = expf(SR[j] - max); #else - wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(SR[j] - max); - memcpy(&scvt[j], &s, sizeof(uint16_t)); - const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); + wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(SR[j] - max); + memcpy(&scvt[j], &s, sizeof(uint16_t)); + const float val = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_exp_f16[scvt[j]]); #endif - sump[j] += (wsp_ggml_float)val; - SW[j] = val; + sump[j] += (wsp_ggml_float)val; + SW[j] = val; + } } } - } - for (int i = 0; i < WSP_GGML_SOFT_MAX_UNROLL; i++) { - sum += sump[i]; - } + for (int i = 0; i < WSP_GGML_SOFT_MAX_UNROLL; i++) { + sum += sump[i]; + } #endif - } - - assert(sum > 0.0); - - sum = 1.0/sum; - wsp_ggml_vec_scale_f32(M, SM, sum); - - } - - // step-by-step explanation - { - // forward-process shape grads from backward process - // parallel_for iq2,iq3: - // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur] - // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] - // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur] - // for iq1: - // kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur - // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur - // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 - // S0 = -Inf [D,1,1,1] - // ~S1[i] = dot(kcur[:D,i], qcur) - // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale - // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) - // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur - // ~S5[i] = dot(vcur[:,i], S4) - // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3] - // ~dst[i,iq1,iq2,iq3] = S5[i] ^ - // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3] - // dst backward-/ grad[dst] = d - // - // output gradients with their dependencies: - // - // grad[kcur] = grad[S1].T @ qcur - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S4] = grad[S5] @ vcur - // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur - // grad[qcur] = grad[S1] @ kcur - // grad[vcur] = grad[S5].T @ S4 - // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 - // - // in post-order: - // - // S1 = qcur @ kcur.T - // S2 = S1 * scale - // S3 = diag_mask_inf(S2, P) - // S4 = softmax(S3) - // grad[S4] = d[:D,iq1,iq2,iq3] @ vcur - // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) - // grad[S1] = diag_mask_zero(grad[S3], P) * scale - // grad[qcur] = grad[S1] @ kcur - // grad[kcur] = grad[S1].T @ qcur - // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4 - // - // using less variables (SM=S4): - // - // S = diag_mask_inf(qcur @ kcur.T * scale, P) - // SM = softmax(S) - // S = d[:D,iq1,iq2,iq3] @ vcur - // dot_SM_gradSM = dot(SM, S) - // S = SM * (S - dot(SM, S)) - // S = diag_mask_zero(S, P) * scale - // - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[k][:D,:M,iq2,iq3] += S.T @ qcur - // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM - } - - // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur - // S = d[:D,iq1,iq2,iq3] @ vcur - // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3] - wsp_ggml_vec_set_f32(M, S, 0); - for (int64_t ic = 0; ic < D; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + } - wsp_ggml_vec_mad_f32(M, - S, - (float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)), - *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); - } + assert(sum > 0.0); - // S = SM * (S - dot(SM, S)) - float dot_SM_gradSM = 0; - wsp_ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S); - wsp_ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); - wsp_ggml_vec_mul_f32 (M, S, S, SM); + sum = 1.0/sum; + wsp_ggml_vec_scale_f32(masked_begin, SM, sum); - // S = diag_mask_zero(S, P) * scale - if (masked) { - // for (int64_t i = P + iq1 + 1; i < M; i++) { - // S[i] = 0; - // } - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = 0; - } } - } - wsp_ggml_vec_scale_f32(M, S, scale); - - void * grad_q = (char *) dst->data; - void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3; - void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3; - - const size_t nbgq1 = nb0*neq0; - const size_t nbgq2 = nb0*neq0*neq1; - const size_t nbgq3 = nb0*neq0*neq1*neq2; - - const size_t nbgk1 = nb0*nek0; - const size_t nbgk2 = nb0*nek0*nek1; - const size_t nbgk3 = nb0*nek0*nek1*neq2; - - const size_t nbgv1 = nb0*nev0; - const size_t nbgv2 = nb0*nev0*nev1; - const size_t nbgv3 = nb0*nev0*nev1*neq2; - - // S shape [M,1] - // SM shape [M,1] - // kcur shape [D,M] - // qcur shape [D,1] - // vcur shape [M,D] - // - // grad[q][:D,iq1,iq2,iq3] += S @ kcur - // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] - // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic] - // - //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T) - //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T) - for (int64_t ic = 0; ic < M; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - wsp_ggml_vec_mad_f32(D, - (float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)), - (float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)), - S[ic]); - } + // step-by-step explanation + { + // forward-process shape grads from backward process + // parallel_for ik2,ik3: + // for irep: + // iq2 = ik2 + irep*nek2 + // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur] + // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur] + // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur] + // for iq1: + // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur + // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur + // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4 + // S0 = -Inf [D,1,1,1] + // ~S1[i] = dot(kcur[:D,i], qcur) + // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale + // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P) + // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur + // ~S5[i] = dot(vcur[:,i], S4) + // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3] + // ~dst[i,iq1,iq2,iq3] = S5[i] ^ + // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3] + // dst backward-/ grad[dst] = d + // + // output gradients with their dependencies: + // + // grad[kcur] = grad[S1].T @ qcur + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S4] = grad[S5] @ vcur + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[qcur] = grad[S1] @ kcur + // grad[vcur] = grad[S5].T @ S4 + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // in post-order: + // + // S1 = qcur @ kcur.T + // S2 = S1 * scale + // S3 = diag_mask_inf(S2, P) + // S4 = softmax(S3) + // grad[S4] = d[:D,id1,id2,id3] @ vcur + // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])) + // grad[S1] = diag_mask_zero(grad[S3], P) * scale + // grad[qcur] = grad[S1] @ kcur + // grad[kcur] = grad[S1].T @ qcur + // grad[vcur] = d[:D,id1,id2,id3].T @ S4 + // + // using less variables (SM=S4): + // + // S = diag_mask_inf(qcur @ kcur.T * scale, P) + // SM = softmax(S) + // S = d[:D,iq1,iq2,iq3] @ vcur + // dot_SM_gradSM = dot(SM, S) + // S = SM * (S - dot(SM, S)) + // S = diag_mask_zero(S, P) * scale + // + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[k][:D,:M,ik2,ik3] += S.T @ qcur + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + } - // grad[k][:D,:M,iq2,iq3] += S.T @ qcur - // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] - // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] - for (int64_t ic = 0; ic < M; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] + // for ic: + // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3] + // exclude known future zero S[..] values from operation + wsp_ggml_vec_set_f32(masked_begin, S, 0); + for (int64_t ic = 0; ic < D; ++ic) { + wsp_ggml_vec_mad_f32(masked_begin, + S, + (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } - // wsp_ggml_vec_set_f32(D, - // (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), - // 0); - wsp_ggml_vec_mad_f32(D, - (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)), - (float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)), - S[ic]); - } + // S = SM * (S - dot(SM, S)) + float dot_SM_gradSM = 0; + wsp_ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S); + wsp_ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); + wsp_ggml_vec_mul_f32 (masked_begin, S, S, SM); + + // S = diag_mask_zero(S, P) * scale + // already done by above wsp_ggml_vec_set_f32 + + // exclude known zero S[..] values from operation + wsp_ggml_vec_scale_f32(masked_begin, S, scale); + + // S shape [M,1] + // SM shape [M,1] + // kcur shape [D,M] + // qcur shape [D,1] + // vcur shape [M,D] + + // grad[q][:D,iq1,iq2,iq3] += S @ kcur + // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] + // for ic: + // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3] + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + wsp_ggml_vec_mad_f32(D, + (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)), + (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)), + S[ic]); + } - // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM - // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M] - // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M] - for (int64_t ic = 0; ic < D; ++ic) { - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // grad[k][:D,:M,iq2,iq3] += S.T @ qcur + // for ic: + // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] + // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { + wsp_ggml_vec_mad_f32(D, + (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)), + (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), + S[ic]); + } - // wsp_ggml_vec_set_f32(M, - // (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)), - // 0); - wsp_ggml_vec_mad_f32(M, - (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)), - SM, - *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3))); + // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM + // for ic: + // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M] + // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M] + // exclude known zero SM[..] values from mad + for (int64_t ic = 0; ic < D; ++ic) { + wsp_ggml_vec_mad_f32(masked_begin, + (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)), + SM, + *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); + } } } } @@ -14962,8 +13865,8 @@ static void wsp_ggml_compute_forward_win_part_f32( return; } - WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) const int32_t nep0 = ((const int32_t *)(dst->op_params))[0]; const int32_t nep1 = ((const int32_t *)(dst->op_params))[1]; @@ -15024,8 +13927,8 @@ static void wsp_ggml_compute_forward_win_unpart_f32( return; } - WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne); - WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + WSP_GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + WSP_GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) const int32_t w = ((const int32_t *)(dst->op_params))[0]; @@ -15123,6 +14026,10 @@ static void wsp_ggml_compute_forward_unary( { wsp_ggml_compute_forward_silu(params, src0, dst); } break; + case WSP_GGML_UNARY_OP_LEAKY: + { + wsp_ggml_compute_forward_leaky(params, src0, dst); + } break; default: { WSP_GGML_ASSERT(false); @@ -15142,7 +14049,7 @@ static void wsp_ggml_compute_forward_get_rel_pos_f16( // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322 - WSP_GGML_TENSOR_UNARY_OP_LOCALS; + WSP_GGML_TENSOR_UNARY_OP_LOCALS const int64_t w = ne1; @@ -15220,7 +14127,6 @@ static void wsp_ggml_compute_forward_add_rel_pos_f32( const int ip0 = dp*ith; const int ip1 = MIN(ip0 + dp, np); - for (int64_t i13 = ip0; i13 < ip1; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = 0; i11 < ne11; ++i11) { @@ -15287,7 +14193,6 @@ static void wsp_ggml_compute_forward_map_unary_f32( } } - static void wsp_ggml_compute_forward_map_unary( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, @@ -15335,7 +14240,6 @@ static void wsp_ggml_compute_forward_map_binary_f32( } } - static void wsp_ggml_compute_forward_map_binary( const struct wsp_ggml_compute_params * params, const struct wsp_ggml_tensor * src0, @@ -15387,7 +14291,6 @@ static void wsp_ggml_compute_forward_map_custom2_f32( fun(dst, a, b); } - // wsp_ggml_compute_forward_map_custom3 static void wsp_ggml_compute_forward_map_custom3_f32( @@ -15531,7 +14434,7 @@ static void wsp_ggml_compute_forward_cross_entropy_loss_f32( #else wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(s0[i] - max); memcpy(&scvt, &s, sizeof(scvt)); - const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt]); + const float val = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_exp_f16[scvt]); #endif sum += (wsp_ggml_float)val; st[i] = val; @@ -15645,7 +14548,7 @@ static void wsp_ggml_compute_forward_cross_entropy_loss_back_f32( #else wsp_ggml_fp16_t s = WSP_GGML_FP32_TO_FP16(s0[i] - max); memcpy(&scvt, &s, sizeof(scvt)); - const float val = WSP_GGML_FP16_TO_FP32(table_exp_f16[scvt]); + const float val = WSP_GGML_FP16_TO_FP32(wsp_ggml_table_exp_f16[scvt]); #endif sum += (wsp_ggml_float)val; ds0[i] = val; @@ -15662,7 +14565,6 @@ static void wsp_ggml_compute_forward_cross_entropy_loss_back_f32( wsp_ggml_vec_sub_f32(nc, ds0, ds0, s1); wsp_ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); - #ifndef NDEBUG for (int i = 0; i < nc; ++i) { assert(!isnan(ds0[i])); @@ -15690,12 +14592,15 @@ static void wsp_ggml_compute_forward_cross_entropy_loss_back( } } - ///////////////////////////////// static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, struct wsp_ggml_tensor * tensor) { WSP_GGML_ASSERT(params); + if (tensor->op == WSP_GGML_OP_NONE) { + return; + } + #ifdef WSP_GGML_USE_CUBLAS bool skip_cpu = wsp_ggml_cuda_compute_forward(params, tensor); if (skip_cpu) { @@ -15840,7 +14745,7 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st } break; case WSP_GGML_OP_GET_ROWS_BACK: { - wsp_ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); + wsp_ggml_compute_forward_get_rows_back(params, tensor->src[0], tensor->src[1], tensor); } break; case WSP_GGML_OP_DIAG: { @@ -15864,11 +14769,11 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st } break; case WSP_GGML_OP_ROPE: { - wsp_ggml_compute_forward_rope(params, tensor->src[0], tensor); + wsp_ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor); } break; case WSP_GGML_OP_ROPE_BACK: { - wsp_ggml_compute_forward_rope_back(params, tensor->src[0], tensor); + wsp_ggml_compute_forward_rope_back(params, tensor->src[0], tensor->src[1], tensor); } break; case WSP_GGML_OP_ALIBI: { @@ -15882,10 +14787,30 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st { wsp_ggml_compute_forward_conv_1d(params, tensor->src[0], tensor->src[1], tensor); } break; + case WSP_GGML_OP_CONV_1D_STAGE_0: + { + wsp_ggml_compute_forward_conv_1d_stage_0(params, tensor->src[0], tensor->src[1], tensor); + } break; + case WSP_GGML_OP_CONV_1D_STAGE_1: + { + wsp_ggml_compute_forward_conv_1d_stage_1(params, tensor->src[0], tensor->src[1], tensor); + } break; + case WSP_GGML_OP_CONV_TRANSPOSE_1D: + { + wsp_ggml_compute_forward_conv_transpose_1d(params, tensor->src[0], tensor->src[1], tensor); + } break; case WSP_GGML_OP_CONV_2D: { wsp_ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor); } break; + case WSP_GGML_OP_CONV_2D_STAGE_0: + { + wsp_ggml_compute_forward_conv_2d_stage_0(params, tensor->src[0], tensor->src[1], tensor); + } break; + case WSP_GGML_OP_CONV_2D_STAGE_1: + { + wsp_ggml_compute_forward_conv_2d_stage_1(params, tensor->src[0], tensor->src[1], tensor); + } break; case WSP_GGML_OP_CONV_TRANSPOSE_2D: { wsp_ggml_compute_forward_conv_transpose_2d(params, tensor->src[0], tensor->src[1], tensor); @@ -16013,7 +14938,265 @@ static void wsp_ggml_compute_forward(struct wsp_ggml_compute_params * params, st //////////////////////////////////////////////////////////////////////////////// -static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor, bool inplace) { +static size_t wsp_ggml_hash_size(size_t min_sz) { + // next primes after powers of two + static const size_t primes[] = { + 2, 3, 5, 11, 17, 37, 67, 131, 257, 521, 1031, + 2053, 4099, 8209, 16411, 32771, 65537, 131101, + 262147, 524309, 1048583, 2097169, 4194319, 8388617, + 16777259, 33554467, 67108879, 134217757, 268435459, + 536870923, 1073741827, 2147483659 + }; + static const size_t n_primes = sizeof(primes)/sizeof(primes[0]); + + // find the smallest prime that is larger or equal to min_sz + size_t l = 0; + size_t r = n_primes; + while (l < r) { + size_t m = (l + r)/2; + if (primes[m] < min_sz) { + l = m + 1; + } else { + r = m; + } + } + size_t sz = l < n_primes ? primes[l] : min_sz | 1; + return sz; +} + +static size_t wsp_ggml_hash(const void * p) { + return (size_t)p; +} + +size_t wsp_ggml_hash_find(const struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key) { + size_t h = wsp_ggml_hash(key) % hash_set.size; + + // linear probing + size_t i = h; + while (hash_set.keys[i] != NULL && hash_set.keys[i] != key) { + i = (i + 1) % hash_set.size; + if (i == h) { + // visited all hash table entries -> not found + return WSP_GGML_HASHTABLE_FULL; + } + } + return i; +} + +bool wsp_ggml_hash_contains(struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key) { + size_t i = wsp_ggml_hash_find(hash_set, key); + return i != WSP_GGML_HASHTABLE_FULL && hash_set.keys[i] == key; +} + +size_t wsp_ggml_hash_insert(struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key) { + size_t i = wsp_ggml_hash_find(hash_set, key); + + WSP_GGML_ASSERT(i != WSP_GGML_HASHTABLE_FULL); + + if (hash_set.keys[i] == key) { + return WSP_GGML_HASHTABLE_ALREADY_EXISTS; + } + + // insert + WSP_GGML_ASSERT(hash_set.keys[i] == NULL); + hash_set.keys[i] = key; + return i; +} + +size_t wsp_ggml_hash_find_or_insert(struct wsp_ggml_hash_set hash_set, struct wsp_ggml_tensor * key) { + size_t i = wsp_ggml_hash_find(hash_set, key); + + WSP_GGML_ASSERT(i != WSP_GGML_HASHTABLE_FULL); + + hash_set.keys[i] = key; + return i; +} + +static struct wsp_ggml_hash_set wsp_ggml_hash_set_new(size_t size) { + size = wsp_ggml_hash_size(size); + struct wsp_ggml_hash_set result; + result.size = size; + result.keys = malloc(sizeof(struct wsp_ggml_tensor *) * size); + memset(result.keys, 0, sizeof(struct wsp_ggml_tensor *) * size); + return result; +} + +static void wsp_ggml_hash_set_free(struct wsp_ggml_hash_set hash_set) { + free(hash_set.keys); +} + +struct hash_map { + struct wsp_ggml_hash_set set; + struct wsp_ggml_tensor ** vals; +}; + +static struct hash_map * wsp_ggml_new_hash_map(size_t size) { + struct hash_map * result = malloc(sizeof(struct hash_map)); + result->set = wsp_ggml_hash_set_new(size); + result->vals = malloc(sizeof(struct wsp_ggml_tensor *) * result->set.size); + memset(result->vals, 0, sizeof(struct wsp_ggml_tensor *) * result->set.size); + return result; +} + +static void wsp_ggml_hash_map_free(struct hash_map * map) { + wsp_ggml_hash_set_free(map->set); + free(map->vals); + free(map); +} + +// gradient checkpointing + +static struct wsp_ggml_tensor * wsp_ggml_recompute_graph_node( + struct wsp_ggml_context * ctx, + struct wsp_ggml_cgraph * graph, + struct hash_map * replacements, + struct wsp_ggml_tensor * node) { + + if (node == NULL) { + return NULL; + } + + if (node->is_param) { + return node; + } + + if (!wsp_ggml_hash_contains(graph->visited_hash_table, node)) { + return node; + } + + int count_children = 0; + for (int k = 0; k < WSP_GGML_MAX_SRC; ++k) { + if (node->src[k]) { + ++count_children; + } + } + + if (count_children == 0) { + return node; + } + + size_t i = wsp_ggml_hash_find(replacements->set, node); + WSP_GGML_ASSERT(i != WSP_GGML_HASHTABLE_FULL); // assert that not full + if (replacements->set.keys[i] == node) { + return replacements->vals[i]; + } + + struct wsp_ggml_tensor * clone = wsp_ggml_new_tensor(ctx, node->type, node->n_dims, node->ne); + + // insert clone into replacements + WSP_GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite + replacements->set.keys[i] = node; + replacements->vals[i] = clone; + + clone->op = node->op; + clone->grad = node->grad; + clone->is_param = node->is_param; + clone->extra = node->extra; + for (int k = 0; k < WSP_GGML_MAX_DIMS; ++k) { + clone->nb[k] = node->nb[k]; + } + for (int k = 0; k < WSP_GGML_MAX_SRC; ++k) { + clone->src[k] = wsp_ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); + } + if (node->view_src != NULL) { + clone->data = (node->view_src->data == NULL) + ? NULL // view_src not yet allocated + : (char *) node->view_src->data // view_src already allocated + + node->view_offs; + clone->view_src = node->view_src; + clone->view_offs = node->view_offs; + } + + WSP_GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (WSP_GGML_MAX_OP_PARAMS / sizeof(int32_t))); + WSP_GGML_ASSERT(sizeof(node->name) == WSP_GGML_MAX_NAME); + memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); + wsp_ggml_format_name(clone, "%s (clone)", wsp_ggml_get_name(node)); + + return clone; +} + +void wsp_ggml_build_backward_gradient_checkpointing( + struct wsp_ggml_context * ctx, + struct wsp_ggml_cgraph * gf, + struct wsp_ggml_cgraph * gb, + struct wsp_ggml_cgraph * gb_tmp, + struct wsp_ggml_tensor * * checkpoints, + int n_checkpoints) { + wsp_ggml_graph_cpy(gf, gb_tmp); + wsp_ggml_build_backward_expand(ctx, gf, gb_tmp, true); + + if (n_checkpoints <= 0) { + wsp_ggml_graph_cpy(gb_tmp, gb); + return; + } + + struct hash_map * replacements = wsp_ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints); + + // insert checkpoints in replacements + for (int i = 0; i < n_checkpoints; ++i) { + size_t k = wsp_ggml_hash_find(replacements->set, checkpoints[i]); + WSP_GGML_ASSERT(k != WSP_GGML_HASHTABLE_FULL); // assert that not full + WSP_GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite + replacements->set.keys[k] = checkpoints[i]; + replacements->vals[k] = checkpoints[i]; + } + + wsp_ggml_graph_cpy(gf, gb); + // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], + // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), + // by recomputing them from checkpoints + for (int i = gf->n_nodes; in_nodes; ++i) { + struct wsp_ggml_tensor * node = gb_tmp->nodes[i]; + for (int k = 0; k < WSP_GGML_MAX_SRC; ++k) { + // insert new tensors recomputing src, reusing already made replacements, + // remember replacements: remember new tensors with mapping from corresponding gf nodes + // recurse for input tensors, + // unless (i.e. terminating when) input tensors are replacments (like checkpoints) + node->src[k] = wsp_ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); + } + // insert rewritten backward node with replacements made into resulting backward graph gb + wsp_ggml_build_forward_expand(gb, node); + } + + wsp_ggml_hash_map_free(replacements); +} + +// functions to change gradients considering the case that input a might be initial gradient with zero value + +static struct wsp_ggml_tensor * wsp_ggml_add_or_set(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b, struct wsp_ggml_hash_set zero_table) { + if (wsp_ggml_hash_contains(zero_table, a)) { + return b; + } else { + return wsp_ggml_add_impl(ctx, a, b, false); + } +} + +static struct wsp_ggml_tensor * wsp_ggml_acc_or_set(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct wsp_ggml_hash_set zero_table) { + if (wsp_ggml_hash_contains(zero_table, a)) { + struct wsp_ggml_tensor * a_zero = wsp_ggml_scale(ctx, a, wsp_ggml_new_f32(ctx, 0)); + return wsp_ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); + } else { + return wsp_ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); + } +} + +static struct wsp_ggml_tensor * wsp_ggml_add1_or_set(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b, struct wsp_ggml_hash_set zero_table) { + if (wsp_ggml_hash_contains(zero_table, a)) { + return wsp_ggml_repeat(ctx, b, a); + } else { + return wsp_ggml_add1_impl(ctx, a, b, false); + } +} + +static struct wsp_ggml_tensor * wsp_ggml_sub_or_set(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b, struct wsp_ggml_hash_set zero_table) { + if (wsp_ggml_hash_contains(zero_table, a)) { + return wsp_ggml_neg(ctx, b); + } else { + return wsp_ggml_sub_impl(ctx, a, b, false); + } +} + +static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor, struct wsp_ggml_hash_set zero_table) { struct wsp_ggml_tensor * src0 = tensor->src[0]; struct wsp_ggml_tensor * src1 = tensor->src[1]; @@ -16021,34 +15204,34 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ case WSP_GGML_OP_DUP: { if (src0->grad) { - src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } } break; case WSP_GGML_OP_ADD: { if (src0->grad) { - src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { - src1->grad = wsp_ggml_add_impl(ctx, src1->grad, tensor->grad, inplace); + src1->grad = wsp_ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); } } break; case WSP_GGML_OP_ADD1: { if (src0->grad) { - src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { - src1->grad = wsp_ggml_add_impl(ctx, + src1->grad = wsp_ggml_add_or_set(ctx, src1->grad, wsp_ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean - inplace); + zero_table); } } break; case WSP_GGML_OP_ACC: { if (src0->grad) { - src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { const size_t nb1 = ((int32_t *) tensor->op_params)[0]; @@ -16065,117 +15248,117 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ nb1, nb2, nb3, offset); src1->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src1->grad, wsp_ggml_reshape(ctx, wsp_ggml_cont(ctx, tensor_grad_view), src1->grad), - inplace); + zero_table); } } break; case WSP_GGML_OP_SUB: { if (src0->grad) { - src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { - src1->grad = wsp_ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace); + src1->grad = wsp_ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table); } } break; case WSP_GGML_OP_MUL: { if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_mul(ctx, src1, tensor->grad), - inplace); + zero_table); } if (src1->grad) { src1->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src1->grad, wsp_ggml_mul(ctx, src0, tensor->grad), - inplace); + zero_table); } } break; case WSP_GGML_OP_DIV: { if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_div(ctx, tensor->grad, src1), - inplace); + zero_table); } if (src1->grad) { src1->grad = - wsp_ggml_sub_impl(ctx, + wsp_ggml_sub_or_set(ctx, src1->grad, wsp_ggml_mul(ctx, tensor->grad, wsp_ggml_div(ctx, tensor, src1)), - inplace); + zero_table); } } break; case WSP_GGML_OP_SQR: { if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_scale(ctx, wsp_ggml_mul(ctx, src0, tensor->grad), wsp_ggml_new_f32(ctx, 2.0f)), - inplace); + zero_table); } } break; case WSP_GGML_OP_SQRT: { if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_scale(ctx, wsp_ggml_div(ctx, tensor->grad, tensor), wsp_ggml_new_f32(ctx, 0.5f)), - inplace); + zero_table); } } break; case WSP_GGML_OP_LOG: { if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_div(ctx, tensor->grad, src0), - inplace); + zero_table); } } break; case WSP_GGML_OP_SUM: { if (src0->grad) { src0->grad = - wsp_ggml_add1_impl(ctx, + wsp_ggml_add1_or_set(ctx, src0->grad, tensor->grad, - inplace); + zero_table); } } break; case WSP_GGML_OP_SUM_ROWS: { if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_repeat(ctx, tensor->grad, src0->grad), - inplace); + zero_table); } } break; case WSP_GGML_OP_MEAN: @@ -16187,20 +15370,20 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ { // necessary for llama if (src0->grad) { - src0->grad = wsp_ggml_add_impl(ctx, + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_repeat_back(ctx, tensor->grad, src0->grad), - inplace); + zero_table); } } break; case WSP_GGML_OP_REPEAT_BACK: { if (src0->grad) { // TODO: test this - src0->grad = wsp_ggml_add_impl(ctx, + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_repeat(ctx, tensor->grad, src0->grad), - inplace); + zero_table); } } break; case WSP_GGML_OP_CONCAT: @@ -16222,10 +15405,10 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ float eps; memcpy(&eps, tensor->op_params, sizeof(float)); - src0->grad = wsp_ggml_add_impl(ctx, + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_rms_norm_back(ctx, src0, tensor->grad, eps), - inplace); + zero_table); } } break; case WSP_GGML_OP_RMS_NORM_BACK: @@ -16249,37 +15432,49 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix // ds1 = t.T.dot(dt) - // tensor.shape [m,p] - // src0.shape [n,m] - // src1.shape [n,p] + // tensor.shape [m,p,qq,rr] + // src0.shape [n,m,q1,r1] + // src1.shape [n,p,qq,rr] // necessary for llama if (src0->grad) { + struct wsp_ggml_tensor * s1_tg = + wsp_ggml_out_prod(ctx, // [n,m,qq,rr] + src1, // [n,p,qq,rr] + tensor->grad); // [m,p,qq,rr] + const int64_t qq = s1_tg->ne[2]; + const int64_t rr = s1_tg->ne[3]; + const int64_t q1 = src0->ne[2]; + const int64_t r1 = src0->ne[3]; + const bool ne2_broadcasted = qq > q1; + const bool ne3_broadcasted = rr > r1; + if (ne2_broadcasted || ne3_broadcasted) { + // sum broadcast repetitions of s1_tg into shape of src0 + s1_tg = wsp_ggml_repeat_back(ctx, s1_tg, src0); + } src0->grad = - wsp_ggml_add_impl(ctx, - src0->grad, - wsp_ggml_out_prod(ctx, // [n,m] - src1, // [n,p] - tensor->grad), // [m,p] - inplace); + wsp_ggml_add_or_set(ctx, + src0->grad, // [n,m,q1,r1] + s1_tg, // [n,m,q1,r1] + zero_table); } if (src1->grad) { src1->grad = - wsp_ggml_add_impl(ctx, - src1->grad, - // wsp_ggml_mul_mat(ctx, // [n,p] - // wsp_ggml_cont(ctx, // [m,n] - // wsp_ggml_transpose(ctx, src0)), // [m,n] - // tensor->grad), // [m,p] + wsp_ggml_add_or_set(ctx, + src1->grad, // [n,p,qq,rr] + // wsp_ggml_mul_mat(ctx, // [n,p,qq,rr] + // wsp_ggml_cont(ctx, // [m,n,q1,r1] + // wsp_ggml_transpose(ctx, src0)), // [m,n,q1,r1] + // tensor->grad), // [m,p,qq,rr] // // when src0 is bigger than tensor->grad (this is mostly the case in llama), // // avoid transpose of src0, rather transpose smaller tensor->grad // // and then use wsp_ggml_out_prod - wsp_ggml_out_prod(ctx, // [n,p] - src0, // [n,m] - wsp_ggml_transpose(ctx, // [p,m] - tensor->grad)), // [m,p] - inplace); + wsp_ggml_out_prod(ctx, // [n,p,qq,rr] + src0, // [n,m,q1,r1] + wsp_ggml_transpose(ctx, // [p,m,qq,rr] + tensor->grad)), // [m,p,qq,rr] + zero_table); } } break; case WSP_GGML_OP_OUT_PROD: @@ -16291,17 +15486,17 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ // necessary for llama if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_scale_impl(ctx, tensor->grad, src1, false), - inplace); + zero_table); } if (src1->grad) { src1->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src1->grad, wsp_ggml_sum(ctx, wsp_ggml_mul_impl(ctx, tensor->grad, src0, false)), - inplace); + zero_table); } } break; case WSP_GGML_OP_SET: @@ -16328,23 +15523,23 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ } if (src0->grad) { - src0->grad = wsp_ggml_add_impl(ctx, + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_acc_impl(ctx, tensor->grad, wsp_ggml_neg(ctx, tensor_grad_view), nb1, nb2, nb3, offset, false), - inplace); + zero_table); } if (src1->grad) { src1->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src1->grad, wsp_ggml_reshape(ctx, wsp_ggml_cont(ctx, tensor_grad_view), src1->grad), - inplace); + zero_table); } } break; case WSP_GGML_OP_CPY: @@ -16355,7 +15550,7 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ // tensor = src0 * 1 + src1 * 0 if (src0->grad) { // dsrc0 = dtensor * 1 - src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } if (src1->grad) { // dsrc1 = dtensor * 0 -> noop @@ -16367,7 +15562,7 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ if (src0->grad) { WSP_GGML_ASSERT(wsp_ggml_is_contiguous(src0->grad)); WSP_GGML_ASSERT(wsp_ggml_is_contiguous(tensor->grad)); - src0->grad = wsp_ggml_add_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); } } break; case WSP_GGML_OP_RESHAPE: @@ -16375,9 +15570,13 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ // necessary for llama if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, src0->grad, - wsp_ggml_reshape(ctx, tensor->grad, src0->grad), - inplace); + wsp_ggml_add_or_set(ctx, src0->grad, + wsp_ggml_reshape(ctx, + wsp_ggml_is_contiguous(tensor->grad) + ? tensor->grad + : wsp_ggml_cont(ctx, tensor->grad), + src0->grad), + zero_table); } } break; case WSP_GGML_OP_VIEW: @@ -16406,7 +15605,7 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ nb3 = (nb3 / n0) * ng; } - src0->grad = wsp_ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace); + src0->grad = wsp_ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table); } } break; case WSP_GGML_OP_PERMUTE: @@ -16424,14 +15623,14 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ axes_backward[axis2] = 2; axes_backward[axis3] = 3; src0->grad = - wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_permute(ctx, tensor->grad, axes_backward[0], axes_backward[1], axes_backward[2], axes_backward[3]), - inplace); + zero_table); } } break; case WSP_GGML_OP_TRANSPOSE: @@ -16439,9 +15638,9 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ // necessary for llama if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_transpose(ctx, tensor->grad), - inplace); + zero_table); } } break; case WSP_GGML_OP_GET_ROWS: @@ -16449,9 +15648,11 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ // necessary for llama (only for tokenizer) if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_add_or_set(ctx, src0->grad, + // last wsp_ggml_get_rows_back argument src0->grad is only + // necessary to setup correct output shape wsp_ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), - inplace); + zero_table); } if (src1->grad) { // noop @@ -16471,9 +15672,9 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ if (src0->grad) { const int n_past = ((int32_t *) tensor->op_params)[0]; src0->grad = - wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - inplace); + zero_table); } } break; case WSP_GGML_OP_DIAG_MASK_ZERO: @@ -16482,9 +15683,9 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ if (src0->grad) { const int n_past = ((int32_t *) tensor->op_params)[0]; src0->grad = - wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - inplace); + zero_table); } } break; case WSP_GGML_OP_SOFT_MAX: @@ -16492,9 +15693,9 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ // necessary for llama if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, src0->grad, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_soft_max_back(ctx, tensor->grad, tensor), - inplace); + zero_table); } } break; @@ -16506,7 +15707,7 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ { // necessary for llama if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; + //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; @@ -16519,11 +15720,11 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); - src0->grad = wsp_ggml_add_impl(ctx, + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_rope_back(ctx, tensor->grad, - n_past, + src1, n_dims, mode, n_ctx, @@ -16531,13 +15732,13 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ freq_scale, xpos_base, xpos_down), - inplace); + zero_table); } } break; case WSP_GGML_OP_ROPE_BACK: { if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; + //const int n_past = ((int32_t *) tensor->op_params)[0]; const int n_dims = ((int32_t *) tensor->op_params)[1]; const int mode = ((int32_t *) tensor->op_params)[2]; const int n_ctx = ((int32_t *) tensor->op_params)[3]; @@ -16550,20 +15751,25 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float)); memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool)); - src0->grad = wsp_ggml_add_impl(ctx, + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_rope_impl(ctx, tensor->grad, - n_past, + src1, n_dims, mode, + 0, n_ctx, freq_base, freq_scale, + 0.0f, + 1.0f, + 0.0f, + 0.0f, xpos_base, xpos_down, false), - inplace); + zero_table); } } break; case WSP_GGML_OP_ALIBI: @@ -16578,10 +15784,30 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ { WSP_GGML_ASSERT(false); // TODO: not implemented } break; + case WSP_GGML_OP_CONV_1D_STAGE_0: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_CONV_1D_STAGE_1: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_CONV_TRANSPOSE_1D: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; case WSP_GGML_OP_CONV_2D: { WSP_GGML_ASSERT(false); // TODO: not implemented } break; + case WSP_GGML_OP_CONV_2D_STAGE_0: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; + case WSP_GGML_OP_CONV_2D_STAGE_1: + { + WSP_GGML_ASSERT(false); // TODO: not implemented + } break; case WSP_GGML_OP_CONV_TRANSPOSE_2D: { WSP_GGML_ASSERT(false); // TODO: not implemented @@ -16614,145 +15840,42 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ masked); } - if (src0->grad) { - struct wsp_ggml_tensor * grad_q = NULL; - const size_t nb0 = flash_grad->nb[0]; - const size_t offset = 0; - switch(src0->n_dims) { - case 2: - { - grad_q = wsp_ggml_view_2d(ctx, - flash_grad, - src0->ne[0], - src0->ne[1], - nb0*src0->ne[0], - offset); - } break; - case 3: - { - grad_q = wsp_ggml_view_3d(ctx, - flash_grad, - src0->ne[0], - src0->ne[1], - src0->ne[2], - nb0*src0->ne[0], - nb0*src0->ne[0]*src0->ne[1], - offset); - } break; - case 4: - { - grad_q = wsp_ggml_view_4d(ctx, - flash_grad, - src0->ne[0], - src0->ne[1], - src0->ne[2], - src0->ne[3], - nb0*src0->ne[0], - nb0*src0->ne[0]*src0->ne[1], - nb0*src0->ne[0]*src0->ne[1]*src0->ne[2], - offset); - } break; - } + struct wsp_ggml_tensor * src2 = tensor->src[2]; + const int64_t elem_q = wsp_ggml_nelements(src0); + const int64_t elem_k = wsp_ggml_nelements(src1); + const int64_t elem_v = wsp_ggml_nelements(src2); + + enum wsp_ggml_type result_type = flash_grad->type; + WSP_GGML_ASSERT(wsp_ggml_blck_size(result_type) == 1); + const size_t tsize = wsp_ggml_type_size(result_type); - src0->grad = wsp_ggml_add_impl(ctx, + const size_t offs_q = 0; + const size_t offs_k = offs_q + WSP_GGML_PAD(elem_q * tsize, WSP_GGML_MEM_ALIGN); + const size_t offs_v = offs_k + WSP_GGML_PAD(elem_k * tsize, WSP_GGML_MEM_ALIGN); + + if (src0->grad) { + struct wsp_ggml_tensor * view_q = wsp_ggml_view_1d(ctx, flash_grad, elem_q, offs_q); + struct wsp_ggml_tensor * grad_q = wsp_ggml_reshape(ctx, view_q, src0); + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, grad_q, - inplace); + zero_table); } - - if (src1->grad) { - struct wsp_ggml_tensor * grad_k = NULL; - const size_t nb0 = flash_grad->nb[0]; - const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]; - switch(src1->n_dims) { - case 2: - { - grad_k = wsp_ggml_view_2d(ctx, - flash_grad, - src1->ne[0], - src1->ne[1], - nb0*src1->ne[0], - offset); - } break; - case 3: - { - grad_k = wsp_ggml_view_3d(ctx, - flash_grad, - src1->ne[0], - src1->ne[1], - src1->ne[2], - nb0*src1->ne[0], - nb0*src1->ne[0]*src1->ne[1], - offset); - } break; - case 4: - { - grad_k = wsp_ggml_view_4d(ctx, - flash_grad, - src1->ne[0], - src1->ne[1], - src1->ne[2], - src1->ne[3], - nb0*src1->ne[0], - nb0*src1->ne[0]*src1->ne[1], - nb0*src1->ne[0]*src1->ne[1]*src1->ne[2], - offset); - } break; - } - - src1->grad = wsp_ggml_add_impl(ctx, + if (src1->grad) { + struct wsp_ggml_tensor * view_k = wsp_ggml_view_1d(ctx, flash_grad, elem_k, offs_k); + struct wsp_ggml_tensor * grad_k = wsp_ggml_reshape(ctx, view_k, src1); + src1->grad = wsp_ggml_add_or_set(ctx, src1->grad, grad_k, - inplace); + zero_table); } - - struct wsp_ggml_tensor * opt0 = tensor->src[2]; - - if (opt0->grad) { - struct wsp_ggml_tensor * grad_v = NULL; - const size_t nb0 = flash_grad->nb[0]; - const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3] - + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3]; - switch(opt0->n_dims) { - case 2: - { - grad_v = wsp_ggml_view_2d(ctx, - flash_grad, - opt0->ne[0], - opt0->ne[1], - nb0*opt0->ne[0], - offset); - } break; - case 3: - { - grad_v = wsp_ggml_view_3d(ctx, - flash_grad, - opt0->ne[0], - opt0->ne[1], - opt0->ne[2], - nb0*opt0->ne[0], - nb0*opt0->ne[0]*opt0->ne[1], - offset); - } break; - case 4: - { - grad_v = wsp_ggml_view_4d(ctx, - flash_grad, - opt0->ne[0], - opt0->ne[1], - opt0->ne[2], - opt0->ne[3], - nb0*opt0->ne[0], - nb0*opt0->ne[0]*opt0->ne[1], - nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2], - offset); - } break; - } - - opt0->grad = wsp_ggml_add_impl(ctx, - opt0->grad, + if (src2->grad) { + struct wsp_ggml_tensor * view_v = wsp_ggml_view_1d(ctx, flash_grad, elem_v, offs_v); + struct wsp_ggml_tensor * grad_v = wsp_ggml_reshape(ctx, view_v, src2); + src2->grad = wsp_ggml_add_or_set(ctx, + src2->grad, grad_v, - inplace); + zero_table); } } break; case WSP_GGML_OP_FLASH_FF: @@ -16772,12 +15895,12 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ { if (src0->grad) { src0->grad = - wsp_ggml_add_impl(ctx, + wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_mul(ctx, wsp_ggml_sgn(ctx, src0), tensor->grad), - inplace); + zero_table); } } break; case WSP_GGML_UNARY_OP_SGN: @@ -16789,7 +15912,7 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ case WSP_GGML_UNARY_OP_NEG: { if (src0->grad) { - src0->grad = wsp_ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace); + src0->grad = wsp_ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table); } } break; case WSP_GGML_UNARY_OP_STEP: @@ -16809,12 +15932,12 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ case WSP_GGML_UNARY_OP_RELU: { if (src0->grad) { - src0->grad = wsp_ggml_add_impl(ctx, + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_mul(ctx, wsp_ggml_step(ctx, src0), tensor->grad), - inplace); + zero_table); } } break; case WSP_GGML_UNARY_OP_GELU: @@ -16829,10 +15952,10 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ { // necessary for llama if (src0->grad) { - src0->grad = wsp_ggml_add_impl(ctx, + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_silu_back(ctx, src0, tensor->grad), - inplace); + zero_table); } } break; default: @@ -16855,13 +15978,13 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ case WSP_GGML_OP_CROSS_ENTROPY_LOSS: { if (src0->grad) { - src0->grad = wsp_ggml_add_impl(ctx, + src0->grad = wsp_ggml_add_or_set(ctx, src0->grad, wsp_ggml_cross_entropy_loss_back(ctx, src0, src1, tensor->grad), - inplace); + zero_table); } } break; case WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK: @@ -16877,34 +16000,12 @@ static void wsp_ggml_compute_backward(struct wsp_ggml_context * ctx, struct wsp_ WSP_GGML_ASSERT(false); } break; } -} - -static_assert(WSP_GGML_GRAPH_HASHTABLE_SIZE > WSP_GGML_MAX_NODES * 2, "WSP_GGML_GRAPH_HT_SIZE is too small"); - -static size_t hash(void * p) { - return (size_t)p % WSP_GGML_GRAPH_HASHTABLE_SIZE; -} - -static bool hash_insert(void * hash_table[], void * p) { - size_t h = hash(p); - // linear probing - size_t i = h; - while (hash_table[i] != NULL && hash_table[i] != p) { - i = (i + 1) % WSP_GGML_GRAPH_HASHTABLE_SIZE; - if (i == h) { - // hash table is full - WSP_GGML_ASSERT(false); + for (int i = 0; i < WSP_GGML_MAX_SRC; ++i) { + if (tensor->src[i] && tensor->src[i]->grad) { + WSP_GGML_ASSERT(wsp_ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad)); } } - - if (hash_table[i] == p) { - return true; - } - - // insert - hash_table[i] = p; - return false; } static void wsp_ggml_visit_parents(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * node) { @@ -16917,19 +16018,23 @@ static void wsp_ggml_visit_parents(struct wsp_ggml_cgraph * cgraph, struct wsp_g } // check if already visited - if (hash_insert(cgraph->visited_hash_table, node)) { + if (wsp_ggml_hash_insert(cgraph->visited_hash_table, node) == WSP_GGML_HASHTABLE_ALREADY_EXISTS) { return; } for (int i = 0; i < WSP_GGML_MAX_SRC; ++i) { - if (node->src[i]) { - wsp_ggml_visit_parents(cgraph, node->src[i]); + const int k = + (cgraph->order == WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : + (cgraph->order == WSP_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (WSP_GGML_MAX_SRC-1-i) : + /* unknown order, just fall back to using i*/ i; + if (node->src[k]) { + wsp_ggml_visit_parents(cgraph, node->src[k]); } } if (node->op == WSP_GGML_OP_NONE && node->grad == NULL) { // reached a leaf node, not part of the gradient graph (e.g. a constant) - WSP_GGML_ASSERT(cgraph->n_leafs < WSP_GGML_MAX_NODES); + WSP_GGML_ASSERT(cgraph->n_leafs < cgraph->size); if (strlen(node->name) == 0) { wsp_ggml_format_name(node, "leaf_%d", cgraph->n_leafs); @@ -16938,22 +16043,24 @@ static void wsp_ggml_visit_parents(struct wsp_ggml_cgraph * cgraph, struct wsp_g cgraph->leafs[cgraph->n_leafs] = node; cgraph->n_leafs++; } else { - WSP_GGML_ASSERT(cgraph->n_nodes < WSP_GGML_MAX_NODES); + WSP_GGML_ASSERT(cgraph->n_nodes < cgraph->size); if (strlen(node->name) == 0) { wsp_ggml_format_name(node, "node_%d", cgraph->n_nodes); } cgraph->nodes[cgraph->n_nodes] = node; - cgraph->grads[cgraph->n_nodes] = node->grad; + if (cgraph->grads) { + cgraph->grads[cgraph->n_nodes] = node->grad; + } cgraph->n_nodes++; } } static void wsp_ggml_build_forward_impl(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor, bool expand) { if (!expand) { - cgraph->n_nodes = 0; - cgraph->n_leafs = 0; + // TODO: this branch isn't accessible anymore, maybe move this to wsp_ggml_build_forward_expand + wsp_ggml_graph_clear(cgraph); } const int n0 = cgraph->n_nodes; @@ -16974,24 +16081,6 @@ void wsp_ggml_build_forward_expand(struct wsp_ggml_cgraph * cgraph, struct wsp_g wsp_ggml_build_forward_impl(cgraph, tensor, true); } -struct wsp_ggml_cgraph wsp_ggml_build_forward(struct wsp_ggml_tensor * tensor) { - struct wsp_ggml_cgraph result = { - /*.n_nodes =*/ 0, - /*.n_leafs =*/ 0, - /*.nodes =*/ { NULL }, - /*.grads =*/ { NULL }, - /*.leafs =*/ { NULL }, - /*.hash_table =*/ { NULL }, - /*.perf_runs =*/ 0, - /*.perf_cycles =*/ 0, - /*.perf_time_us =*/ 0, - }; - - wsp_ggml_build_forward_impl(&result, tensor, false); - - return result; -} - void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool keep) { WSP_GGML_ASSERT(gf->n_nodes > 0); @@ -17007,12 +16096,21 @@ void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_gg } } + // remember original gradients which start with zero values + struct wsp_ggml_hash_set zero_table = wsp_ggml_hash_set_new(gf->size); + for (int i = 0; i < gf->n_nodes; i++) { + if (gf->grads[i]) { + wsp_ggml_hash_insert(zero_table, gf->grads[i]); + } + } + for (int i = gf->n_nodes - 1; i >= 0; i--) { struct wsp_ggml_tensor * node = gf->nodes[i]; - // because we detached the grad nodes from the original graph, we can afford inplace operations + // inplace operations to add gradients are not created by wsp_ggml_compute_backward + // use allocator to automatically make inplace operations if (node->grad) { - wsp_ggml_compute_backward(ctx, node, keep); + wsp_ggml_compute_backward(ctx, node, zero_table); } } @@ -17024,25 +16122,56 @@ void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_gg wsp_ggml_build_forward_expand(gb, node->grad); } } + + wsp_ggml_hash_set_free(zero_table); } -struct wsp_ggml_cgraph wsp_ggml_build_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, bool keep) { - struct wsp_ggml_cgraph result = *gf; - wsp_ggml_build_backward_expand(ctx, gf, &result, keep); - return result; +static size_t wsp_ggml_graph_nbytes(size_t size, bool grads) { + size_t nbytes = sizeof(struct wsp_ggml_cgraph); + nbytes += size * sizeof(struct wsp_ggml_tensor *) * 2; // leafs + nodes + if (grads) { + nbytes += size * sizeof(struct wsp_ggml_tensor *); // grads + } + nbytes += wsp_ggml_hash_size(size * 2) * sizeof(struct wsp_ggml_tensor *); // hash set + return nbytes; } -struct wsp_ggml_cgraph * wsp_ggml_new_graph(struct wsp_ggml_context * ctx) { - struct wsp_ggml_object * obj = wsp_ggml_new_object(ctx, WSP_GGML_OBJECT_GRAPH, WSP_GGML_GRAPH_SIZE); +size_t wsp_ggml_graph_overhead_custom(size_t size, bool grads) { + return WSP_GGML_OBJECT_SIZE + WSP_GGML_PAD(wsp_ggml_graph_nbytes(size, grads), WSP_GGML_MEM_ALIGN); +} + +size_t wsp_ggml_graph_overhead(void) { + return wsp_ggml_graph_overhead_custom(WSP_GGML_DEFAULT_GRAPH_SIZE, false); +} + +struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom(struct wsp_ggml_context * ctx, size_t size, bool grads) { + const size_t obj_size = wsp_ggml_graph_nbytes(size, grads); + struct wsp_ggml_object * obj = wsp_ggml_new_object(ctx, WSP_GGML_OBJECT_GRAPH, obj_size); struct wsp_ggml_cgraph * cgraph = (struct wsp_ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); + struct wsp_ggml_tensor ** data_start = (struct wsp_ggml_tensor **) (cgraph + 1); + + size_t hash_size = wsp_ggml_hash_size(size * 2); + struct wsp_ggml_tensor ** nodes_ptr = data_start; + struct wsp_ggml_tensor ** leafs_ptr = nodes_ptr + size; + struct wsp_ggml_tensor ** hash_keys_ptr = leafs_ptr + size; + struct wsp_ggml_tensor ** grads_ptr = grads ? hash_keys_ptr + hash_size : NULL; + + // check that we allocated the correct amount of memory + assert(obj_size == (size_t) ( + (grads ? (char *)(grads_ptr + size) : (char *)(hash_keys_ptr + hash_size)) - (char *)cgraph)); + + memset(hash_keys_ptr, 0, hash_size * sizeof(struct wsp_ggml_tensor *)); + *cgraph = (struct wsp_ggml_cgraph) { + /*.size =*/ size, /*.n_nodes =*/ 0, /*.n_leafs =*/ 0, - /*.nodes =*/ { NULL }, - /*.grads =*/ { NULL }, - /*.leafs =*/ { NULL }, - /*.hash_table =*/ { NULL }, + /*.nodes =*/ nodes_ptr, + /*.grads =*/ grads_ptr, + /*.leafs =*/ leafs_ptr, + /*.hash_table =*/ { hash_size, hash_keys_ptr }, + /*.order =*/ WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, /*.perf_runs =*/ 0, /*.perf_cycles =*/ 0, /*.perf_time_us =*/ 0, @@ -17051,14 +16180,85 @@ struct wsp_ggml_cgraph * wsp_ggml_new_graph(struct wsp_ggml_context * ctx) { return cgraph; } -struct wsp_ggml_cgraph * wsp_ggml_build_forward_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor) { - struct wsp_ggml_cgraph * cgraph = wsp_ggml_new_graph(ctx); - wsp_ggml_build_forward_impl(cgraph, tensor, false); +struct wsp_ggml_cgraph * wsp_ggml_new_graph(struct wsp_ggml_context * ctx) { + return wsp_ggml_new_graph_custom(ctx, WSP_GGML_DEFAULT_GRAPH_SIZE, false); +} + +struct wsp_ggml_cgraph * wsp_ggml_graph_view(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph0, int i0, int i1) { + const size_t obj_size = sizeof(struct wsp_ggml_cgraph); + struct wsp_ggml_object * obj = wsp_ggml_new_object(ctx, WSP_GGML_OBJECT_GRAPH, obj_size); + struct wsp_ggml_cgraph * cgraph = (struct wsp_ggml_cgraph *) ((char *) ctx->mem_buffer + obj->offs); + + *cgraph = (struct wsp_ggml_cgraph) { + /*.size =*/ 0, + /*.n_nodes =*/ i1 - i0, + /*.n_leafs =*/ 0, + /*.nodes =*/ cgraph0->nodes + i0, + /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL, + /*.leafs =*/ NULL, + /*.hash_table =*/ { 0, NULL }, + /*.order =*/ cgraph0->order, + /*.perf_runs =*/ 0, + /*.perf_cycles =*/ 0, + /*.perf_time_us =*/ 0, + }; + return cgraph; } -size_t wsp_ggml_graph_overhead(void) { - return WSP_GGML_OBJECT_SIZE + WSP_GGML_PAD(WSP_GGML_GRAPH_SIZE, WSP_GGML_MEM_ALIGN); +void wsp_ggml_graph_cpy(struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst) { + WSP_GGML_ASSERT(dst->size >= src->n_leafs); + WSP_GGML_ASSERT(dst->size >= src->n_nodes); + WSP_GGML_ASSERT(dst->visited_hash_table.size >= src->visited_hash_table.size); + + dst->n_leafs = src->n_leafs; + dst->n_nodes = src->n_nodes; + dst->order = src->order; + + for (int i = 0; i < src->n_leafs; ++i) { + dst->leafs[i] = src->leafs[i]; + } + + for (int i = 0; i < src->n_nodes; ++i) { + dst->nodes[i] = src->nodes[i]; + } + + if (src->grads) { + WSP_GGML_ASSERT(dst->grads != NULL); + for (int i = 0; i < src->n_nodes; ++i) { + dst->grads[i] = src->grads[i]; + } + } + + for (size_t i = 0; i < src->visited_hash_table.size; ++i) { + if (src->visited_hash_table.keys[i]) { + wsp_ggml_hash_insert(dst->visited_hash_table, src->visited_hash_table.keys[i]); + } + } +} + +struct wsp_ggml_cgraph * wsp_ggml_graph_dup(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph) { + struct wsp_ggml_cgraph * result = wsp_ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL); + wsp_ggml_graph_cpy(cgraph, result); + return result; +} + +void wsp_ggml_graph_reset(struct wsp_ggml_cgraph * cgraph) { + WSP_GGML_ASSERT(cgraph->grads != NULL); + + for (int i = 0; i < cgraph->n_nodes; i++) { + struct wsp_ggml_tensor * grad = cgraph->grads[i]; + + if (grad) { + wsp_ggml_set_zero(grad); + } + } +} + +void wsp_ggml_graph_clear(struct wsp_ggml_cgraph * cgraph) { + cgraph->n_leafs = 0; + cgraph->n_nodes = 0; + memset(cgraph->visited_hash_table.keys, 0, cgraph->visited_hash_table.size * sizeof(struct wsp_ggml_tensor *)); } // @@ -17202,13 +16402,253 @@ struct wsp_ggml_compute_state { struct wsp_ggml_compute_state_shared * shared; }; -static void wsp_ggml_graph_compute_perf_stats_node(struct wsp_ggml_tensor * node, const struct wsp_ggml_compute_state_shared * st) { - int64_t cycles_cur = wsp_ggml_perf_cycles() - st->perf_node_start_cycles; - int64_t time_us_cur = wsp_ggml_perf_time_us() - st->perf_node_start_time_us; +static void wsp_ggml_graph_compute_perf_stats_node(struct wsp_ggml_tensor * node, const struct wsp_ggml_compute_state_shared * st) { + int64_t cycles_cur = wsp_ggml_perf_cycles() - st->perf_node_start_cycles; + int64_t time_us_cur = wsp_ggml_perf_time_us() - st->perf_node_start_time_us; + + node->perf_runs++; + node->perf_cycles += cycles_cur; + node->perf_time_us += time_us_cur; +} + +static int wsp_ggml_get_n_tasks(struct wsp_ggml_tensor * node, int n_threads) { + int n_tasks = 0; + + switch (node->op) { + case WSP_GGML_OP_CPY: + case WSP_GGML_OP_DUP: + case WSP_GGML_OP_ADD: + case WSP_GGML_OP_ADD1: + case WSP_GGML_OP_ACC: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_SUB: + case WSP_GGML_OP_DIV: + case WSP_GGML_OP_SQR: + case WSP_GGML_OP_SQRT: + case WSP_GGML_OP_LOG: + case WSP_GGML_OP_SUM: + case WSP_GGML_OP_SUM_ROWS: + case WSP_GGML_OP_MEAN: + case WSP_GGML_OP_ARGMAX: + case WSP_GGML_OP_REPEAT: + case WSP_GGML_OP_REPEAT_BACK: + { + n_tasks = 1; + } break; + case WSP_GGML_OP_UNARY: + switch (wsp_ggml_get_unary_op(node)) { + case WSP_GGML_UNARY_OP_ABS: + case WSP_GGML_UNARY_OP_SGN: + case WSP_GGML_UNARY_OP_NEG: + case WSP_GGML_UNARY_OP_STEP: + case WSP_GGML_UNARY_OP_TANH: + case WSP_GGML_UNARY_OP_ELU: + case WSP_GGML_UNARY_OP_RELU: + case WSP_GGML_UNARY_OP_LEAKY: + { + n_tasks = 1; + } break; + + case WSP_GGML_UNARY_OP_GELU: + case WSP_GGML_UNARY_OP_GELU_QUICK: + case WSP_GGML_UNARY_OP_SILU: + { + n_tasks = n_threads; + } break; + } + break; + case WSP_GGML_OP_SILU_BACK: + case WSP_GGML_OP_MUL: + case WSP_GGML_OP_NORM: + case WSP_GGML_OP_RMS_NORM: + case WSP_GGML_OP_RMS_NORM_BACK: + case WSP_GGML_OP_GROUP_NORM: + case WSP_GGML_OP_CONCAT: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_MUL_MAT: + { + n_tasks = n_threads; + + // TODO: use different scheduling for different matrix sizes + //const int nr0 = wsp_ggml_nrows(node->src[0]); + //const int nr1 = wsp_ggml_nrows(node->src[1]); + + //n_tasks = MIN(n_threads, MAX(1, nr0/128)); + //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); + +#if defined(WSP_GGML_USE_CUBLAS) + if (wsp_ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#elif defined(WSP_GGML_USE_CLBLAST) + if (wsp_ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#endif +#if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) + if (wsp_ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } +#endif + } break; + case WSP_GGML_OP_OUT_PROD: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_SCALE: + case WSP_GGML_OP_SET: + case WSP_GGML_OP_CONT: + case WSP_GGML_OP_RESHAPE: + case WSP_GGML_OP_VIEW: + case WSP_GGML_OP_PERMUTE: + case WSP_GGML_OP_TRANSPOSE: + case WSP_GGML_OP_GET_ROWS: + case WSP_GGML_OP_GET_ROWS_BACK: + case WSP_GGML_OP_DIAG: + { + n_tasks = 1; + } break; + case WSP_GGML_OP_DIAG_MASK_ZERO: + case WSP_GGML_OP_DIAG_MASK_INF: + case WSP_GGML_OP_SOFT_MAX: + case WSP_GGML_OP_SOFT_MAX_BACK: + case WSP_GGML_OP_ROPE: + case WSP_GGML_OP_ROPE_BACK: + case WSP_GGML_OP_ADD_REL_POS: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_ALIBI: + { + n_tasks = 1; //TODO + } break; + case WSP_GGML_OP_CLAMP: + { + n_tasks = 1; //TODO + } break; + case WSP_GGML_OP_CONV_1D: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_CONV_1D_STAGE_0: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_CONV_1D_STAGE_1: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_CONV_TRANSPOSE_1D: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_CONV_2D: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_CONV_2D_STAGE_0: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_CONV_2D_STAGE_1: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_CONV_TRANSPOSE_2D: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_POOL_1D: + case WSP_GGML_OP_POOL_2D: + { + n_tasks = 1; + } break; + case WSP_GGML_OP_UPSCALE: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_FLASH_ATTN: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_FLASH_FF: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_FLASH_ATTN_BACK: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_WIN_PART: + case WSP_GGML_OP_WIN_UNPART: + case WSP_GGML_OP_GET_REL_POS: + case WSP_GGML_OP_MAP_UNARY: + case WSP_GGML_OP_MAP_BINARY: + case WSP_GGML_OP_MAP_CUSTOM1_F32: + case WSP_GGML_OP_MAP_CUSTOM2_F32: + case WSP_GGML_OP_MAP_CUSTOM3_F32: + { + n_tasks = 1; + } break; + case WSP_GGML_OP_MAP_CUSTOM1: + { + struct wsp_ggml_map_custom1_op_params * p = (struct wsp_ggml_map_custom1_op_params *) node->op_params; + if (p->n_tasks == WSP_GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } + } break; + case WSP_GGML_OP_MAP_CUSTOM2: + { + struct wsp_ggml_map_custom2_op_params * p = (struct wsp_ggml_map_custom2_op_params *) node->op_params; + if (p->n_tasks == WSP_GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } + } break; + case WSP_GGML_OP_MAP_CUSTOM3: + { + struct wsp_ggml_map_custom3_op_params * p = (struct wsp_ggml_map_custom3_op_params *) node->op_params; + if (p->n_tasks == WSP_GGML_N_TASKS_MAX) { + n_tasks = n_threads; + } else { + n_tasks = MIN(p->n_tasks, n_threads); + } + } break; + case WSP_GGML_OP_CROSS_ENTROPY_LOSS: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + n_tasks = n_threads; + } break; + case WSP_GGML_OP_NONE: + { + n_tasks = 1; + } break; + case WSP_GGML_OP_COUNT: + { + WSP_GGML_ASSERT(false); + } break; + default: + { + WSP_GGML_ASSERT(false); + } break; + } + + assert(n_tasks > 0); - node->perf_runs++; - node->perf_cycles += cycles_cur; - node->perf_time_us += time_us_cur; + return n_tasks; } static thread_ret_t wsp_ggml_graph_compute_thread(void * data) { @@ -17217,7 +16657,6 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) { const struct wsp_ggml_cgraph * cgraph = state->shared->cgraph; const struct wsp_ggml_cplan * cplan = state->shared->cplan; - const int * n_tasks_arr = cplan->n_tasks; const int n_threads = state->shared->n_threads; set_numa_thread_affinity(state->ith, n_threads); @@ -17242,9 +16681,9 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) { if (node_n != -1) { /* FINALIZE */ - struct wsp_ggml_tensor * node = state->shared->cgraph->nodes[node_n]; + struct wsp_ggml_tensor * node = cgraph->nodes[node_n]; if (WSP_GGML_OP_HAS_FINALIZE[node->op]) { - params.nth = n_tasks_arr[node_n]; + params.nth = wsp_ggml_get_n_tasks(node, n_threads); wsp_ggml_compute_forward(¶ms, node); } wsp_ggml_graph_compute_perf_stats_node(node, state->shared); @@ -17255,7 +16694,7 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) { WSP_GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes); struct wsp_ggml_tensor * node = cgraph->nodes[node_n]; - const int n_tasks = n_tasks_arr[node_n]; + const int n_tasks = wsp_ggml_get_n_tasks(node, n_threads); state->shared->perf_node_start_cycles = wsp_ggml_perf_cycles(); state->shared->perf_node_start_time_us = wsp_ggml_perf_time_us(); @@ -17313,7 +16752,7 @@ static thread_ret_t wsp_ggml_graph_compute_thread(void * data) { /* COMPUTE */ struct wsp_ggml_tensor * node = cgraph->nodes[node_n]; - const int n_tasks = n_tasks_arr[node_n]; + const int n_tasks = wsp_ggml_get_n_tasks(node, n_threads); struct wsp_ggml_compute_params params = { /*.type =*/ WSP_GGML_TASK_COMPUTE, @@ -17347,122 +16786,46 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n struct wsp_ggml_tensor * node = cgraph->nodes[i]; + size_t cur = 0; + switch (node->op) { case WSP_GGML_OP_CPY: case WSP_GGML_OP_DUP: { n_tasks = n_threads; - size_t cur = 0; if (wsp_ggml_is_quantized(node->type)) { cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->ne[0] * n_tasks; } - - work_size = MAX(work_size, cur); } break; case WSP_GGML_OP_ADD: case WSP_GGML_OP_ADD1: { n_tasks = n_threads; - size_t cur = 0; - if (wsp_ggml_is_quantized(node->src[0]->type)) { cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; } - - work_size = MAX(work_size, cur); } break; case WSP_GGML_OP_ACC: { n_tasks = n_threads; - size_t cur = 0; - if (wsp_ggml_is_quantized(node->src[0]->type)) { cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; } - - work_size = MAX(work_size, cur); - } break; - case WSP_GGML_OP_SUB: - case WSP_GGML_OP_DIV: - case WSP_GGML_OP_SQR: - case WSP_GGML_OP_SQRT: - case WSP_GGML_OP_LOG: - case WSP_GGML_OP_SUM: - case WSP_GGML_OP_SUM_ROWS: - case WSP_GGML_OP_MEAN: - case WSP_GGML_OP_ARGMAX: - case WSP_GGML_OP_REPEAT: - case WSP_GGML_OP_REPEAT_BACK: - { - n_tasks = 1; - } break; - - case WSP_GGML_OP_UNARY: - { - switch (wsp_ggml_get_unary_op(node)) { - case WSP_GGML_UNARY_OP_ABS: - case WSP_GGML_UNARY_OP_SGN: - case WSP_GGML_UNARY_OP_NEG: - case WSP_GGML_UNARY_OP_STEP: - case WSP_GGML_UNARY_OP_TANH: - case WSP_GGML_UNARY_OP_ELU: - case WSP_GGML_UNARY_OP_RELU: - { - n_tasks = 1; - } break; - - case WSP_GGML_UNARY_OP_GELU: - case WSP_GGML_UNARY_OP_GELU_QUICK: - case WSP_GGML_UNARY_OP_SILU: - { - n_tasks = n_threads; - } break; - } - } break; - case WSP_GGML_OP_SILU_BACK: - case WSP_GGML_OP_MUL: - case WSP_GGML_OP_NORM: - case WSP_GGML_OP_RMS_NORM: - case WSP_GGML_OP_RMS_NORM_BACK: - case WSP_GGML_OP_GROUP_NORM: - { - n_tasks = n_threads; } break; - case WSP_GGML_OP_CONCAT: case WSP_GGML_OP_MUL_MAT: - case WSP_GGML_OP_OUT_PROD: { - n_tasks = n_threads; - - // TODO: use different scheduling for different matrix sizes - //const int nr0 = wsp_ggml_nrows(node->src[0]); - //const int nr1 = wsp_ggml_nrows(node->src[1]); - - //n_tasks = MIN(n_threads, MAX(1, nr0/128)); - //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); - - size_t cur = 0; const enum wsp_ggml_type vec_dot_type = type_traits[node->src[0]->type].vec_dot_type; -#if defined(WSP_GGML_USE_CUBLAS) - if (wsp_ggml_cuda_can_mul_mat(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning - } else -#elif defined(WSP_GGML_USE_CLBLAST) +#if defined(WSP_GGML_USE_CLBLAST) if (wsp_ggml_cl_can_mul_mat(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning cur = wsp_ggml_cl_mul_mat_get_wsize(node->src[0], node->src[1], node); } else #endif #if defined(WSP_GGML_USE_ACCELERATE) || defined(WSP_GGML_USE_OPENBLAS) if (wsp_ggml_compute_forward_mul_mat_use_blas(node->src[0], node->src[1], node)) { - n_tasks = 1; // TODO: this actually is doing nothing - // the threads are still spinning if (node->src[0]->type != WSP_GGML_TYPE_F32) { // here we need memory just for single 2D matrix from src0 cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(node->src[0]->ne[0]*node->src[0]->ne[1]); @@ -17471,79 +16834,75 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n #endif if (node->src[1]->type != vec_dot_type) { cur = wsp_ggml_type_size(vec_dot_type)*wsp_ggml_nelements(node->src[1])/wsp_ggml_blck_size(vec_dot_type); - } else { - cur = 0; } - - work_size = MAX(work_size, cur); - } break; - case WSP_GGML_OP_SCALE: - { - n_tasks = 1; - } break; - case WSP_GGML_OP_SET: - case WSP_GGML_OP_CONT: - case WSP_GGML_OP_RESHAPE: - case WSP_GGML_OP_VIEW: - case WSP_GGML_OP_PERMUTE: - case WSP_GGML_OP_TRANSPOSE: - case WSP_GGML_OP_GET_ROWS: - case WSP_GGML_OP_GET_ROWS_BACK: - case WSP_GGML_OP_DIAG: - { - n_tasks = 1; } break; - case WSP_GGML_OP_DIAG_MASK_ZERO: - case WSP_GGML_OP_DIAG_MASK_INF: - case WSP_GGML_OP_SOFT_MAX: - case WSP_GGML_OP_SOFT_MAX_BACK: - case WSP_GGML_OP_ROPE: - case WSP_GGML_OP_ROPE_BACK: - case WSP_GGML_OP_ADD_REL_POS: + case WSP_GGML_OP_OUT_PROD: { n_tasks = n_threads; - } break; - case WSP_GGML_OP_ALIBI: - { - n_tasks = 1; //TODO - } break; - case WSP_GGML_OP_CLAMP: - { - n_tasks = 1; //TODO + + if (wsp_ggml_is_quantized(node->src[0]->type)) { + cur = wsp_ggml_type_size(WSP_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } } break; case WSP_GGML_OP_CONV_1D: { - n_tasks = n_threads; - WSP_GGML_ASSERT(node->src[0]->ne[3] == 1); WSP_GGML_ASSERT(node->src[1]->ne[2] == 1); WSP_GGML_ASSERT(node->src[1]->ne[3] == 1); - size_t cur = 0; - const int nk = node->src[0]->ne[0]; + const int64_t ne00 = node->src[0]->ne[0]; + const int64_t ne01 = node->src[0]->ne[1]; + const int64_t ne02 = node->src[0]->ne[2]; + + const int64_t ne10 = node->src[1]->ne[0]; + const int64_t ne11 = node->src[1]->ne[1]; + + const int64_t ne0 = node->ne[0]; + const int64_t ne1 = node->ne[1]; + const int64_t nk = ne00; + const int64_t ew0 = nk * ne01; + + UNUSED(ne02); + UNUSED(ne10); + UNUSED(ne11); if (node->src[0]->type == WSP_GGML_TYPE_F16 && - node->src[1]->type == WSP_GGML_TYPE_F32) { - cur = sizeof(wsp_ggml_fp16_t)*( - nk*wsp_ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] + - ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1] - ); + node->src[1]->type == WSP_GGML_TYPE_F32) { + cur = sizeof(wsp_ggml_fp16_t)*(ne0*ne1*ew0); } else if (node->src[0]->type == WSP_GGML_TYPE_F32 && - node->src[1]->type == WSP_GGML_TYPE_F32) { - cur = sizeof(float)*( - nk*wsp_ggml_up32(node->src[0]->ne[1])*node->src[0]->ne[2] + - ( 2*(nk/2) + node->src[1]->ne[0])*node->src[1]->ne[1] - ); + node->src[1]->type == WSP_GGML_TYPE_F32) { + cur = sizeof(float)*(ne0*ne1*ew0); } else { WSP_GGML_ASSERT(false); } + } break; + case WSP_GGML_OP_CONV_TRANSPOSE_1D: + { + WSP_GGML_ASSERT(node->src[0]->ne[3] == 1); + WSP_GGML_ASSERT(node->src[1]->ne[2] == 1); + WSP_GGML_ASSERT(node->src[1]->ne[3] == 1); - work_size = MAX(work_size, cur); + const int64_t ne00 = node->src[0]->ne[0]; // K + const int64_t ne01 = node->src[0]->ne[1]; // Cout + const int64_t ne02 = node->src[0]->ne[2]; // Cin + + const int64_t ne10 = node->src[1]->ne[0]; // L + const int64_t ne11 = node->src[1]->ne[1]; // Cin + + if (node->src[0]->type == WSP_GGML_TYPE_F16 && + node->src[1]->type == WSP_GGML_TYPE_F32) { + cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02; + cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11; + } else if (node->src[0]->type == WSP_GGML_TYPE_F32 && + node->src[1]->type == WSP_GGML_TYPE_F32) { + cur += sizeof(float)*ne00*ne01*ne02; + cur += sizeof(float)*ne10*ne11; + } else { + WSP_GGML_ASSERT(false); + } } break; case WSP_GGML_OP_CONV_2D: { - n_tasks = n_threads; - const int64_t ne00 = node->src[0]->ne[0]; // W const int64_t ne01 = node->src[0]->ne[1]; // H const int64_t ne02 = node->src[0]->ne[2]; // C @@ -17556,30 +16915,26 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n const int64_t ne0 = node->ne[0]; const int64_t ne1 = node->ne[1]; const int64_t ne2 = node->ne[2]; + const int64_t ne3 = node->ne[3]; const int64_t nk = ne00*ne01; const int64_t ew0 = nk * ne02; UNUSED(ne03); UNUSED(ne2); - size_t cur = 0; - if (node->src[0]->type == WSP_GGML_TYPE_F16 && node->src[1]->type == WSP_GGML_TYPE_F32) { - cur = sizeof(wsp_ggml_fp16_t)*(ne0*ne1*ew0); + // im2col: [N*OH*OW, IC*KH*KW] + cur = sizeof(wsp_ggml_fp16_t)*(ne3*ne0*ne1*ew0); } else if (node->src[0]->type == WSP_GGML_TYPE_F32 && node->src[1]->type == WSP_GGML_TYPE_F32) { cur = sizeof(float)* (ne10*ne11*ne12); } else { WSP_GGML_ASSERT(false); } - - work_size = MAX(work_size, cur); } break; case WSP_GGML_OP_CONV_TRANSPOSE_2D: { - n_tasks = n_threads; - const int64_t ne00 = node->src[0]->ne[0]; // W const int64_t ne01 = node->src[0]->ne[1]; // H const int64_t ne02 = node->src[0]->ne[2]; // Channels Out @@ -17589,141 +16944,66 @@ struct wsp_ggml_cplan wsp_ggml_graph_plan(struct wsp_ggml_cgraph * cgraph, int n const int64_t ne11 = node->src[1]->ne[1]; // H const int64_t ne12 = node->src[1]->ne[2]; // Channels In - size_t cur = 0; cur += sizeof(wsp_ggml_fp16_t)*ne00*ne01*ne02*ne03; cur += sizeof(wsp_ggml_fp16_t)*ne10*ne11*ne12; - - work_size = MAX(work_size, cur); - } break; - case WSP_GGML_OP_POOL_1D: - case WSP_GGML_OP_POOL_2D: - { - n_tasks = 1; - } break; - case WSP_GGML_OP_UPSCALE: - { - n_tasks = n_threads; } break; case WSP_GGML_OP_FLASH_ATTN: { n_tasks = n_threads; - size_t cur = 0; - const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL); if (node->src[1]->type == WSP_GGML_TYPE_F32) { cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 - } - - if (node->src[1]->type == WSP_GGML_TYPE_F16) { + } else if (node->src[1]->type == WSP_GGML_TYPE_F16) { cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } - - work_size = MAX(work_size, cur); } break; case WSP_GGML_OP_FLASH_FF: { n_tasks = n_threads; - size_t cur = 0; - if (node->src[1]->type == WSP_GGML_TYPE_F32) { cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 - } - - if (node->src[1]->type == WSP_GGML_TYPE_F16) { + } else if (node->src[1]->type == WSP_GGML_TYPE_F16) { cur = sizeof(float)*node->src[1]->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*node->src[1]->ne[1]*n_tasks; // this is overestimated by x2 } - - work_size = MAX(work_size, cur); } break; case WSP_GGML_OP_FLASH_ATTN_BACK: { n_tasks = n_threads; - size_t cur = 0; - const int64_t D = node->src[0]->ne[0]; const int64_t ne11 = wsp_ggml_up(node->src[1]->ne[1], WSP_GGML_SOFT_MAX_UNROLL); const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in wsp_ggml_compute_forward_flash_attn_back if (node->src[1]->type == WSP_GGML_TYPE_F32) { cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } - - if (node->src[1]->type == WSP_GGML_TYPE_F16) { + } else if (node->src[1]->type == WSP_GGML_TYPE_F16) { cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } - - work_size = MAX(work_size, cur); - } break; - case WSP_GGML_OP_WIN_PART: - case WSP_GGML_OP_WIN_UNPART: - case WSP_GGML_OP_GET_REL_POS: - case WSP_GGML_OP_MAP_UNARY: - case WSP_GGML_OP_MAP_BINARY: - case WSP_GGML_OP_MAP_CUSTOM1_F32: - case WSP_GGML_OP_MAP_CUSTOM2_F32: - case WSP_GGML_OP_MAP_CUSTOM3_F32: - { - n_tasks = 1; - } break; - case WSP_GGML_OP_MAP_CUSTOM1: - { - struct wsp_ggml_map_custom1_op_params * p = (struct wsp_ggml_map_custom1_op_params *) node->op_params; - if (p->n_tasks == WSP_GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p->n_tasks, n_threads); - } - } break; - case WSP_GGML_OP_MAP_CUSTOM2: - { - struct wsp_ggml_map_custom2_op_params * p = (struct wsp_ggml_map_custom2_op_params *) node->op_params; - if (p->n_tasks == WSP_GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p->n_tasks, n_threads); - } - } break; - case WSP_GGML_OP_MAP_CUSTOM3: - { - struct wsp_ggml_map_custom3_op_params * p = (struct wsp_ggml_map_custom3_op_params *) node->op_params; - if (p->n_tasks == WSP_GGML_N_TASKS_MAX) { - n_tasks = n_threads; - } else { - n_tasks = MIN(p->n_tasks, n_threads); - } } break; + case WSP_GGML_OP_CROSS_ENTROPY_LOSS: { n_tasks = n_threads; - size_t cur = wsp_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); - - work_size = MAX(work_size, cur); - } break; - case WSP_GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - n_tasks = n_threads; - } break; - case WSP_GGML_OP_NONE: - { - n_tasks = 1; + cur = wsp_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); } break; case WSP_GGML_OP_COUNT: { WSP_GGML_ASSERT(false); } break; + default: + break; } - cplan.n_tasks[i] = n_tasks; + work_size = MAX(work_size, cur); } if (work_size > 0) { @@ -17745,12 +17025,6 @@ int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cpla if (cplan->work_size > 0) { WSP_GGML_ASSERT(cplan->work_data); } - - for (int i = 0; i < cgraph->n_nodes; ++i) { - if (cgraph->nodes[i]->op != WSP_GGML_OP_NONE) { - WSP_GGML_ASSERT(cplan->n_tasks[i] > 0); - } - } } const int n_threads = cplan->n_threads; @@ -17823,16 +17097,6 @@ int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cpla return compute_status; } -void wsp_ggml_graph_reset(struct wsp_ggml_cgraph * cgraph) { - for (int i = 0; i < cgraph->n_nodes; i++) { - struct wsp_ggml_tensor * grad = cgraph->grads[i]; - - if (grad) { - wsp_ggml_set_zero(grad); - } - } -} - void wsp_ggml_graph_compute_with_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int n_threads) { struct wsp_ggml_cplan cplan = wsp_ggml_graph_plan(cgraph, n_threads); @@ -17959,12 +17223,12 @@ void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * f const uint32_t magic = WSP_GGML_FILE_MAGIC; const uint32_t version = WSP_GGML_FILE_VERSION; const uint32_t n_leafs = cgraph->n_leafs; - const uint32_t nodes = cgraph->n_nodes; + const uint32_t n_nodes = cgraph->n_nodes; fwrite(&magic, sizeof(uint32_t), 1, fout); fwrite(&version, sizeof(uint32_t), 1, fout); fwrite(&n_leafs, sizeof(uint32_t), 1, fout); - fwrite(&nodes, sizeof(uint32_t), 1, fout); + fwrite(&n_nodes, sizeof(uint32_t), 1, fout); fwrite(&size_eval, sizeof(uint64_t), 1, fout); } @@ -18052,7 +17316,7 @@ void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * f if (idx == -1) { for (int k = 0; k < cgraph->n_nodes; ++k) { if (args[j] == cgraph->nodes[k]) { - idx = WSP_GGML_MAX_NODES + k; + idx = cgraph->n_leafs + k; break; } } @@ -18060,6 +17324,7 @@ void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * f if (idx == -1) { fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i); + fclose(fout); return; } @@ -18078,11 +17343,11 @@ void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * f } } -struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval) { +struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval) { assert(*ctx_data == NULL); assert(*ctx_eval == NULL); - struct wsp_ggml_cgraph result = { 0 }; + struct wsp_ggml_cgraph * result = NULL; struct wsp_ggml_tensor * data = NULL; @@ -18154,13 +17419,11 @@ struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml const uint32_t n_leafs = *(const uint32_t *) ptr; ptr += sizeof(n_leafs); const uint32_t n_nodes = *(const uint32_t *) ptr; ptr += sizeof(n_nodes); const uint64_t size_eval = *(const uint64_t *) ptr; ptr += sizeof(size_eval); - - result.n_leafs = n_leafs; - result.n_nodes = n_nodes; + const int graph_size = MAX(n_leafs, n_nodes); // create the data context { - const size_t overhead = (n_leafs + n_nodes)*wsp_ggml_tensor_overhead(); + const size_t overhead = (n_leafs + n_nodes)*wsp_ggml_tensor_overhead() + wsp_ggml_graph_overhead_custom(graph_size, false); struct wsp_ggml_init_params params = { .mem_size = size_eval + overhead, @@ -18176,6 +17439,12 @@ struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml } } + result = wsp_ggml_new_graph_custom(*ctx_eval, graph_size, false); + + result->n_leafs = n_leafs; + result->n_nodes = n_nodes; + + // leafs { uint32_t type; @@ -18214,7 +17483,7 @@ struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml tensor->nb[j] = nb[j]; } - result.leafs[i] = tensor; + result->leafs[i] = tensor; ptr += wsp_ggml_nbytes(tensor); @@ -18266,10 +17535,10 @@ struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml continue; } - if (arg_idx < WSP_GGML_MAX_NODES) { - args[j] = result.leafs[arg_idx]; + if (arg_idx < result->n_leafs) { + args[j] = result->leafs[arg_idx]; } else { - args[j] = result.nodes[arg_idx - WSP_GGML_MAX_NODES]; + args[j] = result->nodes[arg_idx - result->n_leafs]; } } @@ -18321,7 +17590,7 @@ struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml tensor->src[j] = args[j]; } - result.nodes[i] = tensor; + result->nodes[i] = tensor; fprintf(stderr, "%s: loaded node %d: '%16s', %3d dims, %9zu bytes\n", __func__, i, tensor->name, n_dims, wsp_ggml_nbytes(tensor)); } @@ -18568,7 +17837,7 @@ static void wsp_ggml_opt_get_params(int np, struct wsp_ggml_tensor * const ps[], } static void wsp_ggml_opt_get_grad(int np, struct wsp_ggml_tensor * const ps[], float * g) { - int i = 0; + int64_t i = 0; for (int p = 0; p < np; ++p) { const int64_t ne = wsp_ggml_nelements(ps[p]) ; // TODO: add function to get all elements at once @@ -18578,6 +17847,17 @@ static void wsp_ggml_opt_get_grad(int np, struct wsp_ggml_tensor * const ps[], f } } +static void wsp_ggml_opt_acc_grad(int np, struct wsp_ggml_tensor * const ps[], float * g, float scale) { + int64_t i = 0; + for (int p = 0; p < np; ++p) { + const int64_t ne = wsp_ggml_nelements(ps[p]) ; + // TODO: add function to get all elements at once + for (int64_t j = 0; j < ne; ++j) { + g[i++] += wsp_ggml_get_f32_1d(ps[p]->grad, j) * scale; + } + } +} + // // ADAM // @@ -18626,26 +17906,40 @@ static enum wsp_ggml_opt_result wsp_ggml_opt_adam( const float eps = params.adam.eps; const float gclip = params.adam.gclip; const int decay_min_ndim = params.adam.decay_min_ndim; + const int n_accum = MAX(1, params.n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + float * g = opt->adam.g->data; // gradients float * m = opt->adam.m->data; // first moment float * v = opt->adam.v->data; // second moment float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values - if (callback) { - callback(callback_data, &sched); - } - - // compute the function value - wsp_ggml_graph_reset (gf); - wsp_ggml_set_f32 (f->grad, 1.0f); - struct wsp_ggml_cplan cplan = wsp_ggml_graph_plan(gb, params.n_threads); struct wsp_ggml_object * obj = wsp_ggml_new_object(ctx, WSP_GGML_OBJECT_WORK_BUFFER, cplan.work_size); cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs; - wsp_ggml_graph_compute(gb, &cplan); - opt->adam.fx_prev = wsp_ggml_get_f32_1d(f, 0); + bool cancel = false; + + // compute the function value + float fx = 0; + wsp_ggml_set_zero(opt->adam.g); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + return WSP_GGML_OPT_CANCEL; + } + } + // wsp_ggml_graph_reset (gf); + wsp_ggml_set_f32 (f->grad, 1.0f); + wsp_ggml_graph_compute(gb, &cplan); + wsp_ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += wsp_ggml_get_f32_1d(f, 0); + } + fx *= accum_norm; + + opt->adam.fx_prev = fx; opt->adam.fx_best = opt->adam.fx_prev; if (pf) { pf[opt->iter % params.past] = opt->adam.fx_prev; @@ -18690,12 +17984,8 @@ static enum wsp_ggml_opt_result wsp_ggml_opt_adam( if (gclip > 0.0f) { // gradient clipping wsp_ggml_float sum = 0.0; - for (int p = 0; p < np; ++p) { - const int64_t ne = wsp_ggml_nelements(ps[p]); - for (int64_t j = 0; j < ne; ++j) { - float g = wsp_ggml_get_f32_1d(ps[p]->grad, j); - sum += (wsp_ggml_float)(g*g); - } + for (int64_t i = 0; i < nx; ++i) { + sum += (wsp_ggml_float)(g[i]*g[i]); } wsp_ggml_float norm = sqrt(sum); if (norm > (wsp_ggml_float) gclip) { @@ -18709,10 +17999,10 @@ static enum wsp_ggml_opt_result wsp_ggml_opt_adam( const int64_t ne = wsp_ggml_nelements(ps[p]); const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched; for (int64_t j = 0; j < ne; ++j) { - float x = wsp_ggml_get_f32_1d(ps[p], j); - float g = wsp_ggml_get_f32_1d(ps[p]->grad, j)*gnorm; - m[i] = m[i]*beta1 + g*(1.0f - beta1); - v[i] = v[i]*beta2 + g*g*(1.0f - beta2); + float x = wsp_ggml_get_f32_1d(ps[p], j); + float g_ = g[i]*gnorm; + m[i] = m[i]*beta1 + g_*(1.0f - beta1); + v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2); float mh = m[i]*beta1h; float vh = v[i]*beta2h; vh = sqrtf(vh) + eps; @@ -18723,19 +18013,25 @@ static enum wsp_ggml_opt_result wsp_ggml_opt_adam( } } - if (callback) { - callback(callback_data, &sched); + fx = 0; + wsp_ggml_set_zero(opt->adam.g); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + return WSP_GGML_OPT_CANCEL;; + } + } + // wsp_ggml_graph_reset (gf); + wsp_ggml_set_f32 (f->grad, 1.0f); + wsp_ggml_graph_compute(gb, &cplan); + wsp_ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += wsp_ggml_get_f32_1d(f, 0); } + fx *= accum_norm; - wsp_ggml_graph_reset (gf); - wsp_ggml_set_f32 (f->grad, 1.0f); - - wsp_ggml_graph_compute(gb, &cplan); - - const float fx = wsp_ggml_get_f32_1d(f, 0); opt->loss_after = fx; - // check convergence if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { WSP_GGML_PRINT_DEBUG("converged\n"); @@ -18812,11 +18108,11 @@ static enum wsp_ggml_opt_result linesearch_backtracking( float * step, const float * xp, struct wsp_ggml_tensor * f, - struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, struct wsp_ggml_cplan * cplan, const int np, struct wsp_ggml_tensor * ps[], + bool * cancel, wsp_ggml_opt_callback callback, void * callback_data) { int count = 0; @@ -18830,6 +18126,9 @@ static enum wsp_ggml_opt_result linesearch_backtracking( const float dec = 0.5f; const float inc = 2.1f; + const int n_accum = MAX(1, params->n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + if (*step <= 0.f) { return WSP_GGML_LINESEARCH_INVALID_PARAMETERS; } @@ -18847,12 +18146,6 @@ static enum wsp_ggml_opt_result linesearch_backtracking( dgtest = params->lbfgs.ftol*dginit; while (true) { - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, &sched); - } - wsp_ggml_vec_cpy_f32(nx, x, xp); wsp_ggml_vec_mad_f32(nx, x, d, *step); @@ -18860,14 +18153,25 @@ static enum wsp_ggml_opt_result linesearch_backtracking( { wsp_ggml_opt_set_params(np, ps, x); - wsp_ggml_graph_reset (gf); - wsp_ggml_set_f32 (f->grad, 1.0f); - - wsp_ggml_graph_compute(gb, cplan); - - wsp_ggml_opt_get_grad(np, ps, g); + *fx = 0; + memset(g, 0, sizeof(float)*nx); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, accum_step, &sched, cancel); + if (*cancel) { + return WSP_GGML_OPT_CANCEL; + } + } + // wsp_ggml_graph_reset (gf); + wsp_ggml_set_f32 (f->grad, 1.0f); + wsp_ggml_graph_compute(gb, cplan); + wsp_ggml_opt_acc_grad(np, ps, g, accum_norm); + *fx += wsp_ggml_get_f32_1d(f, 0); + } + *fx *= accum_norm; - *fx = wsp_ggml_get_f32_1d(f, 0); } ++count; @@ -18913,7 +18217,7 @@ static enum wsp_ggml_opt_result linesearch_backtracking( (*step) *= width; } - return WSP_GGML_LINESEARCH_FAIL; + WSP_GGML_UNREACHABLE(); } static enum wsp_ggml_opt_result wsp_ggml_opt_lbfgs( @@ -18968,6 +18272,9 @@ static enum wsp_ggml_opt_result wsp_ggml_opt_lbfgs( float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values + const int n_accum = MAX(1, params.n_gradient_accumulation); + const float accum_norm = 1.0f / (float) n_accum; + float fx = 0.0f; // cost function value float xnorm = 0.0f; // ||x|| float gnorm = 0.0f; // ||g|| @@ -18981,24 +18288,30 @@ static enum wsp_ggml_opt_result wsp_ggml_opt_lbfgs( float * lm_s = opt->lbfgs.lms->data; float * lm_y = opt->lbfgs.lmy->data; - if (callback) { - // LBFG-S does not support learning rate -> ignore learning schedule - float sched = 0; - callback(callback_data, &sched); - } + bool cancel = false; // evaluate the function value and its gradient { wsp_ggml_opt_set_params(np, ps, x); - wsp_ggml_graph_reset (gf); - wsp_ggml_set_f32 (f->grad, 1.0f); - - wsp_ggml_graph_compute(gb, &cplan); - - wsp_ggml_opt_get_grad(np, ps, g); - - fx = wsp_ggml_get_f32_1d(f, 0); + fx = 0; + memset(g, 0, sizeof(float)*nx); + for (int accum_step = 0; accum_step < n_accum; ++accum_step) { + if (callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, accum_step, &sched, &cancel); + if (cancel) { + return WSP_GGML_OPT_CANCEL; + } + } + // wsp_ggml_graph_reset (gf); + wsp_ggml_set_f32 (f->grad, 1.0f); + wsp_ggml_graph_compute(gb, &cplan); + wsp_ggml_opt_acc_grad(np, ps, g, accum_norm); + fx += wsp_ggml_get_f32_1d(f, 0); + } + fx *= accum_norm; opt->loss_before = fx; opt->loss_after = fx; @@ -19056,7 +18369,14 @@ static enum wsp_ggml_opt_result wsp_ggml_opt_lbfgs( wsp_ggml_vec_cpy_f32(nx, xp, x); wsp_ggml_vec_cpy_f32(nx, gp, g); - ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, &cplan, np, ps, callback, callback_data); + // TODO: instead of passing &cancel here, use the return code of the linesearch + // to determine if the optimization should be cancelled + // this is a simple change, but not doing this atm, since I don't have a nice + // way to test and don't want to break something with so many changes lined up + ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data); + if (cancel) { + return WSP_GGML_OPT_CANCEL; + } if (ls < 0) { // linesearch failed - go back to the previous point and return @@ -19165,7 +18485,7 @@ static enum wsp_ggml_opt_result wsp_ggml_opt_lbfgs( step[0] = 1.0; } - return WSP_GGML_OPT_DID_NOT_CONVERGE; + WSP_GGML_UNREACHABLE(); } struct wsp_ggml_opt_params wsp_ggml_opt_default_params(enum wsp_ggml_opt_type type) { @@ -19175,16 +18495,19 @@ struct wsp_ggml_opt_params wsp_ggml_opt_default_params(enum wsp_ggml_opt_type ty case WSP_GGML_OPT_ADAM: { result = (struct wsp_ggml_opt_params) { - .type = WSP_GGML_OPT_ADAM, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, + .type = WSP_GGML_OPT_ADAM, + .graph_size = WSP_GGML_DEFAULT_GRAPH_SIZE, + .n_threads = 1, // FIXME: WSP_GGML_DEFAULT_N_THREADS ? + .past = 0, + .delta = 1e-5f, .max_no_improvement = 100, .print_forward_graph = true, .print_backward_graph = true, + .n_gradient_accumulation = 1, + .adam = { .n_iter = 10000, .sched = 1.000f, @@ -19203,16 +18526,19 @@ struct wsp_ggml_opt_params wsp_ggml_opt_default_params(enum wsp_ggml_opt_type ty case WSP_GGML_OPT_LBFGS: { result = (struct wsp_ggml_opt_params) { - .type = WSP_GGML_OPT_LBFGS, - .n_threads = 1, - .past = 0, - .delta = 1e-5f, + .type = WSP_GGML_OPT_LBFGS, + .graph_size = WSP_GGML_DEFAULT_GRAPH_SIZE, + .n_threads = 1, + .past = 0, + .delta = 1e-5f, .max_no_improvement = 0, .print_forward_graph = true, .print_backward_graph = true, + .n_gradient_accumulation = 1, + .lbfgs = { .m = 6, .n_iter = 100, @@ -19243,13 +18569,32 @@ WSP_GGML_API void wsp_ggml_opt_init( opt->iter = 0; opt->nx = nx; opt->just_initialized = true; + if (opt->ctx == NULL) { + struct wsp_ggml_init_params ctx_opt_params; + if (opt->params.type == WSP_GGML_OPT_ADAM) { + ctx_opt_params.mem_size = WSP_GGML_MEM_ALIGN*3 + wsp_ggml_tensor_overhead()*3 + wsp_ggml_type_size(WSP_GGML_TYPE_F32)*nx*3; + if (opt->params.past > 0) { + ctx_opt_params.mem_size += WSP_GGML_MEM_ALIGN + wsp_ggml_tensor_overhead() + wsp_ggml_type_size(WSP_GGML_TYPE_F32)*opt->params.past; + } + } else if (opt->params.type == WSP_GGML_OPT_LBFGS) { + ctx_opt_params.mem_size = WSP_GGML_MEM_ALIGN*9 + wsp_ggml_tensor_overhead()*9 + wsp_ggml_type_size(WSP_GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2); + if (opt->params.past > 0) { + ctx_opt_params.mem_size += WSP_GGML_MEM_ALIGN + wsp_ggml_tensor_overhead() + wsp_ggml_type_size(WSP_GGML_TYPE_F32)*opt->params.past; + } + } + ctx_opt_params.mem_buffer = NULL; + ctx_opt_params.no_alloc = false; + + opt->ctx = wsp_ggml_init(ctx_opt_params); + } switch (opt->params.type) { case WSP_GGML_OPT_ADAM: { - opt->adam.m = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); - opt->adam.v = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->adam.g = wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, nx); + opt->adam.m = wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, nx); + opt->adam.v = wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, nx); opt->adam.pf = params.past > 0 - ? wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, params.past) + ? wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, params.past) : NULL; wsp_ggml_set_zero(opt->adam.m); wsp_ggml_set_zero(opt->adam.v); @@ -19259,18 +18604,18 @@ WSP_GGML_API void wsp_ggml_opt_init( } break; case WSP_GGML_OPT_LBFGS: { - opt->lbfgs.x = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); - opt->lbfgs.xp = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); - opt->lbfgs.g = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); - opt->lbfgs.gp = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); - opt->lbfgs.d = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, nx); + opt->lbfgs.x = wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, nx); + opt->lbfgs.xp = wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, nx); + opt->lbfgs.g = wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, nx); + opt->lbfgs.gp = wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, nx); + opt->lbfgs.d = wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, nx); opt->lbfgs.pf = params.past > 0 - ? wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, params.past) + ? wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, params.past) : NULL; - opt->lbfgs.lmal = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lmys = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_F32, params.lbfgs.m); - opt->lbfgs.lms = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, nx, params.lbfgs.m); - opt->lbfgs.lmy = wsp_ggml_new_tensor_2d(ctx, WSP_GGML_TYPE_F32, nx, params.lbfgs.m); + opt->lbfgs.lmal = wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lmys = wsp_ggml_new_tensor_1d(opt->ctx, WSP_GGML_TYPE_F32, params.lbfgs.m); + opt->lbfgs.lms = wsp_ggml_new_tensor_2d(opt->ctx, WSP_GGML_TYPE_F32, nx, params.lbfgs.m); + opt->lbfgs.lmy = wsp_ggml_new_tensor_2d(opt->ctx, WSP_GGML_TYPE_F32, nx, params.lbfgs.m); wsp_ggml_set_zero(opt->lbfgs.x); wsp_ggml_set_zero(opt->lbfgs.xp); wsp_ggml_set_zero(opt->lbfgs.g); @@ -19327,14 +18672,11 @@ enum wsp_ggml_opt_result wsp_ggml_opt_resume( struct wsp_ggml_tensor * f) { // build forward + backward compute graphs - struct wsp_ggml_tensor * gfbuf = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, sizeof(struct wsp_ggml_cgraph) / wsp_ggml_type_size(WSP_GGML_TYPE_I32)+ (sizeof(struct wsp_ggml_cgraph) % wsp_ggml_type_size(WSP_GGML_TYPE_I32) ? 1 : 0)); - struct wsp_ggml_tensor * gbbuf = wsp_ggml_new_tensor_1d(ctx, WSP_GGML_TYPE_I32, sizeof(struct wsp_ggml_cgraph) / wsp_ggml_type_size(WSP_GGML_TYPE_I32)+ (sizeof(struct wsp_ggml_cgraph) % wsp_ggml_type_size(WSP_GGML_TYPE_I32) ? 1 : 0)); + struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx, opt->params.graph_size, true); + wsp_ggml_build_forward_expand(gf, f); - struct wsp_ggml_cgraph * gf = (struct wsp_ggml_cgraph *) gfbuf->data; - struct wsp_ggml_cgraph * gb = (struct wsp_ggml_cgraph *) gbbuf->data; - - *gf = wsp_ggml_build_forward (f); - *gb = wsp_ggml_build_backward(ctx, gf, true); + struct wsp_ggml_cgraph * gb = wsp_ggml_graph_dup(ctx, gf); + wsp_ggml_build_backward_expand(ctx, gf, gb, true); return wsp_ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); } @@ -19537,7 +18879,6 @@ size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void block_q8_0 * block = (block_q8_0*)dst + start / QK8_0; result = wsp_ggml_quantize_q8_0(src + start, block, n, n, hist); } break; -#ifdef WSP_GGML_USE_K_QUANTS case WSP_GGML_TYPE_Q2_K: { WSP_GGML_ASSERT(start % QK_K == 0); @@ -19568,7 +18909,6 @@ size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void block_q6_K * block = (block_q6_K*)dst + start / QK_K; result = wsp_ggml_quantize_q6_K(src + start, block, n, n, hist); } break; -#endif case WSP_GGML_TYPE_F16: { int elemsize = sizeof(wsp_ggml_fp16_t); @@ -19659,7 +18999,7 @@ struct wsp_gguf_kv { }; struct wsp_gguf_header { - uint32_t magic; + char magic[4]; uint32_t version; uint64_t n_tensors; // GGUFv2 uint64_t n_kv; // GGUFv2 @@ -19700,8 +19040,7 @@ static bool wsp_gguf_fread_el(FILE * file, void * dst, size_t size, size_t * off return n == size; } -// NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 -static bool wsp_gguf_fread_str_cur(FILE * file, struct wsp_gguf_str * p, size_t * offset) { +static bool wsp_gguf_fread_str(FILE * file, struct wsp_gguf_str * p, size_t * offset) { p->n = 0; p->data = NULL; @@ -19713,23 +19052,10 @@ static bool wsp_gguf_fread_str_cur(FILE * file, struct wsp_gguf_str * p, size_t return ok; } -static bool wsp_gguf_fread_str_v1(FILE * file, struct wsp_gguf_str * p, size_t * offset) { - p->n = 0; - p->data = NULL; - - bool ok = true; - - uint32_t n = 0; - ok = ok && wsp_gguf_fread_el(file, &n, sizeof(n), offset); p->data = calloc(n + 1, 1); p->n = n; - ok = ok && wsp_gguf_fread_el(file, p->data, p->n, offset); - - return ok; -} - struct wsp_gguf_context * wsp_gguf_init_empty(void) { struct wsp_gguf_context * ctx = WSP_GGML_ALIGNED_MALLOC(sizeof(struct wsp_gguf_context)); - ctx->header.magic = WSP_GGUF_MAGIC; + memcpy(ctx->header.magic, WSP_GGUF_MAGIC, sizeof(ctx->header.magic)); ctx->header.version = WSP_GGUF_VERSION; ctx->header.n_tensors = 0; ctx->header.n_kv = 0; @@ -19755,16 +19081,18 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp // offset from start of file size_t offset = 0; - uint32_t magic = 0; + char magic[4]; // check the magic before making allocations { wsp_gguf_fread_el(file, &magic, sizeof(magic), &offset); - if (magic != WSP_GGUF_MAGIC) { - fprintf(stderr, "%s: invalid magic number %08x\n", __func__, magic); - fclose(file); - return NULL; + for (uint32_t i = 0; i < sizeof(magic); i++) { + if (magic[i] != WSP_GGUF_MAGIC[i]) { + fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic); + fclose(file); + return NULL; + } } } @@ -19774,27 +19102,22 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp // read the header { - ctx->header.magic = magic; + strncpy(ctx->header.magic, magic, 4); + ctx->kv = NULL; ctx->infos = NULL; ctx->data = NULL; ok = ok && wsp_gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset); + ok = ok && wsp_gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); + ok = ok && wsp_gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); if (ctx->header.version == 1) { - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - uint32_t n_tensors = 0; - uint32_t n_kv = 0; - - ok = ok && wsp_gguf_fread_el(file, &n_tensors, sizeof(n_tensors), &offset); - ok = ok && wsp_gguf_fread_el(file, &n_kv, sizeof(n_kv), &offset); - - ctx->header.n_tensors = n_tensors; - ctx->header.n_kv = n_kv; - } else { - ok = ok && wsp_gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); - ok = ok && wsp_gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); + fprintf(stderr, "%s: GGUFv1 is no longer supported. please use a more up-to-date version\n", __func__); + fclose(file); + wsp_gguf_free(ctx); + return NULL; } if (!ok) { @@ -19805,12 +19128,6 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp } } - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - bool (* wsp_gguf_fread_str)(FILE *, struct wsp_gguf_str *, size_t *) = wsp_gguf_fread_str_cur; - if (ctx->header.version == 1) { - wsp_gguf_fread_str = wsp_gguf_fread_str_v1; - } - // read the kv pairs { ctx->kv = malloc(ctx->header.n_kv * sizeof(struct wsp_gguf_kv)); @@ -19841,15 +19158,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp case WSP_GGUF_TYPE_ARRAY: { ok = ok && wsp_gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset); - - if (ctx->header.version == 1) { - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - uint32_t n = 0; - ok = ok && wsp_gguf_fread_el(file, &n, sizeof(n), &offset); - kv->value.arr.n = n; - } else { - ok = ok && wsp_gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset); - } + ok = ok && wsp_gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset); switch (kv->value.arr.type) { case WSP_GGUF_TYPE_UINT8: @@ -19876,10 +19185,10 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp } break; case WSP_GGUF_TYPE_ARRAY: case WSP_GGUF_TYPE_COUNT: WSP_GGML_ASSERT(false && "invalid type"); break; - }; + } } break; case WSP_GGUF_TYPE_COUNT: WSP_GGML_ASSERT(false && "invalid type"); - }; + } if (!ok) { break; @@ -19908,14 +19217,7 @@ struct wsp_gguf_context * wsp_gguf_init_from_file(const char * fname, struct wsp ok = ok && wsp_gguf_fread_str(file, &info->name, &offset); ok = ok && wsp_gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset); for (uint32_t j = 0; j < info->n_dims; ++j) { - if (ctx->header.version == 1) { - // NOTE: temporary handling of GGUFv1 >> remove after Oct 2023 - uint32_t t = 0; - ok = ok && wsp_gguf_fread_el(file, &t, sizeof(t), &offset); - info->ne[j] = t; - } else { - ok = ok && wsp_gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset); - } + ok = ok && wsp_gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset); } ok = ok && wsp_gguf_fread_el (file, &info->type, sizeof(info->type), &offset); ok = ok && wsp_gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); @@ -20155,78 +19457,94 @@ int wsp_gguf_find_key(const struct wsp_gguf_context * ctx, const char * key) { return keyfound; } -const char * wsp_gguf_get_key(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].key.data; +const char * wsp_gguf_get_key(const struct wsp_gguf_context * ctx, int key_id) { + return ctx->kv[key_id].key.data; } -enum wsp_gguf_type wsp_gguf_get_kv_type(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].type; +enum wsp_gguf_type wsp_gguf_get_kv_type(const struct wsp_gguf_context * ctx, int key_id) { + return ctx->kv[key_id].type; } -enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.arr.type; +enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.type; } -const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.arr.data; +const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.data; } const char * wsp_gguf_get_arr_str(const struct wsp_gguf_context * ctx, int key_id, int i) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY); struct wsp_gguf_kv * kv = &ctx->kv[key_id]; struct wsp_gguf_str * str = &((struct wsp_gguf_str *) kv->value.arr.data)[i]; return str->data; } -int wsp_gguf_get_arr_n(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.arr.n; +int wsp_gguf_get_arr_n(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_ARRAY); + return ctx->kv[key_id].value.arr.n; } -uint8_t wsp_gguf_get_val_u8(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.uint8; +uint8_t wsp_gguf_get_val_u8(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT8); + return ctx->kv[key_id].value.uint8; } -int8_t wsp_gguf_get_val_i8(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.int8; +int8_t wsp_gguf_get_val_i8(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT8); + return ctx->kv[key_id].value.int8; } -uint16_t wsp_gguf_get_val_u16(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.uint16; +uint16_t wsp_gguf_get_val_u16(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT16); + return ctx->kv[key_id].value.uint16; } -int16_t wsp_gguf_get_val_i16(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.int16; +int16_t wsp_gguf_get_val_i16(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT16); + return ctx->kv[key_id].value.int16; } -uint32_t wsp_gguf_get_val_u32(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.uint32; +uint32_t wsp_gguf_get_val_u32(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT32); + return ctx->kv[key_id].value.uint32; } -int32_t wsp_gguf_get_val_i32(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.int32; +int32_t wsp_gguf_get_val_i32(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT32); + return ctx->kv[key_id].value.int32; } -float wsp_gguf_get_val_f32(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.float32; +float wsp_gguf_get_val_f32(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_FLOAT32); + return ctx->kv[key_id].value.float32; } -uint64_t wsp_gguf_get_val_u64(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.uint64; +uint64_t wsp_gguf_get_val_u64(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_UINT64); + return ctx->kv[key_id].value.uint64; } -int64_t wsp_gguf_get_val_i64(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.int64; +int64_t wsp_gguf_get_val_i64(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_INT64); + return ctx->kv[key_id].value.int64; } -double wsp_gguf_get_val_f64(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.float64; +double wsp_gguf_get_val_f64(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_FLOAT64); + return ctx->kv[key_id].value.float64; } -bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.bool_; +bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_BOOL); + return ctx->kv[key_id].value.bool_; } -const char * wsp_gguf_get_val_str (const struct wsp_gguf_context * ctx, int i) { - return ctx->kv[i].value.str.data; +const char * wsp_gguf_get_val_str(const struct wsp_gguf_context * ctx, int key_id) { + WSP_GGML_ASSERT(ctx->kv[key_id].type == WSP_GGUF_TYPE_STRING); + return ctx->kv[key_id].value.str.data; } int wsp_gguf_get_n_tensors(const struct wsp_gguf_context * ctx) { @@ -20591,10 +19909,10 @@ static void wsp_gguf_write_to_buf(const struct wsp_gguf_context * ctx, struct ws } break; case WSP_GGUF_TYPE_ARRAY: case WSP_GGUF_TYPE_COUNT: WSP_GGML_ASSERT(false && "invalid type"); break; - }; + } } break; case WSP_GGUF_TYPE_COUNT: WSP_GGML_ASSERT(false && "invalid type"); - }; + } } // write tensor infos diff --git a/cpp/ggml.h b/cpp/ggml.h index f1bbd88..bf1d729 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -58,7 +58,8 @@ // { // ... // -// struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(f); +// struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx); +// wsp_ggml_build_forward_expand(gf, f); // // // set the input variable and parameter values // wsp_ggml_set_f32(x, 2.0f); @@ -213,15 +214,14 @@ #define WSP_GGML_QNT_VERSION 2 // bump this on quantization format changes #define WSP_GGML_QNT_VERSION_FACTOR 1000 // do not change this -#define WSP_GGML_MAX_DIMS 4 -#define WSP_GGML_MAX_NODES 4096 -#define WSP_GGML_MAX_PARAMS 256 -#define WSP_GGML_MAX_CONTEXTS 64 -#define WSP_GGML_MAX_SRC 6 -#define WSP_GGML_MAX_NAME 64 -#define WSP_GGML_MAX_OP_PARAMS 32 -#define WSP_GGML_DEFAULT_N_THREADS 4 - +#define WSP_GGML_MAX_DIMS 4 +#define WSP_GGML_MAX_PARAMS 1024 +#define WSP_GGML_MAX_CONTEXTS 64 +#define WSP_GGML_MAX_SRC 6 +#define WSP_GGML_MAX_NAME 64 +#define WSP_GGML_MAX_OP_PARAMS 64 +#define WSP_GGML_DEFAULT_N_THREADS 4 +#define WSP_GGML_DEFAULT_GRAPH_SIZE 2048 #if UINTPTR_MAX == 0xFFFFFFFF #define WSP_GGML_MEM_ALIGN 4 #else @@ -231,8 +231,9 @@ #define WSP_GGML_EXIT_SUCCESS 0 #define WSP_GGML_EXIT_ABORTED 1 -#define WSP_GGUF_MAGIC 0x46554747 // "GGUF" -#define WSP_GGUF_VERSION 2 +#define WSP_GGUF_MAGIC "GGUF" + +#define WSP_GGUF_VERSION 3 #define WSP_GGUF_DEFAULT_ALIGNMENT 32 @@ -244,10 +245,21 @@ do { \ if (!(x)) { \ fprintf(stderr, "WSP_GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ - abort(); \ + fflush(stderr); \ + fflush(stdout); \ + wsp_ggml_print_backtrace(); \ + exit(1); \ } \ } while (0) +#ifndef NDEBUG +#define WSP_GGML_UNREACHABLE() WSP_GGML_ASSERT(!"statement should not be reached") +#elif defined(__GNUC__) +#define WSP_GGML_UNREACHABLE() __builtin_unreachable() +#else +#define WSP_GGML_UNREACHABLE() ((void) 0) +#endif + // used to copy the number of elements and stride in bytes of tensors into local variables. // main purpose is to reduce code duplication and improve readability. // @@ -318,7 +330,7 @@ extern "C" { WSP_GGML_TYPE_COUNT, }; - enum wsp_ggml_backend { + enum wsp_ggml_backend_type { WSP_GGML_BACKEND_CPU = 0, WSP_GGML_BACKEND_GPU = 10, WSP_GGML_BACKEND_GPU_SPLIT = 20, @@ -392,7 +404,12 @@ extern "C" { WSP_GGML_OP_ALIBI, WSP_GGML_OP_CLAMP, WSP_GGML_OP_CONV_1D, + WSP_GGML_OP_CONV_1D_STAGE_0, // internal + WSP_GGML_OP_CONV_1D_STAGE_1, // internal + WSP_GGML_OP_CONV_TRANSPOSE_1D, WSP_GGML_OP_CONV_2D, + WSP_GGML_OP_CONV_2D_STAGE_0, // internal + WSP_GGML_OP_CONV_2D_STAGE_1, // internal WSP_GGML_OP_CONV_TRANSPOSE_2D, WSP_GGML_OP_POOL_1D, WSP_GGML_OP_POOL_2D, @@ -437,6 +454,7 @@ extern "C" { WSP_GGML_UNARY_OP_GELU, WSP_GGML_UNARY_OP_GELU_QUICK, WSP_GGML_UNARY_OP_SILU, + WSP_GGML_UNARY_OP_LEAKY }; enum wsp_ggml_object_type { @@ -445,6 +463,12 @@ extern "C" { WSP_GGML_OBJECT_WORK_BUFFER }; + enum wsp_ggml_log_level { + WSP_GGML_LOG_LEVEL_ERROR = 2, + WSP_GGML_LOG_LEVEL_WARN = 3, + WSP_GGML_LOG_LEVEL_INFO = 4 + }; + // ggml object struct wsp_ggml_object { size_t offs; @@ -461,14 +485,16 @@ extern "C" { // n-dimensional tensor struct wsp_ggml_tensor { - enum wsp_ggml_type type; - enum wsp_ggml_backend backend; + enum wsp_ggml_type type; + enum wsp_ggml_backend_type backend; + + struct wsp_ggml_backend_buffer * buffer; int n_dims; int64_t ne[WSP_GGML_MAX_DIMS]; // number of elements size_t nb[WSP_GGML_MAX_DIMS]; // stride in bytes: - // nb[0] = sizeof(type) - // nb[1] = nb[0] * ne[0] + padding + // nb[0] = wsp_ggml_type_size(type) + // nb[1] = nb[0] * (ne[0] / wsp_ggml_blck_size(type)) + padding // nb[i] = nb[i-1] * ne[i-1] // compute data @@ -496,7 +522,7 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - char padding[4]; + char padding[12]; }; static const size_t WSP_GGML_TENSOR_SIZE = sizeof(struct wsp_ggml_tensor); @@ -509,29 +535,35 @@ extern "C" { int n_threads; - // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes - int n_tasks[WSP_GGML_MAX_NODES]; - // abort wsp_ggml_graph_compute when true bool (*abort_callback)(void * data); void * abort_callback_data; }; - // next prime after WSP_GGML_MAX_NODES - // #define WSP_GGML_GRAPH_HASHTABLE_SIZE 4099 - // next prime after WSP_GGML_MAX_NODES * 2 (nodes + leafs) - #define WSP_GGML_GRAPH_HASHTABLE_SIZE 8273 + enum wsp_ggml_cgraph_eval_order { + WSP_GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0, + WSP_GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT, + WSP_GGML_CGRAPH_EVAL_ORDER_COUNT + }; + + struct wsp_ggml_hash_set { + size_t size; + struct wsp_ggml_tensor ** keys; + }; // computation graph struct wsp_ggml_cgraph { + int size; int n_nodes; int n_leafs; - struct wsp_ggml_tensor * nodes[WSP_GGML_MAX_NODES]; - struct wsp_ggml_tensor * grads[WSP_GGML_MAX_NODES]; - struct wsp_ggml_tensor * leafs[WSP_GGML_MAX_NODES]; + struct wsp_ggml_tensor ** nodes; + struct wsp_ggml_tensor ** grads; + struct wsp_ggml_tensor ** leafs; + + struct wsp_ggml_hash_set visited_hash_table; - void * visited_hash_table[WSP_GGML_GRAPH_HASHTABLE_SIZE]; + enum wsp_ggml_cgraph_eval_order order; // performance int perf_runs; @@ -539,8 +571,6 @@ extern "C" { int64_t perf_time_us; }; - static const size_t WSP_GGML_GRAPH_SIZE = sizeof(struct wsp_ggml_cgraph); - // scratch buffer struct wsp_ggml_scratch { size_t offs; @@ -585,6 +615,8 @@ extern "C" { WSP_GGML_API int64_t wsp_ggml_cycles(void); WSP_GGML_API int64_t wsp_ggml_cycles_per_ms(void); + WSP_GGML_API void wsp_ggml_print_backtrace(void); + WSP_GGML_API void wsp_ggml_numa_init(void); // call once for better performance on NUMA systems WSP_GGML_API bool wsp_ggml_is_numa(void); // true if init detected that system has >1 NUMA node @@ -674,18 +706,30 @@ extern "C" { WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_dup_tensor (struct wsp_ggml_context * ctx, const struct wsp_ggml_tensor * src); WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_view_tensor(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * src); + // Context tensor enumeration and lookup + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_first_tensor(struct wsp_ggml_context * ctx); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_next_tensor (struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor); WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_get_tensor(struct wsp_ggml_context * ctx, const char * name); WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_zero(struct wsp_ggml_tensor * tensor); WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_i32 (struct wsp_ggml_tensor * tensor, int32_t value); WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_set_f32 (struct wsp_ggml_tensor * tensor, float value); + // Converts a flat index into coordinates + WSP_GGML_API void wsp_ggml_unravel_index(const struct wsp_ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); + WSP_GGML_API int32_t wsp_ggml_get_i32_1d(const struct wsp_ggml_tensor * tensor, int i); WSP_GGML_API void wsp_ggml_set_i32_1d(const struct wsp_ggml_tensor * tensor, int i, int32_t value); + WSP_GGML_API int32_t wsp_ggml_get_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3); + WSP_GGML_API void wsp_ggml_set_i32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); + WSP_GGML_API float wsp_ggml_get_f32_1d(const struct wsp_ggml_tensor * tensor, int i); WSP_GGML_API void wsp_ggml_set_f32_1d(const struct wsp_ggml_tensor * tensor, int i, float value); + WSP_GGML_API float wsp_ggml_get_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3); + WSP_GGML_API void wsp_ggml_set_f32_nd(const struct wsp_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); + WSP_GGML_API void * wsp_ggml_get_data (const struct wsp_ggml_tensor * tensor); WSP_GGML_API float * wsp_ggml_get_data_f32(const struct wsp_ggml_tensor * tensor); @@ -719,6 +763,12 @@ extern "C" { struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add_cast( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + enum wsp_ggml_type type); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_add1( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, @@ -828,6 +878,7 @@ extern "C" { struct wsp_ggml_tensor * a, struct wsp_ggml_tensor * b); + // sums repetitions in a into shape of b WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_repeat_back( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, @@ -892,6 +943,10 @@ extern "C" { struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_leaky( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_relu_inplace( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a); @@ -970,9 +1025,9 @@ extern "C" { struct wsp_ggml_tensor * b, float eps); - // A: n columns, m rows - // B: n columns, p rows (i.e. we transpose it internally) - // result is m columns, p rows + // A: k columns, n rows => [ne03, ne02, n, k] + // B: k columns, m rows (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k] + // result is n columns, m rows => [ne03 * x, ne02 * y, m, n] WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_mul_mat( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, @@ -1049,7 +1104,6 @@ extern "C" { size_t nb1, size_t offset); - // a -> b, return view(b) WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cpy( struct wsp_ggml_context * ctx, @@ -1072,6 +1126,33 @@ extern "C" { struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a); + // make contiguous, with new shape + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_2d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_3d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_cont_4d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + // return view(a), b specifies the new shape // TODO: when we start computing gradient, make a copy instead of view WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_reshape( @@ -1219,14 +1300,15 @@ extern "C" { struct wsp_ggml_tensor * b); // rotary position embedding - // if mode & 1 == 1, skip n_past elements + // if mode & 1 == 1, skip n_past elements (DEPRECATED) // if mode & 2 == 1, GPT-NeoX style // if mode & 4 == 1, ChatGLM style - // TODO: avoid creating a new tensor every time + // + // b is an int32 vector with size a->ne[2], it contains the positions WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, int mode, int n_ctx); @@ -1235,7 +1317,7 @@ extern "C" { WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_inplace( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, int mode, int n_ctx); @@ -1244,29 +1326,43 @@ extern "C" { WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, int mode, int n_ctx, + int n_orig_ctx, float freq_base, - float freq_scale); + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); // in-place, returns view(a) WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_custom_inplace( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, int mode, int n_ctx, + int n_orig_ctx, float freq_base, - float freq_scale); + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + // compute correction dims for YaRN RoPE scaling + void wsp_ggml_rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); // xPos RoPE, in-place, returns view(a) WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_xpos_inplace( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, float base, bool down); @@ -1276,7 +1372,7 @@ extern "C" { WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_rope_back( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, - int n_past, + struct wsp_ggml_tensor * b, int n_dims, int mode, int n_ctx, @@ -1287,7 +1383,7 @@ extern "C" { // alibi position embedding // in-place, returns view(a) - struct wsp_ggml_tensor * wsp_ggml_alibi( + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_alibi( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, int n_past, @@ -1296,7 +1392,7 @@ extern "C" { // clamp // in-place, returns view(a) - struct wsp_ggml_tensor * wsp_ggml_clamp( + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_clamp( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, float min, @@ -1319,6 +1415,14 @@ extern "C" { int s, int d); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_transpose_1d( + struct wsp_ggml_context * ctx, + struct wsp_ggml_tensor * a, + struct wsp_ggml_tensor * b, + int s0, + int p0, + int d0); + WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_conv_2d( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, @@ -1377,6 +1481,8 @@ extern "C" { int s0, // stride int p0); // padding + // the result will have 2*p0 padding for the first dimension + // and 2*p1 padding for the second dimension WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_pool_2d( struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * a, @@ -1385,8 +1491,8 @@ extern "C" { int k1, int s0, int s1, - int p0, - int p1); + float p0, + float p1); // nearest interpolate // used in stable-diffusion @@ -1627,19 +1733,22 @@ extern "C" { WSP_GGML_API void wsp_ggml_build_forward_expand (struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_tensor * tensor); WSP_GGML_API void wsp_ggml_build_backward_expand(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, struct wsp_ggml_cgraph * gb, bool keep); - WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_forward (struct wsp_ggml_tensor * tensor); - WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_build_backward(struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * gf, bool keep); - // graph allocation in a context - WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); - WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_build_forward_ctx(struct wsp_ggml_context * ctx, struct wsp_ggml_tensor * tensor); + WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph (struct wsp_ggml_context * ctx); // size = WSP_GGML_DEFAULT_GRAPH_SIZE, grads = false + WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_new_graph_custom (struct wsp_ggml_context * ctx, size_t size, bool grads); + WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_dup (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph); + WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_view (struct wsp_ggml_context * ctx, struct wsp_ggml_cgraph * cgraph, int i0, int i1); + WSP_GGML_API void wsp_ggml_graph_cpy (struct wsp_ggml_cgraph * src, struct wsp_ggml_cgraph * dst); + WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); // zero grads + WSP_GGML_API void wsp_ggml_graph_clear (struct wsp_ggml_cgraph * cgraph); + WSP_GGML_API size_t wsp_ggml_graph_overhead(void); + WSP_GGML_API size_t wsp_ggml_graph_overhead_custom(size_t size, bool grads); // wsp_ggml_graph_plan() has to be called before wsp_ggml_graph_compute() // when plan.work_size > 0, caller must allocate memory for plan.work_data WSP_GGML_API struct wsp_ggml_cplan wsp_ggml_graph_plan (struct wsp_ggml_cgraph * cgraph, int n_threads /*= WSP_GGML_DEFAULT_N_THREADS*/); - WSP_GGML_API int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan); - WSP_GGML_API void wsp_ggml_graph_reset (struct wsp_ggml_cgraph * cgraph); + WSP_GGML_API int wsp_ggml_graph_compute(struct wsp_ggml_cgraph * cgraph, struct wsp_ggml_cplan * cplan); // same as wsp_ggml_graph_compute() but the work data is allocated as a part of the context // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data @@ -1647,8 +1756,8 @@ extern "C" { WSP_GGML_API struct wsp_ggml_tensor * wsp_ggml_graph_get_tensor(struct wsp_ggml_cgraph * cgraph, const char * name); - WSP_GGML_API void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname); - WSP_GGML_API struct wsp_ggml_cgraph wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval); + WSP_GGML_API void wsp_ggml_graph_export(const struct wsp_ggml_cgraph * cgraph, const char * fname); + WSP_GGML_API struct wsp_ggml_cgraph * wsp_ggml_graph_import(const char * fname, struct wsp_ggml_context ** ctx_data, struct wsp_ggml_context ** ctx_eval); // print info and performance information for the graph WSP_GGML_API void wsp_ggml_graph_print(const struct wsp_ggml_cgraph * cgraph); @@ -1656,6 +1765,16 @@ extern "C" { // dump the graph into a file using the dot format WSP_GGML_API void wsp_ggml_graph_dump_dot(const struct wsp_ggml_cgraph * gb, const struct wsp_ggml_cgraph * gf, const char * filename); + // build gradient checkpointing backward graph gb for gf using provided checkpoints + // gb_tmp will contain original backward graph with rewritten backward process nodes, + // but without the second forward pass nodes. + WSP_GGML_API void wsp_ggml_build_backward_gradient_checkpointing( + struct wsp_ggml_context * ctx, + struct wsp_ggml_cgraph * gf, + struct wsp_ggml_cgraph * gb, + struct wsp_ggml_cgraph * gb_tmp, + struct wsp_ggml_tensor * * checkpoints, + int n_checkpoints); // // optimization // @@ -1682,6 +1801,7 @@ extern "C" { WSP_GGML_OPT_NO_CONTEXT, WSP_GGML_OPT_INVALID_WOLFE, WSP_GGML_OPT_FAIL, + WSP_GGML_OPT_CANCEL, WSP_GGML_LINESEARCH_FAIL = -128, WSP_GGML_LINESEARCH_MINIMUM_STEP, @@ -1690,7 +1810,8 @@ extern "C" { WSP_GGML_LINESEARCH_INVALID_PARAMETERS, }; - typedef void (*wsp_ggml_opt_callback)(void * data, float * sched); + typedef void (*wsp_ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel); + typedef void (*wsp_ggml_log_callback)(enum wsp_ggml_log_level level, const char * text, void * user_data); // optimization parameters // @@ -1699,6 +1820,8 @@ extern "C" { struct wsp_ggml_opt_params { enum wsp_ggml_opt_type type; + size_t graph_size; + int n_threads; // delta-based convergence test @@ -1721,6 +1844,8 @@ extern "C" { bool print_forward_graph; bool print_backward_graph; + int n_gradient_accumulation; + // ADAM parameters struct { int n_iter; @@ -1766,6 +1891,7 @@ extern "C" { float loss_after; struct { + struct wsp_ggml_tensor * g; // current gradient struct wsp_ggml_tensor * m; // first moment struct wsp_ggml_tensor * v; // second moment struct wsp_ggml_tensor * pf; // past function values @@ -1829,12 +1955,19 @@ extern "C" { // quantization // + // TODO: these would probably get removed in favor of the more general wsp_ggml_quantize_chunk WSP_GGML_API size_t wsp_ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); WSP_GGML_API size_t wsp_ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); WSP_GGML_API size_t wsp_ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); WSP_GGML_API size_t wsp_ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); WSP_GGML_API size_t wsp_ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); + WSP_GGML_API size_t wsp_ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist); + WSP_GGML_API size_t wsp_ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist); + WSP_GGML_API size_t wsp_ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist); + WSP_GGML_API size_t wsp_ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist); + WSP_GGML_API size_t wsp_ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist); + WSP_GGML_API size_t wsp_ggml_quantize_chunk(enum wsp_ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); // @@ -1882,26 +2015,26 @@ extern "C" { WSP_GGML_API int wsp_gguf_get_n_kv(const struct wsp_gguf_context * ctx); WSP_GGML_API int wsp_gguf_find_key(const struct wsp_gguf_context * ctx, const char * key); - WSP_GGML_API const char * wsp_gguf_get_key (const struct wsp_gguf_context * ctx, int i); - - WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_kv_type (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int i); - - // results are undefined if the wrong type is used for the key - WSP_GGML_API uint8_t wsp_gguf_get_val_u8 (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API int8_t wsp_gguf_get_val_i8 (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API uint16_t wsp_gguf_get_val_u16 (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API int16_t wsp_gguf_get_val_i16 (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API uint32_t wsp_gguf_get_val_u32 (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API int32_t wsp_gguf_get_val_i32 (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API float wsp_gguf_get_val_f32 (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API uint64_t wsp_gguf_get_val_u64 (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API int64_t wsp_gguf_get_val_i64 (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API double wsp_gguf_get_val_f64 (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API const char * wsp_gguf_get_val_str (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API int wsp_gguf_get_arr_n (const struct wsp_gguf_context * ctx, int i); - WSP_GGML_API const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int i); + WSP_GGML_API const char * wsp_gguf_get_key (const struct wsp_gguf_context * ctx, int key_id); + + WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_kv_type (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API enum wsp_gguf_type wsp_gguf_get_arr_type(const struct wsp_gguf_context * ctx, int key_id); + + // will abort if the wrong type is used for the key + WSP_GGML_API uint8_t wsp_gguf_get_val_u8 (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API int8_t wsp_gguf_get_val_i8 (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API uint16_t wsp_gguf_get_val_u16 (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API int16_t wsp_gguf_get_val_i16 (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API uint32_t wsp_gguf_get_val_u32 (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API int32_t wsp_gguf_get_val_i32 (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API float wsp_gguf_get_val_f32 (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API uint64_t wsp_gguf_get_val_u64 (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API int64_t wsp_gguf_get_val_i64 (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API double wsp_gguf_get_val_f64 (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API bool wsp_gguf_get_val_bool(const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API const char * wsp_gguf_get_val_str (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API int wsp_gguf_get_arr_n (const struct wsp_gguf_context * ctx, int key_id); + WSP_GGML_API const void * wsp_gguf_get_arr_data(const struct wsp_gguf_context * ctx, int key_id); WSP_GGML_API const char * wsp_gguf_get_arr_str (const struct wsp_gguf_context * ctx, int key_id, int i); WSP_GGML_API int wsp_gguf_get_n_tensors (const struct wsp_gguf_context * ctx); @@ -2008,7 +2141,7 @@ extern "C" { enum wsp_ggml_type vec_dot_type; } wsp_ggml_type_traits_t; - wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type); + WSP_GGML_API wsp_ggml_type_traits_t wsp_ggml_internal_get_type_traits(enum wsp_ggml_type type); #ifdef __cplusplus } diff --git a/cpp/whisper.cpp b/cpp/whisper.cpp index fa58cef..a84e7dd 100644 --- a/cpp/whisper.cpp +++ b/cpp/whisper.cpp @@ -120,6 +120,7 @@ static void byteswap_tensor(wsp_ggml_tensor * tensor) { //#define WHISPER_USE_FLASH_ATTN //#define WHISPER_USE_FLASH_FF #define WHISPER_MAX_DECODERS 16 +#define WHISPER_MAX_NODES 4096 // // ggml helpers @@ -192,6 +193,15 @@ enum e_model { MODEL_LARGE, }; +static const std::map g_model_name = { + { MODEL_UNKNOWN, "unknown" }, + { MODEL_TINY, "tiny" }, + { MODEL_BASE, "base" }, + { MODEL_SMALL, "small" }, + { MODEL_MEDIUM, "medium" }, + { MODEL_LARGE, "large" }, +}; + static const std::map> g_lang = { { "en", { 0, "english", } }, { "zh", { 1, "chinese", } }, @@ -292,6 +302,7 @@ static const std::map> g_lang = { { "ba", { 96, "bashkir", } }, { "jw", { 97, "javanese", } }, { "su", { 98, "sundanese", } }, + { "yue", { 99, "cantonese", } }, }; static const size_t MB = 1ull*1024*1024; @@ -401,7 +412,11 @@ struct whisper_vocab { id token_beg = 50363; // begin timestamps bool is_multilingual() const { - return n_vocab == 51865; + return n_vocab >= 51865; + } + + int num_languages() const { + return n_vocab - 51765 - (is_multilingual() ? 1 : 0); } }; @@ -663,7 +678,7 @@ static void whisper_allocr_graph_init(struct whisper_allocr & allocr, std::funct auto & meta = allocr.meta; auto & data = allocr.data; - meta.resize(wsp_ggml_tensor_overhead()*WSP_GGML_MAX_NODES + wsp_ggml_graph_overhead()); + meta.resize(wsp_ggml_tensor_overhead()*WHISPER_MAX_NODES + wsp_ggml_graph_overhead()); alloc = wsp_ggml_allocr_new_measure(tensor_alignment); @@ -735,7 +750,7 @@ struct whisper_state { int lang_id = 0; // english by default - std::string path_model; // populated by whisper_init_from_file() + std::string path_model; // populated by whisper_init_from_file_with_params() #ifdef WHISPER_USE_COREML whisper_coreml_context * ctx_coreml = nullptr; #endif @@ -769,10 +784,8 @@ struct whisper_context { whisper_vocab vocab; whisper_state * state = nullptr; - std::string path_model; // populated by whisper_init_from_file() -#ifdef WHISPER_USE_COREML - bool load_coreml = true; -#endif + std::string path_model; // populated by whisper_init_from_file_with_params() + whisper_context_params params; }; static void whisper_default_log(const char * text) { @@ -923,6 +936,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con assert(hparams.n_text_state == hparams.n_audio_state); + std::string mver = ""; + if (hparams.n_audio_layer == 4) { model.type = e_model::MODEL_TINY; } @@ -941,6 +956,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con if (hparams.n_audio_layer == 32) { model.type = e_model::MODEL_LARGE; + + if (hparams.n_vocab == 51866) { + mver = " v3"; + } } const int32_t qntvr = hparams.ftype / WSP_GGML_QNT_VERSION_FACTOR; @@ -969,7 +988,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con log("%s: n_mels = %d\n", __func__, hparams.n_mels); log("%s: ftype = %d\n", __func__, model.hparams.ftype); log("%s: qntvr = %d\n", __func__, qntvr); - log("%s: type = %d\n", __func__, model.type); + log("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str()); // print memory requirements { @@ -1040,13 +1059,17 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con if (vocab.is_multilingual()) { vocab.token_eot++; vocab.token_sot++; - vocab.token_translate++; - vocab.token_transcribe++; - vocab.token_solm++; - vocab.token_prev++; - vocab.token_nosp++; - vocab.token_not++; - vocab.token_beg++; + + // account for variable number of language tokens + const int dt = vocab.num_languages() - 98; + + vocab.token_translate += dt; + vocab.token_transcribe += dt; + vocab.token_solm += dt; + vocab.token_prev += dt; + vocab.token_nosp += dt; + vocab.token_not += dt; + vocab.token_beg += dt; } if (n_vocab < model.hparams.n_vocab) { @@ -1075,6 +1098,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con vocab.id_to_token[i] = word; } } + + log("%s: n_langs = %d\n", __func__, vocab.num_languages()); } size_t ctx_size = 0; @@ -1619,7 +1644,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_encoder( struct wsp_ggml_context * ctx0 = wsp_ggml_init(params); - wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0); + wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); wsp_ggml_allocr * alloc = wstate.alloc_encode.alloc; @@ -2037,7 +2062,7 @@ static struct wsp_ggml_cgraph * whisper_build_graph_decoder( struct wsp_ggml_context * ctx0 = wsp_ggml_init(params); - wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0); + wsp_ggml_cgraph * gf = wsp_ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); wsp_ggml_allocr * alloc = wstate.alloc_decode.alloc; @@ -2856,8 +2881,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } + #ifdef WHISPER_USE_COREML -if (ctx->load_coreml) { // Not in correct layer for easy patch + if (ctx->params.use_coreml) { const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); @@ -2873,7 +2899,7 @@ if (ctx->load_coreml) { // Not in correct layer for easy patch } else { log("%s: Core ML model loaded\n", __func__); } -} + } #endif state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx); @@ -2934,59 +2960,64 @@ if (ctx->load_coreml) { // Not in correct layer for easy patch } #ifdef WSP_GGML_USE_METAL - state->ctx_metal = wsp_ggml_metal_init(1); - if (!state->ctx_metal) { - log("%s: wsp_ggml_metal_init() failed\n", __func__); - delete state; - return nullptr; + if (ctx->params.use_gpu) { + state->ctx_metal = wsp_ggml_metal_init(1); + if (!state->ctx_metal) { + log("%s: wsp_ggml_metal_init() failed\n", __func__); + delete state; + return nullptr; + } } - log("%s: Metal context initialized\n", __func__); + if (state->ctx_metal) { + log("%s: Metal context initialized\n", __func__); - // this allocates all Metal resources and memory buffers + // this allocates all Metal resources and memory buffers - void * data_ptr = NULL; - size_t data_size = 0; + void * data_ptr = NULL; + size_t data_size = 0; - // TODO: add mmap support - //if (params.use_mmap) { - // data_ptr = ctx->model.mapping->addr; - // data_size = ctx->model.mapping->size; - //} else { - // data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx); - // data_size = wsp_ggml_get_mem_size (ctx->model.ctx); - //} + // TODO: add mmap support + //if (params.use_mmap) { + // data_ptr = ctx->model.mapping->addr; + // data_size = ctx->model.mapping->size; + //} else { + // data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx); + // data_size = wsp_ggml_get_mem_size (ctx->model.ctx); + //} - data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx); - data_size = wsp_ggml_get_mem_size (ctx->model.ctx); + data_ptr = wsp_ggml_get_mem_buffer(ctx->model.ctx); + data_size = wsp_ggml_get_mem_size (ctx->model.ctx); - const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx); + const size_t max_size = wsp_ggml_get_max_tensor_size(ctx->model.ctx); - log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); + log("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); #define WHISPER_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - log("%s: failed to add metal buffer\n", __func__); \ - delete state; \ - return nullptr; \ - } + if (!(result)) { \ + log("%s: failed to add metal buffer\n", __func__); \ + delete state; \ + return nullptr; \ + } - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data", data_ptr, data_size, max_size)); - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_conv", state->alloc_conv.meta.data(), state->alloc_conv.meta.size(), 0)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_encode", state->alloc_encode.meta.data(), state->alloc_encode.meta.size(), 0)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_cross", state->alloc_cross.meta.data(), state->alloc_cross.meta.size(), 0)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "meta_decode", state->alloc_decode.meta.data(), state->alloc_decode.meta.size(), 0)); - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_conv", state->alloc_conv.data.data(), state->alloc_conv.data.size(), 0)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_encode", state->alloc_encode.data.data(), state->alloc_encode.data.size(), 0)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_cross", state->alloc_cross.data.data(), state->alloc_cross.data.size(), 0)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "data_decode", state->alloc_decode.data.data(), state->alloc_decode.data.size(), 0)); - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_cross", state->kv_cross.buf.data(), state->kv_cross.buf.size(), 0)); - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, "kv_self_0", state->decoders[0].kv_self.buf.data(), state->decoders[0].kv_self.buf.size(), 0)); #undef WHISPER_METAL_CHECK_BUF + + } #endif state->rng = std::mt19937(0); @@ -2994,23 +3025,6 @@ if (ctx->load_coreml) { // Not in correct layer for easy patch return state; } -#ifdef WHISPER_USE_COREML -struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) { - whisper_context * ctx = whisper_init_from_file_no_state(path_model); - if (!ctx) { - return nullptr; - } - ctx->load_coreml = false; - ctx->state = whisper_init_state(ctx); - if (!ctx->state) { - whisper_free(ctx); - return nullptr; - } - - return ctx; -} -#endif - int whisper_ctx_init_openvino_encoder( struct whisper_context * ctx, const char * model_path, @@ -3060,7 +3074,15 @@ int whisper_ctx_init_openvino_encoder( #endif } -struct whisper_context * whisper_init_from_file_no_state(const char * path_model) { +struct whisper_context_params whisper_context_default_params() { + struct whisper_context_params result = { + /*.use_gpu =*/ true, + /*.use_coreml =*/ false, + }; + return result; +} + +struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) { log("%s: loading model from '%s'\n", __func__, path_model); auto fin = std::ifstream(path_model, std::ios::binary); @@ -3089,7 +3111,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model fin->close(); }; - auto ctx = whisper_init_no_state(&loader); + auto ctx = whisper_init_with_params_no_state(&loader, params); if (ctx) { ctx->path_model = path_model; @@ -3098,7 +3120,7 @@ struct whisper_context * whisper_init_from_file_no_state(const char * path_model return ctx; } -struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) { +struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params) { struct buf_context { uint8_t* buffer; size_t size; @@ -3132,13 +3154,14 @@ struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t loader.close = [](void * /*ctx*/) { }; - return whisper_init_no_state(&loader); + return whisper_init_with_params_no_state(&loader, params); } -struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) { +struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) { wsp_ggml_time_init(); whisper_context * ctx = new whisper_context; + ctx->params = params; if (!whisper_model_load(loader, *ctx)) { loader->close(loader->context); @@ -3152,8 +3175,8 @@ struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loa return ctx; } -struct whisper_context * whisper_init_from_file(const char * path_model) { - whisper_context * ctx = whisper_init_from_file_no_state(path_model); +struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params) { + whisper_context * ctx = whisper_init_from_file_with_params_no_state(path_model, params); if (!ctx) { return nullptr; } @@ -3167,8 +3190,8 @@ struct whisper_context * whisper_init_from_file(const char * path_model) { return ctx; } -struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { - whisper_context * ctx = whisper_init_from_buffer_no_state(buffer, buffer_size); +struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params) { + whisper_context * ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, params); if (!ctx) { return nullptr; } @@ -3182,8 +3205,8 @@ struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_s return ctx; } -struct whisper_context * whisper_init(struct whisper_model_loader * loader) { - whisper_context * ctx = whisper_init_no_state(loader); +struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params) { + whisper_context * ctx = whisper_init_with_params_no_state(loader, params); if (!ctx) { return nullptr; } @@ -3197,6 +3220,30 @@ struct whisper_context * whisper_init(struct whisper_model_loader * loader) { return ctx; } +struct whisper_context * whisper_init_from_file(const char * path_model) { + return whisper_init_from_file_with_params(path_model, whisper_context_default_params()); +} + +struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { + return whisper_init_from_buffer_with_params(buffer, buffer_size, whisper_context_default_params()); +} + +struct whisper_context * whisper_init(struct whisper_model_loader * loader) { + return whisper_init_with_params(loader, whisper_context_default_params()); +} + +struct whisper_context * whisper_init_from_file_no_state(const char * path_model) { + return whisper_init_from_file_with_params_no_state(path_model, whisper_context_default_params()); +} + +struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) { + return whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, whisper_context_default_params()); +} + +struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) { + return whisper_init_with_params_no_state(loader, whisper_context_default_params()); +} + void whisper_free_state(struct whisper_state * state) { if (state) { @@ -3251,6 +3298,12 @@ void whisper_free(struct whisper_context * ctx) { } } +void whisper_free_context_params(struct whisper_context_params * params) { + if (params) { + delete params; + } +} + void whisper_free_params(struct whisper_full_params * params) { if (params) { delete params; @@ -3258,7 +3311,7 @@ void whisper_free_params(struct whisper_full_params * params) { } int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { log("%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -3272,7 +3325,7 @@ int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int // same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { - if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, WHISPER_N_MEL, n_threads, ctx->model.filters, false, state->mel)) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { log("%s: failed to compute mel spectrogram\n", __func__); return -1; } @@ -3295,13 +3348,13 @@ int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * // TODO int whisper_set_mel_with_state( - struct whisper_context * /*ctx*/, + struct whisper_context * ctx, struct whisper_state * state, const float * data, int n_len, int n_mel) { - if (n_mel != WHISPER_N_MEL) { - log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, WHISPER_N_MEL); + if (n_mel != ctx->model.filters.n_mel) { + log("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); return -1; } @@ -3665,6 +3718,7 @@ void whisper_print_timings(struct whisper_context * ctx) { } void whisper_reset_timings(struct whisper_context * ctx) { + ctx->t_start_us = wsp_ggml_time_us(); if (ctx->state != nullptr) { ctx->state->t_sample_us = 0; ctx->state->t_encode_us = 0; @@ -3719,6 +3773,14 @@ const char * whisper_print_system_info(void) { //////////////////////////////////////////////////////////////////////////// +struct whisper_context_params * whisper_context_default_params_by_ref() { + struct whisper_context_params params = whisper_context_default_params(); + + struct whisper_context_params* result = new whisper_context_params(); + *result = params; + return result; +} + struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) { struct whisper_full_params params = whisper_full_default_params(strategy); @@ -3795,8 +3857,8 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.encoder_begin_callback =*/ nullptr, /*.encoder_begin_callback_user_data =*/ nullptr, - /*.abort_callback =*/ nullptr, - /*.abort_callback_user_data =*/ nullptr, + /*.abort_callback =*/ nullptr, + /*.abort_callback_user_data =*/ nullptr, /*.logits_filter_callback =*/ nullptr, /*.logits_filter_callback_user_data =*/ nullptr, @@ -3964,6 +4026,7 @@ static void whisper_process_logits( // suppress task tokens logits[vocab.token_translate] = -INFINITY; logits[vocab.token_transcribe] = -INFINITY; + logits[vocab.token_prev] = -INFINITY; if (params.logits_filter_callback) { params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); @@ -4530,17 +4593,19 @@ int whisper_full_with_state( // TODO: not very clean - look for a better way and potentially merging with the init of decoder 0 #ifdef WSP_GGML_USE_METAL + if (state->ctx_metal) { #define WHISPER_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - log("%s: failed to add metal buffer\n", __func__); \ - return 0; \ - } + if (!(result)) { \ + log("%s: failed to add metal buffer\n", __func__); \ + return 0; \ + } - const std::string kv_name = "kv_self_" + std::to_string(j); - auto & kv_self = decoder.kv_self; + const std::string kv_name = "kv_self_" + std::to_string(j); + auto & kv_self = decoder.kv_self; - WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0)); + WHISPER_METAL_CHECK_BUF(wsp_ggml_metal_add_buffer(state->ctx_metal, kv_name.c_str(), kv_self.buf.data(), kv_self.buf.size(), 0)); #undef WHISPER_METAL_CHECK_BUF + } #endif } } @@ -4557,7 +4622,7 @@ int whisper_full_with_state( // initial prompt if (!params.prompt_tokens && params.initial_prompt) { - prompt_tokens.resize(2048); + prompt_tokens.resize(1024); prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size())); params.prompt_tokens = prompt_tokens.data(); params.prompt_n_tokens = prompt_tokens.size(); @@ -4582,6 +4647,7 @@ int whisper_full_with_state( // these tokens determine the task that will be performed std::vector prompt_init = { whisper_token_sot(ctx) }; + if (whisper_is_multilingual(ctx)) { const int lang_id = whisper_lang_id(params.language); state->lang_id = lang_id; @@ -4593,6 +4659,17 @@ int whisper_full_with_state( } } + { + const bool is_distil = ctx->model.hparams.n_text_layer == 2; + + // distilled models require the "no_timestamps" token + // TODO: add input parameter (#1229) + if (is_distil) { + log("%s: using distilled model - forcing no_timestamps\n", __func__); + prompt_init.push_back(whisper_token_not(ctx)); + } + } + int seek = seek_start; std::vector prompt; @@ -5454,7 +5531,7 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) { // b: N*N*sizeof(float) // c: N*N*sizeof(float) // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) - std::vector buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead()); + std::vector buf(3llu*N_max*N_max*sizeof(float) + 3*wsp_ggml_tensor_overhead() + wsp_ggml_graph_overhead()); std::vector work; // put a bunch of random data in the buffer @@ -5505,17 +5582,19 @@ WHISPER_API const char * whisper_bench_wsp_ggml_mul_mat_str(int n_threads) { struct wsp_ggml_tensor * c = wsp_ggml_mul_mat(ctx0, a, b); - struct wsp_ggml_cgraph gf = wsp_ggml_build_forward(c); + struct wsp_ggml_cgraph * gf = wsp_ggml_new_graph(ctx0); + + wsp_ggml_build_forward_expand(gf, c); double tsum = 0.0; // heat-up - wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr , nullptr); + wsp_ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr); for (int i = 0; i < n_max; ++i) { const int64_t t0 = wsp_ggml_time_us(); - wsp_ggml_graph_compute_helper(work, &gf, n_threads, nullptr, nullptr); + wsp_ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr); const int64_t t1 = wsp_ggml_time_us(); diff --git a/cpp/whisper.h b/cpp/whisper.h index 5d2b013..eea563c 100644 --- a/cpp/whisper.h +++ b/cpp/whisper.h @@ -5,6 +5,14 @@ #include #include +#ifdef __GNUC__ +# define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define WHISPER_DEPRECATED(func, hint) func +#endif + #ifdef WHISPER_SHARED # ifdef _WIN32 # ifdef WHISPER_BUILD @@ -21,7 +29,6 @@ #define WHISPER_SAMPLE_RATE 16000 #define WHISPER_N_FFT 400 -#define WHISPER_N_MEL 80 #define WHISPER_HOP_LENGTH 160 #define WHISPER_CHUNK_SIZE 30 @@ -71,6 +78,11 @@ extern "C" { typedef int whisper_token; + struct whisper_context_params { + bool use_gpu; + bool use_coreml; + }; + typedef struct whisper_token_data { whisper_token id; // token id whisper_token tid; // forced timestamp token id @@ -99,18 +111,40 @@ extern "C" { // Various functions for loading a ggml whisper model. // Allocate (almost) all memory needed for the model. // Return NULL on failure -#ifdef WHISPER_USE_COREML - WHISPER_API struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model); -#endif - WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model); - WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size); - WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); + WHISPER_API struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params); // These are the same as the above, but the internal state of the context is not allocated automatically // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523) - WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model); - WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size); - WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader); + WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params); + + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model), + "use whisper_init_from_file_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size), + "use whisper_init_from_buffer_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader), + "use whisper_init_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model), + "use whisper_init_from_file_with_params_no_state instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size), + "use whisper_init_from_buffer_with_params_no_state instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader), + "use whisper_init_with_params_no_state instead" + ); WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx); @@ -135,6 +169,7 @@ extern "C" { WHISPER_API void whisper_free (struct whisper_context * ctx); WHISPER_API void whisper_free_state(struct whisper_state * state); WHISPER_API void whisper_free_params(struct whisper_full_params * params); + WHISPER_API void whisper_free_context_params(struct whisper_context_params * params); // Convert RAW PCM audio to log mel spectrogram. // The resulting spectrogram is stored inside the default state of the provided whisper context. @@ -445,7 +480,9 @@ extern "C" { void * logits_filter_callback_user_data; }; - // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_params() + // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params() + WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref(); + WHISPER_API struct whisper_context_params whisper_context_default_params(void); WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy); WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); diff --git a/docs/API/README.md b/docs/API/README.md index ed45d74..0d83911 100644 --- a/docs/API/README.md +++ b/docs/API/README.md @@ -58,7 +58,7 @@ whisper.rn #### Defined in -[index.ts:76](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L76) +[index.ts:76](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L76) ___ @@ -76,10 +76,11 @@ ___ | `filePath` | `string` \| `number` | - | | `isBundleAsset?` | `boolean` | Is the file path a bundle asset for pure string filePath | | `useCoreMLIos?` | `boolean` | Prefer to use Core ML model if exists. If set to false, even if the Core ML model exists, it will not be used. | +| `useGpu?` | `boolean` | Use GPU if available. Currently iOS only, if it's enabled, Core ML option will be ignored. | #### Defined in -[index.ts:428](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L428) +[index.ts:438](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L438) ___ @@ -89,7 +90,7 @@ ___ #### Defined in -[index.ts:59](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L59) +[index.ts:59](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L59) ___ @@ -107,7 +108,7 @@ ___ #### Defined in -[index.ts:52](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L52) +[index.ts:52](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L52) ___ @@ -126,7 +127,7 @@ ___ #### Defined in -[index.ts:45](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L45) +[index.ts:45](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L45) ___ @@ -156,7 +157,7 @@ ___ #### Defined in -[NativeRNWhisper.ts:5](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/NativeRNWhisper.ts#L5) +[NativeRNWhisper.ts:5](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/NativeRNWhisper.ts#L5) ___ @@ -174,7 +175,7 @@ ___ #### Defined in -[index.ts:70](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L70) +[index.ts:70](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L70) ___ @@ -199,7 +200,7 @@ ___ #### Defined in -[index.ts:133](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L133) +[index.ts:133](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L133) ___ @@ -217,7 +218,7 @@ ___ #### Defined in -[index.ts:166](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L166) +[index.ts:166](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L166) ___ @@ -241,7 +242,7 @@ ___ #### Defined in -[index.ts:153](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L153) +[index.ts:153](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L153) ___ @@ -251,7 +252,7 @@ ___ #### Defined in -[index.ts:84](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L84) +[index.ts:84](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L84) ___ @@ -269,7 +270,7 @@ ___ #### Defined in -[NativeRNWhisper.ts:37](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/NativeRNWhisper.ts#L37) +[NativeRNWhisper.ts:37](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/NativeRNWhisper.ts#L37) ## Variables @@ -294,7 +295,7 @@ AudioSession Utility, iOS only. #### Defined in -[AudioSessionIos.ts:50](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L50) +[AudioSessionIos.ts:50](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L50) ___ @@ -306,7 +307,7 @@ Is allow fallback to CPU if load CoreML model failed #### Defined in -[index.ts:526](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L526) +[index.ts:540](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L540) ___ @@ -318,7 +319,7 @@ Is use CoreML models on iOS #### Defined in -[index.ts:523](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L523) +[index.ts:537](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L537) ___ @@ -330,7 +331,7 @@ Current version of whisper.cpp #### Defined in -[index.ts:518](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L518) +[index.ts:532](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L532) ## Functions @@ -350,7 +351,7 @@ Current version of whisper.cpp #### Defined in -[index.ts:452](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L452) +[index.ts:464](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L464) ___ @@ -364,4 +365,4 @@ ___ #### Defined in -[index.ts:513](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L513) +[index.ts:527](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L527) diff --git a/docs/API/classes/WhisperContext.md b/docs/API/classes/WhisperContext.md index 315003c..1d704f8 100644 --- a/docs/API/classes/WhisperContext.md +++ b/docs/API/classes/WhisperContext.md @@ -10,7 +10,9 @@ ### Properties +- [gpu](WhisperContext.md#gpu) - [id](WhisperContext.md#id) +- [reasonNoGPU](WhisperContext.md#reasonnogpu) ### Methods @@ -22,27 +24,47 @@ ### constructor -• **new WhisperContext**(`id`) +• **new WhisperContext**(`«destructured»`) #### Parameters | Name | Type | | :------ | :------ | -| `id` | `number` | +| `«destructured»` | `NativeWhisperContext` | #### Defined in -[index.ts:186](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L186) +[index.ts:190](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L190) ## Properties +### gpu + +• **gpu**: `boolean` = `false` + +#### Defined in + +[index.ts:186](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L186) + +___ + ### id • **id**: `number` #### Defined in -[index.ts:184](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L184) +[index.ts:184](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L184) + +___ + +### reasonNoGPU + +• **reasonNoGPU**: `string` = `''` + +#### Defined in + +[index.ts:188](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L188) ## Methods @@ -56,7 +78,7 @@ #### Defined in -[index.ts:423](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L423) +[index.ts:433](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L433) ___ @@ -84,7 +106,7 @@ Transcribe audio file #### Defined in -[index.ts:191](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L191) +[index.ts:201](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L201) ___ @@ -106,4 +128,4 @@ Transcribe the microphone audio stream, the microphone user permission is requir #### Defined in -[index.ts:287](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/index.ts#L287) +[index.ts:297](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/index.ts#L297) diff --git a/docs/API/enums/AudioSessionCategoryIos.md b/docs/API/enums/AudioSessionCategoryIos.md index fb620f0..1d3168e 100644 --- a/docs/API/enums/AudioSessionCategoryIos.md +++ b/docs/API/enums/AudioSessionCategoryIos.md @@ -25,7 +25,7 @@ https://developer.apple.com/documentation/avfaudio/avaudiosessioncategory?langua #### Defined in -[AudioSessionIos.ts:8](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L8) +[AudioSessionIos.ts:8](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L8) ___ @@ -35,7 +35,7 @@ ___ #### Defined in -[AudioSessionIos.ts:13](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L13) +[AudioSessionIos.ts:13](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L13) ___ @@ -45,7 +45,7 @@ ___ #### Defined in -[AudioSessionIos.ts:12](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L12) +[AudioSessionIos.ts:12](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L12) ___ @@ -55,7 +55,7 @@ ___ #### Defined in -[AudioSessionIos.ts:10](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L10) +[AudioSessionIos.ts:10](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L10) ___ @@ -65,7 +65,7 @@ ___ #### Defined in -[AudioSessionIos.ts:11](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L11) +[AudioSessionIos.ts:11](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L11) ___ @@ -75,4 +75,4 @@ ___ #### Defined in -[AudioSessionIos.ts:9](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L9) +[AudioSessionIos.ts:9](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L9) diff --git a/docs/API/enums/AudioSessionCategoryOptionIos.md b/docs/API/enums/AudioSessionCategoryOptionIos.md index a46772d..e0656cb 100644 --- a/docs/API/enums/AudioSessionCategoryOptionIos.md +++ b/docs/API/enums/AudioSessionCategoryOptionIos.md @@ -26,7 +26,7 @@ https://developer.apple.com/documentation/avfaudio/avaudiosessioncategoryoptions #### Defined in -[AudioSessionIos.ts:25](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L25) +[AudioSessionIos.ts:25](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L25) ___ @@ -36,7 +36,7 @@ ___ #### Defined in -[AudioSessionIos.ts:23](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L23) +[AudioSessionIos.ts:23](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L23) ___ @@ -46,7 +46,7 @@ ___ #### Defined in -[AudioSessionIos.ts:24](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L24) +[AudioSessionIos.ts:24](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L24) ___ @@ -56,7 +56,7 @@ ___ #### Defined in -[AudioSessionIos.ts:26](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L26) +[AudioSessionIos.ts:26](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L26) ___ @@ -66,7 +66,7 @@ ___ #### Defined in -[AudioSessionIos.ts:21](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L21) +[AudioSessionIos.ts:21](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L21) ___ @@ -76,7 +76,7 @@ ___ #### Defined in -[AudioSessionIos.ts:22](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L22) +[AudioSessionIos.ts:22](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L22) ___ @@ -86,4 +86,4 @@ ___ #### Defined in -[AudioSessionIos.ts:20](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L20) +[AudioSessionIos.ts:20](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L20) diff --git a/docs/API/enums/AudioSessionModeIos.md b/docs/API/enums/AudioSessionModeIos.md index 74ebb30..6f2a3eb 100644 --- a/docs/API/enums/AudioSessionModeIos.md +++ b/docs/API/enums/AudioSessionModeIos.md @@ -27,7 +27,7 @@ https://developer.apple.com/documentation/avfaudio/avaudiosessionmode?language=o #### Defined in -[AudioSessionIos.ts:33](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L33) +[AudioSessionIos.ts:33](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L33) ___ @@ -37,7 +37,7 @@ ___ #### Defined in -[AudioSessionIos.ts:36](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L36) +[AudioSessionIos.ts:36](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L36) ___ @@ -47,7 +47,7 @@ ___ #### Defined in -[AudioSessionIos.ts:38](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L38) +[AudioSessionIos.ts:38](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L38) ___ @@ -57,7 +57,7 @@ ___ #### Defined in -[AudioSessionIos.ts:39](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L39) +[AudioSessionIos.ts:39](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L39) ___ @@ -67,7 +67,7 @@ ___ #### Defined in -[AudioSessionIos.ts:40](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L40) +[AudioSessionIos.ts:40](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L40) ___ @@ -77,7 +77,7 @@ ___ #### Defined in -[AudioSessionIos.ts:35](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L35) +[AudioSessionIos.ts:35](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L35) ___ @@ -87,7 +87,7 @@ ___ #### Defined in -[AudioSessionIos.ts:37](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L37) +[AudioSessionIos.ts:37](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L37) ___ @@ -97,4 +97,4 @@ ___ #### Defined in -[AudioSessionIos.ts:34](https://github.com/mybigday/whisper.rn/blob/f3ce9a6/src/AudioSessionIos.ts#L34) +[AudioSessionIos.ts:34](https://github.com/mybigday/whisper.rn/blob/ae1df68/src/AudioSessionIos.ts#L34) diff --git a/example/ios/Podfile b/example/ios/Podfile index e1c5cf6..6d182c8 100644 --- a/example/ios/Podfile +++ b/example/ios/Podfile @@ -27,7 +27,8 @@ target 'RNWhisperExample' do # Tip: You can use RNWHISPER_DISABLE_COREML = '1' to disable CoreML support. ENV['RNWHISPER_DISABLE_COREML'] = '0' - ENV['RNWHISPER_ENABLE_METAL'] = '0' # TODO + # Tip: You can use RNWHISPER_DISABLE_METAL = '1' to disable GPU support. + ENV['RNWHISPER_DISABLE_METAL'] = '0' config = use_native_modules! diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock index 600a807..f79eae6 100644 --- a/example/ios/Podfile.lock +++ b/example/ios/Podfile.lock @@ -764,7 +764,7 @@ PODS: - SSZipArchive (~> 2.2) - SocketRocket (0.6.0) - SSZipArchive (2.4.3) - - whisper-rn (0.4.0-rc.2): + - whisper-rn (0.4.0-rc.3): - RCT-Folly - RCTRequired - RCTTypeSafety @@ -1006,10 +1006,10 @@ SPEC CHECKSUMS: RNZipArchive: ef9451b849c45a29509bf44e65b788829ab07801 SocketRocket: fccef3f9c5cedea1353a9ef6ada904fde10d6608 SSZipArchive: fe6a26b2a54d5a0890f2567b5cc6de5caa600aef - whisper-rn: 02c39fead176096e6d8420c4f7f4326f122b36e3 + whisper-rn: 36ab43448dfed18fac97b74a461bd7e3a472bd53 Yoga: f7decafdc5e8c125e6fa0da38a687e35238420fa YogaKit: f782866e155069a2cca2517aafea43200b01fd5a -PODFILE CHECKSUM: a78cf54fa529c6dc4b44aaf32b861fdf1245919a +PODFILE CHECKSUM: 957e6780c2a33b892b094d14de1977f5749cbd7b COCOAPODS: 1.11.3 diff --git a/example/src/context-opts.ios.ts b/example/src/context-opts.ios.ts index 459979e..f0784b9 100644 --- a/example/src/context-opts.ios.ts +++ b/example/src/context-opts.ios.ts @@ -1,16 +1,18 @@ -// import { Platform } from 'react-native' +import { Platform } from 'react-native' export default { + useCoreMLIos: true, // If you don't want to enable Core ML, you can remove this property - // coreMLModelAsset: - // Platform.OS === 'ios' - // ? { - // filename: 'ggml-tiny.en-encoder.mlmodelc', - // assets: [ - // require('../assets/ggml-tiny.en-encoder.mlmodelc/weights/weight.bin'), - // require('../assets/ggml-tiny.en-encoder.mlmodelc/model.mil'), - // require('../assets/ggml-tiny.en-encoder.mlmodelc/coremldata.bin'), - // ], - // } - // : undefined, + coreMLModelAsset: + Platform.OS === 'ios' + ? { + filename: 'ggml-tiny.en-encoder.mlmodelc', + assets: [ + require('../assets/ggml-tiny.en-encoder.mlmodelc/weights/weight.bin'), + require('../assets/ggml-tiny.en-encoder.mlmodelc/model.mil'), + require('../assets/ggml-tiny.en-encoder.mlmodelc/coremldata.bin'), + ], + } + : undefined, + useGpu: false, // Enable Metal (Will skip Core ML if enabled) } diff --git a/ios/RNWhisper.mm b/ios/RNWhisper.mm index f9699af..a8c69f8 100644 --- a/ios/RNWhisper.mm +++ b/ios/RNWhisper.mm @@ -48,6 +48,7 @@ - (NSDictionary *)constantsToExport NSString *modelPath = [modelOptions objectForKey:@"filePath"]; BOOL isBundleAsset = [[modelOptions objectForKey:@"isBundleAsset"] boolValue]; + BOOL useGpu = [[modelOptions objectForKey:@"useGpu"] boolValue]; BOOL useCoreMLIos = [[modelOptions objectForKey:@"useCoreMLIos"] boolValue]; // For support debug assets in development mode @@ -77,6 +78,7 @@ - (NSDictionary *)constantsToExport initWithModelPath:path contextId:contextId noCoreML:!useCoreMLIos + noMetal:!useGpu ]; if ([context getContext] == NULL) { reject(@"whisper_cpp_error", @"Failed to load the model", nil); @@ -85,7 +87,11 @@ - (NSDictionary *)constantsToExport [contexts setObject:context forKey:[NSNumber numberWithInt:contextId]]; - resolve([NSNumber numberWithInt:contextId]); + resolve(@{ + @"contextId": @(contextId), + @"gpu": @([context isMetalEnabled]), + @"reasonNoGPU": [context reasonNoMetal], + }); } - (NSArray *)supportedEvents { diff --git a/ios/RNWhisperContext.h b/ios/RNWhisperContext.h index 4bee2fa..4d6d4ad 100644 --- a/ios/RNWhisperContext.h +++ b/ios/RNWhisperContext.h @@ -46,9 +46,13 @@ typedef struct { dispatch_queue_t dQueue; struct whisper_context * ctx; RNWhisperContextRecordState recordState; + NSString * reasonNoMetal; + bool isMetalEnabled; } -+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId noCoreML:(BOOL)noCoreML; ++ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId noCoreML:(BOOL)noCoreML noMetal:(BOOL)noMetal; +- (bool)isMetalEnabled; +- (NSString *)reasonNoMetal; - (struct whisper_context *)getContext; - (dispatch_queue_t)getDispatchQueue; - (OSStatus)transcribeRealtime:(int)jobId diff --git a/ios/RNWhisperContext.mm b/ios/RNWhisperContext.mm index 2c90755..4c0c3c3 100644 --- a/ios/RNWhisperContext.mm +++ b/ios/RNWhisperContext.mm @@ -1,30 +1,92 @@ #import "RNWhisperContext.h" #import "RNWhisperAudioUtils.h" +#import #include #define NUM_BYTES_PER_BUFFER 16 * 1024 @implementation RNWhisperContext -+ (instancetype)initWithModelPath:(NSString *)modelPath contextId:(int)contextId noCoreML:(BOOL)noCoreML { ++ (instancetype)initWithModelPath:(NSString *)modelPath + contextId:(int)contextId + noCoreML:(BOOL)noCoreML + noMetal:(BOOL)noMetal +{ RNWhisperContext *context = [[RNWhisperContext alloc] init]; context->contextId = contextId; -#ifdef WHISPER_USE_COREML - if (noCoreML) { - context->ctx = whisper_init_from_file_no_coreml([modelPath UTF8String]); - } else { - context->ctx = whisper_init_from_file([modelPath UTF8String]); + struct whisper_context_params cparams; + NSString *reasonNoMetal = @""; + cparams.use_gpu = !noMetal; + + cparams.use_coreml = !noCoreML; +#ifndef WHISPER_USE_COREML + if (cparams.use_coreml) { + NSLog(@"[RNWhisper] CoreML is not enabled in this build, ignoring use_coreml option"); + cparams.use_coreml = false; + } +#endif + +#ifndef WSP_GGML_USE_METAL + if (cparams.use_gpu) { + NSLog(@"[RNWhisper] ggml-metal is not enabled in this build, ignoring use_gpu option"); + cparams.use_gpu = false; } -#else - context->ctx = whisper_init_from_file([modelPath UTF8String]); #endif + +#ifdef WSP_GGML_USE_METAL + if (cparams.use_gpu) { +#if TARGET_OS_SIMULATOR + NSLog(@"[RNWhisper] ggml-metal is not available in simulator, ignoring use_gpu option: %@", reasonNoMetal); + cparams.use_gpu = false; +#else // TARGET_OS_SIMULATOR + // Check ggml-metal availability + NSError * error = nil; + id device = MTLCreateSystemDefaultDevice(); + id library = [device + newLibraryWithSource:@"#include \n" + "using namespace metal;" + "kernel void test() { simd_sum(0); }" + options:nil + error:&error + ]; + if (error) { + reasonNoMetal = [error localizedDescription]; + } else { + id kernel = [library newFunctionWithName:@"test"]; + id pipeline = [device newComputePipelineStateWithFunction:kernel error:&error]; + if (pipeline == nil) { + reasonNoMetal = [error localizedDescription]; + NSLog(@"[RNWhisper] ggml-metal is not available, ignoring use_gpu option: %@", reasonNoMetal); + cparams.use_gpu = false; + } + } +#endif // TARGET_OS_SIMULATOR + } +#endif // WSP_GGML_USE_METAL + + if (cparams.use_gpu && cparams.use_coreml) { + NSLog(@"[RNWhisper] Both use_gpu and use_coreml are enabled, ignoring use_coreml option"); + cparams.use_coreml = false; // Skip CoreML if Metal is enabled + } + + context->ctx = whisper_init_from_file_with_params([modelPath UTF8String], cparams); context->dQueue = dispatch_queue_create( [[NSString stringWithFormat:@"RNWhisperContext-%d", contextId] UTF8String], DISPATCH_QUEUE_SERIAL ); + context->isMetalEnabled = cparams.use_gpu; + context->reasonNoMetal = reasonNoMetal; return context; } +- (bool)isMetalEnabled { + return isMetalEnabled; +} + +- (NSString *)reasonNoMetal { + return reasonNoMetal; +} + - (struct whisper_context *)getContext { return self->ctx; } diff --git a/jest/mock.js b/jest/mock.js index b7e6567..dca094a 100644 --- a/jest/mock.js +++ b/jest/mock.js @@ -2,7 +2,7 @@ const { NativeModules, DeviceEventEmitter } = require('react-native') if (!NativeModules.RNWhisper) { NativeModules.RNWhisper = { - initContext: jest.fn(() => Promise.resolve(1)), + initContext: jest.fn(() => Promise.resolve({ contextId: 1 })), transcribeFile: jest.fn(() => Promise.resolve({ result: ' Test', segments: [{ text: ' Test', t0: 0, t1: 33 }], diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index 8549b68..4e80b62 100755 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -5,8 +5,14 @@ git submodule update --recursive cp ./whisper.cpp/ggml.h ./cpp/ggml.h cp ./whisper.cpp/ggml.c ./cpp/ggml.c +cp ./whisper.cpp/ggml-impl.h ./cpp/ggml-impl.h cp ./whisper.cpp/ggml-alloc.h ./cpp/ggml-alloc.h cp ./whisper.cpp/ggml-alloc.c ./cpp/ggml-alloc.c +cp ./whisper.cpp/ggml-quants.h ./cpp/ggml-quants.h +cp ./whisper.cpp/ggml-quants.c ./cpp/ggml-quants.c +cp ./whisper.cpp/ggml-backend.h ./cpp/ggml-backend.h +cp ./whisper.cpp/ggml-backend.c ./cpp/ggml-backend.c +cp ./whisper.cpp/ggml-backend-impl.h ./cpp/ggml-backend-impl.h cp ./whisper.cpp/ggml-metal.h ./cpp/ggml-metal.h cp ./whisper.cpp/ggml-metal.m ./cpp/ggml-metal.m cp ./whisper.cpp/ggml-metal.metal ./cpp/ggml-metal-whisper.metal @@ -20,8 +26,14 @@ cp -R ./whisper.cpp/coreml/ ./cpp/coreml/ files=( "./cpp/ggml.h" "./cpp/ggml.c" + "./cpp/ggml-impl.h" "./cpp/ggml-alloc.h" "./cpp/ggml-alloc.c" + "./cpp/ggml-quants.h" + "./cpp/ggml-quants.c" + "./cpp/ggml-backend.h" + "./cpp/ggml-backend.c" + "./cpp/ggml-backend-impl.h" "./cpp/ggml-metal.h" "./cpp/ggml-metal.m" "./cpp/whisper.h" diff --git a/scripts/ggml-metal.m.patch b/scripts/ggml-metal.m.patch index 222da8a..0fe27b0 100644 --- a/scripts/ggml-metal.m.patch +++ b/scripts/ggml-metal.m.patch @@ -1,47 +1,37 @@ ---- ggml-metal.m.orig 2023-10-25 17:55:09 -+++ ggml-metal.m 2023-10-25 17:55:42 -@@ -178,7 +178,7 @@ - - //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"]; - NSBundle * bundle = [NSBundle bundleForClass:[WSPGGMLMetalClass class]]; -- NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; -+ NSString * path = [bundle pathForResource:@"ggml-metal-whisper" ofType:@"metal"]; - metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]); - - NSString * src = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error]; -@@ -207,7 +207,7 @@ - #define WSP_GGML_METAL_ADD_KERNEL(name) \ - ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \ - ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \ -- metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \ -+ metal_printf("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (__bridge void *) ctx->pipeline_##name, \ - (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \ - (int) ctx->pipeline_##name.threadExecutionWidth); \ - if (error) { \ -@@ -286,8 +286,6 @@ +--- ggml-metal.m.orig 2023-11-07 18:03:28 ++++ ggml-metal.m 2023-11-07 18:03:29 +@@ -215,7 +215,7 @@ + if (ggmlMetalPathResources) { + sourcePath = [ggmlMetalPathResources stringByAppendingPathComponent:@"ggml-metal.metal"]; + } else { +- sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"]; ++ sourcePath = [bundle pathForResource:@"ggml-metal-whisper" ofType:@"metal"]; + } + if (sourcePath == nil) { + WSP_GGML_METAL_LOG_WARN("%s: error: could not use bundle path to find ggml-metal.metal, falling back to trying cwd\n", __func__); +@@ -355,8 +355,6 @@ void wsp_ggml_metal_free(struct wsp_ggml_metal_context * ctx) { - metal_printf("%s: deallocating\n", __func__); + WSP_GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); #define WSP_GGML_METAL_DEL_KERNEL(name) \ - [ctx->function_##name release]; \ - [ctx->pipeline_##name release]; WSP_GGML_METAL_DEL_KERNEL(add); WSP_GGML_METAL_DEL_KERNEL(add_row); -@@ -342,17 +340,7 @@ - WSP_GGML_METAL_DEL_KERNEL(cpy_f16_f16); +@@ -423,16 +421,6 @@ + WSP_GGML_METAL_DEL_KERNEL(sqr); #undef WSP_GGML_METAL_DEL_KERNEL - - for (int i = 0; i < ctx->n_buffers; ++i) { - [ctx->buffers[i].metal release]; - } - +- - [ctx->library release]; - [ctx->queue release]; - [ctx->device release]; - - dispatch_release(ctx->d_queue); -- + free(ctx); } - diff --git a/scripts/whisper.cpp.patch b/scripts/whisper.cpp.patch index a41f652..44e6316 100644 --- a/scripts/whisper.cpp.patch +++ b/scripts/whisper.cpp.patch @@ -1,53 +1,28 @@ ---- whisper.cpp.orig 2023-10-12 11:44:51 -+++ whisper.cpp 2023-10-12 11:43:31 -@@ -770,6 +770,9 @@ - whisper_state * state = nullptr; - - std::string path_model; // populated by whisper_init_from_file() -+#ifdef WHISPER_USE_COREML -+ bool load_coreml = true; -+#endif - }; - - static void whisper_default_log(const char * text) { -@@ -2854,6 +2857,7 @@ +--- whisper.cpp.orig 2023-11-08 05:39:06 ++++ whisper.cpp 2023-11-08 05:39:07 +@@ -2881,7 +2881,9 @@ + log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - + ++ #ifdef WHISPER_USE_COREML -+if (ctx->load_coreml) { // Not in correct layer for easy patch ++ if (ctx->params.use_coreml) { const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); - + log("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); -@@ -2869,6 +2873,7 @@ +@@ -2896,6 +2898,7 @@ + #endif } else { log("%s: Core ML model loaded\n", __func__); ++ } } -+} #endif - - state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx); -@@ -2987,7 +2992,24 @@ - state->rng = std::mt19937(0); - - return state; -+} -+ -+#ifdef WHISPER_USE_COREML -+struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model) { -+ whisper_context * ctx = whisper_init_from_file_no_state(path_model); -+ if (!ctx) { -+ return nullptr; -+ } -+ ctx->load_coreml = false; -+ ctx->state = whisper_init_state(ctx); -+ if (!ctx->state) { -+ whisper_free(ctx); -+ return nullptr; -+ } -+ -+ return ctx; + +@@ -3074,6 +3077,7 @@ + struct whisper_context_params whisper_context_default_params() { + struct whisper_context_params result = { + /*.use_gpu =*/ true, ++ /*.use_coreml =*/ false, + }; + return result; } -+#endif - - int whisper_ctx_init_openvino_encoder( - struct whisper_context * ctx, diff --git a/scripts/whisper.h.patch b/scripts/whisper.h.patch index 88fcd3d..9df401a 100644 --- a/scripts/whisper.h.patch +++ b/scripts/whisper.h.patch @@ -1,12 +1,10 @@ ---- whisper.h.orig 2023-10-12 10:41:41 -+++ whisper.h 2023-10-12 10:38:11 -@@ -99,6 +99,9 @@ - // Various functions for loading a ggml whisper model. - // Allocate (almost) all memory needed for the model. - // Return NULL on failure -+#ifdef WHISPER_USE_COREML -+ WHISPER_API struct whisper_context * whisper_init_from_file_no_coreml(const char * path_model); -+#endif - WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model); - WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size); - WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader); +--- whisper.h.orig 2023-11-08 05:39:06 ++++ whisper.h 2023-11-08 05:39:07 +@@ -80,6 +80,7 @@ + + struct whisper_context_params { + bool use_gpu; ++ bool use_coreml; + }; + + typedef struct whisper_token_data { diff --git a/src/NativeRNWhisper.ts b/src/NativeRNWhisper.ts index 18c7490..8290f62 100644 --- a/src/NativeRNWhisper.ts +++ b/src/NativeRNWhisper.ts @@ -52,17 +52,24 @@ export type CoreMLAsset = { type NativeContextOptions = { filePath: string, isBundleAsset: boolean, + useGpu?: boolean, useCoreMLIos?: boolean, downloadCoreMLAssets?: boolean, coreMLAssets?: CoreMLAsset[], } +export type NativeWhisperContext = { + contextId: number + gpu: boolean + reasonNoGPU: string +} + export interface Spec extends TurboModule { getConstants(): { useCoreML: boolean coreMLAllowFallback: boolean }; - initContext(options: NativeContextOptions): Promise; + initContext(options: NativeContextOptions): Promise; releaseContext(contextId: number): Promise; releaseAllContexts(): Promise; transcribeFile( diff --git a/src/index.ts b/src/index.ts index 968626e..6314355 100644 --- a/src/index.ts +++ b/src/index.ts @@ -5,7 +5,7 @@ import { DeviceEventEmitterStatic, Image, } from 'react-native' -import RNWhisper from './NativeRNWhisper' +import RNWhisper, { NativeWhisperContext } from './NativeRNWhisper' import type { TranscribeOptions, TranscribeResult, @@ -183,8 +183,18 @@ const updateAudioSession = async (setting: AudioSessionSettingIos) => { export class WhisperContext { id: number - constructor(id: number) { - this.id = id + gpu: boolean = false + + reasonNoGPU: string = '' + + constructor({ + contextId, + gpu, + reasonNoGPU, + }: NativeWhisperContext) { + this.id = contextId + this.gpu = gpu + this.reasonNoGPU = reasonNoGPU } /** Transcribe audio file */ @@ -440,6 +450,8 @@ export type ContextOptions = { isBundleAsset?: boolean /** Prefer to use Core ML model if exists. If set to false, even if the Core ML model exists, it will not be used. */ useCoreMLIos?: boolean + /** Use GPU if available. Currently iOS only, if it's enabled, Core ML option will be ignored. */ + useGpu?: boolean } const coreMLModelAssetPaths = [ @@ -453,6 +465,7 @@ export async function initWhisper({ filePath, coreMLModelAsset, isBundleAsset, + useGpu = true, useCoreMLIos = true, }: ContextOptions): Promise { let path = '' @@ -499,15 +512,16 @@ export async function initWhisper({ path = filePath } if (path.startsWith('file://')) path = path.slice(7) - const id = await RNWhisper.initContext({ + const { contextId, gpu, reasonNoGPU } = await RNWhisper.initContext({ filePath: path, isBundleAsset: !!isBundleAsset, + useGpu, useCoreMLIos, // Only development mode need download Core ML model assets (from packager server) downloadCoreMLAssets: __DEV__ && !!coreMLAssets, coreMLAssets, }) - return new WhisperContext(id) + return new WhisperContext({ contextId, gpu, reasonNoGPU }) } export async function releaseAllWhisper(): Promise { diff --git a/src/version.json b/src/version.json index f564e33..9b64601 100644 --- a/src/version.json +++ b/src/version.json @@ -1 +1 @@ -{"version":"1.4.2"} \ No newline at end of file +{"version":"1.4.3"} \ No newline at end of file diff --git a/whisper-rn.podspec b/whisper-rn.podspec index 2fd1f8e..a0311ee 100644 --- a/whisper-rn.podspec +++ b/whisper-rn.podspec @@ -16,8 +16,7 @@ if ENV['RNWHISPER_DISABLE_COREML'] != '1' then base_compiler_flags += " -DWHISPER_USE_COREML -DWHISPER_COREML_ALLOW_FALLBACK" end -# TODO: Enable Metal by default when we have use_gpu param -if ENV["RNWHISPER_ENABLE_METAL"] == "1" then +if ENV["RNWHISPER_DISABLE_METAL"] != "1" then base_compiler_flags += " -DWSP_GGML_USE_METAL" # -DWSP_GGML_METAL_NDEBUG end diff --git a/whisper.cpp b/whisper.cpp index 940cdb1..6a5d195 160000 --- a/whisper.cpp +++ b/whisper.cpp @@ -1 +1 @@ -Subproject commit 940cdb13964a563d86c7dc6e160a43ec89b8bb2e +Subproject commit 6a5d195109994b865e1c92a88258ac182399eb64