From 951436c9356f2df7757dbbfa87baa02bb7b34775 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Mon, 23 Dec 2024 14:21:51 +0800 Subject: [PATCH] feat: sync llama.cpp (#100) * feat: sync llama.cpp * feat: sync llama.cpp * fix: add missing GGML_USE_CPU --- android/src/main/CMakeLists.txt | 10 +- .../main/java/com/rnllama/LlamaContext.java | 3 - android/src/main/jni.cpp | 14 +- cpp/amx/amx.cpp | 220 ++ cpp/amx/amx.h | 8 + cpp/amx/common.h | 91 + cpp/amx/mmq.cpp | 2511 +++++++++++++++++ cpp/amx/mmq.h | 10 + cpp/common.cpp | 212 +- cpp/common.h | 105 +- cpp/ggml-aarch64.c | 129 - cpp/ggml-aarch64.h | 19 - cpp/ggml-alloc.c | 1 - cpp/ggml-backend-impl.h | 62 +- cpp/ggml-backend-reg.cpp | 405 ++- cpp/ggml-backend.cpp | 8 +- cpp/ggml-backend.h | 16 + cpp/ggml-common.h | 84 +- ...gml-cpu-aarch64.c => ggml-cpu-aarch64.cpp} | 1188 ++++++-- cpp/ggml-cpu-aarch64.h | 26 +- cpp/ggml-cpu-impl.h | 15 + cpp/ggml-cpu-quants.c | 63 +- cpp/ggml-cpu-traits.cpp | 36 + cpp/ggml-cpu-traits.h | 38 + cpp/ggml-cpu.c | 951 ++++--- cpp/ggml-cpu.cpp | 237 +- cpp/ggml-cpu.h | 56 +- cpp/ggml-impl.h | 43 +- cpp/ggml-metal-impl.h | 39 + cpp/ggml-metal.m | 1148 +++++--- cpp/ggml-opt.cpp | 147 +- cpp/ggml-quants.c | 9 - cpp/ggml-threading.h | 6 +- cpp/ggml.c | 574 ++-- cpp/ggml.h | 130 +- cpp/llama-grammar.cpp | 30 +- cpp/llama-grammar.h | 7 +- cpp/llama-sampling.cpp | 125 +- cpp/llama-vocab.cpp | 7 +- cpp/llama.cpp | 2468 ++++++++++++---- cpp/llama.h | 42 +- cpp/rn-llama.hpp | 33 +- cpp/sampling.cpp | 72 +- cpp/sampling.h | 23 +- cpp/sgemm.cpp | 1 + cpp/unicode.cpp | 113 +- cpp/unicode.h | 19 +- docs/API/README.md | 47 +- docs/API/classes/LlamaContext.md | 36 +- docs/API/classes/SchemaGrammarConverter.md | 32 +- .../SchemaGrammarConverterBuiltinRule.md | 6 +- example/ios/.xcode.env.local | 2 +- example/ios/Podfile.lock | 2 +- .../RNLlamaExample.xcodeproj/project.pbxproj | 4 +- example/src/App.tsx | 1 - ios/RNLlamaContext.mm | 13 +- llama-rn.podspec | 4 +- llama.cpp | 2 +- scripts/bootstrap.sh | 22 +- scripts/common.cpp.patch | 26 +- scripts/common.h.patch | 14 +- scripts/ggml-backend-reg.cpp.patch | 17 +- scripts/ggml-cpu-aarch64.c.patch | 11 - scripts/ggml-metal.m.patch | 12 +- scripts/ggml.c.patch | 6 +- scripts/llama.cpp.patch | 6 +- scripts/sgemm.cpp.patch | 12 - src/NativeRNLlama.ts | 4 - 68 files changed, 8941 insertions(+), 2892 deletions(-) create mode 100644 cpp/amx/amx.cpp create mode 100644 cpp/amx/amx.h create mode 100644 cpp/amx/common.h create mode 100644 cpp/amx/mmq.cpp create mode 100644 cpp/amx/mmq.h delete mode 100644 cpp/ggml-aarch64.c delete mode 100644 cpp/ggml-aarch64.h rename cpp/{ggml-cpu-aarch64.c => ggml-cpu-aarch64.cpp} (79%) create mode 100644 cpp/ggml-cpu-traits.cpp create mode 100644 cpp/ggml-cpu-traits.h delete mode 100644 scripts/ggml-cpu-aarch64.c.patch delete mode 100644 scripts/sgemm.cpp.patch diff --git a/android/src/main/CMakeLists.txt b/android/src/main/CMakeLists.txt index e9c3bb42..c689457e 100644 --- a/android/src/main/CMakeLists.txt +++ b/android/src/main/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.10) project(llama.rn) -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) set(RNLLAMA_LIB_DIR ${CMAKE_SOURCE_DIR}/../../../cpp) include_directories(${RNLLAMA_LIB_DIR}) @@ -10,14 +10,14 @@ include_directories(${RNLLAMA_LIB_DIR}) set( SOURCE_FILES ${RNLLAMA_LIB_DIR}/ggml.c - ${RNLLAMA_LIB_DIR}/ggml-aarch64.c ${RNLLAMA_LIB_DIR}/ggml-alloc.c ${RNLLAMA_LIB_DIR}/ggml-backend.cpp ${RNLLAMA_LIB_DIR}/ggml-backend-reg.cpp ${RNLLAMA_LIB_DIR}/ggml-cpu.c 
${RNLLAMA_LIB_DIR}/ggml-cpu.cpp - ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.c + ${RNLLAMA_LIB_DIR}/ggml-cpu-aarch64.cpp ${RNLLAMA_LIB_DIR}/ggml-cpu-quants.c + ${RNLLAMA_LIB_DIR}/ggml-cpu-traits.cpp ${RNLLAMA_LIB_DIR}/ggml-opt.cpp ${RNLLAMA_LIB_DIR}/ggml-threading.cpp ${RNLLAMA_LIB_DIR}/ggml-quants.c @@ -32,6 +32,8 @@ set( ${RNLLAMA_LIB_DIR}/sgemm.cpp ${RNLLAMA_LIB_DIR}/common.cpp ${RNLLAMA_LIB_DIR}/rn-llama.hpp + ${RNLLAMA_LIB_DIR}/amx/amx.cpp + ${RNLLAMA_LIB_DIR}/amx/mmq.cpp ${CMAKE_SOURCE_DIR}/jni-utils.h ${CMAKE_SOURCE_DIR}/jni.cpp ) @@ -47,7 +49,7 @@ function(build_library target_name cpu_flags) target_link_libraries(${target_name} ${LOG_LIB} android) - target_compile_options(${target_name} PRIVATE -pthread ${cpu_flags}) + target_compile_options(${target_name} PRIVATE -DLM_GGML_USE_CPU -pthread ${cpu_flags}) if (${CMAKE_BUILD_TYPE} STREQUAL "Debug") target_compile_options(${target_name} PRIVATE -DRNLLAMA_ANDROID_ENABLE_LOGGING) diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index 9c14a3a0..2b36d524 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -217,8 +217,6 @@ public WritableMap completion(ReadableMap params) { params.hasKey("mirostat_tau") ? (float) params.getDouble("mirostat_tau") : 5.00f, // float mirostat_eta, params.hasKey("mirostat_eta") ? (float) params.getDouble("mirostat_eta") : 0.10f, - // boolean penalize_nl, - params.hasKey("penalize_nl") ? params.getBoolean("penalize_nl") : false, // int top_k, params.hasKey("top_k") ? params.getInt("top_k") : 40, // float top_p, @@ -463,7 +461,6 @@ protected static native WritableMap doCompletion( float mirostat, float mirostat_tau, float mirostat_eta, - boolean penalize_nl, int top_k, float top_p, float min_p, diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index d2616571..774b44d7 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -280,8 +280,8 @@ Java_com_rnllama_LlamaContext_initContext( const char *cache_type_k_chars = env->GetStringUTFChars(cache_type_k, nullptr); const char *cache_type_v_chars = env->GetStringUTFChars(cache_type_v, nullptr); - defaultParams.cache_type_k = cache_type_k_chars; - defaultParams.cache_type_v = cache_type_v_chars; + defaultParams.cache_type_k = rnllama::kv_cache_type_from_str(cache_type_k_chars); + defaultParams.cache_type_v = rnllama::kv_cache_type_from_str(cache_type_v_chars); defaultParams.use_mlock = use_mlock; defaultParams.use_mmap = use_mmap; @@ -553,7 +553,6 @@ Java_com_rnllama_LlamaContext_doCompletion( jfloat mirostat, jfloat mirostat_tau, jfloat mirostat_eta, - jboolean penalize_nl, jint top_k, jfloat top_p, jfloat min_p, @@ -579,7 +578,7 @@ Java_com_rnllama_LlamaContext_doCompletion( //llama_reset_timings(llama->ctx); llama->params.prompt = env->GetStringUTFChars(prompt, nullptr); - llama->params.sparams.seed = (seed == -1) ? time(NULL) : seed; + llama->params.sampling.seed = (seed == -1) ? time(NULL) : seed; int max_threads = std::thread::hardware_concurrency(); // Use 2 threads by default on 4-core devices, 4 threads on more cores @@ -587,9 +586,9 @@ Java_com_rnllama_LlamaContext_doCompletion( llama->params.cpuparams.n_threads = n_threads > 0 ? 
n_threads : default_n_threads;
     llama->params.n_predict = n_predict;
-    llama->params.sparams.ignore_eos = ignore_eos;
+    llama->params.sampling.ignore_eos = ignore_eos;
-    auto & sparams = llama->params.sparams;
+    auto & sparams = llama->params.sampling;
     sparams.temp = temperature;
     sparams.penalty_last_n = penalty_last_n;
     sparams.penalty_repeat = penalty_repeat;
@@ -598,7 +597,6 @@
     sparams.mirostat = mirostat;
     sparams.mirostat_tau = mirostat_tau;
     sparams.mirostat_eta = mirostat_eta;
-    sparams.penalize_nl = penalize_nl;
     sparams.top_k = top_k;
     sparams.top_p = top_p;
     sparams.min_p = min_p;
@@ -714,7 +712,7 @@
     auto tokenResult = createWriteableMap(env);
     putString(env, tokenResult, "token", to_send.c_str());
-    if (llama->params.sparams.n_probs > 0) {
+    if (llama->params.sampling.n_probs > 0) {
       const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
       size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
       size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
diff --git a/cpp/amx/amx.cpp b/cpp/amx/amx.cpp
new file mode 100644
index 00000000..41802f88
--- /dev/null
+++ b/cpp/amx/amx.cpp
@@ -0,0 +1,220 @@
+#include "amx.h"
+#include "common.h"
+#include "mmq.h"
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-impl.h"
+#include "ggml-cpu.h"
+#include "ggml-cpu-traits.h"
+
+#if defined(__gnu_linux__)
+#include <sys/syscall.h>
+#include <unistd.h>
+#endif
+
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+
+// AMX type traits
+namespace ggml::cpu::amx {
+class tensor_traits : public ggml::cpu::tensor_traits {
+    bool work_size(int /* n_threads */, const struct lm_ggml_tensor * op, size_t & size) override {
+        size = lm_ggml_backend_amx_desired_wsize(op);
+        return true;
+    }
+
+    bool compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) override {
+        if (op->op == LM_GGML_OP_MUL_MAT) {
+            lm_ggml_backend_amx_mul_mat(params, op);
+            return true;
+        }
+        return false;
+    }
+};
+
+static ggml::cpu::tensor_traits * get_tensor_traits(lm_ggml_backend_buffer_t, struct lm_ggml_tensor *) {
+    static tensor_traits traits;
+    return &traits;
+}
+}  // namespace ggml::cpu::amx
+
+// AMX buffer interface
+static void lm_ggml_backend_amx_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) {
+    free(buffer->context);
+}
+
+static void * lm_ggml_backend_amx_buffer_get_base(lm_ggml_backend_buffer_t buffer) {
+    return (void *) (buffer->context);
+}
+
+static void lm_ggml_backend_amx_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) {
+    tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);
+
+    LM_GGML_UNUSED(buffer);
+}
+
+static void lm_ggml_backend_amx_buffer_memset_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor,
+                                                     uint8_t value, size_t offset, size_t size) {
+    memset((char *) tensor->data + offset, value, size);
+
+    LM_GGML_UNUSED(buffer);
+}
+
+static void lm_ggml_backend_amx_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor,
+                                                  const void * data, size_t offset, size_t size) {
+    if (qtype_has_amx_kernels(tensor->type)) {
+        LM_GGML_LOG_DEBUG("%s: amx repack tensor %s of type %s\n", __func__, tensor->name, lm_ggml_type_name(tensor->type));
+        lm_ggml_backend_amx_convert_weight(tensor, data, offset, size);
+    } else {
+        memcpy((char *) tensor->data + offset,
data, size); + } + + LM_GGML_UNUSED(buffer); +} + +/* +// need to figure what we need to do with buffer->extra. +static void lm_ggml_backend_amx_buffer_get_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) { + LM_GGML_ASSERT(!qtype_has_amx_kernels(tensor->type)); + memcpy(data, (const char *)tensor->data + offset, size); + + LM_GGML_UNUSED(buffer); +} + +static bool lm_ggml_backend_amx_buffer_cpy_tensor(lm_ggml_backend_buffer_t buffer, const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst) { + if (lm_ggml_backend_buffer_is_host(src->buffer)) { + if (qtype_has_amx_kernels(src->type)) { + lm_ggml_backend_amx_convert_weight(dst, src->data, 0, lm_ggml_nbytes(dst)); + } else { + memcpy(dst->data, src->data, lm_ggml_nbytes(src)); + } + return true; + } + return false; + + LM_GGML_UNUSED(buffer); +} +*/ + +static void lm_ggml_backend_amx_buffer_clear(lm_ggml_backend_buffer_t buffer, uint8_t value) { + memset(buffer->context, value, buffer->size); +} + +static lm_ggml_backend_buffer_i lm_ggml_backend_amx_buffer_interface = { + /* .free_buffer = */ lm_ggml_backend_amx_buffer_free_buffer, + /* .get_base = */ lm_ggml_backend_amx_buffer_get_base, + /* .init_tensor = */ lm_ggml_backend_amx_buffer_init_tensor, + /* .memset_tensor = */ lm_ggml_backend_amx_buffer_memset_tensor, + /* .set_tensor = */ lm_ggml_backend_amx_buffer_set_tensor, + /* .get_tensor = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ lm_ggml_backend_amx_buffer_clear, + /* .reset = */ nullptr, +}; + +static const char * lm_ggml_backend_amx_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { + return "AMX"; + + LM_GGML_UNUSED(buft); +} + +static lm_ggml_backend_buffer_t lm_ggml_backend_amx_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { + void * data = lm_ggml_aligned_malloc(size); + if (data == NULL) { + fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size); + return NULL; + } + + return lm_ggml_backend_buffer_init(buft, lm_ggml_backend_amx_buffer_interface, data, size); +} + +static size_t lm_ggml_backend_amx_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + LM_GGML_UNUSED(buft); +} + +namespace ggml::cpu::amx { +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(lm_ggml_backend_dev_t, const struct lm_ggml_tensor * op) override { + // handle only 2d gemm for now + auto is_contiguous_2d = [](const struct lm_ggml_tensor * t) { + return lm_ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1; + }; + + if (op->op == LM_GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous + is_contiguous_2d(op->src[1]) && // src1 must be contiguous + op->src[0]->buffer && op->src[0]->buffer->buft == lm_ggml_backend_amx_buffer_type() && + op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x + (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == LM_GGML_TYPE_F16))) { + // src1 must be host buffer + if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + // src1 must be float32 + if (op->src[1]->type == LM_GGML_TYPE_F32) { + return true; + } + } + return false; + } + + ggml::cpu::tensor_traits * get_tensor_traits(const struct lm_ggml_tensor * op) override { + if (op->op == LM_GGML_OP_MUL_MAT && op->src[0]->buffer && + op->src[0]->buffer->buft == lm_ggml_backend_amx_buffer_type()) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + + return 
nullptr; + } +}; +} // namespace ggml::cpu::amx + +static size_t lm_ggml_backend_amx_buffer_type_get_alloc_size(lm_ggml_backend_buffer_type_t buft, const lm_ggml_tensor * tensor) { + return lm_ggml_backend_amx_get_alloc_size(tensor); + + LM_GGML_UNUSED(buft); +} + +#define ARCH_GET_XCOMP_PERM 0x1022 +#define ARCH_REQ_XCOMP_PERM 0x1023 +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 + +static bool lm_ggml_amx_init() { +#if defined(__gnu_linux__) + if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) { + fprintf(stderr, "AMX is not ready to be used!\n"); + return false; + } + return true; +#elif defined(_WIN32) + return true; +#endif +} + +lm_ggml_backend_buffer_type_t lm_ggml_backend_amx_buffer_type() { + static struct lm_ggml_backend_buffer_type lm_ggml_backend_buffer_type_amx = { + /* .iface = */ { + /* .get_name = */ lm_ggml_backend_amx_buffer_type_get_name, + /* .alloc_buffer = */ lm_ggml_backend_amx_buffer_type_alloc_buffer, + /* .get_alignment = */ lm_ggml_backend_amx_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ lm_ggml_backend_amx_buffer_type_get_alloc_size, + /* .is_host = */ nullptr, + }, + /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), + /* .context = */ new ggml::cpu::amx::extra_buffer_type(), + }; + + if (!lm_ggml_amx_init()) { + return nullptr; + } + + return &lm_ggml_backend_buffer_type_amx; +} + +#endif // defined(__AMX_INT8__) && defined(__AVX512VNNI__) diff --git a/cpp/amx/amx.h b/cpp/amx/amx.h new file mode 100644 index 00000000..ffabbcb2 --- /dev/null +++ b/cpp/amx/amx.h @@ -0,0 +1,8 @@ +#include "ggml-backend.h" +#include "ggml-cpu-impl.h" + +// GGML internal header + +#if defined(__AMX_INT8__) && defined(__AVX512VNNI__) +lm_ggml_backend_buffer_type_t lm_ggml_backend_amx_buffer_type(void); +#endif diff --git a/cpp/amx/common.h b/cpp/amx/common.h new file mode 100644 index 00000000..559d0bb6 --- /dev/null +++ b/cpp/amx/common.h @@ -0,0 +1,91 @@ +#pragma once + +#include "ggml.h" +#include "ggml-cpu-impl.h" + +#include +#include +#include + +#if defined(LM_GGML_USE_OPENMP) +#include +#endif + +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 +#define VNNI_BLK 4 + +#define AMX_BLK_SIZE 32 + +#define TMM0 0 +#define TMM1 1 +#define TMM2 2 +#define TMM3 3 +#define TMM4 4 +#define TMM5 5 +#define TMM6 6 +#define TMM7 7 + +// parallel routines +template ::value, int>::type = 0> +inline T div_up(T x, T y) { return (x + y - 1) / y; } + +template +inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { +#if 0 + // onednn partition pattern + T& n_my = n_end; + if (nth <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else { + T n1 = div_up(n, nth); + T n2 = n1 - 1; + T T1 = n - n2 * nth; + n_my = ith < T1 ? n1 : n2; + n_start = ith <= T1 ? 
ith*n1 : T1 * n1 + (ith - T1) * n2;
+    }
+    n_end += n_start;
+#else
+    // pytorch aten partition pattern
+    T n_my = div_up(n, nth);
+    n_start = ith * n_my;
+    n_end = std::min(n_start + n_my, n);
+#endif
+}
+
+template <typename func_t>
+inline void parallel_for(int n, const func_t& f) {
+#if defined(LM_GGML_USE_OPENMP)
+#pragma omp parallel
+{
+    int nth = omp_get_num_threads();
+    int ith = omp_get_thread_num();
+    int tbegin, tend;
+    balance211(n, nth, ith, tbegin, tend);
+    f(tbegin, tend);
+}
+#else
+    f(0, n);
+#endif
+}
+
+template <typename func_t>
+inline void parallel_for_ggml(const lm_ggml_compute_params * params, int n, const func_t & f) {
+    int tbegin, tend;
+    balance211(n, params->nth, params->ith, tbegin, tend);
+    f(tbegin, tend);
+}
+
+// quantized types that have AMX support
+inline bool qtype_has_amx_kernels(const enum lm_ggml_type type) {
+    // TODO: fix padding for vnni format
+    return (type == LM_GGML_TYPE_Q4_0) ||
+        (type == LM_GGML_TYPE_Q4_1) ||
+        (type == LM_GGML_TYPE_Q8_0) ||
+        (type == LM_GGML_TYPE_Q4_K) ||
+        (type == LM_GGML_TYPE_Q5_K) ||
+        (type == LM_GGML_TYPE_Q6_K) ||
+        (type == LM_GGML_TYPE_IQ4_XS);
+}
diff --git a/cpp/amx/mmq.cpp b/cpp/amx/mmq.cpp
new file mode 100644
index 00000000..e036d988
--- /dev/null
+++ b/cpp/amx/mmq.cpp
@@ -0,0 +1,2511 @@
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Wpedantic"
+#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
+#endif
+
+#include "amx.h"
+#include "mmq.h"
+#include "ggml-impl.h"
+#include "ggml-cpu-impl.h"
+#include "ggml-cpu-quants.h"
+#include "ggml-quants.h"
+#include <algorithm>
+#include <type_traits>
+
+#if defined(__gnu_linux__)
+#include <sys/syscall.h>
+#include <unistd.h>
+#endif
+
+#if (defined(_WIN32) || defined(_WIN64))
+#define RESTRICT __restrict
+#else
+#define RESTRICT __restrict__
+#endif
+
+#if (defined(_WIN32) || defined(_WIN64))
+#define ALWAYS_INLINE __forceinline
+#elif __has_attribute(always_inline) || defined(__GNUC__)
+#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
+#else
+#define ALWAYS_INLINE inline
+#endif
+
+#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
+
+namespace {
+
+// Forced unrolling
+template <int n>
+struct Unroll {
+    template <typename Func, typename... Args>
+    ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
+        Unroll<n - 1>{}(f, args...);
+        f(std::integral_constant<int, n - 1>{}, args...);
+    }
+};
+
+template <>
+struct Unroll<1> {
+    template <typename Func, typename... Args>
+    ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
+        f(std::integral_constant<int, 0>{}, args...);
+    }
+};
+
+// type traits
+template <typename T> struct PackedTypes {};
+template <> struct PackedTypes<block_q4_0> { using type = int8_t;  };
+template <> struct PackedTypes<block_q4_1> { using type = uint8_t; };
+template <> struct PackedTypes<block_q8_0> { using type = int8_t;  };
+template <typename T> using packed_B_type = typename PackedTypes<T>::type;
+
+template <typename T>
+struct do_compensate : std::integral_constant<bool,
+    std::is_same<T, block_q8_0>::value> {};
+
+template <typename T>
+struct do_unpack : std::integral_constant<bool,
+    std::is_same<T, block_q4_0>::value ||
+    std::is_same<T, block_q4_1>::value> {};
+
+template <typename T>
+struct is_type_qkk : std::integral_constant<bool,
+    std::is_same<T, block_q4_K>::value ||
+    std::is_same<T, block_q5_K>::value ||
+    std::is_same<T, block_q6_K>::value ||
+    std::is_same<T, block_iq4_xs>::value> {};
+
+#define LM_GGML_DISPATCH_FLOATING_TYPES(TYPE, ...)                           \
+    [&] {                                                                    \
+        switch (TYPE) {                                                      \
+            case LM_GGML_TYPE_F16: {                                         \
+                using type = lm_ggml_fp16_t;                                 \
+                constexpr int blck_size = 16;                                \
+                return __VA_ARGS__();                                        \
+            }                                                                \
+            case LM_GGML_TYPE_BF16: {                                        \
+                using type = lm_ggml_bf16_t;                                 \
+                constexpr int blck_size = 32;                                \
+                return __VA_ARGS__();                                        \
+            }                                                                \
+            default:                                                         \
+                fprintf(stderr, "Unsupported floating data type\n");         \
+        }                                                                    \
+    }()
+
+#define LM_GGML_DISPATCH_QTYPES(QT, ...)                                     \
+    [&] {                                                                    \
+        switch (QT) {                                                        \
+            case LM_GGML_TYPE_Q4_0: {                                        \
+                using type = block_q4_0;                                     \
+                using vec_dot_type = block_q8_0;                             \
+                constexpr int blck_size = QK4_0;                             \
+                return __VA_ARGS__();                                        \
+            }                                                                \
+            case LM_GGML_TYPE_Q4_1: {                                        \
+                using type = block_q4_1;                                     \
+                using vec_dot_type = block_q8_1;                             \
+                constexpr int blck_size = QK4_1;                             \
+                return __VA_ARGS__();                                        \
+            }                                                                \
+            case LM_GGML_TYPE_Q8_0: {                                        \
+                using type = block_q8_0;                                     \
+                using vec_dot_type = block_q8_0;                             \
+                constexpr int blck_size = QK8_0;                             \
+                return __VA_ARGS__();                                        \
+            }                                                                \
+            case LM_GGML_TYPE_Q4_K: {                                        \
+                using type = block_q4_K;                                     \
+                using vec_dot_type = block_q8_K;                             \
+                constexpr int blck_size = QK_K;                              \
+                return __VA_ARGS__();                                        \
+            }                                                                \
+            case LM_GGML_TYPE_Q5_K: {                                        \
+                using type = block_q5_K;                                     \
+                using vec_dot_type = block_q8_K;                             \
+                constexpr int blck_size = QK_K;                              \
+                return __VA_ARGS__();                                        \
+            }                                                                \
+            case LM_GGML_TYPE_Q6_K: {                                        \
+                using type = block_q6_K;                                     \
+                using vec_dot_type = block_q8_K;                             \
+                constexpr int blck_size = QK_K;                              \
+                return __VA_ARGS__();                                        \
+            }                                                                \
+            case LM_GGML_TYPE_IQ4_XS: {                                      \
+                using type = block_iq4_xs;                                   \
+                using vec_dot_type = block_q8_K;                             \
+                constexpr int blck_size = QK_K;                              \
+                return __VA_ARGS__();                                        \
+            }                                                                \
+            default:                                                         \
+                fprintf(stderr, "Unsupported quantized data type: %d\n", int(QT)); \
+        }                                                                    \
+    }()
+
+#define LM_GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...)                        \
+    [&] {                                                                    \
+        if (BOOL_V) {                                                        \
+            constexpr bool BOOL_NAME = true;                                 \
+            return __VA_ARGS__();                                            \
+        } else {                                                             \
+            constexpr bool BOOL_NAME = false;                                \
+            return __VA_ARGS__();                                            \
+        }                                                                    \
+    }()
+
+// define amx tile config data structure
+struct tile_config_t {
+    uint8_t palette_id = 0;
+    uint8_t start_row = 0;
+    uint8_t reserved_0[14] = {0};
+    uint16_t colsb[16] = {0};
+    uint8_t rows[16] = {0};
+};
+
+// Notes: amx tile config
+//
+// Typically, TMUL calculates A and B of size 16 x 64 containing INT8 values,
+// and accumulates the result into a 16 x 16 matrix C containing INT32 values.
+//
+// As many GGUF quantized types have a `block_size` of 32, a 16-16-32 config is used
+// instead of the normally used 16-16-64 config.
+//
+//    Block A: {16, 32}, dtype = int8_t
+//    Block B: {16, 32}, dtype = uint8_t/int8_t
+//    Block C: {16, 16}, dtype = int32_t
+//
+// Block B needs to be prepacked to vnni format before feeding into TMUL:
+//    packed_B: from {n, k} to {k/vnni_blk, n, vnni_blk}, viewed in 2d, we get {8, 64}
+//
+// Therefore, we get tileconfig:
+//             A    B    C
+//    rows    16    8   16
+//    colsb   32   64   16
+//
+// For tile distribution, follow a 2-2-4 pattern, e.g. A uses TMM2-TMM3, B uses TMM0-TMM1,
+// and C uses TMM4-TMM7:
+//            B TMM0  B TMM1
+//    A TMM2  C TMM4  C TMM6
+//    A TMM3  C TMM5  C TMM7
+//
+// Each `amx` kernel handles 4 blocks at a time: 2MB * 2NB; when m < 2 * BLOCK_M,
+// unpacking A is needed.
+//
+// Here another commonly used pattern, 1-3-3, is skipped, as it is mostly used when m <= 16,
+// and the single-batch gemm (m = 1) has a special fast path with `avx512-vnni`.
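+//
+// For intuition, a minimal sketch of one inner loop step under this tile assignment
+// (illustrative only: the pointer names are made up, and the real kernels below also
+// handle repacking, scales, and compensation):
+//
+//     for (int k = 0; k < KB; ++k) {
+//         _tile_loadd(TMM0, packed_B0 + k * TILE_N * TILE_K / 2, 64);  // B tile, vnni {8, 64}
+//         _tile_loadd(TMM1, packed_B1 + k * TILE_N * TILE_K / 2, 64);
+//         _tile_loadd(TMM2, A0 + k * TILE_K, lda);                     // A tile, {16, 32} int8
+//         _tile_loadd(TMM3, A1 + k * TILE_K, lda);
+//         _tile_dpbssd(TMM4, TMM2, TMM0);  // C00 += A0 * B0, {16, 16} int32
+//         _tile_dpbssd(TMM5, TMM3, TMM0);  // C10 += A1 * B0
+//         _tile_dpbssd(TMM6, TMM2, TMM1);  // C01 += A0 * B1
+//         _tile_dpbssd(TMM7, TMM3, TMM1);  // C11 += A1 * B1
+//     }
+//     _tile_stored(TMM4, C00, ldc * sizeof(int32_t));  // then scale back to fp32
+//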
+// +// ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/ +// advanced-matrix-extensions-intrinsics-functions.html +// + +#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb +void lm_ggml_tile_config_init(void) { + static thread_local bool is_first_time = true; + + if (!is_first_time) { + return; + } + + static thread_local tile_config_t tc; + tile_config_t current_tc; + _tile_storeconfig(¤t_tc); + + // load only when config changes + if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 && + memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { + tc.palette_id = 1; + tc.start_row = 0; + TC_CONFIG_TILE(TMM0, 8, 64); + TC_CONFIG_TILE(TMM1, 8, 64); + TC_CONFIG_TILE(TMM2, 16, 32); + TC_CONFIG_TILE(TMM3, 16, 32); + TC_CONFIG_TILE(TMM4, 16, 64); + TC_CONFIG_TILE(TMM5, 16, 64); + TC_CONFIG_TILE(TMM6, 16, 64); + TC_CONFIG_TILE(TMM7, 16, 64); + _tile_loadconfig(&tc); + } + + is_first_time = false; +} + +// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation. +// See the notes `s8s8 igemm compensation in avx512-vnni` for detail. +template +int get_tile_size() { + int tile_size = TILE_N * sizeof(TB); + if (do_compensate::value) { + tile_size += TILE_N * sizeof(int32_t); + } + if (std::is_same::value || + std::is_same::value) { + tile_size += TILE_N * 4; + } + if (std::is_same::value) { + tile_size += TILE_N * 2; + } + return tile_size; +} + +template +int get_row_size(int K) { + int KB = K / BLOCK_K; + int row_size = KB * sizeof(TB); + if (do_compensate::value) { + row_size += KB * sizeof(int32_t); + } + if (std::is_same::value || + std::is_same::value) { + row_size += KB * 4; + } + if (std::is_same::value) { + row_size += KB * 2; + } + return row_size; +} + +// vectorized dtype conversion +inline float FP16_TO_FP32(lm_ggml_half val) { + __m256i v = _mm256_setr_epi16( + val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); + __m512 o = _mm512_cvtph_ps(v); + return _mm512_cvtss_f32(o); +} + +inline __m512 FP16_TO_FP32_VEC(lm_ggml_half val) { + __m256i v = _mm256_set1_epi16(val); + return _mm512_cvtph_ps(v); +} + +// horizontal reduce +inline float _mm512_reduce_max_ps(const __m512 x) { + __m512 v = x; + __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_f32x4(v, v, 0xB1); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_ps(v, v, 0x4E); + v = _mm512_max_ps(v, v1); + v1 = _mm512_shuffle_ps(v, v, 0xB1); + v = _mm512_max_ps(v, v1); + return _mm512_cvtss_f32(v); +} + +// transpose utils +#define SHUFFLE_EPI32(a, b, mask) \ + _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask)) +inline void transpose_8x8_32bit(__m256i * v, __m256i * v1) { + // unpacking and 32-bit elements + v1[0] = _mm256_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm256_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm256_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm256_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm256_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm256_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm256_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm256_unpackhi_epi32(v[6], v[7]); + + // shuffling the 32-bit elements + v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44); + v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee); + v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44); + v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee); + v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44); + v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee); + v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44); + v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee); + + // 
shuffling 128-bit elements + v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02); + v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02); + v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02); + v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02); + v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13); + v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13); + v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13); + v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13); +} + +inline void transpose_16x4_32bit(__m512i * r, __m512i * d) { + + static const __m512i index1 = _mm512_set_epi32( + 0x0f, 0x0b, 0x07, 0x03, + 0x0e, 0x0a, 0x06, 0x02, + 0x0d, 0x09, 0x05, 0x01, + 0x0c, 0x08, 0x04, 0x00); + + d[0] = _mm512_permutexvar_epi32(index1, r[0]); + d[1] = _mm512_permutexvar_epi32(index1, r[1]); + d[2] = _mm512_permutexvar_epi32(index1, r[2]); + d[3] = _mm512_permutexvar_epi32(index1, r[3]); + + r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44); + r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee); + r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44); + r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee); + + d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88); + d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd); + d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88); + d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd); +} + +inline void transpose_16x16_32bit(__m512i * v) { + __m512i v1[16]; + v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); + v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); + v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); + v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); + v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); + v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); + v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); + v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); + v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); + + v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); + v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); + v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); + v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); + v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); + v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); + v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); + v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); + v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); + v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); + v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); + v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); + v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); + v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); + v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); + v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); + + v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); + v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); + v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); + v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); + v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); + v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); + v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); + v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); + v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); + v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); + v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); + v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); + v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); + v1[13] = _mm512_shuffle_i32x4(v[9], 
v[13], 0xdd); + v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); + v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); + + v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); + v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); + v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); + v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); + v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); + v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); + v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); + v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); + v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); + v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); + v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); + v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); + v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); + v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); + v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); + v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); +} + +void quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) { + assert(k % QK_K == 0); + const int KB = k / QK_K; + constexpr int kVecs = QK_K / 16; + + block_q8_K * y = reinterpret_cast(vy); + + // hold 16 float vecs from x + __m512 v[kVecs]; + + // hold the quants vecs + __m512i vq[kVecs / 4]; + + // hold the packed quants vecs + __m512i vq_packed[kVecs / 4]; + + const __m512 signBit = _mm512_set1_ps(-0.f); + + for (int i = 0; i < KB; ++i) { + // Compute max(abs(e)) for the block + __m512 vamax = _mm512_set1_ps(0.f); + for (int j = 0; j < kVecs; ++j) { + v[j] = _mm512_loadu_ps(x); x += 16; + vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j])); + } + const float amax = _mm512_reduce_max_ps(vamax); + + // Quantize these floats + const float iscale = 127.f / amax; + y[i].d = LM_GGML_FP32_TO_FP16(1 / iscale); + const float id = ( amax != 0.0f ) ? 
iscale : 0.f; + const __m512 vscale = _mm512_set1_ps(id); + + // Apply multiplier and round to nearest integer + for (int j = 0; j < kVecs; ++j) { + v[j] = _mm512_mul_ps(v[j], vscale); + v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + } + + // Pack to epi8 vecs + for (int j = 0; j < kVecs / 4; ++j) { + __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0])); + __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1])); + __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2])); + __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3])); + + __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1); + __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1); + + vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1); + _mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]); + } + + // Compute the bsums with vnni + transpose_16x4_32bit(vq, vq_packed); + + const __m512i one = _mm512_set1_epi8(1); + __m512i sum = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]); + } + _mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum)); + } +} + +// quantize A from float to `vec_dot_type` +template +inline void from_float(const float * x, char * vy, int64_t k); + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { + quantize_row_q8_0(x, (block_q8_0 *)vy, k); +} + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { + quantize_row_q8_1(x, (block_q8_1 *)vy, k); +} + +template <> +inline void from_float(const float * x, char * vy, int64_t k) { +#if 1 + // TODO: this is reference impl! + quantize_row_q8_K_ref(x, (block_q8_K *)vy, k); +#else + quantize_row_q8_K_vnni(x, vy, k); +#endif +} + +// load A from memory to array when nrows can not fill in whole tile +void unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) { + assert(nr != TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +void unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) { + assert(nr != TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +template +void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { + assert(nr <= TILE_M); + for (int m = 0; m < nr; ++m) { + const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32)); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v); + } +} + +template <> +void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) { + assert(nr <= TILE_M); + // zero padding k from 16 to 32, so that we don't have to re-config amx + const __m128i zero = _mm_setzero_si128(); + for (int m = 0; m < nr; ++m) { + const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16)); + const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1); + _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r); + } +} + +#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) +inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) { + const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi); 
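+    // The 16 bytes at rsi hold 32 packed 4-bit values. Below, the high nibbles are
+    // shifted down into the upper 128-bit lane and both lanes are masked to 4 bits,
+    // so e.g. rsi[0] = 0xAB expands to byte 0 = 0x0B and byte 16 = 0x0A.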
+ const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp); + const __m256i lowMask = _mm256_set1_epi8(0xF); + return _mm256_and_si256(lowMask, bytes); +} + +// used for block_q4_K +inline __m512i bytes_from_nibbles_64(const uint8_t * rsi) { + const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi); + const __m256i lowMask = _mm256_set1_epi8(0xF); + const __m256i q4l = _mm256_and_si256(tmp, lowMask); + const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask); + return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1); +} + +// used for block_q5_K +inline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) { + const __m256i lowMask = _mm256_set1_epi8(0xF); + __m256i hmask = _mm256_set1_epi8(1); + hmask = _mm256_slli_epi16(hmask, k); + + const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs); + const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh); + + const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask); + const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4); + const __m256i q5_0 = _mm256_add_epi8(q5l_0, q5h_0); + hmask = _mm256_slli_epi16(hmask, 1); + + const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask); + const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4); + const __m256i q5_1 = _mm256_add_epi8(q5l_1, q5h_1); + + return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1); +} + +// used for block_q6_K +inline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) { + const __m256i m4 = _mm256_set1_epi8(0xF); + const __m256i m2 = _mm256_set1_epi8(0x3); + + const __m256i q6bits1 = _mm256_loadu_si256((const __m256i *)qs); + const __m256i q6bits2 = _mm256_loadu_si256((const __m256i *)(qs + 32)); + const __m256i q6bitsH = _mm256_loadu_si256((const __m256i *)qh); + + const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256( q6bitsH, m2), 4); + const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4); + const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4); + const __m256i q6h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4); + + const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0); + const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1); + const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2); + const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3); + + r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1); + r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1); +} + +inline __m512i packNibbles(__m512i r0, __m512i r1) { + return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4)); +} + +template +inline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) { + int8_t tmp[8 * 64]; + __m256i v[8], v2[8]; + for (int n = 0; n < 8; ++n) { + v[n] = bytes_from_nibbles_32(B[n * KB].qs); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]); + } + for (int n = 0; n < 8; ++n) { + v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]); + } + + // pack again with 128 to fully utilize vector length + for (int n = 
0; n < 8; n += 2) { + __m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64)); + __m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64)); + __m512i r1r0 = packNibbles(r0, r1); + _mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0); + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { + __m256i v[8], v2[8]; + for (int n = 0; n < 8; ++n) { + v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs)); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]); + } + for (int n = 0; n < 8; ++n) { + v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs)); + } + transpose_8x8_32bit(v, v2); + for (int n = 0; n < 8; ++n) { + _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]); + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { + __m512i v[16]; + // QK_K 256 with 8 groups, handle 2 groups at a time + char * pb = (char *)packed_B; + for (int k = 0; k < QK_K / 64; ++k) { + // pack 2 groups { n, g, k} to {g, k/4, 4n} + // e.g. {16, 2, 32} to {2, 8, 64} + for (int n = 0; n < TILE_N; ++n) { + v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32); + } + + transpose_16x16_32bit(v); + + // pack again with 128 to fully utilize vector length + for (int n = 0; n < TILE_N; n += 2) { + _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); + pb += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { + __m512i v[16]; + const __m512i lowMask = _mm512_set1_epi8(0xF); + // QK_K 256 with 8 groups, handle 2 groups at a time + char * pb = (char *)packed_B; + char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; + for (int k = 0; k < QK_K / 64; ++k) { + // pack 2 groups { n, g, k} to {g, k/4, 4n} + // e.g. {16, 2, 32} to {2, 8, 64} + for (int n = 0; n < TILE_N; ++n) { + v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k); + } + + transpose_16x16_32bit(v); + + // 1. pack lower 4bits with 2 groups + for (int n = 0; n < TILE_N; n += 2) { + // get lower 4 bits + const __m512i r0 = _mm512_and_si512(v[n], lowMask); + const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); + _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; + } + + // 2. 
pack higher 1bit with 2 groups + const __m512i hmask = _mm512_set1_epi8(0x10); + for (int g = 0; g < 2; ++g) { + __m512i hbits = _mm512_setzero_si512(); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1)); + hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 8 + 4], hmask) ); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1)); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3)); + _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { + __m512i v[32]; + const __m512i lowMask = _mm512_set1_epi8(0xF); + // QK_K 256 with 8 groups, handle 4 groups at a time + char * pb = (char *)packed_B; + char * ph = (char *)packed_B + (QK_K / 2) * TILE_N; + for (int k = 0; k < QK_K / 128; ++k) { + for (int n = 0; n < TILE_N; ++n) { + bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32); + } + + // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7 + transpose_16x16_32bit(v); + transpose_16x16_32bit(v + 16); + + // 1. pack lower 4bits with 4 groups + for (int n = 0; n < 32; n += 2) { + const __m512i r0 = _mm512_and_si512(v[n], lowMask); + const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask); + _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64; + } + + // 2. 
pack higher 2bit with 4 groups + const __m512i hmask = _mm512_set1_epi8(0x30); + for (int g = 0; g < 8; ++g) { + __m512i hbits = _mm512_setzero_si512(); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4)); + hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2)); + hbits = _mm512_add_epi8(hbits, _mm512_and_si512(v[g * 4 + 2], hmask) ); + hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2)); + _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64; + } + } +} + +template <> +inline void pack_qs(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { + __m512i v[16]; + char * pb = (char *)packed_B; + for (int k = 0; k < QK_K / 64; ++k) { + for (int n = 0; n < TILE_N; ++n) { + __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 0); + __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16); + v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1); + } + + transpose_16x16_32bit(v); + + // pack again with 128 to fully utilize vector length + for (int n = 0; n < TILE_N; n += 2) { + _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1])); + pb += 64; + } + } +} + +// pack B to vnni formats in 4bits or 8 bits +void pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + lm_ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K / 2); + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + } +} + +void pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + lm_ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K / 2); + lm_ggml_half * m0 = d0 + TILE_N; + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + m0[n] = B[n * KB].m; + } +} + +inline void s8s8_compensation(void * RESTRICT packed_B) { + // packed_B layout: + // quants {TILE_N, TILEK} int8_t + // d0 {TILE_N} lm_ggml_half + // comp {TILE_N} int32_t + const int offset = TILE_N * TILE_K + TILE_N * sizeof(lm_ggml_half); + __m512i vcomp = _mm512_setzero_si512(); + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + for (int k = 0; k < 8; ++k) { + __m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64)); + vcomp = _mm512_dpbusd_epi32(vcomp, off, vb); + } + _mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp); +} + +void pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + lm_ggml_half * d0 = reinterpret_cast((char *)packed_B + TILE_N * TILE_K); + for (int n = 0; n < TILE_N; ++n) { + d0[n] = B[n * KB].d; + } + s8s8_compensation(packed_B); +} + +// convert 8 * {min, scale} from int6 to int8 +inline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) { + const uint32_t kmask1 = 0x3f3f3f3f; + const uint32_t kmask2 = 0x0f0f0f0f; + const uint32_t kmask3 = 0x03030303; + + memcpy(utmp, scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// scales {8, TILE_N} uint8 +// mins {8, TILE_N} uint8 +// d {TILE_N} lm_ggml_half +// dmin {TILE_N} lm_ggml_half +void pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + 
(QK_K / 2) * TILE_N); + uint8_t * mins = scales + 8 * TILE_N; + lm_ggml_half * d = reinterpret_cast(mins + 8 * TILE_N); + lm_ggml_half * dmin = d + TILE_N; + + union { + uint32_t u32[4]; + uint8_t u8[16]; + } s; + + for (int n = 0; n < TILE_N; ++n) { + unpack_mins_and_scales(B[n * KB].scales, s.u32); + for (int k = 0; k < 8; ++k) { + scales[k * TILE_N + n] = s.u8[k]; + mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; + } + d[n] = B[n * KB].d; + dmin[n] = B[n * KB].dmin; + } +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// qh {8, TILE_N, 4} uint8 +// scales {8, TILE_N} uint8 +// mins {8, TILE_N} uint8 +// d {TILE_N} lm_ggml_half +// dmin {TILE_N} lm_ggml_half +void pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); + uint8_t * mins = scales + 8 * TILE_N; + lm_ggml_half * d = reinterpret_cast(mins + 8 * TILE_N); + lm_ggml_half * dmin = d + TILE_N; + + union { + uint32_t u32[4]; + uint8_t u8[16]; + } s; + + for (int n = 0; n < TILE_N; ++n) { + unpack_mins_and_scales(B[n * KB].scales, s.u32); + for (int k = 0; k < 8; ++k) { + scales[k * TILE_N + n] = s.u8[k]; + mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8]; + } + d[n] = B[n * KB].d; + dmin[n] = B[n * KB].dmin; + } +} + +// packed_B layout: +// quants {16, TILE_N, 8} uint8 +// qh {16, TILE_N, 4} uint8 +// scales {16, TILE_N} uint8 +// d {TILE_N} lm_ggml_half +void pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + uint8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); + lm_ggml_half * d = reinterpret_cast(scales + 16 * TILE_N); + for (int n = 0; n < TILE_N; ++n) { + const int8_t * ps = B[n * KB].scales; + for (int k = 0; k < 16; ++k) { + scales[k * TILE_N + n] = ps[k]; + } + d[n] = B[n * KB].d; + } +} + +// packed_B layout: +// quants {8, TILE_N, 16} uint8 +// scales {8, TILE_N} int8 +// d {TILE_N} lm_ggml_half +void pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) { + pack_qs(packed_B, B, KB); + + int8_t * scales = reinterpret_cast((char *)packed_B + (QK_K / 2) * TILE_N); + lm_ggml_half * d = reinterpret_cast(scales + 8 * TILE_N); + + // pack the scales + for (int n = 0; n < TILE_N; ++n) { + uint16_t sh = B[n * KB].scales_h; + for (int k = 0; k < 8; k += 2) { + const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32; + const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >> 4) | ((sh << 2) & 0x30)) - 32; + scales[(k + 0) * TILE_N + n] = ls1; + scales[(k + 1) * TILE_N + n] = ls2; + sh >>= 4; + } + d[n] = B[n * KB].d; + } +} + +template> +void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) { + LM_GGML_UNUSED(tile); + LM_GGML_UNUSED(packed_B); +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B) { + const __m512i off = _mm512_set1_epi8(8); + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); + const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off); + const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(uint8_t * 
RESTRICT tile, const void * RESTRICT packed_B) { + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32)); + const __m512i r0 = _mm512_and_si512(bytes, lowMask); + const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +// packed_B_t for QKK is int8_t +template +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + const int packed_B_group_size = QK_K / 2 * TILE_N / 8; + const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size; + const __m512i lowMask = _mm512_set1_epi8(0xF); + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32); + const __m512i r0 = _mm512_and_si512(bytes, lowMask); + const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + // lower 4bits, stride 256 bytes + const int packed_l4_group_size = QK_K / 2 * TILE_N / 8; + const char * pb = (const char *)packed_B + k * packed_l4_group_size; + + // higher 1bit, stride 64 bytes + const int packed_h1_group_size = QK_K / 8 * TILE_N / 8; + const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size; + const __m512i hbits = _mm512_loadu_si512(ph); + + const __m512i lowMask = _mm512_set1_epi8(0xF); + __m512i hmask0 = _mm512_set1_epi8(0x1); + __m512i hmask1 = _mm512_set1_epi8(0x2); + + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(pb + n * 32); + __m512i r0 = _mm512_and_si512(bytes, lowMask); + __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4); + __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4); + + hmask0 = _mm512_slli_epi16(hmask0, 2); + hmask1 = _mm512_slli_epi16(hmask1, 2); + r0 = _mm512_add_epi8(r0, h0); + r1 = _mm512_add_epi8(r1, h1); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + // lower 4bits, stride 128 bytes + const int packed_l4_group_size = QK_K / 2 * TILE_N / 16; + const char * pb = (const char *)packed_B + k * packed_l4_group_size; + + // higher 2bits, stride 64 bytes + const int packed_h2_group_size = QK_K / 4 * TILE_N / 16; + const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size; + const __m512i hbits = _mm512_loadu_si512(ph); + + const __m512i off = _mm512_set1_epi8(32); + const __m512i lowMask = _mm512_set1_epi8(0xF); + __m512i hmask0 = _mm512_set1_epi8(0x3); // 0011 + __m512i hmask1 = _mm512_set1_epi8(0xC); // 1100 + + // notes: skip zero padding from row4 to row7 as we have done so in `unpack_A` + __m512i bytes = _mm512_loadu_si512(pb); + __m512i r0 = _mm512_and_si512(bytes, lowMask); + __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4); + __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2); + 
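+    // Reassemble the 6-bit values: each output byte becomes
+    // (low 4 bits | high 2 bits << 4) - 32, the signed q6_K offset;
+    // e.g. low nibble 0x7 with high bits 0b10 gives 0x27 - 32 = 7.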
_mm512_storeu_si512((__m512i *)(tile + 0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); + _mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); + + hmask0 = _mm512_slli_epi16(hmask0, 4); + hmask1 = _mm512_slli_epi16(hmask1, 4); + + bytes = _mm512_loadu_si512(pb + 64); + r0 = _mm512_and_si512(bytes, lowMask); + r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + h0 = _mm512_and_si512(hbits, hmask0); + h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2); + _mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off)); + _mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off)); +} + +template <> +void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) { + static const __m512i values128 = _mm512_set_epi8( + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 + ); + + const int packed_B_group_size = QK_K / 2 * TILE_N / 8; + const char * pb = (const char *)packed_B + k * packed_B_group_size; + const __m512i lowMask = _mm512_set1_epi8(0xF); + + for (int n = 0; n < 8; n += 2) { + __m512i bytes = _mm512_loadu_si512(pb + n * 32); + const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask)); + const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 0), r0); + _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1); + } +} + +template +struct acc_C {}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(lm_ggml_half)))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vs1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[m * lda].s)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); + vsum = _mm512_fmadd_ps(vm0, vs1, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct 
acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) { + const int offset = TILE_N * TILE_K; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset))); + + for (int m = 0; m < nr; ++m) { + const __m512 vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[m * lda].d)); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N); + const uint8_t * mins = scales + 8 * TILE_N; + const lm_ggml_half * d0 = reinterpret_cast(mins + 8 * TILE_N); + const lm_ggml_half * dmin = d0 + TILE_N; + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N); + const uint8_t * mins = scales + 8 * TILE_N; + const lm_ggml_half * d0 = reinterpret_cast(mins + 8 * TILE_N); + const lm_ggml_half * dmin = d0 + TILE_N; + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), 
_mm256_extracti128_si256(q8sums, 1)); + + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s)); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N); + const lm_ggml_half * d0 = reinterpret_cast(scales + 16 * TILE_N); + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template +struct acc_C { + static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) { + const int8_t * scales = reinterpret_cast((const char *)packed_B + (QK_K / 2) * TILE_N); + const lm_ggml_half * d0 = reinterpret_cast(scales + 8 * TILE_N); + + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0)); + + for (int m = 0; m < nr; ++m) { + const float d1 = A[m * lda].d; + const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0); + const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N)); + + __m512 vsum; + if (is_acc) { + vsum = _mm512_loadu_ps(C + m * ldc); + } else { + vsum = _mm512_set1_ps(0.f); + } + + vsum = _mm512_fmadd_ps(vtile, vd, vsum); + _mm512_storeu_ps(C + m * ldc, vsum); + } + } +}; + +template constexpr int get_quants_size(); +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N; } +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; } +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; } +template <> constexpr int get_quants_size() { return (QK_K / 2) * TILE_N; } + +// used for QKK format +template ::value, int>::type = 0> +inline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) { + const uint8_t * scales = reinterpret_cast((const char *)packed_B + get_quants_size()); + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N))); + + for (int m = 0; m < nr; ++m) { + __m512i vsumi; + if (is_acc) { + vsumi = _mm512_loadu_si512(sumi + m * TILE_N); + } else { + vsumi = _mm512_setzero_si512(); + } + __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N); + vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale)); + _mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi); + } +} + +template +struct tinygemm_kernel_avx { + static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) { + LM_GGML_UNUSED(K); + LM_GGML_UNUSED(A); + 
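+        // the LM_GGML_UNUSED marks silence unused-parameter warnings: this
+        // unspecialized template is intentionally a no-op fallback, and the
+        // real kernels are the specializations that follow.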
LM_GGML_UNUSED(B);
+        LM_GGML_UNUSED(C);
+        LM_GGML_UNUSED(ldc);
+    }
+};
+
+template 
+struct tinygemm_kernel_avx {
+    static void apply(int K, const float * RESTRICT A, const lm_ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) {
+        constexpr int ROWS = BLOCK_M;
+        constexpr int COLS = BLOCK_N;
+        assert(BLOCK_K == 16);
+
+        __m512 va;
+        __m512 vb[COLS];
+        __m512 vc[ROWS * COLS];
+
+        auto loadc = [&](auto idx) {
+            vc[idx] = _mm512_setzero_ps();
+        };
+        Unroll{}(loadc);
+
+        auto compute = [&](auto idx, auto k) {
+            constexpr int row = idx / COLS;
+            constexpr int col = idx % COLS;
+
+            if constexpr (col == 0) {
+                va = _mm512_loadu_ps(A + row * K + k);
+            }
+            if constexpr (row == 0) {
+                vb[col] = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k)));
+            }
+            vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
+        };
+
+        for (int k = 0; k < K; k += 16) {
+            Unroll{}(compute, k);
+        }
+
+        auto storec = [&](auto idx) {
+            constexpr int row = idx / COLS;
+            constexpr int col = idx % COLS;
+            C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]);
+        };
+        Unroll{}(storec);
+    }
+};
+
+#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \
+    tinygemm_kernel_avx::apply( \
+        K, (const float *)src1->data + mb_start * K, \
+        (const type *)src0->data + nb_start * K, \
+        (float *)dst->data + mb_start * ldc + nb_start, ldc);
+
+
+// re-organize in the format {NB, KB, TILE_SIZE}:
+#define PACKED_INDEX(n, k, KB, tile_size) (n * KB + k) * tile_size
+
+template 
+void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K) {
+    const int NB = N / TILE_N;
+    const int KB = K / BLOCK_K;
+    const int TILE_SIZE = get_tile_size();
+
+    // parallel on NB should be enough
+    parallel_for(NB, [&](int begin, int end) {
+        for (int n = begin; n < end; ++n) {
+            for (int k = 0; k < KB; ++k) {
+                int n0 = n * TILE_N;
+                pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB);
+            }
+        }
+    });
+}
+
+template 
+struct tinygemm_kernel_vnni {};
+
+template 
+struct tinygemm_kernel_vnni {
+    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
+
+        constexpr int COLS = BLOCK_N / 16;
+        const int TILE_SIZE = TILE_N * sizeof(block_q4_0);
+
+        const block_q8_0 * RESTRICT A = static_cast(_A);
+        const char * RESTRICT B = static_cast(_B);
+
+        __m512i va[8];
+        __m512 vc[COLS];
+        __m512 vd1;
+
+        // sum of offsets, shared across COLS
+        //
+        // avx512-vnni does not have `_mm512_dpbssd_epi32`,
+        // need to transform ss to us:
+        // a * (b - 8) is equivalent to b * a - 8 * a
+        //   s    u  u                  u   s    u   s
+        //
+        __m512i vcomp;
+
+        const __m512i off = _mm512_set1_epi8(8);
+        const __m512i lowMask = _mm512_set1_epi8(0xF);
+
+        auto loadc = [&](auto col) {
+            vc[col] = _mm512_setzero_ps();
+        };
+        Unroll{}(loadc);
+
+        auto compute = [&](auto col, auto i) {
+            // load a and compute compensation
+            if constexpr (col == 0) {
+                const int32_t * a_ptr = reinterpret_cast(A[0 * KB + i].qs);
+                vcomp = _mm512_setzero_si512();
+                for (int k = 0; k < 8; ++k) {
+                    va[k] = _mm512_set1_epi32(a_ptr[k]);
+                    vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]);
+                }
+                vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[0 * KB + i].d));
+            }
+
+            // load b
+            __m512i vsum = _mm512_setzero_si512();
+            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
+            for (int k = 0; k < 8; k += 2) {
+                __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));
+                __m512i vb0 = _mm512_and_si512(bytes, lowMask);
+                vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]);
+                __m512i vb1 =
_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]); + } + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + vsum = _mm512_sub_epi32(vsum, vcomp); + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](auto col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q4_1); + + const block_q8_1 * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + __m512i va[8]; + __m512i vb[8]; + __m512 vc[COLS]; + __m512 vd1, vs1; + + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](auto col, auto i) { + // load a + if constexpr (col == 0) { + const int32_t * a_ptr = reinterpret_cast(A[0 * KB + i].qs); + for (int k = 0; k < 8; ++k) { + va[k] = _mm512_set1_epi32(a_ptr[k]); + } + vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[0 * KB + i].d)); + vs1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[0 * KB + i].s)); + } + + // load b + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + for (int k = 0; k < 8; k += 2) { + __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32)); + vb[k + 0] = _mm512_and_si512(bytes, lowMask); + vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + } + const int offset = TILE_N * TILE_K / 2; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(lm_ggml_half)))); + + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; ++k) { + vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]); + } + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](auto col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t); + + const block_q8_0 * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + __m512i va[8]; + __m512i vb[8]; + __m512 vc[COLS]; + __m512 vd1; + + // Notes: s8s8 igemm compensation in avx512-vnni + // change s8s8 to u8s8 with compensate + // a * b = (a + 128) * b - 128 * b + // s s u s u s + // + // (128 * b is pre-computed when packing B to vnni formats) + // + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](auto col, auto i) { + // load a and add offset 128 + if constexpr (col == 0) { + const int32_t * a_ptr = reinterpret_cast(A[0 * KB + i].qs); + for (int k = 0; k < 8; 
++k) { + va[k] = _mm512_set1_epi32(a_ptr[k]); + va[k] = _mm512_add_epi8(va[k], off); + } + vd1 = _mm512_set1_ps(LM_GGML_FP16_TO_FP32(A[0 * KB + i].d)); + } + + // load b + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + for (int k = 0; k < 8; ++k) { + vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64)); + } + const int offset = TILE_N * TILE_K; + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset))); + const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(lm_ggml_half); + const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2)); + + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; ++k) { + vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]); + } + vsum = _mm512_sub_epi32(vsum, vcomp); + + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](auto col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4; + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // a.qs: 8 groups, 32 bytes each group (m256i) + __m512i va[8]; + // a.bsum: 8 groups, 2 bytes each group (m128i) + __m512i va_bsum; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_scales = (QK_K / 2) * TILE_N; + const int offset_mins = (QK_K / 2) * TILE_N + 8 * TILE_N; + const int offset_d0 = (QK_K / 2) * TILE_N + 16 * TILE_N; + const int offset_dmin = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(lm_ggml_half); + + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + // Notes: vnni formats in QK_K + // a) quants vnni format + // int8 {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32 + // from {16, 32} to {8, 64} + // + // b) min vnni format + // int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8 + // from {16, 8} to {4, 32} + // + auto compute = [&](auto col, auto i) { + // load a + if constexpr (col == 0) { + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); + } + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + va_bsum = _mm512_castsi128_si512(q8s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // step 1: accumultate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); + + __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vsum = 
_mm512_dpbusd_epi32(vsum, vb1, va1);
+
+                b_qs += 64;
+            }
+            // vacc += scale * (q8 @ q4)
+            const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
+            acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
+        }
+        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
+        vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
+
+        // step 2: accumulate the mins
+        __m512i acc_m = _mm512_setzero_si512();
+        for (int k = 0; k < 4; ++k) {
+            __m512i vmask = _mm512_set1_epi32(k);
+            __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);
+            __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));
+            acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
+        }
+        const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));
+        vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);
+    };
+
+    for (int i = 0; i < KB; ++i) {
+        Unroll{}(compute, i);
+    }
+
+    //store to C
+    auto storec = [&](auto col) {
+        _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]);
+    };
+    Unroll{}(storec);
+    }
+};
+
+template 
+struct tinygemm_kernel_vnni {
+    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
+
+        constexpr int COLS = BLOCK_N / 16;
+        const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4;
+
+        const block_q8_K * RESTRICT A = static_cast(_A);
+        const char * RESTRICT B = static_cast(_B);
+
+        // a.qs: 8 groups, 32 bytes each group (m256i)
+        __m512i va[8];
+        // a.bsum: 8 groups, 2 bytes each group (m128i)
+        __m512i va_bsum;
+        __m512 vc[COLS];
+        __m512 vd1;
+
+        // packed_B:
+        const int offset_qh = (QK_K / 2) * TILE_N;
+        const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N;
+        const int offset_mins = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 8 * TILE_N;
+        const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N;
+        const int offset_dmin = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(lm_ggml_half);
+
+        const __m512i lowMask = _mm512_set1_epi8(0xF);
+
+        auto loadc = [&](auto col) {
+            vc[col] = _mm512_setzero_ps();
+        };
+        Unroll{}(loadc);
+
+        // Q5_K and Q4_K share the same vnni formats, refer to notes above.
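+        // a rough scalar sketch of the compute lambda below, using this
+        // kernel's names: for each 32-weight group g of a QK_K block,
+        //   acc   += scale[g] * dot(a_quants[g], b_quants[g])   // step 1
+        //   acc_m += bsum[g]  * min[g]                          // step 2
+        // and the column update is C += (d0 * d1) * acc - (dmin * d1) * acc_m.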
+ auto compute = [&](auto col, auto i) { + // load a + if constexpr (col == 0) { + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32))); + } + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1)); + va_bsum = _mm512_castsi128_si512(q8s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // step 1: accumultate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + const char * b_qh = b_ptr + offset_qh; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + __m512i vsum = _mm512_setzero_si512(); + __m512i hmask0 = _mm512_set1_epi8(0x1); + __m512i hmask1 = _mm512_set1_epi8(0x2); + __m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64)); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]); + + __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + + __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4); + __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4); + + hmask0 = _mm512_slli_epi16(hmask0, 2); + hmask1 = _mm512_slli_epi16(hmask1, 2); + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + + b_qs += 64; + } + // vacc += scale * (q8 @ q5) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + + // step 2: accumulate the mins + __m512i acc_m = _mm512_setzero_si512(); + for (int k = 0; k < 4; ++k) { + __m512i vmask = _mm512_set1_epi32(k); + __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum); + __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32))); + acc_m = _mm512_dpwssds_epi32(acc_m, va, vb); + } + const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin))); + vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](auto col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_q6_K); + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // load the 256 bytes from A to 4 avx512 vectors + __m512i va[4]; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_qh = (QK_K / 2) * TILE_N; + const int 
offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; + const int offset_d0 = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N; + + // compensation + __m512i vcomp; + + const __m512i m32s = _mm512_set1_epi32(32); + const __m512i lowMask = _mm512_set1_epi8(0xF); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](auto col, auto i) { + if constexpr (col == 0) { + // load a + va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); + va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); + va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); + va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); + + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // accmulate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + const char * b_qh = b_ptr + offset_qh; + int mask = 0; + for (int k_group = 0; k_group < QK_K / 16; ++k_group) { + int r = k_group >> 2; + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + __m512i vsum = _mm512_setzero_si512(); + __m512i hmask = _mm512_set1_epi8(0x3); + + __m512i bytes = _mm512_loadu_si512(b_qs); + __m512i hbits = _mm512_loadu_si512(b_qh); + __m512i vb0 = _mm512_and_si512(bytes, lowMask); + __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4); + __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2); + + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + + va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + bytes = _mm512_loadu_si512(b_qs); + vb0 = _mm512_and_si512(bytes, lowMask); + vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask); + vh0 = _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4)); + vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2); + vb0 = _mm512_add_epi8(vb0, vh0); + vb1 = _mm512_add_epi8(vb1, vh1); + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + b_qh += 64; + + // B * A - 32 * A + __m512i vmask = _mm512_set1_epi32(k_group); + vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp)); + + // vacc += scale * (q8 @ q6) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](int col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +template +struct tinygemm_kernel_vnni { + static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int 
ldc) { + + constexpr int COLS = BLOCK_N / 16; + const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2; + + const block_q8_K * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + // load the 256 bytes from A to 4 avx512 vectors + __m512i va[4]; + __m512 vc[COLS]; + __m512 vd1; + + // packed_B: + const int offset_scales = (QK_K / 2) * TILE_N ; + const int offset_d0 = (QK_K / 2) * TILE_N + 8 * TILE_N; + + // compensation + __m512i vcomp; + + const __m256i m128s = _mm256_set1_epi16(128); + const __m512i lowMask = _mm512_set1_epi8(0xF); + + const __m512i values128 = _mm512_set_epi8( + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127, + 113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127 + ); + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + const __m512i values256 = _mm512_add_epi8(values128, off); + + auto loadc = [&](auto col) { + vc[col] = _mm512_setzero_ps(); + }; + Unroll{}(loadc); + + auto compute = [&](auto col, auto i) { + if constexpr (col == 0) { + // load a + va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 0)); + va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 64)); + va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128)); + va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192)); + + // compensation: 128 * A + const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums); + vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s)); + vd1 = _mm512_set1_ps(A[0 * KB + i].d); + } + + // accmulate the quants + __m512i acc = _mm512_setzero_si512(); + const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE); + const char * b_qs = b_ptr; + int mask = 0; + for (int k_group = 0; k_group < QK_K / 32; ++k_group) { + int r = k_group >> 1; + __m512i vmask = _mm512_set1_epi32(k_group); + __m512i vsum = _mm512_setzero_si512(); + for (int k = 0; k < 8; k += 2) { + __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]); + + __m512i bytes = _mm512_loadu_si512(b_qs); + __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask)); + __m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask)); + + vsum = _mm512_dpbusd_epi32(vsum, vb0, va0); + vsum = _mm512_dpbusd_epi32(vsum, vb1, va1); + b_qs += 64; + } + // (B + 128) * A - 128 * A + vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp)); + + // vacc += scale * (q8 @ q4) + const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N))); + acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale)); + } + const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0))); + vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]); + }; + + for (int i = 0; i < KB; ++i) { + Unroll{}(compute, i); + } + + //store to C + auto storec = [&](auto col) { + _mm512_storeu_ps((__m512i*)(C + 0 * ldc + col * 16), vc[col]); + }; + Unroll{}(storec); + } +}; + +#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + KB, (const char *)wdata + 0 * row_size_A, \ + (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ + 
(float *) dst->data + 0 * N + nb_start, ldc) + +template ::value, int>::type = 0> +void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) { + using packed_B_t = packed_B_type; + const int TILE_SIZE = get_tile_size(); + const bool need_unpack = do_unpack::value; + + LM_GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); + const TA * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + const int m0 = std::min(M, TILE_M); + const int m1 = std::max(M - TILE_M, 0); + const int lda = KB * sizeof(TA); + //const int ldb = KB * sizeof(TB); + + static thread_local packed_B_t Tile0[TILE_N * TILE_K]; + static thread_local packed_B_t Tile1[TILE_N * TILE_K]; + static thread_local int8_t Tile23[TILE_M * TILE_K]; + + static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; + static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; + + // double buffering C to interleave avx512 and amx + int32_t * C_cur = TileC0; + int32_t * C_pre = TileC1; + + auto Tile4 = [&](int32_t * base) { return base; }; + auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; }; + auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; }; + auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; }; + + if (M == 2 * TILE_M) { + // i = 0 + const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE); + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + + _tile_zero(TMM4); + _tile_loadd(TMM2, A[0].qs, lda); + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM5); + _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda); + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t)); + + if (need_unpack) { + unpack_B(Tile1, B_blk0); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + + _tile_zero(TMM6); + _tile_dpbssd(TMM6, TMM2, TMM1); + _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM7); + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t)); + + for (int i = 1; i < KB; ++i) { + // index of previous iter + const int ii = i - 1; + const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); + LM_GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] { + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + _tile_zero(TMM4); + _tile_loadd(TMM2, A[i].qs, lda); + acc_C::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM5); + _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda); + acc_C::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); + + if (need_unpack) { + unpack_B(Tile1, B_blk1); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + _tile_zero(TMM6); + acc_C::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + 
PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM6, TMM2, TMM1); + _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); + + _tile_zero(TMM7); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); + + std::swap(C_cur, C_pre); + }); + } + // final accumulation + { + int ii = KB - 1; + acc_C::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M); + } + } else { + for (int i = 0; i < KB; ++i) { + _tile_zero(TMM4); + _tile_zero(TMM6); + if (m1 != 0) { + _tile_zero(TMM5); + _tile_zero(TMM7); + } + + const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE); + const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE); + if (need_unpack) { + unpack_B(Tile0, B_blk0); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK); + } + + if (need_unpack) { + unpack_B(Tile1, B_blk1); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + } else { + _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); + } + + if (m0 == TILE_M) { + _tile_loadd(TMM2, A[i].qs, lda); + } else { + unpack_A(Tile23, &A[i], KB, m0); + _tile_loadd(TMM2, Tile23, TILE_K); + } + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_dpbssd(TMM6, TMM2, TMM1); + + _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t)); + _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t)); + + LM_GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C, ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); + acc_C::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); + }); + + if (m1 != 0) { + unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1); + _tile_loadd(TMM3, Tile23, TILE_K); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_dpbssd(TMM7, TMM3, TMM1); + _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t)); + _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t)); + LM_GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C + TILE_M * ldc, ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); + }); + } + } + } + return; +} + +template ::value, int>::type = 0> +void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) { + static_assert(std::is_same::value); + const int TILE_SIZE = get_tile_size(); + + LM_GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N); + const TA * RESTRICT A = static_cast(_A); + const char * RESTRICT B = static_cast(_B); + + const int m0 = std::min(M, TILE_M); + const int m1 = std::max(M - TILE_M, 0); + //const int lda = KB * sizeof(TA); + + static thread_local int8_t Tile0[TILE_N * TILE_K]; + static thread_local int8_t Tile1[TILE_N * TILE_K]; + static thread_local int8_t Tile23[TILE_M * TILE_K]; + + // mat mul result for each group + static thread_local int32_t Tile4[TILE_M * TILE_N]; + static thread_local int32_t 
Tile5[TILE_M * TILE_N]; + static thread_local int32_t Tile6[TILE_M * TILE_N]; + static thread_local int32_t Tile7[TILE_M * TILE_N]; + + // sum of each QK_K block, contains 8 groups, int32 + static thread_local int32_t Sumi4[TILE_M * TILE_N]; + static thread_local int32_t Sumi5[TILE_M * TILE_N]; + static thread_local int32_t Sumi6[TILE_M * TILE_N]; + static thread_local int32_t Sumi7[TILE_M * TILE_N]; + + const int k_group_size = std::is_same::value ? 16 : 32; + for (int i = 0; i < KB; ++i) { + // step 1: accumulate the quants across 8 groups, each group with 32 + for (int k = 0; k < QK_K / k_group_size; ++k) { + LM_GGML_DISPATCH_BOOL(k > 0, is_acc, [&] { + _tile_zero(TMM4); + _tile_zero(TMM6); + + unpack_B(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k); + _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK); + + unpack_B(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k); + _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); + + unpack_A(Tile23, &A[i], KB, k, m0); + _tile_loadd(TMM2, Tile23, TILE_K); + + _tile_dpbssd(TMM4, TMM2, TMM0); + _tile_dpbssd(TMM6, TMM2, TMM1); + + _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t)); + _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t)); + + scale_C(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0); + scale_C(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0); + + if (m1 != 0) { + _tile_zero(TMM5); + _tile_zero(TMM7); + + unpack_A(Tile23, &A[TILE_M * KB + i], KB, k, m1); + _tile_loadd(TMM3, Tile23, TILE_K); + + _tile_dpbssd(TMM5, TMM3, TMM0); + _tile_dpbssd(TMM7, TMM3, TMM1); + + _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t)); + _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t)); + + scale_C(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1); + scale_C(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1); + } + }); + } + + // step 2: accmulate the mins + LM_GGML_DISPATCH_BOOL(i > 0, is_acc, [&] { + acc_C::apply(C, ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0); + acc_C::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0); + if (m1 != 0) { + acc_C::apply(C + TILE_M * ldc, ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1); + acc_C::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1); + } + }); + } + return; +} + +} // anonymous namespace + +// get the packed tensor size for quantized weights +size_t lm_ggml_backend_amx_get_alloc_size(const struct lm_ggml_tensor * tensor) { + const enum lm_ggml_type TYPE = tensor->type; + + const int K = tensor->ne[0]; // ne0: in_features + const int N = tensor->ne[1]; // ne1: out_features + + auto get_tensor_size = [&] { + size_t row_size_B{0}; + LM_GGML_DISPATCH_QTYPES(TYPE, [&] { + row_size_B = get_row_size(K); + }); + return N * row_size_B; + }; + + if (qtype_has_amx_kernels(TYPE)) { + return get_tensor_size(); + } else { + // for f16, bf16 we don't do packing + return lm_ggml_nbytes(tensor); + } +} + +// pack weight to vnni format +void lm_ggml_backend_amx_convert_weight(struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + LM_GGML_ASSERT(offset == 0 && size == lm_ggml_nbytes(tensor)); // only full tensor conversion is supported for now + + const enum lm_ggml_type TYPE = tensor->type; + + const int K = tensor->ne[0]; // ne0: in_features + const int N = tensor->ne[1]; // ne1: out_features + + LM_GGML_DISPATCH_QTYPES(TYPE, [&] { + convert_B_packed_format((void *)((char *)tensor->data + offset), (const type *)data, N, K); + 
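+        // convert_B_packed_format() (defined earlier in this file) re-tiles the
+        // row-major quantized weight into the {NB, KB, TILE_SIZE} layout that
+        // the kernels address via PACKED_INDEX; `type` is presumably bound by
+        // the LM_GGML_DISPATCH_QTYPES dispatch above.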
    });
+}
+
+size_t lm_ggml_backend_amx_desired_wsize(const struct lm_ggml_tensor * dst) {
+    struct lm_ggml_tensor * src0 = dst->src[0];
+
+    const enum lm_ggml_type TYPE = src0->type;
+
+    const bool is_floating_type = TYPE == LM_GGML_TYPE_F16;
+    if (is_floating_type) {
+        return 0;
+    }
+
+    const int M = dst->ne[1];
+    const int K = src0->ne[0];
+
+    size_t desired_wsize = 0;
+
+    LM_GGML_DISPATCH_QTYPES(TYPE, [&] {
+        const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
+        desired_wsize = M * row_size_A;
+    });
+
+    return desired_wsize;
+}
+
+// NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX)
+//
+// src0: weight in shape of {N, K}, quantized
+// src1: input in shape of {M, K}, float32
+// dst: output in shape of {M, N}, float32
+//
+// the function performs: dst = src1 @ src0.T
+//
+void lm_ggml_backend_amx_mul_mat(const lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) {
+    struct lm_ggml_tensor * src0 = dst->src[0];
+    struct lm_ggml_tensor * src1 = dst->src[1];
+
+    const enum lm_ggml_type TYPE = src0->type;
+
+    // f16 only has avx512 kernels for now,
+    // amx kernels will be added once 6th gen xeon is released.
+    const bool is_floating_type = TYPE == LM_GGML_TYPE_F16;
+
+    const int M = dst->ne[1];
+    const int N = dst->ne[0];
+    const int K = src0->ne[0];
+    const int ldc = dst->nb[1] / dst->nb[0];
+
+    if (is_floating_type) {
+        constexpr int BLOCK_M = 4;
+        constexpr int BLOCK_N = 6;
+        const int MB = div_up(M, BLOCK_M);
+        const int NB = div_up(N, BLOCK_N);
+
+        parallel_for_ggml(params, MB * NB, [&](int begin, int end) {
+            LM_GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
+                for (int i = begin; i < end; ++i) {
+                    int mb = i / NB;
+                    int nb = i % NB;
+
+                    int mb_start = mb * BLOCK_M;
+                    int mb_size = std::min(BLOCK_M, M - mb_start);
+                    int nb_start = nb * BLOCK_N;
+                    int nb_size = std::min(BLOCK_N, N - nb_start);
+
+                    switch (mb_size << 4 | nb_size) {
+                        case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break;
+                        case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break;
+                        case 0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break;
+                        case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break;
+                        case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break;
+                        case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break;
+                        case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break;
+                        case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break;
+                        case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break;
+                        case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break;
+                        case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break;
+                        case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break;
+                        default: fprintf(stderr, "Unexpected block size!\n");
+                    }
+                }
+            });
+        });
+        return;
+    }
+
+    // pointer to work space, used to convert A from float to quantized type
+    void * wdata = params->wdata;
+
+    //TODO: performance improvement: merge quant A
+    if (params->ith == 0) {
+        LM_GGML_DISPATCH_QTYPES(TYPE, [&] {
+            const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
+            const size_t desired_wsize = M * row_size_A;
+            if (params->wsize < desired_wsize) {
+                LM_GGML_ABORT("insufficient work space size");
+            }
+
+            // Q4_0, Q4_1, Q8_0 handle 1 TILE_K per blck_size
+            // Q4_K, Q5_K, Q6_K, IQ4_XS handle 8 TILE_K per blck_size
+            LM_GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
+
+            const float * A_data = static_cast(src1->data);
+            for (int m = 0; m < M; ++m) {
+                from_float(A_data + m * K, (char *)wdata + m * row_size_A, K);
+            }
+        });
+    }
+
+    lm_ggml_barrier(params->threadpool);
+
+    if (M == 1) {
+        // MB = 1 and handle 8 tiles in each block
+        constexpr int kTilesN = 4;
+        constexpr int 
BLOCK_N = TILE_N * kTilesN; + const int NB = div_up(N, BLOCK_N); + + parallel_for_ggml(params, NB, [&](int begin, int end) { + LM_GGML_DISPATCH_QTYPES(TYPE, [&] { + const int KB = K / blck_size; + const int TILE_SIZE = get_tile_size(); + const int row_size_A = KB * sizeof(vec_dot_type); + for (int i = begin; i < end; ++i) { + int nb = i; + int nb_start = nb * BLOCK_N; + int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96 + + switch (nb_size) { + //case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break; + case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break; + case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break; + case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break; + case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break; + default: fprintf(stderr, "Unexpected n block size!\n"); + } + } + }); + }); + return; + } + + // handle 4 tiles at a tile + constexpr int BLOCK_M = TILE_M * 2; + constexpr int BLOCK_N = TILE_N * 2; + const int MB = div_up(M, BLOCK_M); + const int NB = div_up(N, BLOCK_N); + + parallel_for_ggml(params, MB * NB, [&](int begin, int end) { + // init tile config for each thread + lm_ggml_tile_config_init(); + + LM_GGML_DISPATCH_QTYPES(TYPE, [&] { + const int KB = K / blck_size; + const int TILE_SIZE = get_tile_size(); + const int row_size_A = KB * sizeof(vec_dot_type); + + for (int i = begin; i < end; ++i) { + int mb = i / NB; + int nb = i % NB; + + int mb_start = mb * BLOCK_M; + int mb_size = std::min(BLOCK_M, M - mb_start); + int nb_start = nb * BLOCK_N; + int nb_size = BLOCK_N; + + tinygemm_kernel_amx( + mb_size, nb_size, KB, + (const char *)wdata + mb_start * row_size_A, + (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), + (float *) dst->data + mb_start * N + nb_start, ldc); + } + }); + }); +} + +#endif // if defined(__AMX_INT8__) && defined(__AVX512VNNI__) diff --git a/cpp/amx/mmq.h b/cpp/amx/mmq.h new file mode 100644 index 00000000..c108ca1c --- /dev/null +++ b/cpp/amx/mmq.h @@ -0,0 +1,10 @@ +#pragma once +#include "common.h" + +size_t lm_ggml_backend_amx_desired_wsize(const struct lm_ggml_tensor * dst); + +size_t lm_ggml_backend_amx_get_alloc_size(const struct lm_ggml_tensor * tensor); + +void lm_ggml_backend_amx_convert_weight(struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size); + +void lm_ggml_backend_amx_mul_mat(const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst); diff --git a/cpp/common.cpp b/cpp/common.cpp index 667cd0fa..2279a2fa 100644 --- a/cpp/common.cpp +++ b/cpp/common.cpp @@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat [](const unsigned char c) { return !std::isprint(c); }), detokenized.end()); - buf << "\n" << std::to_string(i) - << ":token '" << detokenized << "'" - << ":pos " << std::to_string(batch.pos[i]) - << ":n_seq_id " << std::to_string(batch.n_seq_id[i]) - << ":seq_id " << std::to_string(batch.seq_id[i][0]) - << ":logits " << std::to_string(batch.logits[i]); + buf << "\n" << std::to_string(i) + << ", token '" << detokenized << "'" + << ", pos " << std::to_string(batch.pos[i]) + << ", n_seq_id " << std::to_string(batch.n_seq_id[i]) + << ", seq_id " << std::to_string(batch.seq_id[i][0]) + << ", logits " << std::to_string(batch.logits[i]); } buf << " ]"; @@ -652,7 +652,17 @@ bool fs_validate_filename(const std::string & filename) { std::u32string filename_utf32; try { +#if defined(__clang__) + // disable C++17 deprecation warning for std::codecvt_utf8 +# pragma clang diagnostic push +# pragma clang diagnostic ignored 
"-Wdeprecated-declarations" +#endif std::wstring_convert, char32_t> converter; + +#if defined(__clang__) +# pragma clang diagnostic pop +#endif + filename_utf32 = converter.from_bytes(filename); // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used, @@ -829,9 +839,9 @@ struct common_init_result common_init_from_params(common_params & params) { llama_model * model = nullptr; if (!params.hf_repo.empty() && !params.hf_file.empty()) { - model = common_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); + model = common_load_model_from_hf(params.hf_repo, params.hf_file, params.model, params.hf_token, mparams); } else if (!params.model_url.empty()) { - model = common_load_model_from_url(params.model_url.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams); + model = common_load_model_from_url(params.model_url, params.model, params.hf_token, mparams); } else { model = llama_load_model_from_file(params.model.c_str(), mparams); } @@ -925,9 +935,28 @@ struct common_init_result common_init_from_params(common_params & params) { common_lora_adapters_apply(lctx, iparams.lora_adapters); } - if (params.sparams.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { + if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__); - params.sparams.ignore_eos = false; + params.sampling.ignore_eos = false; + } + + if (params.sampling.ignore_eos) { + for (llama_token i = 0; i < llama_n_vocab(model); i++) { + if (llama_token_is_eog(model, i)) { + LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); + params.sampling.logit_bias.push_back({i, -INFINITY}); + } + } + } + + if (params.sampling.penalty_last_n == -1) { + LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + params.sampling.penalty_last_n = llama_n_ctx(lctx); + } + + if (params.sampling.dry_penalty_last_n == -1) { + LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); } if (params.warmup) { @@ -979,9 +1008,12 @@ void common_lora_adapters_apply(struct llama_context * ctx, std::vector 0) { @@ -1132,17 +1126,17 @@ static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_ } struct llama_model * common_load_model_from_url( - const char * model_url, - const char * path_model, - const char * hf_token, + const std::string & model_url, + const std::string & local_path, + const std::string & hf_token, const struct llama_model_params & params) { // Basic validation of the model_url - if (!model_url || strlen(model_url) == 0) { + if (model_url.empty()) { LOG_ERR("%s: invalid model_url\n", __func__); return NULL; } - if (!common_download_file(model_url, path_model, hf_token)) { + if (!common_download_file(model_url, local_path, hf_token)) { return NULL; } @@ -1153,9 +1147,9 @@ struct llama_model * common_load_model_from_url( /*.no_alloc = */ true, /*.ctx = */ NULL, }; - auto * ctx_gguf = lm_gguf_init_from_file(path_model, lm_gguf_params); + auto * ctx_gguf = lm_gguf_init_from_file(local_path.c_str(), lm_gguf_params); if (!ctx_gguf) { - LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, path_model); + LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, local_path.c_str()); return NULL; } @@ -1174,13 +1168,13 @@ struct 
llama_model * common_load_model_from_url( // Verify the first split file format // and extract split URL and PATH prefixes { - if (!llama_split_prefix(split_prefix, sizeof(split_prefix), path_model, 0, n_split)) { - LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, path_model, n_split); + if (!llama_split_prefix(split_prefix, sizeof(split_prefix), local_path.c_str(), 0, n_split)) { + LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, local_path.c_str(), n_split); return NULL; } - if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url, 0, n_split)) { - LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url, n_split); + if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model_url.c_str(), 0, n_split)) { + LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model_url.c_str(), n_split); return NULL; } } @@ -1207,14 +1201,14 @@ struct llama_model * common_load_model_from_url( } } - return llama_load_model_from_file(path_model, params); + return llama_load_model_from_file(local_path.c_str(), params); } struct llama_model * common_load_model_from_hf( - const char * repo, - const char * model, - const char * path_model, - const char * hf_token, + const std::string & repo, + const std::string & remote_path, + const std::string & local_path, + const std::string & hf_token, const struct llama_model_params & params) { // construct hugging face model url: // @@ -1228,27 +1222,27 @@ struct llama_model * common_load_model_from_hf( std::string model_url = "https://huggingface.co/"; model_url += repo; model_url += "/resolve/main/"; - model_url += model; + model_url += remote_path; - return common_load_model_from_url(model_url.c_str(), path_model, hf_token, params); + return common_load_model_from_url(model_url, local_path, hf_token, params); } #else struct llama_model * common_load_model_from_url( - const char * /*model_url*/, - const char * /*path_model*/, - const char * /*hf_token*/, + const std::string & /*model_url*/, + const std::string & /*local_path*/, + const std::string & /*hf_token*/, const struct llama_model_params & /*params*/) { LOG_WRN("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); return nullptr; } struct llama_model * common_load_model_from_hf( - const char * /*repo*/, - const char * /*model*/, - const char * /*path_model*/, - const char * /*hf_token*/, + const std::string & /*repo*/, + const std::string & /*remote_path*/, + const std::string & /*local_path*/, + const std::string & /*hf_token*/, const struct llama_model_params & /*params*/) { LOG_WRN("%s: llama.cpp built without libcurl, downloading from Hugging Face not supported.\n", __func__); return nullptr; @@ -1283,6 +1277,66 @@ void common_batch_add( batch.n_tokens++; } +// +// Token utils +// + +size_t common_lcp(const llama_tokens & a, const llama_tokens & b) { + size_t i; + for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} + + return i; +} + +size_t common_lcs(const llama_tokens & a, const llama_tokens & b) { + // check for empty sequences + if (a.empty() || b.empty()) { + return 0; + } + + // get the lengths of the input sequences + size_t a_len = a.size(); + size_t b_len = b.size(); + + // initialize the maximum length of the longest common subsequence (LCS) + size_t max_length = 0; + + // use two rows instead of a 2D matrix to optimize space + std::vector prev_row(b_len + 1, 0); + std::vector curr_row(b_len + 1, 0); + + // iterate through the elements of a 
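+    // (note: the reset-on-mismatch recurrence below tracks contiguous matches,
+    // substring-style, in O(a_len * b_len) time; keeping only two rows bounds
+    // the extra space at O(b_len))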
+ for (size_t i = 1; i <= a_len; i++) { + // iterate through the elements of b + for (size_t j = 1; j <= b_len; j++) { + // if elements at the current positions match + if (a[i - 1] == b[j - 1]) { + // if it's the first element of either sequences, set LCS length to 1 + if (i == 1 || j == 1) { + curr_row[j] = 1; + } else { + // increment LCS length by 1 compared to the previous element + curr_row[j] = prev_row[j - 1] + 1; + } + + // update max_length if necessary + if (curr_row[j] > max_length) { + max_length = curr_row[j]; + } + } else { + // reset LCS length if elements don't match + curr_row[j] = 0; + } + } + + // update the previous row for the next iteration + prev_row = curr_row; + } + + // return the maximum length of the LCS + return max_length; +} + // // Vocab utils // @@ -1519,7 +1573,9 @@ void common_embd_normalize(const float * inp, float * out, int n, int embd_norm) break; case 0: // max absolute for (int i = 0; i < n; i++) { - if (sum < std::abs(inp[i])) sum = std::abs(inp[i]); + if (sum < std::abs(inp[i])) { + sum = std::abs(inp[i]); + } } sum /= 32760.0; // make an int16 range break; diff --git a/cpp/common.h b/cpp/common.h index 38649e7f..37216ec5 100644 --- a/cpp/common.h +++ b/cpp/common.h @@ -33,11 +33,13 @@ struct common_lora_adapter_container : common_lora_adapter_info { struct llama_lora_adapter * adapter; }; +using llama_tokens = std::vector; + // build info extern int LLAMA_BUILD_NUMBER; -extern char const * LLAMA_COMMIT; -extern char const * LLAMA_COMPILER; -extern char const * LLAMA_BUILD_TARGET; +extern const char * LLAMA_COMMIT; +extern const char * LLAMA_COMPILER; +extern const char * LLAMA_BUILD_TARGET; struct common_control_vector_load_info; @@ -89,6 +91,7 @@ enum llama_example { LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_PARALLEL, + LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_COUNT, }; @@ -104,6 +107,7 @@ enum common_sampler_type { COMMON_SAMPLER_TYPE_TEMPERATURE = 7, COMMON_SAMPLER_TYPE_XTC = 8, COMMON_SAMPLER_TYPE_INFILL = 9, + COMMON_SAMPLER_TYPE_PENALTIES = 10, }; // dimensionality reduction methods, used by cvector-generator @@ -112,8 +116,8 @@ enum dimre_method { DIMRE_METHOD_MEAN, }; -// sampler parameters -struct common_sampler_params { +// sampling parameters +struct common_params_sampling { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler int32_t n_prev = 64; // number of previous tokens to remember @@ -139,14 +143,15 @@ struct common_sampler_params { int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = false; // consider newlines as a repeatable token bool ignore_eos = false; bool no_perf = false; // disable performance metrics + bool timing_per_token = false; std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY std::vector samplers = { + COMMON_SAMPLER_TYPE_PENALTIES, COMMON_SAMPLER_TYPE_DRY, COMMON_SAMPLER_TYPE_TOP_K, COMMON_SAMPLER_TYPE_TYPICAL_P, @@ -164,6 +169,30 @@ struct common_sampler_params { std::string print() const; }; +struct common_params_speculative { + std::vector devices; // devices to use for offloading + + int32_t n_ctx = 0; // draft context size + int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding + int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding + int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + 
diff --git a/cpp/common.h b/cpp/common.h
index 38649e7f..37216ec5 100644
--- a/cpp/common.h
+++ b/cpp/common.h
@@ -33,11 +33,13 @@ struct common_lora_adapter_container : common_lora_adapter_info {
     struct llama_lora_adapter * adapter;
 };
+using llama_tokens = std::vector<llama_token>;
+
 // build info
 extern int LLAMA_BUILD_NUMBER;
-extern char const * LLAMA_COMMIT;
-extern char const * LLAMA_COMPILER;
-extern char const * LLAMA_BUILD_TARGET;
+extern const char * LLAMA_COMMIT;
+extern const char * LLAMA_COMPILER;
+extern const char * LLAMA_BUILD_TARGET;
 struct common_control_vector_load_info;
@@ -89,6 +91,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LLAVA,
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
+    LLAMA_EXAMPLE_TTS,
     LLAMA_EXAMPLE_COUNT,
 };
@@ -104,6 +107,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
     COMMON_SAMPLER_TYPE_XTC         = 8,
     COMMON_SAMPLER_TYPE_INFILL      = 9,
+    COMMON_SAMPLER_TYPE_PENALTIES   = 10,
 };
 // dimensionality reduction methods, used by cvector-generator
@@ -112,8 +116,8 @@ enum dimre_method {
     DIMRE_METHOD_MEAN,
 };
-// sampler parameters
-struct common_sampler_params {
+// sampling parameters
+struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
     int32_t n_prev = 64; // number of previous tokens to remember
@@ -139,14 +143,15 @@ struct common_sampler_params {
     int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate
-    bool penalize_nl = false; // consider newlines as a repeatable token
     bool ignore_eos = false;
     bool no_perf = false; // disable performance metrics
+    bool timing_per_token = false;
     std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
     std::vector<enum common_sampler_type> samplers = {
+        COMMON_SAMPLER_TYPE_PENALTIES,
         COMMON_SAMPLER_TYPE_DRY,
         COMMON_SAMPLER_TYPE_TOP_K,
         COMMON_SAMPLER_TYPE_TYPICAL_P,
@@ -164,6 +169,30 @@ struct common_sampler_params {
     std::string print() const;
 };
+struct common_params_speculative {
+    std::vector<lm_ggml_backend_dev_t> devices; // devices to use for offloading
+
+    int32_t n_ctx = 0; // draft context size
+    int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
+    int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
+    int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+    float p_split = 0.1f; // speculative decoding split probability
+    float p_min = 0.9f; // minimum speculative decoding probability (greedy)
+
+    struct cpu_params cpuparams;
+    struct cpu_params cpuparams_batch;
+
+    std::string model = ""; // draft model for speculative decoding // NOLINT
+};
+
+struct common_params_vocoder {
+    std::string hf_repo = ""; // HF repo // NOLINT
+    std::string hf_file = ""; // HF file // NOLINT
+
+    std::string model = ""; // model path // NOLINT
+    std::string model_url = ""; // model url to download // NOLINT
+};
+
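The flat draft-model fields that used to live directly on common_params (and the old common_sampler_params sparams member) are regrouped by these hunks into the nested sampling, speculative, and vocoder structs. A sketch of how call sites address the regrouped fields after this change; the model paths are placeholders, and the old names are noted in comments:

    #include "common.h"

    void configure(common_params & params) {
        params.sampling.seed     = LLAMA_DEFAULT_SEED; // was params.sparams.seed
        params.sampling.top_k    = 40;                 // was params.sparams.top_k
        params.speculative.n_max = 16;                 // replaces the old top-level n_draft
        params.speculative.model = "draft.gguf";       // was params.model_draft (placeholder path)
        params.vocoder.model     = "vocoder.gguf";     // new, added alongside LLAMA_EXAMPLE_TTS (placeholder path)
    }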
 struct common_params {
     bool vocab_only = false;
     int32_t n_predict = -1; // new tokens to predict
@@ -171,15 +200,9 @@ struct common_params {
     int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_ubatch = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep = 0; // number of tokens to keep from initial prompt
-    int32_t n_draft = 5; // number of tokens to draft during speculative decoding
     int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel = 1; // number of parallel sequences to decode
     int32_t n_sequences = 1; // number of sequences to decode
-    float p_split = 0.1f; // speculative decoding split probability
-    int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
-    int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
-    int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
-    float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
     int32_t grp_attn_n = 1; // group-attention factor
     int32_t grp_attn_w = 512; // group-attention width
     int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
@@ -192,26 +215,33 @@ struct common_params {
     int32_t yarn_orig_ctx = 0; // YaRN original context length
     float defrag_thold = 0.1f; // KV cache defragmentation threshold
+    // offload params
+    std::vector<lm_ggml_backend_dev_t> devices; // devices to use for offloading
+
+    int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
+    int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
+    float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+
+    enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
+
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;
-    struct cpu_params draft_cpuparams;
-    struct cpu_params draft_cpuparams_batch;
     lm_ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data = nullptr;
     lm_ggml_numa_strategy numa = LM_GGML_NUMA_STRATEGY_DISABLED;
-    enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
     enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
-    struct common_sampler_params sparams;
+    struct common_params_sampling sampling;
+    struct common_params_speculative speculative;
+    struct common_params_vocoder vocoder;
     std::string model = ""; // model path // NOLINT
-    std::string model_draft = ""; // draft model for speculative decoding // NOLINT
-    std::string model_alias = "unknown"; // model alias // NOLINT
+    std::string model_alias = ""; // model alias // NOLINT
 std::string model_url = ""; // model url to download // NOLINT
 std::string hf_token = ""; // HF token // NOLINT
 std::string hf_repo = ""; // HF repo // NOLINT
@@ -285,8 +315,8 @@ struct common_params {
     llama_progress_callback progress_callback = nullptr;
     void * progress_callback_user_data = nullptr;
-    std::string cache_type_k = "f16"; // KV cache data type for the K
-    std::string cache_type_v = "f16"; // KV cache data type for the V
+    lm_ggml_type cache_type_k = LM_GGML_TYPE_F16; // KV cache data type for the K
+    lm_ggml_type cache_type_v = LM_GGML_TYPE_F16; // KV cache data type for the V
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector // NOLINT
@@ -436,6 +466,11 @@ std::vector<std::string> string_split(const std::string & input, ch
     return parts;
 }
+static bool string_starts_with(const std::string & str,
+                               const std::string & prefix) {  // While we wait for C++20's std::string::starts_with...
+    return str.rfind(prefix, 0) == 0;
+}
+
 bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
 void string_process_escapes(std::string & input);
@@ -466,17 +501,28 @@ struct common_init_result {
 struct common_init_result common_init_from_params(common_params & params);
-struct llama_model_params common_model_params_to_llama  (const common_params & params);
+struct llama_model_params common_model_params_to_llama  (      common_params & params);
 struct llama_context_params common_context_params_to_llama(const common_params & params);
 struct lm_ggml_threadpool_params lm_ggml_threadpool_params_from_cpu_params(const cpu_params & params);
-struct llama_model * common_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
-struct llama_model * common_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
+struct llama_model * common_load_model_from_url(
+    const std::string & model_url,
+    const std::string & local_path,
+    const std::string & hf_token,
+    const struct llama_model_params & params);
+struct llama_model * common_load_model_from_hf(
+    const std::string & repo,
+    const std::string & remote_path,
+    const std::string & local_path,
+    const std::string & hf_token,
+    const struct llama_model_params & params);
 // clear LoRA adapters from context, then apply new list of adapters
 void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
+//
 // Batch utils
+//
 void common_batch_clear(struct llama_batch & batch);
@@ -487,6 +533,16 @@ void common_batch_add(
                  const std::vector<llama_seq_id> & seq_ids,
                                              bool logits);
+//
+// Token utils
+//
+
+// longest common prefix
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
+
+// longest common subsequence
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
+
 //
 // Vocab utils
 //
@@ -566,7 +622,8 @@ void common_kv_cache_dump_view_seqs(const llama_kv_cache_view & view, int row_si
 // Embedding utils
 //
-void common_embd_normalize(const float * inp, float * out, int n, int embd_norm = 2);
+// TODO: replace embd_norm with an enum
+void common_embd_normalize(const float * inp, float * out, int n, int embd_norm);
 float common_embd_similarity_cos(const float * embd1, const float * embd2, int n);
diff --git a/cpp/ggml-aarch64.c b/cpp/ggml-aarch64.c
deleted file mode 100644
index 2e396a91..00000000
--- a/cpp/ggml-aarch64.c
+++ /dev/null
@@ -1,129 +0,0 @@
-#define LM_GGML_COMMON_DECL_C
-#include "ggml-common.h"
-
-#include
"ggml-aarch64.h" -#include "ggml-impl.h" -#include "ggml-quants.h" -#include - -#define UNUSED LM_GGML_UNUSED - -static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x4 out; - - for (int i = 0; i < 4; i++) { - out.d[i] = in[i].d; - } - - const int end = QK4_0 * 2 / blck_size_interleave; - - if (blck_size_interleave == 8) { - const uint64_t xor_mask = 0x8888888888888888ULL; - for (int i = 0; i < end; ++i) { - int src_id = i % 4; - int src_offset = (i / 4) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - - uint64_t elems; - // Using memcpy to avoid unaligned memory accesses - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - elems ^= xor_mask; - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); - } - } else if (blck_size_interleave == 4) { - const uint32_t xor_mask = 0x88888888; - for (int i = 0; i < end; ++i) { - int src_id = i % 4; - int src_offset = (i / 4) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - - uint32_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t)); - elems ^= xor_mask; - memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t)); - } - } else { - LM_GGML_ASSERT(false); - } - - return out; -} - -// interleave 8 block_q4_0s in blocks of blck_size_interleave -// returns an interleaved block_q4_0x8 -// in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks -// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave -static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x8 out; - - for (int i = 0; i < 8; i++) { - out.d[i] = in[i].d; - } - - const int end = QK4_0 * 4 / blck_size_interleave; - const uint64_t xor_mask = 0x8888888888888888ULL; - - for (int i = 0; i < end; ++i) { - int src_id = i % 8; - int src_offset = (i / 8) * blck_size_interleave; - int dst_offset = i * blck_size_interleave; - - uint64_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - elems ^= xor_mask; - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); - } - - return out; -} - -static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) { - assert(n_per_row % QK4_0 == 0); - const int nb = n_per_row / QK4_0; - - void * out_ptr = NULL; - if (nrows_interleaved == 8) { - out_ptr = (block_q4_0x8 *) dst; - } - else if (nrows_interleaved == 4) { - out_ptr = (block_q4_0x4 *) dst; - } - assert(nrows_interleaved <= 8); - block_q4_0 dst_tmp[8]; - - for (int b = 0; b < (nrow * n_per_row); b += nrows_interleaved * n_per_row) { - - for (int64_t x = 0; x < nb; x++) { - - for (int i = 0; i < nrows_interleaved; i++ ) { - quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0); - } - - if (nrows_interleaved == 8) { - *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave); - out_ptr = (block_q4_0x8 *) out_ptr + 1; - } - else if (nrows_interleaved == 4) { - *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave); - out_ptr = (block_q4_0x4 *) out_ptr + 1; - } - } - } - - return ((nrow * n_per_row) / QK4_0 * sizeof(block_q4_0)); -} - -size_t quantize_q4_0_4x4(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - UNUSED(quant_weights); - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 4); -} - -size_t quantize_q4_0_4x8(const float * 
restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - UNUSED(quant_weights); - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 4, 8); -} - -size_t quantize_q4_0_8x8(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - UNUSED(quant_weights); - return quantize_q4_0_nr_bl(src, dst, nrow, n_per_row, 8, 8); -} diff --git a/cpp/ggml-aarch64.h b/cpp/ggml-aarch64.h deleted file mode 100644 index 3b8c8b37..00000000 --- a/cpp/ggml-aarch64.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include "ggml.h" - -// GGML internal header - -#ifdef __cplusplus -extern "C" { -#endif - -// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") -size_t quantize_q4_0_4x4(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_0_4x8(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q4_0_8x8(const float * LM_GGML_RESTRICT src, void * LM_GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); - -#ifdef __cplusplus -} -#endif - diff --git a/cpp/ggml-alloc.c b/cpp/ggml-alloc.c index e4c00b12..e102749b 100644 --- a/cpp/ggml-alloc.c +++ b/cpp/ggml-alloc.c @@ -534,7 +534,6 @@ static void lm_ggml_gallocr_allocate_node(lm_ggml_gallocr_t galloc, struct lm_gg size_t offset = lm_ggml_dyn_tallocr_alloc(alloc, size, node); hn->buffer_id = buffer_id; hn->offset = offset; - return; } } diff --git a/cpp/ggml-backend-impl.h b/cpp/ggml-backend-impl.h index afcd3a9e..c5470781 100644 --- a/cpp/ggml-backend-impl.h +++ b/cpp/ggml-backend-impl.h @@ -8,6 +8,8 @@ extern "C" { #endif + #define LM_GGML_BACKEND_API_VERSION 1 + // // Backend buffer type // @@ -63,20 +65,20 @@ extern "C" { enum lm_ggml_backend_buffer_usage usage; }; - lm_ggml_backend_buffer_t lm_ggml_backend_buffer_init( + LM_GGML_API lm_ggml_backend_buffer_t lm_ggml_backend_buffer_init( lm_ggml_backend_buffer_type_t buft, struct lm_ggml_backend_buffer_i iface, void * context, size_t size); // do not use directly, use lm_ggml_backend_tensor_copy instead - bool lm_ggml_backend_buffer_copy_tensor(const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst); + LM_GGML_API bool lm_ggml_backend_buffer_copy_tensor(const struct lm_ggml_tensor * src, struct lm_ggml_tensor * dst); // multi-buffer // buffer that contains a collection of buffers - lm_ggml_backend_buffer_t lm_ggml_backend_multi_buffer_alloc_buffer(lm_ggml_backend_buffer_t * buffers, size_t n_buffers); - bool lm_ggml_backend_buffer_is_multi_buffer(lm_ggml_backend_buffer_t buffer); - void lm_ggml_backend_multi_buffer_set_usage(lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage); + LM_GGML_API lm_ggml_backend_buffer_t lm_ggml_backend_multi_buffer_alloc_buffer(lm_ggml_backend_buffer_t * buffers, size_t n_buffers); + LM_GGML_API bool lm_ggml_backend_buffer_is_multi_buffer(lm_ggml_backend_buffer_t buffer); + LM_GGML_API void lm_ggml_backend_multi_buffer_set_usage(lm_ggml_backend_buffer_t buffer, enum lm_ggml_backend_buffer_usage usage); // // Backend (stream) @@ -199,17 +201,55 @@ extern "C" { }; struct lm_ggml_backend_reg { - // int api_version; // TODO: for dynamic loading + int api_version; // initialize to LM_GGML_BACKEND_API_VERSION struct lm_ggml_backend_reg_i iface; void * context; }; - // Internal backend registry API - void 
lm_ggml_backend_register(lm_ggml_backend_reg_t reg);
-    void lm_ggml_backend_device_register(lm_ggml_backend_dev_t device);
-
-    // TODO: backends can be loaded as a dynamic library, in which case it needs to export this function
-    // typedef lm_ggml_backend_register_t * (*lm_ggml_backend_init)(void);
+    LM_GGML_API void lm_ggml_backend_register(lm_ggml_backend_reg_t reg);
+    LM_GGML_API void lm_ggml_backend_device_register(lm_ggml_backend_dev_t device);
+
+    // Add backend dynamic loading support to the backend
+
+    // Initialize the backend
+    typedef lm_ggml_backend_reg_t (*lm_ggml_backend_init_t)(void);
+    // Optional: obtain a score for the backend based on the system configuration
+    // Higher scores are preferred, 0 means the backend is not supported in the current system
+    typedef int (*lm_ggml_backend_score_t)(void);
+
+#ifdef LM_GGML_BACKEND_DL
+#    ifdef __cplusplus
+#        define LM_GGML_BACKEND_DL_IMPL(reg_fn)                                       \
+            extern "C" {                                                              \
+                LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_init(void); \
+            }                                                                         \
+            lm_ggml_backend_reg_t lm_ggml_backend_init(void) {                        \
+                return reg_fn();                                                      \
+            }
+#        define LM_GGML_BACKEND_DL_SCORE_IMPL(score_fn)                               \
+            extern "C" {                                                              \
+                LM_GGML_BACKEND_API int lm_ggml_backend_score(void);                  \
+            }                                                                         \
+            int lm_ggml_backend_score(void) {                                         \
+                return score_fn();                                                    \
+            }
+#    else
+#        define LM_GGML_BACKEND_DL_IMPL(reg_fn)                                       \
+            LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_init(void);     \
+            lm_ggml_backend_reg_t lm_ggml_backend_init(void) {                        \
+                return reg_fn();                                                      \
+            }
+#        define LM_GGML_BACKEND_DL_SCORE_IMPL(score_fn)                               \
+            LM_GGML_BACKEND_API int lm_ggml_backend_score(void);                      \
+            int lm_ggml_backend_score(void) {                                         \
+                return score_fn();                                                    \
+            }
+#    endif
+#else
+#    define LM_GGML_BACKEND_DL_IMPL(reg_fn)
+#    define LM_GGML_BACKEND_DL_SCORE_IMPL(score_fn)
+#endif
 #ifdef __cplusplus
 }
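These macros define the export surface of a dynamically loadable backend: LM_GGML_BACKEND_DL_IMPL wraps the backend's registration function in the exported lm_ggml_backend_init entry point, and LM_GGML_BACKEND_DL_SCORE_IMPL optionally exports lm_ggml_backend_score so the loader can rank candidate libraries (0 meaning unsupported). A sketch of a hypothetical out-of-tree backend compiled with -DLM_GGML_BACKEND_DL; my_backend_reg and my_backend_score are illustrative names, not functions that exist in the tree:

    #include "ggml-backend-impl.h"

    static lm_ggml_backend_reg_t my_backend_reg(void) {
        // must return a registry whose api_version is LM_GGML_BACKEND_API_VERSION,
        // otherwise load_backend() below rejects the library
        static struct lm_ggml_backend_reg reg = {
            /* .api_version = */ LM_GGML_BACKEND_API_VERSION,
            /* .iface       = */ { /* backend-specific callbacks */ },
            /* .context     = */ nullptr,
        };
        return &reg;
    }

    static int my_backend_score(void) {
        return 1; // 0 = not supported on this system, higher = preferred
    }

    LM_GGML_BACKEND_DL_IMPL(my_backend_reg)
    LM_GGML_BACKEND_DL_SCORE_IMPL(my_backend_score)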
diff --git a/cpp/ggml-backend-reg.cpp b/cpp/ggml-backend-reg.cpp
index 4009d5fc..649cd0b3 100644
--- a/cpp/ggml-backend-reg.cpp
+++ b/cpp/ggml-backend-reg.cpp
@@ -1,11 +1,34 @@
 #include "ggml-backend-impl.h"
 #include "ggml-backend.h"
-#include "ggml-cpu.h"
 #include "ggml-impl.h"
+#include <algorithm>
+#include <codecvt>
 #include <cstring>
+#include <filesystem>
+#include <locale>
+#include <memory>
+#include <string>
+#include <type_traits>
 #include <vector>
+
+#ifdef _WIN32
+#    define WIN32_LEAN_AND_MEAN
+#    ifndef NOMINMAX
+#        define NOMINMAX
+#    endif
+#    include <windows.h>
+#elif defined(__APPLE__)
+#    include <mach-o/dyld.h>
+#    include <dlfcn.h>
+#else
+#    include <dlfcn.h>
+#    include <unistd.h>
+#endif
+
 // Backend registry
+#ifdef LM_GGML_USE_CPU
+#include "ggml-cpu.h"
+#endif
 #ifdef LM_GGML_USE_CUDA
 #include "ggml-cuda.h"
@@ -28,6 +51,10 @@
 #include "ggml-vulkan.h"
 #endif
+#ifdef LM_GGML_USE_OPENCL
+#include "ggml-opencl.h"
+#endif
+
 #ifdef LM_GGML_USE_BLAS
 #include "ggml-blas.h"
 #endif
@@ -36,10 +63,6 @@
 #include "ggml-rpc.h"
 #endif
-#ifdef LM_GGML_USE_AMX
-#  include "ggml-amx.h"
-#endif
-
 #ifdef LM_GGML_USE_CANN
 #include "ggml-cann.h"
 #endif
@@ -48,8 +71,75 @@
 #include "ggml-kompute.h"
 #endif
+#ifdef _WIN32
+
+using dl_handle = std::remove_pointer_t<HMODULE>;
+
+struct dl_handle_deleter {
+    void operator()(HMODULE handle) {
+        FreeLibrary(handle);
+    }
+};
+
+static dl_handle * dl_load_library(const std::wstring & path) {
+    // suppress error dialogs for missing DLLs
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    HMODULE handle = LoadLibraryW(path.c_str());
+
+    SetErrorMode(old_mode);
+
+    return handle;
+}
+
+static dl_handle * dl_load_library(const std::string & path) {
+    std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
+    return dl_load_library(converter.from_bytes(path));
+}
+
+static void * dl_get_sym(dl_handle * handle, const char * name) {
+    DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
+    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+
+    void * p = (void *) GetProcAddress(handle, name);
+
+    SetErrorMode(old_mode);
+
+    return p;
+}
+
+#else
+
+using dl_handle = void;
+
+struct dl_handle_deleter {
+    void operator()(void * handle) {
+        dlclose(handle);
+    }
+};
+
+static void * dl_load_library(const std::string & path) {
+    dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
+
+    return handle;
+}
+
+static void * dl_get_sym(dl_handle * handle, const char * name) {
+    return dlsym(handle, name);
+}
+
+#endif
+
+using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
+
+struct lm_ggml_backend_reg_entry {
+    lm_ggml_backend_reg_t reg;
+    dl_handle_ptr handle;
+};
+
 struct lm_ggml_backend_registry {
-    std::vector<lm_ggml_backend_reg_t> backends;
+    std::vector<lm_ggml_backend_reg_entry> backends;
     std::vector<lm_ggml_backend_dev_t> devices;
     lm_ggml_backend_registry() {
@@ -69,6 +159,9 @@ struct lm_ggml_backend_registry {
 #ifdef LM_GGML_USE_VULKAN
         register_backend(lm_ggml_backend_vk_reg());
 #endif
+#ifdef LM_GGML_USE_OPENCL
+        register_backend(lm_ggml_backend_opencl_reg());
+#endif
 #ifdef LM_GGML_USE_CANN
         register_backend(lm_ggml_backend_cann_reg());
 #endif
@@ -78,17 +171,25 @@ struct lm_ggml_backend_registry {
 #ifdef LM_GGML_USE_RPC
         register_backend(lm_ggml_backend_rpc_reg());
 #endif
-#ifdef LM_GGML_USE_AMX
-        register_backend(lm_ggml_backend_amx_reg());
-#endif
 #ifdef LM_GGML_USE_KOMPUTE
         register_backend(lm_ggml_backend_kompute_reg());
 #endif
-
+#ifdef LM_GGML_USE_CPU
         register_backend(lm_ggml_backend_cpu_reg());
+#endif
+    }
+
+    ~lm_ggml_backend_registry() {
+        // FIXME: backends cannot be safely unloaded without a function to destroy all the backend resources,
+        // since backend threads may still be running and accessing resources from the dynamic library
+        for (auto & entry : backends) {
+            if (entry.handle) {
+                entry.handle.release(); // NOLINT
+            }
+        }
     }
-    void register_backend(lm_ggml_backend_reg_t reg) {
+    void register_backend(lm_ggml_backend_reg_t reg, dl_handle_ptr handle = nullptr) {
         if (!reg) {
             return;
         }
@@ -97,7 +198,7 @@ struct lm_ggml_backend_registry {
         LM_GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
             __func__, lm_ggml_backend_reg_name(reg), lm_ggml_backend_reg_dev_count(reg));
 #endif
-        backends.push_back(reg);
+        backends.push_back({ reg, std::move(handle) });
         for (size_t i = 0; i < lm_ggml_backend_reg_dev_count(reg); i++) {
            register_device(lm_ggml_backend_reg_dev_get(reg, i));
        }
@@ -109,6 +210,76 @@
 #endif
         devices.push_back(device);
     }
+
+    lm_ggml_backend_reg_t load_backend(const char * path, bool silent) {
+        dl_handle_ptr handle { dl_load_library(path) };
+        if (!handle) {
+            if (!silent) {
+                LM_GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path);
+            }
+            return nullptr;
+        }
+
+        auto score_fn = (lm_ggml_backend_score_t) dl_get_sym(handle.get(), "lm_ggml_backend_score");
+        if (score_fn && score_fn() == 0) {
+            if (!silent) {
+                LM_GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path);
+            }
+            return nullptr;
+        }
+
+        auto backend_init_fn = (lm_ggml_backend_init_t) dl_get_sym(handle.get(), "lm_ggml_backend_init");
+        if (!backend_init_fn) {
+            if (!silent) {
+                LM_GGML_LOG_ERROR("%s: failed to find lm_ggml_backend_init in %s\n", __func__, path);
+            }
+            return nullptr;
+        }
+
+        lm_ggml_backend_reg_t reg = backend_init_fn();
+        if (!reg || reg->api_version != LM_GGML_BACKEND_API_VERSION) {
+            if (!silent) {
+                if (!reg) {
+                    LM_GGML_LOG_ERROR("%s: failed to initialize backend from %s: lm_ggml_backend_init
returned NULL\n", __func__, path); + } else { + LM_GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n", + __func__, path, reg->api_version, LM_GGML_BACKEND_API_VERSION); + } + } + return nullptr; + } + + LM_GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, lm_ggml_backend_reg_name(reg), path); + + register_backend(reg, std::move(handle)); + + return reg; + } + + void unload_backend(lm_ggml_backend_reg_t reg, bool silent) { + auto it = std::find_if(backends.begin(), backends.end(), + [reg](const lm_ggml_backend_reg_entry & entry) { return entry.reg == reg; }); + + if (it == backends.end()) { + if (!silent) { + LM_GGML_LOG_ERROR("%s: backend not found\n", __func__); + } + return; + } + + if (!silent) { + LM_GGML_LOG_DEBUG("%s: unloading %s backend\n", __func__, lm_ggml_backend_reg_name(reg)); + } + + // remove devices + devices.erase( + std::remove_if(devices.begin(), devices.end(), + [reg](lm_ggml_backend_dev_t dev) { return lm_ggml_backend_dev_backend_reg(dev) == reg; }), + devices.end()); + + // remove backend + backends.erase(it); + } }; static lm_ggml_backend_registry & get_reg() { @@ -126,23 +297,32 @@ void lm_ggml_backend_device_register(lm_ggml_backend_dev_t device) { } // Backend (reg) enumeration +static bool striequals(const char * a, const char * b) { + for (; *a && *b; a++, b++) { + if (std::tolower(*a) != std::tolower(*b)) { + return false; + } + } + return *a == *b; +} + size_t lm_ggml_backend_reg_count() { return get_reg().backends.size(); } lm_ggml_backend_reg_t lm_ggml_backend_reg_get(size_t index) { LM_GGML_ASSERT(index < lm_ggml_backend_reg_count()); - return get_reg().backends[index]; + return get_reg().backends[index].reg; } lm_ggml_backend_reg_t lm_ggml_backend_reg_by_name(const char * name) { for (size_t i = 0; i < lm_ggml_backend_reg_count(); i++) { lm_ggml_backend_reg_t reg = lm_ggml_backend_reg_get(i); - if (std::strcmp(lm_ggml_backend_reg_name(reg), name) == 0) { + if (striequals(lm_ggml_backend_reg_name(reg), name)) { return reg; } } - return NULL; + return nullptr; } // Device enumeration @@ -158,11 +338,11 @@ lm_ggml_backend_dev_t lm_ggml_backend_dev_get(size_t index) { lm_ggml_backend_dev_t lm_ggml_backend_dev_by_name(const char * name) { for (size_t i = 0; i < lm_ggml_backend_dev_count(); i++) { lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i); - if (strcmp(lm_ggml_backend_dev_name(dev), name) == 0) { + if (striequals(lm_ggml_backend_dev_name(dev), name)) { return dev; } } - return NULL; + return nullptr; } lm_ggml_backend_dev_t lm_ggml_backend_dev_by_type(enum lm_ggml_backend_dev_type type) { @@ -172,14 +352,14 @@ lm_ggml_backend_dev_t lm_ggml_backend_dev_by_type(enum lm_ggml_backend_dev_type return dev; } } - return NULL; + return nullptr; } // Convenience functions lm_ggml_backend_t lm_ggml_backend_init_by_name(const char * name, const char * params) { lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_by_name(name); if (!dev) { - return NULL; + return nullptr; } return lm_ggml_backend_dev_init(dev, params); } @@ -187,7 +367,7 @@ lm_ggml_backend_t lm_ggml_backend_init_by_name(const char * name, const char * p lm_ggml_backend_t lm_ggml_backend_init_by_type(enum lm_ggml_backend_dev_type type, const char * params) { lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_by_type(type); if (!dev) { - return NULL; + return nullptr; } return lm_ggml_backend_dev_init(dev, params); } @@ -198,7 +378,188 @@ lm_ggml_backend_t lm_ggml_backend_init_best(void) { dev = 
lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU);
     }
     if (!dev) {
-        return NULL;
+        return nullptr;
+    }
+    return lm_ggml_backend_dev_init(dev, nullptr);
+}
+
+// Dynamic loading
+lm_ggml_backend_reg_t lm_ggml_backend_load(const char * path) {
+    return get_reg().load_backend(path, false);
+}
+
+void lm_ggml_backend_unload(lm_ggml_backend_reg_t reg) {
+    get_reg().unload_backend(reg, true);
+}
+
+static std::string get_executable_path() {
+#if defined(__APPLE__)
+    // get executable path
+    std::vector<char> path;
+    uint32_t size;
+    while (true) {
+        size = path.size();
+        if (_NSGetExecutablePath(path.data(), &size) == 0) {
+            break;
+        }
+        path.resize(size);
+    }
+    std::string base_path(path.data(), size);
+    // remove executable name
+    auto last_slash = base_path.find_last_of('/');
+    if (last_slash != std::string::npos) {
+        base_path = base_path.substr(0, last_slash);
+    }
+    return base_path + "/";
+#elif defined(__linux__) || defined(__FreeBSD__)
+    std::string base_path = ".";
+    std::vector<char> path(1024);
+    while (true) {
+        // get executable path
+#    if defined(__linux__)
+        ssize_t len = readlink("/proc/self/exe", path.data(), path.size());
+#    elif defined(__FreeBSD__)
+        ssize_t len = readlink("/proc/curproc/file", path.data(), path.size());
+#    endif
+        if (len == -1) {
+            break;
+        }
+        if (len < (ssize_t) path.size()) {
+            base_path = std::string(path.data(), len);
+            // remove executable name
+            auto last_slash = base_path.find_last_of('/');
+            if (last_slash != std::string::npos) {
+                base_path = base_path.substr(0, last_slash);
+            }
+            break;
+        }
+        path.resize(path.size() * 2);
+    }
+
+    return base_path + "/";
+#elif defined(_WIN32)
+    std::vector<char> path(MAX_PATH);
+    DWORD len = GetModuleFileNameA(NULL, path.data(), path.size());
+    if (len == 0) {
+        return "";
     }
-    return lm_ggml_backend_dev_init(dev, NULL);
+    std::string base_path(path.data(), len);
+    // remove executable name
+    auto last_slash = base_path.find_last_of('\\');
+    if (last_slash != std::string::npos) {
+        base_path = base_path.substr(0, last_slash);
+    }
+    return base_path + "\\";
+#endif
+}
+
+static std::string backend_filename_prefix() {
+#ifdef _WIN32
+    return "ggml-";
+#else
+    return "libggml-";
+#endif
+}
+
+static std::string backend_filename_suffix() {
+#ifdef _WIN32
+    return ".dll";
+#else
+    return ".so";
+#endif
+}
+
+static lm_ggml_backend_reg_t lm_ggml_backend_load_best(const char * name, bool silent, const char * user_search_path) {
+    // enumerate all the files that match [lib]ggml-name-*.[so|dll] in the search paths
+    // TODO: search system paths
+    std::string file_prefix = backend_filename_prefix() + name + "-";
+    std::vector<std::string> search_paths;
+    if (user_search_path == nullptr) {
+        search_paths.push_back("./");
+        search_paths.push_back(get_executable_path());
+    } else {
+#if defined(_WIN32)
+        search_paths.push_back(std::string(user_search_path) + "\\");
+#else
+        search_paths.push_back(std::string(user_search_path) + "/");
+#endif
+    }
+
+    int best_score = 0;
+    std::string best_path;
+
+    namespace fs = std::filesystem;
+    for (const auto & search_path : search_paths) {
+        if (!fs::exists(search_path)) {
+            continue;
+        }
+        fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
+        for (const auto & entry : dir_it) {
+            if (entry.is_regular_file()) {
+                std::string filename = entry.path().filename().string();
+                std::string ext = entry.path().extension().string();
+                if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+                    dl_handle_ptr handle { dl_load_library(entry.path().c_str()) };
+                    if
(!handle && !silent) { + LM_GGML_LOG_ERROR("%s: failed to load %s\n", __func__, entry.path().string().c_str()); + } + if (handle) { + auto score_fn = (lm_ggml_backend_score_t) dl_get_sym(handle.get(), "lm_ggml_backend_score"); + if (score_fn) { + int s = score_fn(); +#ifndef NDEBUG + LM_GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, entry.path().string().c_str(), s); +#endif + if (s > best_score) { + best_score = s; + best_path = entry.path().string(); + } + } else { + if (!silent) { + LM_GGML_LOG_INFO("%s: failed to find lm_ggml_backend_score in %s\n", __func__, entry.path().string().c_str()); + } + } + } + } + } + } + } + + if (best_score == 0) { + // try to load the base backend + for (const auto & search_path : search_paths) { + std::string path = search_path + backend_filename_prefix() + name + backend_filename_suffix(); + if (fs::exists(path)) { + return get_reg().load_backend(path.c_str(), silent); + } + } + return nullptr; + } + + return get_reg().load_backend(best_path.c_str(), silent); +} + +void lm_ggml_backend_load_all() { + lm_ggml_backend_load_all_from_path(nullptr); +} + +void lm_ggml_backend_load_all_from_path(const char * dir_path) { +#ifdef NDEBUG + bool silent = true; +#else + bool silent = false; +#endif + + lm_ggml_backend_load_best("blas", silent, dir_path); + lm_ggml_backend_load_best("cann", silent, dir_path); + lm_ggml_backend_load_best("cuda", silent, dir_path); + lm_ggml_backend_load_best("hip", silent, dir_path); + lm_ggml_backend_load_best("kompute", silent, dir_path); + lm_ggml_backend_load_best("metal", silent, dir_path); + lm_ggml_backend_load_best("rpc", silent, dir_path); + lm_ggml_backend_load_best("sycl", silent, dir_path); + lm_ggml_backend_load_best("vulkan", silent, dir_path); + lm_ggml_backend_load_best("opencl", silent, dir_path); + lm_ggml_backend_load_best("musa", silent, dir_path); + lm_ggml_backend_load_best("cpu", silent, dir_path); } diff --git a/cpp/ggml-backend.cpp b/cpp/ggml-backend.cpp index 27f92ad5..65ad6b4c 100644 --- a/cpp/ggml-backend.cpp +++ b/cpp/ggml-backend.cpp @@ -252,6 +252,7 @@ void lm_ggml_backend_tensor_get_async(lm_ggml_backend_t backend, const struct lm } void lm_ggml_backend_tensor_set(struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + LM_GGML_ASSERT(tensor); lm_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; if (size == 0) { @@ -266,6 +267,7 @@ void lm_ggml_backend_tensor_set(struct lm_ggml_tensor * tensor, const void * dat } void lm_ggml_backend_tensor_get(const struct lm_ggml_tensor * tensor, void * data, size_t offset, size_t size) { + LM_GGML_ASSERT(tensor); lm_ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; if (size == 0) { @@ -740,7 +742,8 @@ static int lm_ggml_backend_sched_backend_id_from_cur(lm_ggml_backend_sched_t sch if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) { // since the tensor is pre-allocated, it cannot be moved to another backend - LM_GGML_ABORT("pre-allocated tensor (%s) in a backend that cannot run the operation", tensor->name); + lm_ggml_backend_buffer_t buffer = tensor->view_src ? 
tensor->view_src->buffer : tensor->buffer;
+        LM_GGML_ABORT("pre-allocated tensor (%s) in a buffer (%s) that cannot run the operation (%s)", tensor->name, lm_ggml_backend_buffer_name(buffer), lm_ggml_op_name(tensor->op));
     }
     // graph input
@@ -884,9 +887,6 @@ static void lm_ggml_backend_sched_split_graph(lm_ggml_backend_sched_t sched, str
     for (int i = 0; i < graph->n_nodes; i++) {
         struct lm_ggml_tensor * node = graph->nodes[i];
         int * node_backend_id = &tensor_backend_id(node);
-        if (lm_ggml_is_view_op(node->op)) {
-            continue;
-        }
         // do not overwrite user assignments
         if (*node_backend_id == -1) {
             *node_backend_id = lm_ggml_backend_sched_backend_id_from_cur(sched, node);
diff --git a/cpp/ggml-backend.h b/cpp/ggml-backend.h
index 30f0a495..d86f06f6 100644
--- a/cpp/ggml-backend.h
+++ b/cpp/ggml-backend.h
@@ -190,6 +190,14 @@ extern "C" {
     typedef void (*lm_ggml_backend_set_n_threads_t)(lm_ggml_backend_t backend, int n_threads);
     // Get additional buffer types provided by the device (returns a NULL-terminated array)
     typedef lm_ggml_backend_buffer_type_t * (*lm_ggml_backend_dev_get_extra_bufts_t)(lm_ggml_backend_dev_t device);
+    // Set the abort callback for the backend
+    typedef void (*lm_ggml_backend_set_abort_callback_t)(lm_ggml_backend_t backend, lm_ggml_abort_callback abort_callback, void * abort_callback_data);
+    // Get a list of feature flags supported by the backend (returns a NULL-terminated array)
+    struct lm_ggml_backend_feature {
+        const char * name;
+        const char * value;
+    };
+    typedef struct lm_ggml_backend_feature * (*lm_ggml_backend_get_features_t)(lm_ggml_backend_reg_t reg);
     //
     // Backend registry
     //
@@ -214,6 +222,14 @@ extern "C" {
     // = lm_ggml_backend_dev_init(lm_ggml_backend_dev_by_type(GPU) OR lm_ggml_backend_dev_by_type(CPU), NULL)
     LM_GGML_API lm_ggml_backend_t lm_ggml_backend_init_best(void);
+    // Load a backend from a dynamic library and register it
+    LM_GGML_API lm_ggml_backend_reg_t lm_ggml_backend_load(const char * path);
+    // Unload a backend if loaded dynamically and unregister it
+    LM_GGML_API void lm_ggml_backend_unload(lm_ggml_backend_reg_t reg);
+    // Load all known backends from dynamic libraries
+    LM_GGML_API void lm_ggml_backend_load_all(void);
+    LM_GGML_API void lm_ggml_backend_load_all_from_path(const char * dir_path);
+
     //
     // Backend scheduler
     //
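These declarations are the public entry points for the loader implemented in ggml-backend-reg.cpp above. A minimal usage sketch; the explicit directory is a placeholder:

    #include "ggml-backend.h"

    int main(void) {
        // scan the current directory and the executable's directory for
        // [lib]ggml-<name>-*.{so,dll} and register everything that loads
        // with a non-zero score
        lm_ggml_backend_load_all();

        // or search an explicit directory instead (placeholder path)
        // lm_ggml_backend_load_all_from_path("/opt/ggml-backends");

        // dynamically registered backends are then visible to the regular
        // device enumeration and init helpers
        lm_ggml_backend_t backend = lm_ggml_backend_init_best();
        (void) backend;
        return 0;
    }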
diff --git a/cpp/ggml-common.h b/cpp/ggml-common.h
index fc17c612..a860dde7 100644
--- a/cpp/ggml-common.h
+++ b/cpp/ggml-common.h
@@ -6,7 +6,20 @@ typedef uint16_t lm_ggml_half;
 typedef uint32_t lm_ggml_half2;
-#define LM_GGML_COMMON_AGGR
+#define LM_GGML_COMMON_AGGR_U
+#define LM_GGML_COMMON_AGGR_S
+
+#define LM_GGML_COMMON_DECL
+#elif defined(LM_GGML_COMMON_DECL_CPP)
+#include <cstdint>
+
+typedef uint16_t lm_ggml_half;
+typedef uint32_t lm_ggml_half2;
+
+// std-c++ allows anonymous unions but some compilers warn on it
+#define LM_GGML_COMMON_AGGR_U data
+// std-c++ does not allow it.
+#define LM_GGML_COMMON_AGGR_S data
 #define LM_GGML_COMMON_DECL
 #elif defined(LM_GGML_COMMON_DECL_METAL)
@@ -15,7 +28,8 @@ typedef uint32_t lm_ggml_half2;
 typedef half  lm_ggml_half;
 typedef half2 lm_ggml_half2;
-#define LM_GGML_COMMON_AGGR
+#define LM_GGML_COMMON_AGGR_U
+#define LM_GGML_COMMON_AGGR_S
 #define LM_GGML_COMMON_DECL
 #elif defined(LM_GGML_COMMON_DECL_CUDA)
@@ -29,7 +43,8 @@ typedef half2 lm_ggml_half2;
 typedef half  lm_ggml_half;
 typedef half2 lm_ggml_half2;
-#define LM_GGML_COMMON_AGGR data
+#define LM_GGML_COMMON_AGGR_U
+#define LM_GGML_COMMON_AGGR_S data
 #define LM_GGML_COMMON_DECL
 #elif defined(LM_GGML_COMMON_DECL_HIP)
@@ -39,7 +54,8 @@ typedef half2 lm_ggml_half2;
 typedef half  lm_ggml_half;
 typedef half2 lm_ggml_half2;
-#define LM_GGML_COMMON_AGGR data
+#define LM_GGML_COMMON_AGGR_U
+#define LM_GGML_COMMON_AGGR_S data
 #define LM_GGML_COMMON_DECL
 #elif defined(LM_GGML_COMMON_DECL_SYCL)
@@ -49,7 +65,8 @@ typedef half2 lm_ggml_half2;
 typedef sycl::half  lm_ggml_half;
 typedef sycl::half2 lm_ggml_half2;
-#define LM_GGML_COMMON_AGGR data
+#define LM_GGML_COMMON_AGGR_U
+#define LM_GGML_COMMON_AGGR_S data
 #define LM_GGML_COMMON_DECL
 #endif
@@ -154,9 +171,9 @@ typedef struct {
         struct {
             lm_ggml_half d; // delta
             lm_ggml_half m; // min
-        } LM_GGML_COMMON_AGGR;
+        } LM_GGML_COMMON_AGGR_S;
         lm_ggml_half2 dm;
-    };
+    } LM_GGML_COMMON_AGGR_U;
     uint8_t qs[QK4_1 / 2]; // nibbles / quants
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == 2 * sizeof(lm_ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
@@ -175,9 +192,9 @@ typedef struct {
         struct {
             lm_ggml_half d; // delta
             lm_ggml_half m; // min
-        } LM_GGML_COMMON_AGGR;
+        } LM_GGML_COMMON_AGGR_S;
         lm_ggml_half2 dm;
-    };
+    } LM_GGML_COMMON_AGGR_U;
     uint8_t qh[4]; // 5-th bit of quants
     uint8_t qs[QK5_1 / 2]; // nibbles / quants
 } block_q5_1;
@@ -196,37 +213,13 @@ typedef struct {
         struct {
             lm_ggml_half d; // delta
             lm_ggml_half s; // d * sum(qs[i])
-        } LM_GGML_COMMON_AGGR;
+        } LM_GGML_COMMON_AGGR_S;
         lm_ggml_half2 ds;
-    };
+    } LM_GGML_COMMON_AGGR_U;
     int8_t qs[QK8_1]; // quants
 } block_q8_1;
 static_assert(sizeof(block_q8_1) == 2*sizeof(lm_ggml_half) + QK8_1, "wrong q8_1 block size/padding");
-typedef struct {
-    lm_ggml_half d[4];     // deltas for 4 q4_0 blocks
-    uint8_t qs[QK4_0 * 2]; // nibbles / quants for 4 q4_0 blocks
-} block_q4_0x4;
-static_assert(sizeof(block_q4_0x4) == 4 * sizeof(lm_ggml_half) + QK4_0 * 2, "wrong q4_0x4 block size/padding");
-
-typedef struct {
-    lm_ggml_half d[8];     // deltas for 8 q4_0 blocks
-    uint8_t qs[QK4_0 * 4]; // nibbles / quants for 8 q4_0 blocks
-} block_q4_0x8;
-static_assert(sizeof(block_q4_0x8) == 8 * sizeof(lm_ggml_half) + QK4_0 * 4, "wrong q4_0x8 block size/padding");
-
-typedef struct {
-    lm_ggml_half d[4];     // deltas for 4 q8_0 blocks
-    int8_t qs[QK8_0 * 4];  // quants for 4 q8_0 blocks
-} block_q8_0x4;
-static_assert(sizeof(block_q8_0x4) == 4 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong q8_0x4 block size/padding");
-
-typedef struct {
-    lm_ggml_half d[8];     // deltas for 8 q8_0 blocks
-    int8_t qs[QK8_0 * 8];  // quants for 8 q8_0 blocks
-} block_q8_0x8;
-static_assert(sizeof(block_q8_0x8) == 8 * sizeof(lm_ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding");
-
 //
 // Ternary quantization
 //
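The renamed macros make the nesting explicit: LM_GGML_COMMON_AGGR_U names (or leaves anonymous) the union that overlays the packed scale pair with a single lm_ggml_half2, and LM_GGML_COMMON_AGGR_S names the struct inside it. For example, under the CUDA/HIP/SYCL settings above (AGGR_U empty, AGGR_S defined to data), the q4_1 block from the hunk above expands to:

    typedef struct {
        union {                    // anonymous: LM_GGML_COMMON_AGGR_U is empty
            struct {
                lm_ggml_half d;    // delta
                lm_ggml_half m;    // min
            } data;                // named: LM_GGML_COMMON_AGGR_S expands to data
            lm_ggml_half2 dm;
        };
        uint8_t qs[QK4_1 / 2];     // nibbles / quants
    } block_q4_1;

    // device code can then read either x.data.d / x.data.m, or the fused x.dm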
@@ -261,9 +254,9 @@ typedef struct {
         struct {
             lm_ggml_half d;    // super-block scale for quantized scales
             lm_ggml_half dmin; // super-block scale for quantized mins
-        } LM_GGML_COMMON_AGGR;
+        } LM_GGML_COMMON_AGGR_S;
         lm_ggml_half2 dm;
-    };
+    } LM_GGML_COMMON_AGGR_U;
 } block_q2_K;
 static_assert(sizeof(block_q2_K) == 2*sizeof(lm_ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
@@ -288,9 +281,9 @@ typedef struct {
         struct {
             lm_ggml_half d;    // super-block scale for quantized scales
             lm_ggml_half dmin; // super-block scale for quantized mins
-        } LM_GGML_COMMON_AGGR;
+        } LM_GGML_COMMON_AGGR_S;
         lm_ggml_half2 dm;
-    };
+    } LM_GGML_COMMON_AGGR_U;
     uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qs[QK_K/2];           // 4-bit quants
 } block_q4_K;
@@ -305,9 +298,9 @@ typedef struct {
         struct {
             lm_ggml_half d;    // super-block scale for quantized scales
             lm_ggml_half dmin; // super-block scale for quantized mins
-        } LM_GGML_COMMON_AGGR;
+        } LM_GGML_COMMON_AGGR_S;
         lm_ggml_half2 dm;
-    };
+    } LM_GGML_COMMON_AGGR_U;
     uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8];           // quants, high bit
     uint8_t qs[QK_K/2];           // quants, low 4 bits
@@ -431,6 +424,13 @@ static_assert(sizeof(block_iq4_xs) == sizeof(lm_ggml_half) + sizeof(uint16_t) +
 #define LM_GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
 #define LM_GGML_TABLE_END() };
+#define LM_GGML_COMMON_IMPL
+#elif defined(LM_GGML_COMMON_IMPL_CPP)
+#include <cstdint>
+
+#define LM_GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
+#define LM_GGML_TABLE_END() };
+
 #define LM_GGML_COMMON_IMPL
 #elif defined(LM_GGML_COMMON_IMPL_METAL)
 #include <metal_stdlib>
@@ -473,7 +473,7 @@ LM_GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 LM_GGML_TABLE_END()
-//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+//#if __CUDA_ARCH__ >= LM_GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
 LM_GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
     0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
diff --git a/cpp/ggml-cpu-aarch64.c b/cpp/ggml-cpu-aarch64.cpp
similarity index 79%
rename from cpp/ggml-cpu-aarch64.c
rename to cpp/ggml-cpu-aarch64.cpp
index 40e033b9..0cd99cc5 100644
--- a/cpp/ggml-cpu-aarch64.c
+++ b/cpp/ggml-cpu-aarch64.cpp
@@ -1,24 +1,57 @@
-// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
-// SPDX-License-Identifier: MIT
-//
-
-#define LM_GGML_COMMON_IMPL_C
+#define LM_GGML_COMMON_IMPL_CPP
+#define LM_GGML_COMMON_DECL_CPP
 #include "ggml-common.h"
+#include "ggml-backend-impl.h"
 #include "ggml-quants.h"
 #include "ggml-impl.h"
 #include "ggml-cpu.h"
 #include "ggml-cpu-impl.h"
+#include "ggml-cpu-traits.h"
-#include <math.h>
-#include <string.h>
-#include <assert.h>
-#include <float.h>
-#include <stdlib.h> // for qsort
-#include <stdio.h>  // for LM_GGML_ASSERT
+#include <cmath>
+#include <cstring>
+#include <cassert>
+#include <cfloat>
+#include <cstdlib> // for qsort
+#include <cstdio>  // for LM_GGML_ASSERT
 #include "ggml-cpu-aarch64.h"
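Now that the file compiles as C++, the per-combination structs deleted from ggml-common.h earlier in this patch are regenerated from the single block<K, N> template that follows immediately below. As a sketch of what that template expands to, block<4, 4> (aliased back to block_q4_0x4 below) has the same layout as the old hand-written struct; the name here is illustrative only:

    // block<4, 4>: four interleaved q4_0 blocks
    // d  -> 4 per-block scales
    // qs -> (QK_0<4>() * 4 * 4) / 8 = (32 * 16) / 8 = 64 bytes of packed nibbles,
    //       matching the old uint8_t qs[QK4_0 * 2]
    struct block_q4_0x4_expanded {
        lm_ggml_half d[4];
        int8_t       qs[64];
    };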
+
+// TODO: move to include file?
+template <int K> constexpr int QK_0() {
+    if constexpr (K == 4) {
+        return QK4_0;
+    }
+    if constexpr (K == 8) {
+        return QK8_0;
+    }
+    return -1;
+}
+
+template <int K, int N> struct block {
+    lm_ggml_half d[N];                       // deltas for N qK_0 blocks
+    int8_t       qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks
+};
+
+// control size
+static_assert(sizeof(block<4, 4>) == 4 * sizeof(lm_ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding");
+static_assert(sizeof(block<4, 8>) == 8 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding");
+static_assert(sizeof(block<8, 4>) == 4 * sizeof(lm_ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding");
+static_assert(sizeof(block<8, 8>) == 8 * sizeof(lm_ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding");
+
+using block_q4_0x4 = block<4, 4>;
+using block_q4_0x8 = block<4, 8>;
+using block_q8_0x4 = block<8, 4>;
+using block_q8_0x8 = block<8, 8>;
+
+struct block_iq4_nlx4 {
+    lm_ggml_half d[4];      // deltas for 4 iq4_nl blocks
+    uint8_t qs[QK4_NL * 2]; // nibbles / quants for 4 iq4_nl blocks
+};
+
+static_assert(sizeof(block_iq4_nlx4) == 4 * sizeof(lm_ggml_half) + QK4_NL * 2, "wrong iq4_nlx4 block size/padding");
+
 #if defined(__GNUC__)
 #pragma GCC diagnostic ignored "-Woverlength-strings"
 #elif defined(_MSC_VER)
@@ -132,7 +165,7 @@ static inline __m512i sum_i16_pairs_int_32x16(const __m512i x) {
 }
 static inline __m512i mul_sum_us8_pairs_int32x16(const __m512i ax, const __m512i sy) {
-#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
+#if defined(__AVX512VNNI__)
     const __m512i zero = _mm512_setzero_si512();
     return _mm512_dpbusd_epi32(zero, ax, sy);
 #else
@@ -187,12 +220,14 @@ static inline __m256i mul_sum_i8_pairs_int32x8(const __m256i x, const __m256i y)
 }
 #endif
-static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k) {
+static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
+
+static void quantize_q8_0_4x4(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
-    block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+    block_q8_0x4 * LM_GGML_RESTRICT y = (block_q8_0x4 *) vy;
 #if defined(__ARM_NEON)
     float32x4_t srcv[4][8];
@@ -281,12 +316,12 @@ static void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int6
 #endif
 }
-static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k) {
+static void quantize_q8_0_4x8(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
-    block_q8_0x4 * restrict y = (block_q8_0x4 *) vy;
+    block_q8_0x4 * LM_GGML_RESTRICT y = (block_q8_0x4 *) vy;
 #if defined(__ARM_NEON)
     float32x4_t srcv[4][8];
@@ -496,7 +531,7 @@ static void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int6
 #endif
 }
-void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
+static void quantize_mat_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
     assert(nrow == 4);
     UNUSED(nrow);
     if (blck_size_interleave == 4) {
@@ -508,7 +543,7 @@ void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nro
     }
 }
-void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) {
+static void
lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 4; @@ -527,67 +562,47 @@ void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void UNUSED(ncols_interleaved); UNUSED(blocklen); -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) - if (lm_ggml_cpu_has_neon()) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "movi v31.16b, #0x4\n" - "movi v30.16b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x8\n" - "1:" // Column loop - "add x22, %x[a_ptr], #0x2\n" - "movi v29.16b, #0x0\n" - "mov x21, %x[nb]\n" - "2:" // Block loop - "ldr q28, [%x[b_ptr], #0x0]\n" - "ldr q27, [x22, #0x0]\n" - "movi v26.4s, #0x0\n" - "sub x20, x22, #0x2\n" - "ldr q25, [x22, #0x10]\n" - "ldr q24, [%x[b_ptr], #0x10]\n" - "sub x21, x21, #0x1\n" - "add x22, x22, #0x22\n" - "ldr q23, [%x[b_ptr], #0x20]\n" - "ldr q22, [%x[b_ptr], #0x30]\n" - "ld1r { v21.8h }, [x20]\n" - "ldr q20, [%x[b_ptr], #-0x8]\n" - "sshl v16.16b, v28.16b, v31.16b\n" - "and v28.16b, v28.16b, v30.16b\n" - "sshl v19.16b, v24.16b, v31.16b\n" - "and v24.16b, v24.16b, v30.16b\n" - "add %x[b_ptr], %x[b_ptr], #0x48\n" - "sshl v18.16b, v23.16b, v31.16b\n" - "and v23.16b, v23.16b, v30.16b\n" - ".inst 0x4f9be21a // sdot v26.4s, v16.16b, v27.4b[0]\n" - "sshl v17.16b, v22.16b, v31.16b\n" - "and v22.16b, v22.16b, v30.16b\n" - "fcvtl v21.4s, v21.4h\n" - "fcvtl v16.4s, v20.4h\n" - ".inst 0x4f99e39a // sdot v26.4s, v28.16b, v25.4b[0]\n" - "fmul v16.4s, v16.4s, v21.4s\n" - ".inst 0x4fbbe27a // sdot v26.4s, v19.16b, v27.4b[1]\n" - ".inst 0x4fb9e31a // sdot v26.4s, v24.16b, v25.4b[1]\n" - ".inst 0x4f9bea5a // sdot v26.4s, v18.16b, v27.4b[2]\n" - ".inst 0x4f99eafa // sdot v26.4s, v23.16b, v25.4b[2]\n" - ".inst 0x4fbbea3a // sdot v26.4s, v17.16b, v27.4b[3]\n" - ".inst 0x4fb9eada // sdot v26.4s, v22.16b, v25.4b[3]\n" - "scvtf v26.4s, v26.4s, #0x4\n" - "fmla v29.4s, v26.4s, v16.4s\n" - "cbnz x21, 2b\n" - "sub %x[nc], %x[nc], #0x4\n" - "str q29, [%x[res_ptr], #0x0]\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22" - ); +#if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx; + + for (int c = 0; c < nc; c += ncols_interleaved) { + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float32x4_t acc = vdupq_n_f32(0); + for (int b = 0; b < nb; b++) { + int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs); + int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16); + int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32); + int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48); + float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + + int8x16_t a0 = vld1q_s8(a_ptr->qs); + int8x16_t a1 = vld1q_s8(a_ptr->qs + qk/2); + float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + + int32x4_t ret = vdupq_n_s32(0); + + ret = vdotq_laneq_s32(ret, b0 << 4, a0, 0); + ret = vdotq_laneq_s32(ret, b1 << 4, a0, 1); + ret = vdotq_laneq_s32(ret, b2 << 4, a0, 2); + ret = vdotq_laneq_s32(ret, b3 << 4, a0, 3); + + ret = vdotq_laneq_s32(ret, b0 & 0xf0U, a1, 0); + ret = vdotq_laneq_s32(ret, b1 & 0xf0U, a1, 1); + ret = vdotq_laneq_s32(ret, b2 & 0xf0U, a1, 2); + ret = vdotq_laneq_s32(ret, b3 & 0xf0U, a1, 3); + + acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4), + vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); + a_ptr++; + b_ptr++; + } + vst1q_f32(s, acc); + s += ncols_interleaved; + } return; } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) float sumf[4]; int sumi; @@ -613,7 +628,7 @@ void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void } } -void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { +static void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 4; @@ -632,72 +647,52 @@ void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void UNUSED(ncols_interleaved); UNUSED(blocklen); -#if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) { - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - - __asm__ __volatile__( - "movi v2.16b, #0x4\n" - "movi v1.16b, #0xf0\n" - "add %x[b_ptr], %x[b_ptr], #0x8\n" - "1:" // Column loop - "add x23, %x[a_ptr], #0x2\n" - "movi v0.16b, #0x0\n" - "mov x22, %x[nb]\n" - "2:" // Block loop - "ldr q31, [%x[b_ptr], #0x0]\n" - "ldr q30, [%x[b_ptr], #0x10]\n" - "mov x21, x23\n" - "movi v29.4s, #0x0\n" - "ldr q28, [%x[b_ptr], #0x20]\n" - "ldr q27, [%x[b_ptr], #0x30]\n" - "movi v26.4s, #0x0\n" - "sub x20, x23, #0x2\n" - "ld1r { v25.8h }, [x20]\n" - "ldr q24, [%x[b_ptr], #-0x8]\n" - "sub x22, x22, #0x1\n" - "add x23, x23, #0x22\n" - "ld1r { v23.2d }, [x21], #0x8\n" - "sshl v22.16b, v31.16b, v2.16b\n" - "sshl v16.16b, v30.16b, v2.16b\n" - "add %x[b_ptr], %x[b_ptr], #0x48\n" - "ld1r { v21.2d }, [x21], #0x8\n" - "sshl v20.16b, v28.16b, v2.16b\n" - "sshl v19.16b, v27.16b, v2.16b\n" - "ld1r { v18.2d }, [x21], #0x8\n" - "ld1r { v17.2d }, [x21], #0x8\n" - "and v31.16b, v31.16b, v1.16b\n" - "and v30.16b, v30.16b, v1.16b\n" - ".inst 0x4e9796dd // sdot v29.4s, v22.16b, v23.16b\n" - ".inst 0x4e97961a // sdot v26.4s, v16.16b, v23.16b\n" - "and v28.16b, v28.16b, v1.16b\n" - "and v27.16b, v27.16b, v1.16b\n" - "fcvtl v25.4s, v25.4h\n" - "fcvtl v16.4s, v24.4h\n" - ".inst 0x4e95969d // sdot v29.4s, v20.16b, v21.16b\n" - ".inst 0x4e95967a // sdot v26.4s, v19.16b, v21.16b\n" - "fmul v16.4s, v16.4s, v25.4s\n" - ".inst 0x4e9297fd // sdot v29.4s, v31.16b, v18.16b\n" - ".inst 0x4e9297da // sdot v26.4s, v30.16b, v18.16b\n" - ".inst 0x4e91979d // sdot v29.4s, v28.16b, v17.16b\n" - ".inst 0x4e91977a // sdot v26.4s, v27.16b, v17.16b\n" - "addp v29.4s, v29.4s, v26.4s\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "fmla v0.4s, v29.4s, v16.4s\n" - "cbnz x22, 2b\n" - "sub %x[nc], %x[nc], #0x4\n" - "str q0, [%x[res_ptr], #0x0]\n" - "add %x[res_ptr], %x[res_ptr], #0x10\n" - "cbnz %x[nc], 1b\n" - : [b_ptr] "+&r" (b_ptr), [res_ptr] "+&r" (res_ptr), [nc] "+&r" (nc) - : [a_ptr] "r" (a_ptr), [nb] "r" (nb) - : "memory", "v0", "v1", "v2", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x20", "x21", "x22", "x23" - ); +#if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx; + + for (int c = 0; c < nc; c += ncols_interleaved) { + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float32x4_t acc = vdupq_n_f32(0); + for (int b = 0; b < nb; b++) { + int8x16_t b0 = vld1q_s8((const int8_t *) b_ptr->qs); + int8x16_t b1 = vld1q_s8((const int8_t *) b_ptr->qs + 16); + int8x16_t b2 = vld1q_s8((const int8_t *) b_ptr->qs + 32); + int8x16_t b3 = vld1q_s8((const int8_t *) b_ptr->qs + 48); + float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + + int8x16_t a0 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs); + int8x16_t a1 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 1); + int8x16_t a2 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 2); + int8x16_t a3 = (int8x16_t) vld1q_dup_s64((const int64_t *) a_ptr->qs + 3); + float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + + int32x4_t ret0 = vdupq_n_s32(0); + int32x4_t ret1 = vdupq_n_s32(0); + + ret0 = vdotq_s32(ret0, b0 << 4, a0); + ret1 = vdotq_s32(ret1, b1 << 4, a0); + ret0 = vdotq_s32(ret0, b2 << 4, a1); + ret1 = vdotq_s32(ret1, b3 << 4, a1); + + ret0 = vdotq_s32(ret0, b0 & 0xf0U, a2); + ret1 = vdotq_s32(ret1, b1 & 0xf0U, a2); + ret0 = vdotq_s32(ret0, b2 & 0xf0U, a3); + ret1 = vdotq_s32(ret1, b3 & 0xf0U, a3); + + int32x4_t ret = vpaddq_s32(ret0, ret1); + + acc = vfmaq_f32(acc, vcvtq_n_f32_s32(ret, 4), + vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); + a_ptr++; + b_ptr++; + } + vst1q_f32(s, acc); + s += ncols_interleaved; + } return; } -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) float sumf[4]; int sumi; @@ -723,7 +718,7 @@ void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void } } -void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { +static void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 8; @@ -996,7 +991,103 @@ void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void } } -void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { +static void lm_ggml_gemv_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) { + const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float * res_ptr = s; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + float32x4_t sumf = vdupq_n_f32(0); + for (int l = 0; l < nb; l++) { + uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0); + uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16); + uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32); + uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48); + + int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4); + int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F); + int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4); + int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F); + int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4); + int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F); + int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4); + int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F); + + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16); + + int32x4_t sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0); + sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0); + sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1); + sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1); + sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2); + sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2); + sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3); + sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3); + + float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d)); + float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d)); + float32x4_t d = a_d * b_d; + + sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi)); + } + + vst1q_f32(res_ptr + x * 4, sumf); + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + { + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +static void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 4; @@ -1017,7 +1108,7 @@ void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void UNUSED(blocklen); #if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) - if (lm_ggml_cpu_has_neon()) { + if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) { const void * b_ptr = vx; const void * a_ptr = vy; float * res_ptr = s; @@ -1512,7 +1603,7 @@ void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * restrict s, size_t bs, const void } } -void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { +static void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 4; @@ -1966,7 +2057,7 @@ void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * restrict s, size_t bs, const void } } -void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, const void * restrict vy, int nr, int nc) { +static void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 8; @@ -2486,31 +2577,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) // Shuffle pattern one - right side input - const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) - const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) + const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) + const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) - const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) - const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) + const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) + const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) - const __m512i rhs_mat_014589CD_2_sp1 = 
_mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) - const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) + const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) + const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) - const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) - const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) + const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) + const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) // Shuffle pattern two - right side input - const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) - const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) + const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) + const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) - const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) - const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) + const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 
(_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) + const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) - const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) - const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) + const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) + const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) - const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) - const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) + const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) + const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) // Scale values - Load the weight scale values of two block_q4_0x8 const __m512 col_scale_f32 = LM_GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); @@ -2544,31 +2635,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void // Shuffle pattern one - left side input - const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) 
A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) - const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) // Shuffle pattern two - left side input - const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) 
A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version @@ -2597,10 +2688,10 @@ void 
lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void // Straighten out to make 4 row vectors - __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78)); - __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01); - __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78)); - __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11); + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11); // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68); @@ -2679,31 +2770,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void const __m512i rhs_mat_2367ABEF_3 = _mm512_shuffle_epi8(signextendlutexpanded, _mm512_and_si512(_mm512_srli_epi16(rhs_raw_mat_2367ABEF_1, 4), m4bexpanded)); //B2(24-31) B3(24-31) B6(24-31) B7(24-31) BA(24-31) BB(24-31) BE(24-31) BF(24-31) // Shuffle pattern one - right side input - const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) - const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) + const __m512i rhs_mat_014589CD_0_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)136); //B0(0-3) B1(0-3) B0(0-3) B1(0-3) B4(0-3) B5(0-3) B4(0-3) B5(0-3) B8(0-3) B9(0-3) B8(0-3) B9(0-3) BC(0-3) BD(0-3) BC(0-3) BD(0-3) + const __m512i rhs_mat_2367ABEF_0_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)136); //B2(0-3) B3(0-3) B2(0-3) B3(0-3) B6(0-3) B7(0-3) B6(0-3) B7(0-3) BA(0-3) BB(0-3) BA(0-3) BB(0-3) BE(0-3) BF(0-3) BE(0-3) BF(0-3) - const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) - const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) + const __m512i rhs_mat_014589CD_1_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, (_MM_PERM_ENUM)136); //B0(8-11) B1(8-11) B0(8-11) B1(8-11) B4(8-11) B5(8-11) B4(8-11) B5(8-11) B8(8-11) B9(8-11) B8(8-11) B9(8-11) BC(8-11) BD(8-11) BC(8-11) BD(8-11) + const __m512i rhs_mat_2367ABEF_1_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)136); //B2(8-11) B3(8-11) B2(8-11) B3(8-11) B6(8-11) B7(8-11) B6(8-11) B7(8-11) BA(8-11) BB(8-11) BA(8-11) BB(8-11) BE(8-11) BF(8-11) BE(8-11) BF(8-11) - const __m512i rhs_mat_014589CD_2_sp1 = 
_mm512_shuffle_epi32(rhs_mat_014589CD_2, 136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) - const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) + const __m512i rhs_mat_014589CD_2_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)136); //B0(16-19) B1(16-19) B0(16-19) B1(16-19) B4(16-19) B5(16-19) B4(16-19) B5(16-19) B8(16-19) B9(16-19) B8(16-19) B9(16-19) BC(16-19) BD(16-19) BC(16-19) BD(16-19) + const __m512i rhs_mat_2367ABEF_2_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)136); //B2(16-19) B3(16-19) B2(16-19) B3(16-19) B6(16-19) B7(16-19) B6(16-19) B7(16-19) BA(16-19) BB(16-19) BA(16-19) BB(16-19) BE(16-19) BF(16-19) BE(16-19) BF(16-19) - const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) - const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) + const __m512i rhs_mat_014589CD_3_sp1 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)136); //B0(24-27) B1(24-27) B0(24-27) B1(24-27) B4(24-27) B5(24-27) B4(24-27) B5(24-27) B8(24-27) B9(24-27) B8(24-27) B9(24-27) BC(24-27) BD(24-27) BC(24-27) BD(24-27) + const __m512i rhs_mat_2367ABEF_3_sp1 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)136); //B2(24-27) B3(24-27) B2(24-27) B3(24-27) B6(24-27) B7(24-27) B6(24-27) B7(24-27) BA(24-27) BB(24-27) BA(24-27) BB(24-27) BE(24-27) BF(24-27) BE(24-27) BF(24-27) // Shuffle pattern two - right side input - const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, 221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) - const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, 221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) + const __m512i rhs_mat_014589CD_0_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_0, (_MM_PERM_ENUM)221); //B0(4-7) B1(4-7) B0(4-7) B1(4-7) B4(4-7) B5(4-7) B4(4-7) B5(4-7) B8(4-7) B9(4-7) B8(4-7) B9(4-7) BC(4-7) BD(4-7) BC(4-7) BD(4-7) + const __m512i rhs_mat_2367ABEF_0_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_0, (_MM_PERM_ENUM)221); //B2(4-7) B3(4-7) B2(4-7) B3(4-7) B6(4-7) B7(4-7) B6(4-7) B7(4-7) BA(4-7) BB(4-7) BA(4-7) BB(4-7) BE(4-7) BF(4-7) BE(4-7) BF(4-7) - const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) - const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, 221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) + const __m512i rhs_mat_014589CD_1_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_1, 
(_MM_PERM_ENUM)221); //B0(12-15) B1(12-15) B0(12-15) B1(12-15) B4(12-15) B5(12-15) B4(12-15) B5(12-15) B8(12-15) B9(12-15) B8(12-15) B9(12-15) BC(12-15) BD(12-15) BC(12-15) BD(12-15) + const __m512i rhs_mat_2367ABEF_1_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_1, (_MM_PERM_ENUM)221); //B2(12-15) B3(12-15) B2(12-15) B3(12-15) B6(12-15) B7(12-15) B6(12-15) B7(12-15) BA(12-15) BB(12-15) BA(12-15) BB(12-15) BE(12-15) BF(12-15) BE(12-15) BF(12-15) - const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, 221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) - const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, 221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) + const __m512i rhs_mat_014589CD_2_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_2, (_MM_PERM_ENUM)221); //B0(20-23) B1(20-23) B0(20-23) B1(20-23) B4(20-23) B5(20-23) B4(20-23) B5(20-23) B8(20-23) B9(20-23) B8(20-23) B9(20-23) BC(20-23) BD(20-23) BC(20-23) BD(20-23) + const __m512i rhs_mat_2367ABEF_2_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_2, (_MM_PERM_ENUM)221); //B2(20-23) B3(20-23) B2(20-23) B3(20-23) B6(20-23) B7(20-23) B6(20-23) B7(20-23) BA(20-23) BB(20-23) BA(20-23) BB(20-23) BE(20-23) BF(20-23) BE(20-23) BF(20-23) - const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, 221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) - const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, 221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) + const __m512i rhs_mat_014589CD_3_sp2 = _mm512_shuffle_epi32(rhs_mat_014589CD_3, (_MM_PERM_ENUM)221); //B0(28-31) B1(28-31) B0(28-31) B1(28-31) B4(28-31) B5(28-31) B4(28-31) B5(28-31) B8(28-31) B9(28-31) B8(28-31) B9(28-31) BC(28-31) BD(28-31) BC(28-31) BD(28-31) + const __m512i rhs_mat_2367ABEF_3_sp2 = _mm512_shuffle_epi32(rhs_mat_2367ABEF_3, (_MM_PERM_ENUM)221); //B2(28-31) B3(28-31) B2(28-31) B3(28-31) B6(28-31) B7(28-31) B6(28-31) B7(28-31) BA(28-31) BB(28-31) BA(28-31) BB(28-31) BE(28-31) BF(28-31) BE(28-31) BF(28-31) // Scale values - Load the weight scale values of two block_q4_0x8 @@ -2735,31 +2826,31 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void // Shuffle pattern one - left side input - const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, 160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) - const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, 160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) + const __m512i lhs_mat_01_0_sp1 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)160); //A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) A0(0-3) A0(0-3) A1(0-3) A1(0-3) + const __m512i lhs_mat_23_0_sp1 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)160); //A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) A2(0-3) A3(0-3) A3(0-3) A2(0-3) 
A2(0-3) A3(0-3) A3(0-3) - const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, 160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) - const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, 160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) + const __m512i lhs_mat_01_1_sp1 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)160); //A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) A0(8-11) A0(8-11) A1(8-11) A1(8-11) + const __m512i lhs_mat_23_1_sp1 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)160); //A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) A2(8-11) A2(8-11) A3(8-11) A3(8-11) - const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, 160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) - const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, 160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) + const __m512i lhs_mat_01_2_sp1 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)160); //A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) A0(16-19) A0(16-19) A1(16-19) A1(16-19) + const __m512i lhs_mat_23_2_sp1 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)160); //A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) A2(16-19) A2(16-19) A3(16-19) A3(16-19) - const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, 160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) - const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, 160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) + const __m512i lhs_mat_01_3_sp1 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)160); //A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) A0(24-27) A0(24-27) A1(24-27) A1(24-27) + const __m512i lhs_mat_23_3_sp1 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)160); //A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) A2(24-27) A2(24-27) A3(24-27) A3(24-27) // Shuffle pattern two - left side input - const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, 245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) - const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, 245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) + const __m512i lhs_mat_01_0_sp2 = _mm512_shuffle_epi32(lhs_mat_01_0, (_MM_PERM_ENUM)245); //A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) A0(4-7) A0(4-7) A1(4-7) A1(4-7) 
A0(4-7) A0(4-7) A1(4-7) A1(4-7) + const __m512i lhs_mat_23_0_sp2 = _mm512_shuffle_epi32(lhs_mat_23_0, (_MM_PERM_ENUM)245); //A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) A2(4-7) A2(4-7) A3(4-7) A3(4-7) - const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, 245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) - const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, 245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) + const __m512i lhs_mat_01_1_sp2 = _mm512_shuffle_epi32(lhs_mat_01_1, (_MM_PERM_ENUM)245); //A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) A0(12-15) A0(12-15) A1(12-15) A1(12-15) + const __m512i lhs_mat_23_1_sp2 = _mm512_shuffle_epi32(lhs_mat_23_1, (_MM_PERM_ENUM)245); //A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) A2(12-15) A2(12-15) A3(12-15) A3(12-15) - const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, 245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) - const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, 245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) + const __m512i lhs_mat_01_2_sp2 = _mm512_shuffle_epi32(lhs_mat_01_2, (_MM_PERM_ENUM)245); //A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) A0(20-23) A0(20-23) A1(20-23) A1(20-23) + const __m512i lhs_mat_23_2_sp2 = _mm512_shuffle_epi32(lhs_mat_23_2, (_MM_PERM_ENUM)245); //A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) A2(20-23) A2(20-23) A3(20-23) A3(20-23) - const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, 245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) - const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, 245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) + const __m512i lhs_mat_01_3_sp2 = _mm512_shuffle_epi32(lhs_mat_01_3, (_MM_PERM_ENUM)245); //A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) A0(28-31) A0(28-31) A1(28-31) A1(28-31) + const __m512i lhs_mat_23_3_sp2 = _mm512_shuffle_epi32(lhs_mat_23_3, (_MM_PERM_ENUM)245); //A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) A2(28-31) A2(28-31) A3(28-31) A3(28-31) // The values arranged in shuffle patterns are operated with dot product operation within 32 bit lane i.e corresponding bytes and multiplied and added into 32 bit integers within 32 bit lane // Resembles MMLAs into 2x2 matrices in ARM Version @@ -2788,10 +2879,10 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void // 
Straighten out to make 4 row vectors - __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, 78)); - __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, 78), iacc_mat_01); - __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, 78)); - __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, 78), iacc_mat_11); + __m512i iacc_row_0 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_00, _mm512_shuffle_epi32(iacc_mat_01, (_MM_PERM_ENUM)78)); + __m512i iacc_row_1 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_00, (_MM_PERM_ENUM)78), iacc_mat_01); + __m512i iacc_row_2 = _mm512_mask_blend_epi32(0xCCCC, iacc_mat_10, _mm512_shuffle_epi32(iacc_mat_11, (_MM_PERM_ENUM)78)); + __m512i iacc_row_3 = _mm512_mask_blend_epi32(0xCCCC, _mm512_shuffle_epi32(iacc_mat_10, (_MM_PERM_ENUM)78), iacc_mat_11); // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68); @@ -3386,7 +3477,117 @@ void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void } } -// FIXME: this code is duplicated from ggml-aarch64.c +static void lm_ggml_gemm_iq4_nl_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) { + const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + float32x4_t sumf[4]; + for (int m = 0; m < 4; m++) { + sumf[m] = vdupq_n_f32(0); + } + + for (int l = 0; l < nb; l++) { + float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d)); + float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d)); + + int32x4_t sumi_0 = vdupq_n_s32(0); + int32x4_t sumi_1 = vdupq_n_s32(0); + int32x4_t sumi_2 = vdupq_n_s32(0); + int32x4_t sumi_3 = vdupq_n_s32(0); + + for (int k = 0; k < 4; k++) { + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64); + + uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k); + int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4); + int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF); + + sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3); + sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3); + } + + sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0)); + sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1)); + sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2)); + sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3)); + } + + for (int m = 0; m < 4; m++) { + vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]); + } + } + } + return; + } +#endif // #if ! ((defined(_MSC_VER)) && ! 
defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * LM_GGML_FP16_TO_FP32(b_ptr[l].d[j]) * LM_GGML_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) { block_q4_0x4 out; @@ -3456,20 +3657,20 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in return out; } -static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * restrict data, size_t data_size) { +static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) { LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0); LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8); + constexpr int nrows_interleaved = 4; block_q4_0x4 * dst = (block_q4_0x4 *)t->data; const block_q4_0 * src = (const block_q4_0 *)data; block_q4_0 dst_tmp[4]; - int nrow = t->ne[1]; // Number of rows - int nrows_interleaved = 4; + int nrow = lm_ggml_nrows(t); int nblocks = t->ne[0] / QK4_0; LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); - if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { return -1; } @@ -3487,20 +3688,20 @@ static int repack_q4_0_to_q4_0_4_bl(struct lm_ggml_tensor * t, int interleave_bl LM_GGML_UNUSED(data_size); } -static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor *t, int interleave_block, const void * restrict data, size_t data_size) { +static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) { LM_GGML_ASSERT(t->type == LM_GGML_TYPE_Q4_0); LM_GGML_ASSERT(interleave_block == 8); + constexpr int nrows_interleaved = 8; block_q4_0x8 * dst = (block_q4_0x8*)t->data; const block_q4_0 * src = (const block_q4_0*) data; block_q4_0 dst_tmp[8]; - int nrow = t->ne[1]; // Number of rows - int nrows_interleaved = 8; + int nrow = lm_ggml_nrows(t); int nblocks = t->ne[0] / QK4_0; LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); - if (nrow % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { return -1; } @@ -3518,43 +3719,524 @@ static int repack_q4_0_to_q4_0_8_bl(struct lm_ggml_tensor *t, int interleave_blo LM_GGML_UNUSED(data_size); } -// Prepare for optimized kernels if 
applicable
-void lm_ggml_aarch64_repack_tensor(struct lm_ggml_tensor * cur, enum lm_ggml_type repack_type, const void * restrict data, size_t data_size) {
-    if (cur->type == repack_type) {
-        memcpy(cur->data, data, data_size);
-        return;
+static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
+    block_iq4_nlx4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.d[i] = in[i].d;
     }
-    LM_GGML_ASSERT(cur->type == LM_GGML_TYPE_Q4_0);
+    const int end = QK4_NL * 2 / blck_size_interleave;
-    switch (repack_type) {
-        case LM_GGML_TYPE_Q4_0_8_8:
-            repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
-            break;
-        case LM_GGML_TYPE_Q4_0_4_8:
-            repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
-            break;
-        case LM_GGML_TYPE_Q4_0_4_4:
-            repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
+    // TODO: this branch seems wrong
+    //if (blck_size_interleave == 8) {
+    //    for (int i = 0; i < end; ++i) {
+    //        int src_id = i % 4;
+    //        int src_offset = (i / 4) * blck_size_interleave;
+    //        int dst_offset = i * blck_size_interleave;
+
+    //        // Using memcpy to avoid unaligned memory accesses
+    //        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+    //    }
+    //} else
+    if (blck_size_interleave == 4) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
+        }
+    } else {
+        LM_GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_iq4_nl_to_iq4_nl_4_bl(struct lm_ggml_tensor * t, int interleave_block, const void * LM_GGML_RESTRICT data, size_t data_size) {
+    LM_GGML_ASSERT(t->type == LM_GGML_TYPE_IQ4_NL);
+    //LM_GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
+    LM_GGML_ASSERT(interleave_block == 4);
+
+    block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
+    const block_iq4_nl * src = (const block_iq4_nl *)data;
+    block_iq4_nl dst_tmp[4];
+    int nrow = lm_ggml_nrows(t);
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK4_0;
+
+    LM_GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    LM_GGML_UNUSED(data_size);
+}
+
+namespace ggml::cpu::aarch64 {
+// repack
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+int repack(struct lm_ggml_tensor *, const void *, size_t);
+
+// TODO: generalise.
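+// Illustrative sketch (assumed usage, not part of this patch): a call such as
+//
+//     ggml::cpu::aarch64::repack<block_q4_0, 8, 8>(t, data, data_size);
+//
+// resolves at compile time, via the specializations below, to
+// repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size), which returns 0 on success
+// and -1 when the tensor shape cannot be interleaved.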
+template <> int repack<block_q4_0, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_q4_0, 8, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_q4_0, 8, 8>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
+}
+
+template <> int repack<block_iq4_nl, 4, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
+}
+
+// TODO: needs to be revisited
+//template <> int repack<block_iq4_nl, 8, 4>(struct lm_ggml_tensor * t, const void * data, size_t data_size) {
+//    return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
+//}
+
+// gemv
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+void gemv(int, float *, size_t, const void *, const void *, int, int);
+
+template <> void gemv<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <>
+void gemv<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+// gemm
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
+void gemm(int, float *, size_t, const void *, const void *, int, int);
+
+template <> void gemm<block_q4_0, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_0, 8, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_q4_0, 8, 8>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <>
+void gemm<block_iq4_nl, 4, 4>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    lm_ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+class tensor_traits_base : public ggml::cpu::tensor_traits {
+  public:
+    virtual int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) = 0;
+};
+
+template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base {
+
+    bool work_size(int /* n_threads */, const struct lm_ggml_tensor * op, size_t & size) override {
+        // not really a LM_GGML_TYPE_Q8_0 but same size.
+        switch (op->op) {
+            case LM_GGML_OP_MUL_MAT:
+                size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
+                return true;
+            case LM_GGML_OP_MUL_MAT_ID:
+                size = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, lm_ggml_nelements(op->src[1]));
+                size = LM_GGML_PAD(size, sizeof(int64_t)); // + padding for next block.
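+                // Assumed reading of the workspace layout consumed by
+                // forward_mul_mat_id() below (not normative):
+                //   [ src1 rows quantized to Q8_0 | pad to int64_t |
+                //     int64_t row counts, one per expert |
+                //     mmid_row_mapping[n_expert][ne12] ]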
+                size += sizeof(int64_t) * (1+op->src[0]->ne[2]) * op->src[1]->ne[2];
+                return true;
+            default:
+                // LM_GGML_ABORT("fatal error"); break;
+        }
+        return false;
+    }
+
+    bool compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) override {
+        switch (op->op) {
+            case LM_GGML_OP_MUL_MAT:
+                forward_mul_mat(params, op);
+                return true;
+            case LM_GGML_OP_MUL_MAT_ID:
+                forward_mul_mat_id(params, op);
+                return true;
         default:
-            LM_GGML_ABORT("Unsupported type");
+            // LM_GGML_ABORT("fatal error");
+            break;
+        }
+        return false;
     }
-}
-enum lm_ggml_type lm_ggml_aarch64_get_optimal_repack_type(const struct lm_ggml_tensor * cur) {
+    void forward_mul_mat(lm_ggml_compute_params * params, lm_ggml_tensor * op) {
+        const lm_ggml_tensor * src0 = op->src[0];
+        const lm_ggml_tensor * src1 = op->src[1];
+        lm_ggml_tensor * dst = op;
+
+        LM_GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        LM_GGML_ASSERT(ne0 == ne01);
+        LM_GGML_ASSERT(ne1 == ne11);
+        LM_GGML_ASSERT(ne2 == ne12);
+        LM_GGML_ASSERT(ne3 == ne13);
+
+        // dst cannot be transposed or permuted
+        LM_GGML_ASSERT(nb0 == sizeof(float));
+        LM_GGML_ASSERT(nb0 <= nb1);
+        LM_GGML_ASSERT(nb1 <= nb2);
+        LM_GGML_ASSERT(nb2 <= nb3);
+
+        LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
+
+        LM_GGML_ASSERT(lm_ggml_n_dims(op->src[0]) == 2);
+        // LM_GGML_ASSERT(lm_ggml_n_dims(op->src[1]) == 2);
+
+        char * wdata = static_cast<char *>(params->wdata);
+        const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+
+        assert(params->wsize >= nbw1 * ne11);
+
+        const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
+
+        int64_t i11_processed = 0;
+        for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
+            quantize_mat_q8_0((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), 4, ne10,
+                              INTER_SIZE);
+        }
+        i11_processed = ne11 - ne11 % 4;
+        for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
+            from_float((float *) ((char *) src1->data + i11 * nb11), (void *) (wdata + i11 * nbw1), ne10);
+        }
+
+        lm_ggml_barrier(params->threadpool);
+
+        const void * src1_wdata      = params->wdata;
+        const size_t src1_col_stride = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+        int64_t src0_start = (ith * ne01) / nth;
+        int64_t src0_end   = ((ith + 1) * ne01) / nth;
+        src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
+        src0_end   = (src0_end   % NB_COLS) ? src0_end   + NB_COLS - (src0_end   % NB_COLS) : src0_end;
+        if (src0_start >= src0_end) {
+            return;
+        }
+
+        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
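+        // For example (a reading of the loop bounds, not new behaviour):
+        // with ne11 == 7, rows 0..3 go through the 4-row gemm path and
+        // rows 4..6 fall back to the single-row gemv calls below.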
+        if (ne11 > 3) {
+            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data) + src0_start, ne01,
+                 (const char *) src0->data + src0_start * nb01,
+                 (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
+        }
+        for (int iter = ne11 - ne11 % 4; iter < ne11; iter++) {
+            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(ne00, (float *) ((char *) dst->data + (iter * nb1)) + src0_start, ne01,
+                 (const char *) src0->data + src0_start * nb01,
+                 (const char *) src1_wdata + (src1_col_stride * iter), 1,
+                 src0_end - src0_start);
+        }
+    }
+
+    void forward_mul_mat_id(lm_ggml_compute_params * params, lm_ggml_tensor * op) {
+        const lm_ggml_tensor * src0 = op->src[0];
+        const lm_ggml_tensor * src1 = op->src[1];
+        const lm_ggml_tensor * ids  = op->src[2];
+        lm_ggml_tensor * dst = op;
+
+        LM_GGML_TENSOR_BINARY_OP_LOCALS
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const lm_ggml_from_float_t from_float = lm_ggml_get_type_traits_cpu(LM_GGML_TYPE_Q8_0)->from_float;
+
+        // we don't support permuted src0 or src1
+        LM_GGML_ASSERT(nb00 == lm_ggml_type_size(src0->type));
+        LM_GGML_ASSERT(nb10 == lm_ggml_type_size(src1->type));
+
+        // dst cannot be transposed or permuted
+        LM_GGML_ASSERT(nb0 == sizeof(float));
+        LM_GGML_ASSERT(nb0 <= nb1);
+        LM_GGML_ASSERT(nb1 <= nb2);
+        LM_GGML_ASSERT(nb2 <= nb3);
+
+        LM_GGML_ASSERT(ne03 == 1);
+        LM_GGML_ASSERT(ne13 == 1);
+        LM_GGML_ASSERT(ne3 == 1);
+
+        LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32);
+
+        // row groups
+        const int n_ids = ids->ne[0]; // n_expert_used
+        const int n_as  = ne02;       // n_expert
+
+        const size_t nbw1 = lm_ggml_row_size(LM_GGML_TYPE_Q8_0, ne10);
+        const size_t nbw2 = nbw1*ne11;
+        const size_t nbw3 = nbw2*ne12;
+
+        struct mmid_row_mapping {
+            int32_t i1;
+            int32_t i2;
+        };
+
+        LM_GGML_ASSERT(params->wsize >= (LM_GGML_PAD(nbw3, sizeof(int64_t)) + n_as * sizeof(int64_t) +
+                                         n_as * ne12 * sizeof(mmid_row_mapping)));
+
+        auto wdata = (char *) params->wdata;
+        auto wdata_src1_end = (char *) wdata + LM_GGML_PAD(nbw3, sizeof(int64_t));
+        int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+        struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
+
+        // src1: float32 => block_q8_0
+        for (int64_t i12 = 0; i12 < ne12; ++i12) {
+            for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
+                from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
+                           (void *) (wdata + i12 * nbw2 + i11 * nbw1),
+                           ne10);
+            }
+        }
+
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
+
+        if (ith == 0) {
+            // initialize matrix_row_counts
+            memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
+
+            // group rows by src0 matrix
+            for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+                for (int32_t id = 0; id < n_ids; ++id) {
+                    const int32_t i02 =
+                        *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
+
+                    LM_GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                    MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
+                    matrix_row_counts[i02] += 1;
+                }
+            }
+        }
+
+        lm_ggml_barrier(params->threadpool);
+
+        // compute each matrix multiplication in sequence
+        for (int cur_a = 0; cur_a < n_as; ++cur_a) {
+            const int64_t cne1 = matrix_row_counts[cur_a];
+
+            if (cne1 == 0) {
+                continue;
+            }
+
+            auto src0_cur = (const char *) src0->data + cur_a*nb02;
+
+            //const int64_t nr0 = ne01; // src0 rows
+            const int64_t nr1 = cne1; // src1 rows
+
+            int64_t src0_cur_start = (ith * ne01) / nth;
+            int64_t src0_cur_end   = ((ith + 1) * ne01) / nth;
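+            // Assumed illustration of the rounding below: with ne01 == 96,
+            // nth == 5 and NB_COLS == 8, the raw thread splits 19/38/57/76
+            // round up to 24/40/64/80, so no interleaved row group is split
+            // across threads (threads left with an empty range just return).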
+            src0_cur_start =
+                (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
+            src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
+
+            if (src0_cur_start >= src0_cur_end) return;
+
+            for (int ir1 = 0; ir1 < nr1; ir1++) {
+                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
+                const int id = row_mapping.i1; // selected expert index
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = row_mapping.i2; // row index in src1
+
+                const int64_t i1 = id;  // selected expert index
+                const int64_t i2 = i12; // row
+
+                auto src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
+
+                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS>(
+                    ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start,
+                    ne01, src0_cur + src0_cur_start * nb01,
+                    src1_col, 1, src0_cur_end - src0_cur_start);
+            }
+        }
+#undef MMID_MATRIX_ROW
+    }
+
+    int repack(struct lm_ggml_tensor * t, const void * data, size_t data_size) override {
+        LM_GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, lm_ggml_type_name(t->type),
+                          (int) NB_COLS, (int) INTER_SIZE);
+        return ggml::cpu::aarch64::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
+    }
+};
+
+// instance for Q4
+static const tensor_traits<block_q4_0, 4, 4> q4_0_4x4_q8_0;
+static const tensor_traits<block_q4_0, 8, 4> q4_0_4x8_q8_0;
+static const tensor_traits<block_q4_0, 8, 8> q4_0_8x8_q8_0;
+
+// instance for IQ4
+static const tensor_traits<block_iq4_nl, 4, 4> iq4_nl_4x4_q8_0;
+
+} // namespace ggml::cpu::aarch64
+
+static const ggml::cpu::tensor_traits * lm_ggml_aarch64_get_optimal_repack_type(const struct lm_ggml_tensor * cur) {
     if (cur->type == LM_GGML_TYPE_Q4_0) {
-        // TODO: enable for AVX2 - currently disabled due to bad gemv performance
-        if (/* lm_ggml_cpu_has_avx2() || */ (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0)) {
-            return LM_GGML_TYPE_Q4_0_8_8;
+        if (lm_ggml_cpu_has_avx2() || (lm_ggml_cpu_has_sve() && lm_ggml_cpu_has_matmul_int8() && lm_ggml_cpu_get_sve_cnt() == QK8_0)) {
+            if (cur->ne[1] % 8 == 0) {
+                return &ggml::cpu::aarch64::q4_0_8x8_q8_0;
+            }
         }
         if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_matmul_int8()) {
-            return LM_GGML_TYPE_Q4_0_4_8;
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::q4_0_4x8_q8_0;
+            }
         }
-        if (lm_ggml_cpu_has_neon()) {
-            return LM_GGML_TYPE_Q4_0_4_4;
+        if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::q4_0_4x4_q8_0;
+            }
+        }
+    } else if (cur->type == LM_GGML_TYPE_IQ4_NL) {
+        if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &ggml::cpu::aarch64::iq4_nl_4x4_q8_0;
+            }
         }
     }
-    return cur->type;
+    return nullptr;
+}
+
+static void lm_ggml_backend_cpu_aarch64_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) {
+    tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(lm_ggml_aarch64_get_optimal_repack_type(tensor));
+
+    LM_GGML_UNUSED(buffer);
+}
+
+static void lm_ggml_backend_cpu_aarch64_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor,
+                                                          const void * data, size_t offset, size_t size) {
+    LM_GGML_ASSERT(offset == 0);
+    LM_GGML_ASSERT(size == lm_ggml_nbytes(tensor));
+
+    auto tensor_traits = (ggml::cpu::aarch64::tensor_traits_base *) tensor->extra;
+    auto OK = tensor_traits->repack(tensor, data, size);
+
+    LM_GGML_ASSERT(OK == 0);
+    LM_GGML_UNUSED(buffer);
+}
+
+static const char * lm_ggml_backend_cpu_aarch64_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) {
+    return "CPU_AARCH64";
+
+    LM_GGML_UNUSED(buft);
+}
+
+static lm_ggml_backend_buffer_t
lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { + lm_ggml_backend_buffer_t buffer = lm_ggml_backend_buft_alloc_buffer(lm_ggml_backend_cpu_buffer_type(), size); + + if (buffer == nullptr) { + return nullptr; + } + + buffer->buft = buft; + buffer->iface.init_tensor = lm_ggml_backend_cpu_aarch64_buffer_init_tensor; + buffer->iface.set_tensor = lm_ggml_backend_cpu_aarch64_buffer_set_tensor; + return buffer; +} + +static size_t lm_ggml_backend_cpu_aarch64_buffer_type_get_alignment(lm_ggml_backend_buffer_type_t buft) { + return TENSOR_ALIGNMENT; + + LM_GGML_UNUSED(buft); +} + +namespace ggml::cpu::aarch64 { +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(lm_ggml_backend_dev_t, const struct lm_ggml_tensor * op) override { + if ( op->op == LM_GGML_OP_MUL_MAT && + op->src[0]->buffer && + (lm_ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type() && + lm_ggml_aarch64_get_optimal_repack_type(op->src[0]) + ) { + if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == LM_GGML_TYPE_F32) { + return true; + } + //if (op->src[1]->type == LM_GGML_TYPE_Q8_0) { + // return true; + //} + // may be possible if Q8_0 packed... + } else if (op->op == LM_GGML_OP_MUL_MAT_ID + && op->src[0]->buffer + && (lm_ggml_n_dims(op->src[0]) == 3) + && op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type() + && lm_ggml_aarch64_get_optimal_repack_type(op->src[0]) + ) { + if (op->src[1]->buffer && !lm_ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == LM_GGML_TYPE_F32) { + return true; + } + //if (op->src[1]->type == LM_GGML_TYPE_Q8_0) { + // return true; + //} + } + return false; + } + + ggml::cpu::tensor_traits * get_tensor_traits(const struct lm_ggml_tensor * op) override { + if (op->op == LM_GGML_OP_MUL_MAT || op->op == LM_GGML_OP_MUL_MAT_ID) { + if (op->src[0]->buffer && op->src[0]->buffer->buft == lm_ggml_backend_cpu_aarch64_buffer_type()) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + } + return nullptr; + } +}; +} // namespace ggml::cpu::aarch64 + +lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_aarch64_buffer_type(void) { + static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type_aarch64 = { + /* .iface = */ { + /* .get_name = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_name, + /* .alloc_buffer = */ lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer, + /* .get_alignment = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ nullptr, // defaults to lm_ggml_nbytes + /* .is_host = */ nullptr, + }, + /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), + /* .context = */ new ggml::cpu::aarch64::extra_buffer_type(), + }; + + return &lm_ggml_backend_cpu_buffer_type_aarch64; } diff --git a/cpp/ggml-cpu-aarch64.h b/cpp/ggml-cpu-aarch64.h index dd7b10fc..0b15a23d 100644 --- a/cpp/ggml-cpu-aarch64.h +++ b/cpp/ggml-cpu-aarch64.h @@ -1,30 +1,8 @@ #pragma once +#include "ggml-cpu-traits.h" #include "ggml.h" // GGML internal header -#ifdef __cplusplus -extern "C" { -#endif - -// Quantization -void quantize_mat_q8_0(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave); - -// GEMV -void lm_ggml_gemv_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, 
size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); -void lm_ggml_gemv_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); -void lm_ggml_gemv_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); - -// GEMM -void lm_ggml_gemm_q4_0_4x4_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); -void lm_ggml_gemm_q4_0_4x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); -void lm_ggml_gemm_q4_0_8x8_q8_0(int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT vx, const void * LM_GGML_RESTRICT vy, int nr, int nc); - -void lm_ggml_aarch64_repack_tensor(struct lm_ggml_tensor * cur, enum lm_ggml_type repack_type, const void * data, size_t data_size); -enum lm_ggml_type lm_ggml_aarch64_get_optimal_repack_type(const struct lm_ggml_tensor * cur); - -#ifdef __cplusplus -} -#endif - +lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_aarch64_buffer_type(void); diff --git a/cpp/ggml-cpu-impl.h b/cpp/ggml-cpu-impl.h index 47f0cad5..23c4e104 100644 --- a/cpp/ggml-cpu-impl.h +++ b/cpp/ggml-cpu-impl.h @@ -15,6 +15,18 @@ extern "C" { #endif +struct lm_ggml_compute_params { + // ith = thread index, nth = number of threads + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; + + struct lm_ggml_threadpool * threadpool; +}; + + #if defined(_MSC_VER) #define m512bh(p) p @@ -366,6 +378,9 @@ static __m256 __lasx_xvreplfr2vr_s(float val) { } #endif +// TODO: move to ggml-threading +void lm_ggml_barrier(struct lm_ggml_threadpool * tp); + #ifdef __cplusplus } #endif diff --git a/cpp/ggml-cpu-quants.c b/cpp/ggml-cpu-quants.c index 3a2cfe34..dd2bf146 100644 --- a/cpp/ggml-cpu-quants.c +++ b/cpp/ggml-cpu-quants.c @@ -1791,11 +1791,12 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void const int8x16_t y1_l = vld1q_s8(b_y1->qs); const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - float32_t _scale[4] = { LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y0->d), - LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y1->d), - LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y0->d), - LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y1->d)}; - + float32_t _scale[4] = { + LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y0->d), + LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y1->d), + LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y0->d), + LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y1->d) + }; float32x4_t scale = vld1q_f32(_scale); int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); @@ -1811,13 +1812,15 @@ void lm_ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); + l1, r1)), l2, r2)), l3, r3))), scale); } - float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + + float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s, 
vget_low_f32 (sumv2)); vst1_f32(s + bs, vget_high_f32(sumv2)); + return; } #endif @@ -2345,10 +2348,12 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void const block_q8_1 * restrict b_y0 = &vy0[i]; const block_q8_1 * restrict b_y1 = &vy1[i]; - float32_t summs_t[4] = {LM_GGML_FP16_TO_FP32(b_x0->m) * LM_GGML_FP16_TO_FP32(b_y0->s), - LM_GGML_FP16_TO_FP32(b_x1->m) * LM_GGML_FP16_TO_FP32(b_y0->s), - LM_GGML_FP16_TO_FP32(b_x0->m) * LM_GGML_FP16_TO_FP32(b_y1->s), - LM_GGML_FP16_TO_FP32(b_x1->m) * LM_GGML_FP16_TO_FP32(b_y1->s)}; + float32_t summs_t[4] = { + LM_GGML_FP16_TO_FP32(b_x0->m) * LM_GGML_FP16_TO_FP32(b_y0->s), + LM_GGML_FP16_TO_FP32(b_x1->m) * LM_GGML_FP16_TO_FP32(b_y0->s), + LM_GGML_FP16_TO_FP32(b_x0->m) * LM_GGML_FP16_TO_FP32(b_y1->s), + LM_GGML_FP16_TO_FP32(b_x1->m) * LM_GGML_FP16_TO_FP32(b_y1->s) + }; summs0 = vaddq_f32(summs0, vld1q_f32(summs_t)); const uint8x16_t m4b = vdupq_n_u8(0x0F); @@ -2369,10 +2374,12 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); // mmla into int32x4_t - float32_t _scale[4] = {LM_GGML_FP16_TO_FP32(b_x0->d)*b_y0->d, - LM_GGML_FP16_TO_FP32(b_x0->d)*b_y1->d, - LM_GGML_FP16_TO_FP32(b_x1->d)*b_y0->d, - LM_GGML_FP16_TO_FP32(b_x1->d)*b_y1->d}; + float32_t _scale[4] = { + LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y0->d), + LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y1->d), + LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y0->d), + LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y1->d) + }; float32x4_t scale = vld1q_f32(_scale); int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); @@ -2387,15 +2394,17 @@ void lm_ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void int8x16_t r2 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); int8x16_t r3 = vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); + l1, r1)), l2, r2)), l3, r3))), scale); } - float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); + sumv2 = vaddq_f32(sumv2, summs0); vst1_f32(s, vget_low_f32 (sumv2)); vst1_f32(s + bs, vget_high_f32(sumv2)); + return; } #endif @@ -3372,10 +3381,12 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void const int8x16_t y1_l = vld1q_s8(b_y1->qs); const int8x16_t y1_h = vld1q_s8(b_y1->qs + 16); - float32_t _scale[4] = {LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y0->d), - LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y1->d), - LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y0->d), - LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y1->d)}; + float32_t _scale[4] = { + LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y0->d), + LM_GGML_FP16_TO_FP32(b_x0->d)*LM_GGML_FP16_TO_FP32(b_y1->d), + LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y0->d), + LM_GGML_FP16_TO_FP32(b_x1->d)*LM_GGML_FP16_TO_FP32(b_y1->d) + }; float32x4_t scale = vld1q_f32(_scale); int8x16_t l0 = vreinterpretq_s8_s64(vzip1q_s64(vreinterpretq_s64_s8(x0_l), vreinterpretq_s64_s8(x1_l))); @@ -3391,13 +3402,15 @@ void lm_ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void int8x16_t r3 = 
vreinterpretq_s8_s64(vzip2q_s64(vreinterpretq_s64_s8(y0_h), vreinterpretq_s64_s8(y1_h))); sumv0 = vmlaq_f32(sumv0,(vcvtq_f32_s32(vmmlaq_s32((vmmlaq_s32((vmmlaq_s32((vmmlaq_s32(vdupq_n_s32(0), l0, r0)), - l1, r1)), l2, r2)), l3, r3))), scale); + l1, r1)), l2, r2)), l3, r3))), scale); } - float32x4_t sumv1 = vextq_f32(sumv0, sumv0, 2); + + float32x4_t sumv1 = vextq_f32 (sumv0, sumv0, 2); float32x4_t sumv2 = vzip1q_f32(sumv0, sumv1); - vst1_f32(s, vget_low_f32(sumv2)); + vst1_f32(s, vget_low_f32 (sumv2)); vst1_f32(s + bs, vget_high_f32(sumv2)); + return; } #endif diff --git a/cpp/ggml-cpu-traits.cpp b/cpp/ggml-cpu-traits.cpp new file mode 100644 index 00000000..84a276f9 --- /dev/null +++ b/cpp/ggml-cpu-traits.cpp @@ -0,0 +1,36 @@ +#include "ggml-cpu-traits.h" + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" + +namespace ggml::cpu { +tensor_traits::~tensor_traits() {} + +extra_buffer_type::~extra_buffer_type() {} +} // namespace ggml::cpu + +bool lm_ggml_cpu_extra_compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) { + for (auto extra : lm_ggml_backend_cpu_get_extra_buffers_type()) { + if (extra && extra->context) { + auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context; + auto tensor_traits = buf_extra->get_tensor_traits(op); + if (tensor_traits && tensor_traits->compute_forward(params, op)) { + return true; + } + } + } + return false; +} + +bool lm_ggml_cpu_extra_work_size(int n_threads, const struct lm_ggml_tensor * op, size_t * size) { + for (auto extra : lm_ggml_backend_cpu_get_extra_buffers_type()) { + if (extra && extra->context) { + auto buf_extra = (ggml::cpu::extra_buffer_type *) extra->context; + auto tensor_traits = buf_extra->get_tensor_traits(op); + if (tensor_traits && tensor_traits->work_size(n_threads, op, *size)) { + return true; + } + } + } + return false; +} diff --git a/cpp/ggml-cpu-traits.h b/cpp/ggml-cpu-traits.h new file mode 100644 index 00000000..58349294 --- /dev/null +++ b/cpp/ggml-cpu-traits.h @@ -0,0 +1,38 @@ +#pragma once +#include "ggml-backend-impl.h" +#include "ggml-cpu-impl.h" +#include "ggml.h" + +#ifdef __cplusplus +# include <vector> +extern "C" { +#endif + +// return true if op part of extra "accelerator" +bool lm_ggml_cpu_extra_compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op); +bool lm_ggml_cpu_extra_work_size(int n_threads, const struct lm_ggml_tensor * op, size_t * size); + +#ifdef __cplusplus +} + +namespace ggml::cpu { +// register in tensor->extra +class tensor_traits { + public: + virtual ~tensor_traits(); + virtual bool work_size(int n_threads, const struct lm_ggml_tensor * op, size_t & size) = 0; + virtual bool compute_forward(struct lm_ggml_compute_params * params, struct lm_ggml_tensor * op) = 0; +}; + +class extra_buffer_type { + public: + virtual ~extra_buffer_type(); + virtual bool supports_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) = 0; + virtual tensor_traits * get_tensor_traits(const struct lm_ggml_tensor * op) = 0; +}; +} // namespace ggml::cpu +
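
[Note: illustration only, not part of the diff.] The two classes above are the whole "extra accelerator" plug-in surface of the CPU backend: a buffer type stores a ggml::cpu::tensor_traits pointer in tensor->extra, and lm_ggml_cpu_extra_work_size() / lm_ggml_cpu_extra_compute_forward() (ggml-cpu-traits.cpp above) let that object size and run the op before the generic kernels are consulted. A minimal sketch of a custom trait under that contract; the namespace and class name here are hypothetical:

#include "ggml-cpu-traits.h"

namespace ggml::cpu::example {
// Hypothetical trait: requests no scratch memory and defers execution to the stock kernels.
class passthrough_traits : public tensor_traits {
  public:
    bool work_size(int /*n_threads*/, const struct lm_ggml_tensor * /*op*/, size_t & size) override {
        size = 0;     // no extra work buffer needed for this op
        return true;  // claimed: lm_ggml_cpu_extra_work_size() stops searching
    }
    bool compute_forward(struct lm_ggml_compute_params * /*params*/, struct lm_ggml_tensor * /*op*/) override {
        return false; // not claimed: the regular CPU op switch runs instead
    }
};
} // namespace ggml::cpu::example
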
+// implemented in ggml-cpu.cpp. +std::vector<lm_ggml_backend_buffer_type_t> & lm_ggml_backend_cpu_get_extra_buffers_type(); + +#endif diff --git a/cpp/ggml-cpu.c b/cpp/ggml-cpu.c index 51fe133e..1788c831 100644 --- a/cpp/ggml-cpu.c +++ b/cpp/ggml-cpu.c @@ -3,13 +3,14 @@ #include "ggml-backend-impl.h" #include "ggml-backend.h" -#include "ggml-cpu-aarch64.h" +#include "ggml-cpu-traits.h" #include "ggml-cpu-impl.h" #include "ggml-cpu.h" #include "ggml-impl.h" #include "ggml-quants.h" #include "ggml-cpu-quants.h" #include "ggml-threading.h" +#include "amx/amx.h" #include "ggml.h" #if defined(_MSC_VER) || defined(__MINGW32__) @@ -109,10 +110,11 @@ static lm_ggml_fp16_t lm_ggml_table_gelu_quick_f16[1 << 16]; #if defined(__ARM_ARCH) struct lm_ggml_arm_arch_features_type { int has_neon; + int has_dotprod; int has_i8mm; int has_sve; int sve_cnt; -} lm_ggml_arm_arch_features = {-1, -1, -1, 0}; +} lm_ggml_arm_arch_features = {-1, -1, -1, -1, 0}; #endif @@ -124,8 +126,7 @@ struct lm_ggml_arm_arch_features_type { #endif #include <windows.h> - -#if !defined(__clang__) +#if defined(_MSC_VER) && !defined(__clang__) #define LM_GGML_CACHE_ALIGN __declspec(align(LM_GGML_CACHE_LINE)) typedef volatile LONG atomic_int; @@ -222,10 +223,6 @@ typedef void * thread_ret_t; typedef pthread_t lm_ggml_thread_t; -#ifdef LM_GGML_USE_CPU_HBM -#include <hbwmalloc.h> -#endif - #if defined(__APPLE__) #include #include @@ -299,7 +296,6 @@ static const struct lm_ggml_type_traits_cpu type_traits_cpu[LM_GGML_TYPE_COUNT] }, [LM_GGML_TYPE_Q8_0] = { .from_float = quantize_row_q8_0, - .from_float_to_mat = quantize_mat_q8_0, .vec_dot = lm_ggml_vec_dot_q8_0_q8_0, .vec_dot_type = LM_GGML_TYPE_Q8_0, #if defined (__ARM_FEATURE_MATMUL_INT8) @@ -407,33 +403,6 @@ static const struct lm_ggml_type_traits_cpu type_traits_cpu[LM_GGML_TYPE_COUNT] .vec_dot_type = LM_GGML_TYPE_BF16, .nrows = 1, }, - [LM_GGML_TYPE_Q4_0_4_4] = { - .from_float = NULL, - .vec_dot = NULL, - .vec_dot_type = LM_GGML_TYPE_Q8_0, - .nrows = 1, - .ncols = 4, - .gemv = lm_ggml_gemv_q4_0_4x4_q8_0, - .gemm = lm_ggml_gemm_q4_0_4x4_q8_0, - }, - [LM_GGML_TYPE_Q4_0_4_8] = { - .from_float = NULL, - .vec_dot = NULL, - .vec_dot_type = LM_GGML_TYPE_Q8_0, - .nrows = 1, - .ncols = 4, - .gemv = lm_ggml_gemv_q4_0_4x8_q8_0, - .gemm = lm_ggml_gemm_q4_0_4x8_q8_0, - }, - [LM_GGML_TYPE_Q4_0_8_8] = { - .from_float = NULL, - .vec_dot = NULL, - .vec_dot_type = LM_GGML_TYPE_Q8_0, - .nrows = 1, - .ncols = 8, - .gemv = lm_ggml_gemv_q4_0_8x8_q8_0, - .gemm = lm_ggml_gemm_q4_0_8x8_q8_0, - }, [LM_GGML_TYPE_TQ1_0] = { .from_float = quantize_row_tq1_0, .vec_dot = lm_ggml_vec_dot_tq1_0_q8_K, @@ -485,21 +454,21 @@ const struct lm_ggml_type_traits_cpu * lm_ggml_get_type_traits_cpu(enum lm_ggml_ #define LM_GGML_F32x4_ADD vaddq_f32 #define LM_GGML_F32x4_MUL vmulq_f32 #define LM_GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) -#define LM_GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ - } \ - (res) = LM_GGML_F32x4_REDUCE_ONE((x)[0]); \ +#define LM_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + (x)[i] = 
vaddq_f32((x)[i], (x)[offset+i]); \ + } \ + (res) = (lm_ggml_float) LM_GGML_F32x4_REDUCE_ONE((x)[0]); \ } #define LM_GGML_F32_VEC LM_GGML_F32x4 @@ -614,7 +583,7 @@ do { \ for (int i = 0; i < offset; ++i) { \ x[i] = _mm512_add_ps(x[i], x[offset+i]); \ } \ - res = _mm512_reduce_add_ps(x[0]); \ + res = (lm_ggml_float) _mm512_reduce_add_ps(x[0]); \ } while (0) // TODO: is this optimal ? @@ -664,7 +633,7 @@ do { \ for (int i = 0; i < offset; ++i) { \ x[i] = _mm512_add_ps(x[i], x[offset+i]); \ } \ - res = _mm512_reduce_add_ps(x[0]); \ + res = (lm_ggml_float) _mm512_reduce_add_ps(x[0]); \ } while (0) #define LM_GGML_F16_VEC LM_GGML_F32Cx16 @@ -675,8 +644,8 @@ do { \ #define LM_GGML_F16_VEC_FMA LM_GGML_F32Cx16_FMA #define LM_GGML_F16_VEC_ADD LM_GGML_F32Cx16_ADD #define LM_GGML_F16_VEC_MUL LM_GGML_F32Cx16_MUL -#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx16_REDUCE +#define LM_GGML_F16_VEC_REDUCE LM_GGML_F32Cx16_REDUCE #elif defined(__AVX__) #define LM_GGML_SIMD @@ -745,7 +714,7 @@ do { \ #define LM_GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define LM_GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else -static inline __m256 __avx_f32cx8_load(lm_ggml_fp16_t *x) { +static inline __m256 __avx_f32cx8_load(const lm_ggml_fp16_t * x) { float tmp[8]; for (int i = 0; i < 8; i++) { @@ -1168,28 +1137,28 @@ static inline void __lasx_f32cx8_store(lm_ggml_fp16_t * x, __m256 y) { #define LM_GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a) #define LM_GGML_F32x4_ADD __lsx_vfadd_s #define LM_GGML_F32x4_MUL __lsx_vfmul_s -#define LM_GGML_F32x4_REDUCE(res, x) \ -{ \ - int offset = LM_GGML_F32_ARR >> 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ - } \ - offset >>= 1; \ - for (int i = 0; i < offset; ++i) { \ - x[i] = __lsx_vfadd_s(x[i], x[offset+i]); \ - } \ - __m128i tmp = __lsx_vsrli_d((__m128i)x[0], 32); \ - tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, x[0]); \ - tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ - const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \ - tmp = __lsx_vsrli_d((__m128i)t0, 32); \ - tmp = (__m128i)__lsx_vfadd_s((__m128)tmp, t0); \ - tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ - res = (lm_ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \ +#define LM_GGML_F32x4_REDUCE(res, x) \ +{ \ + int offset = LM_GGML_F32_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = __lsx_vfadd_s(x[i], x[offset + i]); \ + } \ + __m128i tmp = __lsx_vsrli_d((__m128i) x[0], 32); \ + tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]); \ + tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ + const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88); \ + tmp = __lsx_vsrli_d((__m128i) t0, 32); \ + tmp = (__m128i) __lsx_vfadd_s((__m128) tmp, t0); \ + tmp = __lsx_vpickev_w(__lsx_vldi(0), tmp); \ + res = (lm_ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \ } #define LM_GGML_F32_VEC LM_GGML_F32x4 @@ -1357,31 +1326,18 @@ struct lm_ggml_compute_state { int ith; }; -struct lm_ggml_compute_params { - // ith = thread index, nth = number of threads - int ith, nth; - - // work buffer for all threads - size_t wsize; - void * wdata; - - struct lm_ggml_threadpool * threadpool; -}; - // // fundamental 
operations // inline static void lm_ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - inline static void lm_ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void lm_ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void lm_ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } +inline static void lm_ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } inline static void lm_ggml_vec_set_f16(const int n, lm_ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - inline static void lm_ggml_vec_set_bf16(const int n, lm_ggml_bf16_t * x, const lm_ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } - inline static void lm_ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; } inline static void lm_ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; } inline static void lm_ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; } @@ -2276,7 +2232,7 @@ struct lm_ggml_state { static struct lm_ggml_state g_state = {0}; -static void lm_ggml_barrier(struct lm_ggml_threadpool * tp) { +void lm_ggml_barrier(struct lm_ggml_threadpool * tp) { int n_threads = atomic_load_explicit(&tp->n_threads_cur, memory_order_relaxed); if (n_threads == 1) { return; @@ -2430,7 +2386,7 @@ bool lm_ggml_is_numa(void) { #endif #if !defined(HWCAP2_I8MM) -#define HWCAP2_I8MM 0 +#define HWCAP2_I8MM (1 << 13) #endif static void lm_ggml_init_arm_arch_features(void) { @@ -2439,6 +2395,7 @@ static void lm_ggml_init_arm_arch_features(void) { uint32_t hwcap2 = getauxval(AT_HWCAP2); lm_ggml_arm_arch_features.has_neon = !!(hwcap & HWCAP_ASIMD); + lm_ggml_arm_arch_features.has_dotprod = !!(hwcap & HWCAP_ASIMDDP); lm_ggml_arm_arch_features.has_i8mm = !!(hwcap2 & HWCAP2_I8MM); lm_ggml_arm_arch_features.has_sve = !!(hwcap & HWCAP_SVE); @@ -2453,6 +2410,11 @@ static void lm_ggml_init_arm_arch_features(void) { } lm_ggml_arm_arch_features.has_neon = oldp; + if (sysctlbyname("hw.optional.arm.FEAT_DotProd", &oldp, &size, NULL, 0) != 0) { + oldp = 0; + } + lm_ggml_arm_arch_features.has_dotprod = oldp; + if (sysctlbyname("hw.optional.arm.FEAT_I8MM", &oldp, &size, NULL, 0) != 0) { oldp = 0; } @@ -4505,9 +4467,6 @@ static void lm_ggml_compute_forward_add( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: { lm_ggml_compute_forward_add_q_f32(params, dst); } break; @@ -4885,9 +4844,6 @@ static void lm_ggml_compute_forward_add1( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: { lm_ggml_compute_forward_add1_q_f32(params, dst); } break; @@ -5015,9 +4971,6 @@ static void lm_ggml_compute_forward_acc( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: default: { LM_GGML_ABORT("fatal error"); @@ -7433,20 +7386,9 @@ static void lm_ggml_compute_forward_mul_mat( const int ith = params->ith; const int nth = params->nth; - enum 
lm_ggml_type type = src0->type; - - if (src0->buffer && lm_ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) { - type = (enum lm_ggml_type)(intptr_t)src0->extra; - } - - enum lm_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; + enum lm_ggml_type const vec_dot_type = type_traits_cpu[src0->type].vec_dot_type; lm_ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; - lm_ggml_from_float_to_mat_t const from_float_to_mat = type_traits_cpu[vec_dot_type].from_float_to_mat; - int64_t const vec_dot_num_rows = type_traits_cpu[type].nrows; - int64_t const matmul_num_cols = type_traits_cpu[type].ncols; - int64_t const blck_size_interleave = lm_ggml_get_type_traits(type)->blck_size_interleave; - lm_ggml_gemv_t const gemv = type_traits_cpu[type].gemv; - lm_ggml_gemm_t const gemm = type_traits_cpu[type].gemm; + int64_t const vec_dot_num_rows = type_traits_cpu[src0->type].nrows; LM_GGML_ASSERT(ne0 == ne01); LM_GGML_ASSERT(ne1 == ne11); @@ -7454,7 +7396,7 @@ static void lm_ggml_compute_forward_mul_mat( LM_GGML_ASSERT(ne3 == ne13); // we don't support permuted src0 or src1 - LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); + LM_GGML_ASSERT(nb00 == lm_ggml_type_size(src0->type)); LM_GGML_ASSERT(nb10 == lm_ggml_type_size(src1->type)); // dst cannot be transposed or permuted @@ -7466,6 +7408,7 @@ static void lm_ggml_compute_forward_mul_mat( // nb01 >= nb00 - src0 is not transposed // compute by src0 rows + // TODO: extract to "extra_op" #if LM_GGML_USE_LLAMAFILE // broadcast factors const int64_t r2 = ne12 / ne02; @@ -7476,15 +7419,15 @@ static void lm_ggml_compute_forward_mul_mat( if (src1_cont) { for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) - if (!llamafile_sgemm(ne01, ne11, ne00/lm_ggml_blck_size(type), + if (!llamafile_sgemm(ne01, ne11, ne00/lm_ggml_blck_size(src0->type), (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, - nb01/lm_ggml_type_size(type), + nb01/lm_ggml_type_size(src0->type), (const char *)src1->data + i12*nb12 + i13*nb13, nb11/lm_ggml_type_size(src1->type), (char *)dst->data + i12*nb2 + i13*nb3, nb1/lm_ggml_type_size(dst->type), ith, nth, - type, + src0->type, src1->type, dst->type)) goto UseGgmlGemm1; @@ -7505,19 +7448,10 @@ UseGgmlGemm1:; for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { - int64_t i11_processed = 0; - if ((lm_ggml_n_dims(src1) == 2) && from_float_to_mat && gemm) { - for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) { - from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - 4, ne10, blck_size_interleave); - } - i11_processed = ne11 - ne11 % 4; - } - for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) { + for (int64_t i11 = ith; i11 < ne11; i11 += nth) { from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), - (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), - ne10); + (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1), + ne10); } } } @@ -7537,15 +7471,15 @@ UseGgmlGemm1:; for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i12 = 0; i12 < ne12; i12++) - if (!llamafile_sgemm(ne01, ne11, ne00/lm_ggml_blck_size(type), + if (!llamafile_sgemm(ne01, ne11, ne00/lm_ggml_blck_size(src0->type), (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, - nb01/lm_ggml_type_size(type), + nb01/lm_ggml_type_size(src0->type), (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/lm_ggml_type_size(vec_dot_type), (char 
*)dst->data + i12*nb2 + i13*nb3, nb1/lm_ggml_type_size(dst->type), ith, nth, - type, + src0->type, vec_dot_type, dst->type)) goto UseGgmlGemm2; @@ -7560,14 +7494,6 @@ UseGgmlGemm2:; // This is the size of the rest of the dimensions of the result const int64_t nr1 = ne1 * ne2 * ne3; - // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols - int64_t num_rows_per_vec_dot = vec_dot_num_rows; - // TODO: currently the mmla kernels support only even numbered rows/cols. - // this check can be removed once they are extended to support odd numbered rows/cols too - if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) { - num_rows_per_vec_dot = 1; - } - // Now select a reasonable chunk size. int chunk_size = 16; @@ -7595,28 +7521,6 @@ UseGgmlGemm2:; const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - if ((lm_ggml_n_dims(src0) == 2) && gemv) { - const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; - const size_t src1_col_stride = lm_ggml_is_contiguous(src1) || src1->type != vec_dot_type ? lm_ggml_row_size(vec_dot_type, ne10) : nb11; - int64_t src0_start = (ith * ne01) / nth; - int64_t src0_end = ((ith + 1) * ne01) / nth; - src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start; - src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end; - if (src0_start >= src0_end) return; - - // If there are more than three rows in src1, use gemm; otherwise, use gemv. - if (gemm && (ne11 > 3)) { - gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01, - (const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start); - } - for (int iter = gemm ? ne11 - ne11 % 4 : 0; iter < ne11; iter++) { - gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01, - (const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1, - src0_end - src0_start); - } - return; - } - // The first chunk comes from our thread_id, the rest will get auto-assigned. 
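
[Note: illustration only, not part of the diff.] The comment above describes a small atomic work-stealing loop, which the code below implements with the threadpool's current_chunk counter. A standalone C++ sketch of the same pattern, with hypothetical names:

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

static std::atomic<int> next_chunk; // stands in for threadpool->current_chunk

static void worker(int ith, int nth, int nchunk) {
    int current_chunk = ith;                     // the first chunk comes from the thread id
    while (current_chunk < nchunk) {
        std::printf("thread %d -> chunk %d\n", ith, current_chunk);
        if (nth >= nchunk) {
            break;                               // at most one chunk per thread: nothing to steal
        }
        current_chunk = next_chunk.fetch_add(1); // the rest get auto-assigned
    }
}

int main() {
    const int nth = 4, nchunk = 16;
    next_chunk = nth;                            // chunks 0..nth-1 are taken implicitly
    std::vector<std::thread> pool;
    for (int ith = 0; ith < nth; ith++) { pool.emplace_back(worker, ith, nth, nchunk); }
    for (auto & t : pool) { t.join(); }
}
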
int current_chunk = ith; @@ -7630,7 +7534,16 @@ UseGgmlGemm2:; const int64_t ir1_start = dr1 * ith1; const int64_t ir1_end = MIN(ir1_start + dr1, nr1); - lm_ggml_compute_forward_mul_mat_one_chunk(params, dst, type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); + // dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols + int64_t num_rows_per_vec_dot = vec_dot_num_rows; + + // these checks are needed to avoid crossing dim1 boundaries + // can be optimized, but the logic would become more complicated, so keeping it like this for simplicity + if ((nr0 % 2 != 0) || (ne11 % 2 != 0) || ((ir0_end - ir0_start) % 2 != 0) || ((ir1_end - ir1_start) % 2 != 0)) { + num_rows_per_vec_dot = 1; + } + + lm_ggml_compute_forward_mul_mat_one_chunk(params, dst, src0->type, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end); if (nth >= nchunk0 * nchunk1) { break; @@ -7662,8 +7575,6 @@ static void lm_ggml_compute_forward_mul_mat_id( lm_ggml_vec_dot_t const vec_dot = type_traits_cpu[type].vec_dot; enum lm_ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type; lm_ggml_from_float_t const from_float = type_traits_cpu[vec_dot_type].from_float; - int64_t const matmul_num_cols = type_traits_cpu[type].ncols; - lm_ggml_gemv_t const gemv = type_traits_cpu[type].gemv; // we don't support permuted src0 or src1 LM_GGML_ASSERT(nb00 == lm_ggml_type_size(type)); @@ -7749,34 +7660,6 @@ static void lm_ggml_compute_forward_mul_mat_id( const int64_t nr0 = ne01; // src0 rows const int64_t nr1 = cne1; // src1 rows - if (((lm_ggml_n_dims(src0) - 1) == 2) && gemv) { - int64_t src0_cur_start = (ith * ne01) / nth; - int64_t src0_cur_end = ((ith + 1) * ne01) / nth; - src0_cur_start = (src0_cur_start % matmul_num_cols) ? src0_cur_start + matmul_num_cols - (src0_cur_start % matmul_num_cols): src0_cur_start; - src0_cur_end = (src0_cur_end % matmul_num_cols) ? src0_cur_end + matmul_num_cols - (src0_cur_end % matmul_num_cols): src0_cur_end; - if (src0_cur_start >= src0_cur_end) return; - - for (int ir1 = 0; ir1 < nr1; ir1++) { - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); - const int id = row_mapping.i1; // selected expert index - - const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 - - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row - - const char * src1_col = (const char *) wdata + - (src1_cont || src1->type != vec_dot_type - ? (i11 + i12 * ne11) * row_size - : (i11 * nb11 + i12 * nb12)); - - gemv(ne00, (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, - (const char *) src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); - } - continue; - } - // distribute the thread work across the inner or outer loop based on which one is larger const int64_t nth0 = nr0 > nr1 ? 
nth : 1; // parallelize by src0 rows @@ -8084,9 +7967,6 @@ static void lm_ggml_compute_forward_out_prod( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: { lm_ggml_compute_forward_out_prod_q_f32(params, dst); } break; @@ -8239,6 +8119,77 @@ static void lm_ggml_compute_forward_set_f32( } } +static void lm_ggml_compute_forward_set_i32( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + const struct lm_ggml_tensor * src1 = dst->src[1]; + + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); + + // view src0 and dst with these strides and data offset inbytes during set + // nb0 is implicitly element_size because src0 and dst are contiguous + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + if (params->ith == 0) { + // memcpy needs to be synchronized across threads to avoid race conditions. + // => do it in INIT phase + memcpy( + ((char *) dst->data), + ((char *) src0->data), + lm_ggml_nbytes(dst)); + } + lm_ggml_barrier(params->threadpool); + } + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = lm_ggml_nrows(src1); + const int nc = src1->ne[0]; + + LM_GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) + LM_GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) + + // src0 and dst as viewed during set + const size_t nb0 = lm_ggml_element_size(src0); + + const int im0 = (ne10 == 0 ? 0 : ne10-1); + const int im1 = (ne11 == 0 ? 0 : ne11-1); + const int im2 = (ne12 == 0 ? 0 : ne12-1); + const int im3 = (ne13 == 0 ? 
0 : ne13-1); + + LM_GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= lm_ggml_nbytes(dst)); + + LM_GGML_ASSERT(nb10 == sizeof(int32_t)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int ir = ir0; ir < ir1; ++ir) { + // src0 and dst are viewed with shape of src1 and offset + // => same indices + const int i3 = ir/(ne12*ne11); + const int i2 = (ir - i3*ne12*ne11)/ne11; + const int i1 = (ir - i3*ne12*ne11 - i2*ne11); + + lm_ggml_vec_cpy_i32(nc, + (int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), + (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + } +} + static void lm_ggml_compute_forward_set( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst) { @@ -8250,6 +8201,10 @@ static void lm_ggml_compute_forward_set( { lm_ggml_compute_forward_set_f32(params, dst); } break; + case LM_GGML_TYPE_I32: + { + lm_ggml_compute_forward_set_i32(params, dst); + } break; case LM_GGML_TYPE_F16: case LM_GGML_TYPE_BF16: case LM_GGML_TYPE_Q4_0: @@ -8274,9 +8229,6 @@ static void lm_ggml_compute_forward_set( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: default: { LM_GGML_ABORT("fatal error"); @@ -8538,9 +8490,6 @@ static void lm_ggml_compute_forward_get_rows( case LM_GGML_TYPE_IQ4_XS: case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: { lm_ggml_compute_forward_get_rows_q(params, dst); } break; @@ -9130,9 +9079,6 @@ static void lm_ggml_compute_forward_clamp( case LM_GGML_TYPE_IQ3_S: case LM_GGML_TYPE_IQ2_S: case LM_GGML_TYPE_Q8_K: - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - case LM_GGML_TYPE_Q4_0_8_8: case LM_GGML_TYPE_I8: case LM_GGML_TYPE_I16: case LM_GGML_TYPE_I32: @@ -9187,6 +9133,64 @@ static void lm_ggml_rope_cache_init( } } +static void lm_ggml_mrope_cache_init( + float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool indep_sects, + float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale, + float * cache, float sin_sign, float theta_scale) { + // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py + float theta_t = theta_base_t; + float theta_h = theta_base_h; + float theta_w = theta_base_w; + float theta_e = theta_base_e; // extra position id for vision encoder + int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + int sec_w = sections[1] + sections[0]; + int sec_e = sections[2] + sec_w; + LM_GGML_ASSERT(sect_dims <= ne0); + + for (int64_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0/2] : 1.0f; + + int sector = (i0 / 2) % sect_dims; + if (indep_sects) { + // compute theta independently for each dim sections + // (i.e. 
reset corresponding theta when `i0` goes from one section to another) + if (sector == 0) { + theta_t = theta_base_t; + } + else if (sector == sections[0]) { + theta_h = theta_base_h; + } + else if (sector == sec_w) { + theta_w = theta_base_w; + } + else if (sector == sec_e) { + theta_e = theta_base_e; + } + } + + float theta = theta_t; + if (sector >= sections[0] && sector < sec_w) { + theta = theta_h; + } + else if (sector >= sec_w && sector < sec_w + sections[2]) { + theta = theta_w; + } + else if (sector >= sec_w + sections[2]) { + theta = theta_e; + } + + rope_yarn( + theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] + ); + cache[i0 + 1] *= sin_sign; + + theta_t *= theta_scale; + theta_w *= theta_scale; + theta_h *= theta_scale; + theta_e *= theta_scale; + } +} + static void lm_ggml_compute_forward_rope_f32( const struct lm_ggml_compute_params * params, struct lm_ggml_tensor * dst, @@ -9197,6 +9201,7 @@ static void lm_ggml_compute_forward_rope_f32( const struct lm_ggml_tensor * src2 = dst->src[2]; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -9210,6 +9215,7 @@ static void lm_ggml_compute_forward_rope_f32( memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); LM_GGML_TENSOR_UNARY_OP_LOCALS @@ -9242,6 +9248,16 @@ static void lm_ggml_compute_forward_rope_f32( lm_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & LM_GGML_ROPE_TYPE_MROPE; // lm_ggml_rope_multi, multimodal rotary position embedding + const bool is_vision = mode == LM_GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + LM_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + LM_GGML_ASSERT(n_dims == ne0/2); + } const float * freq_factors = NULL; if (src2 != NULL) { @@ -9257,18 +9273,63 @@ static void lm_ggml_compute_forward_rope_f32( const int32_t * pos = (const int32_t *) src1->data; - for (int64_t i3 = 0; i3 < ne3; i3++) { - for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; + for (int64_t i3 = 0; i3 < ne3; i3++) { // batch + for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - lm_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + if (!is_mrope) { + const int64_t p = pos[i2]; + lm_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + else { + const int64_t p_t = pos[i2]; + const int64_t p_h = pos[i2 + ne2]; + const int64_t p_w = pos[i2 + ne2 * 2]; + const int64_t p_e = pos[i2 + ne2 * 3]; + lm_ggml_mrope_cache_init( + p_t, p_h, p_w, p_e, sections, is_vision, + freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } - for (int64_t i1 = 0; i1 < ne1; i1++) { + for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads if (ir++ < ir0) continue; if (ir > ir1) break; - if (!is_neox) { + if (is_neox || is_mrope) { + if (is_vision){ + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = 
i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims] = x0*sin_theta + x1*cos_theta; + } + } else { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = src[0]; + const float x1 = src[n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } + } else { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; @@ -9282,8 +9343,10 @@ static void lm_ggml_compute_forward_rope_f32( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; } - } else { - for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + } + + if (is_vision) { + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { const int64_t ic = i0/2; const float cos_theta = cache[i0 + 0]; @@ -9293,19 +9356,20 @@ static void lm_ggml_compute_forward_rope_f32( float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); const float x0 = src[0]; - const float x1 = src[n_dims/2]; + const float x1 = src[n_dims]; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[n_dims] = x0*sin_theta + x1*cos_theta; } - } - - for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { - const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + } else { + // fill the remaining channels with data from src tensor + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - dst_data[0] = src[0]; - dst_data[1] = src[1]; + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } } } } @@ -9323,6 +9387,7 @@ static void lm_ggml_compute_forward_rope_f16( const struct lm_ggml_tensor * src2 = dst->src[2]; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -9335,6 +9400,8 @@ static void lm_ggml_compute_forward_rope_f16( memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); + LM_GGML_TENSOR_UNARY_OP_LOCALS @@ -9367,6 +9434,16 @@ static void lm_ggml_compute_forward_rope_f16( lm_ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); const bool is_neox = mode & LM_GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & LM_GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == 
LM_GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + LM_GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + LM_GGML_ASSERT(n_dims == ne0/2); + } const float * freq_factors = NULL; if (src2 != NULL) { @@ -9384,16 +9461,61 @@ static void lm_ggml_compute_forward_rope_f16( for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = 0; i2 < ne2; i2++) { - const int64_t p = pos[i2]; float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - lm_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + if (!is_mrope) { + const int64_t p = pos[i2]; + lm_ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + else { + const int64_t p_t = pos[i2]; + const int64_t p_h = pos[i2 + ne2]; + const int64_t p_w = pos[i2 + ne2 * 2]; + const int64_t p_e = pos[i2 + ne2 * 3]; + lm_ggml_mrope_cache_init( + p_t, p_h, p_w, p_e, sections, is_vision, + freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } for (int64_t i1 = 0; i1 < ne1; i1++) { if (ir++ < ir0) continue; if (ir > ir1) break; - if (!is_neox) { + if (is_neox || is_mrope) { + if (is_vision) { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = LM_GGML_FP16_TO_FP32(src[0]); + const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims]); + + dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } else { + for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + const int64_t ic = i0/2; + + const float cos_theta = cache[i0 + 0]; + const float sin_theta = cache[i0 + 1]; + + const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); + lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + + const float x0 = LM_GGML_FP16_TO_FP32(src[0]); + const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims/2]); + + dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims/2] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); + } + } + } else { for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { const float cos_theta = cache[i0 + 0]; const float sin_theta = cache[i0 + 1]; @@ -9407,8 +9529,10 @@ static void lm_ggml_compute_forward_rope_f16( dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); dst_data[1] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } - } else { - for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { + } + + if (is_vision) { + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { const int64_t ic = i0/2; const float cos_theta = cache[i0 + 0]; @@ -9418,19 +9542,19 @@ static void lm_ggml_compute_forward_rope_f16( lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); const float x0 = LM_GGML_FP16_TO_FP32(src[0]); - const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims/2]); + const float x1 = LM_GGML_FP16_TO_FP32(src[n_dims]); - dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); - dst_data[n_dims/2] = LM_GGML_FP32_TO_FP16(x0*sin_theta 
+ x1*cos_theta); + dst_data[0] = LM_GGML_FP32_TO_FP16(x0*cos_theta - x1*sin_theta); + dst_data[n_dims] = LM_GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta); } - } - - for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { - const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + } else { + for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) { + const lm_ggml_fp16_t * const src = (lm_ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + lm_ggml_fp16_t * dst_data = (lm_ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - dst_data[0] = src[0]; - dst_data[1] = src[1]; + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } } } } @@ -10429,6 +10553,40 @@ static void lm_ggml_compute_forward_pad( } } +// lm_ggml_compute_forward_pad_reflect_1d + +static void lm_ggml_compute_forward_pad_reflect_1d( + const struct lm_ggml_compute_params * params, + struct lm_ggml_tensor * dst) { + + const struct lm_ggml_tensor * src0 = dst->src[0]; + + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); + + const int ith = params->ith; + const int nth = params->nth; + + const int32_t * opts = (const int32_t *) dst->op_params; + const int p0 = opts[0]; + const int p1 = opts[1]; + + LM_GGML_TENSOR_UNARY_OP_LOCALS + + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + for (int64_t i1 = ith; i1 < ne1; i1 += nth) { + float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0); + float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0); + + lm_ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); + + for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; } + for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; } + } + } + } +} // lm_ggml_compute_forward_arange @@ -12304,6 +12462,9 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru return; } + // extra_buffer op? 
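
[Note: illustration only, not part of the diff.] On the LM_GGML_OP_PAD_REFLECT_1D hunk above: reflective padding mirrors the samples next to each edge without duplicating the edge sample itself, which is exactly what the `left[-i0] = left[i0]` / `right[i0] = right[-i0]` indexing does. A self-contained C++ sketch of the same indexing on one row, with made-up values (valid while p0 and p1 are at most ne00 - 1):

#include <cstdio>

int main() {
    const float src[4] = {1, 2, 3, 4};
    const int ne00 = 4, p0 = 2, p1 = 3;  // left/right pad widths, as read from op_params
    float dst[4 + 2 + 3];

    float * left  = dst + p0;            // points at the first copied sample
    float * right = dst + p0 + ne00 - 1; // points at the last copied sample
    for (int i = 0; i < ne00; i++) { left[i] = src[i]; }
    for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; }   // mirror into the left pad
    for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; } // mirror into the right pad

    for (int i = 0; i < ne00 + p0 + p1; i++) { std::printf("%g ", dst[i]); }
    std::printf("\n"); // prints: 3 2 1 2 3 4 3 2 1
}
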
+ if (lm_ggml_cpu_extra_compute_forward(params, tensor)) return; + switch (tensor->op) { case LM_GGML_OP_DUP: { @@ -12525,6 +12686,10 @@ static void lm_ggml_compute_forward(struct lm_ggml_compute_params * params, stru { lm_ggml_compute_forward_pad(params, tensor); } break; + case LM_GGML_OP_PAD_REFLECT_1D: + { + lm_ggml_compute_forward_pad_reflect_1d(params, tensor); + } break; case LM_GGML_OP_ARANGE: { lm_ggml_compute_forward_arange(params, tensor); @@ -12867,6 +13032,7 @@ static int lm_ggml_get_n_tasks(struct lm_ggml_tensor * node, int n_threads) { } break; case LM_GGML_OP_UPSCALE: case LM_GGML_OP_PAD: + case LM_GGML_OP_PAD_REFLECT_1D: case LM_GGML_OP_ARANGE: case LM_GGML_OP_TIMESTEP_EMBEDDING: case LM_GGML_OP_ARGSORT: @@ -12956,7 +13122,7 @@ static thread_ret_t lm_ggml_graph_compute_secondary_thread(void* data); #include "windows.h" // TODO: support > 64 CPUs -bool lm_ggml_thread_apply_affinity(bool * mask) { +static bool lm_ggml_thread_apply_affinity(bool * mask) { HANDLE h = GetCurrentThread(); uint64_t bitmask = 0ULL; @@ -13246,140 +13412,142 @@ struct lm_ggml_cplan lm_ggml_graph_plan( size_t cur = 0; - switch (node->op) { - case LM_GGML_OP_CPY: - case LM_GGML_OP_DUP: - { - if (lm_ggml_is_quantized(node->type) || - // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 - (node->src[0]->type == LM_GGML_TYPE_F16 && node->src[1] && node->src[1]->type == LM_GGML_TYPE_BF16) || - (node->src[0]->type == LM_GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == LM_GGML_TYPE_F16)) { + if (!lm_ggml_cpu_extra_work_size(n_threads, node, &cur)) { + + switch (node->op) { + case LM_GGML_OP_CPY: + case LM_GGML_OP_DUP: + { + if (lm_ggml_is_quantized(node->type) || + // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 + (node->src[0]->type == LM_GGML_TYPE_F16 && node->src[1] && node->src[1]->type == LM_GGML_TYPE_BF16) || + (node->src[0]->type == LM_GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == LM_GGML_TYPE_F16)) { + cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->ne[0] * n_tasks; + } + } break; + case LM_GGML_OP_ADD: + case LM_GGML_OP_ADD1: + { + if (lm_ggml_is_quantized(node->src[0]->type)) { + cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } + } break; + case LM_GGML_OP_ACC: + { + if (lm_ggml_is_quantized(node->src[0]->type)) { + cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; + } + } break; + case LM_GGML_OP_COUNT_EQUAL: + { + cur = lm_ggml_type_size(node->type)*n_tasks; + } break; + case LM_GGML_OP_MUL_MAT: + { + const enum lm_ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type; + + if (node->src[1]->type != vec_dot_type) { + cur = lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(node->src[1])); + } + } break; + case LM_GGML_OP_MUL_MAT_ID: + { + cur = 0; + const struct lm_ggml_tensor * src0 = node->src[0]; + const struct lm_ggml_tensor * src1 = node->src[1]; + const enum lm_ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type; + if (src1->type != vec_dot_type) { + cur += lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1)); + } + const int n_as = src0->ne[2]; + cur += LM_GGML_PAD(cur, sizeof(int64_t)); // align + cur += n_as * sizeof(int64_t); // matrix_row_counts + cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows + } break; + case LM_GGML_OP_OUT_PROD: + { + if (lm_ggml_is_quantized(node->src[0]->type)) { + cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; + } + } break; + case LM_GGML_OP_SOFT_MAX: + case LM_GGML_OP_ROPE: + { cur = 
lm_ggml_type_size(LM_GGML_TYPE_F32) * node->ne[0] * n_tasks; - } - } break; - case LM_GGML_OP_ADD: - case LM_GGML_OP_ADD1: - { - if (lm_ggml_is_quantized(node->src[0]->type)) { - cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; - } - } break; - case LM_GGML_OP_ACC: - { - if (lm_ggml_is_quantized(node->src[0]->type)) { - cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[1]->ne[0] * n_tasks; - } - } break; - case LM_GGML_OP_COUNT_EQUAL: - { - cur = lm_ggml_type_size(node->type)*n_tasks; - } break; - case LM_GGML_OP_MUL_MAT: - { - const enum lm_ggml_type vec_dot_type = type_traits_cpu[node->src[0]->type].vec_dot_type; + } break; + case LM_GGML_OP_CONV_TRANSPOSE_1D: + { + LM_GGML_ASSERT(node->src[0]->ne[3] == 1); + LM_GGML_ASSERT(node->src[1]->ne[2] == 1); + LM_GGML_ASSERT(node->src[1]->ne[3] == 1); + + const int64_t ne00 = node->src[0]->ne[0]; // K + const int64_t ne01 = node->src[0]->ne[1]; // Cout + const int64_t ne02 = node->src[0]->ne[2]; // Cin + const int64_t ne10 = node->src[1]->ne[0]; // L + const int64_t ne11 = node->src[1]->ne[1]; // Cin + + if ((node->src[0]->type == LM_GGML_TYPE_F16 || + node->src[0]->type == LM_GGML_TYPE_BF16) && + node->src[1]->type == LM_GGML_TYPE_F32) { + cur += sizeof(lm_ggml_fp16_t)*ne00*ne01*ne02; + cur += sizeof(lm_ggml_fp16_t)*ne10*ne11; + } else if (node->src[0]->type == LM_GGML_TYPE_F32 && + node->src[1]->type == LM_GGML_TYPE_F32) { + cur += sizeof(float)*ne00*ne01*ne02; + cur += sizeof(float)*ne10*ne11; + } else { + LM_GGML_ABORT("fatal error"); + } + } break; + case LM_GGML_OP_CONV_TRANSPOSE_2D: + { + const int64_t ne00 = node->src[0]->ne[0]; // W + const int64_t ne01 = node->src[0]->ne[1]; // H + const int64_t ne02 = node->src[0]->ne[2]; // Channels Out + const int64_t ne03 = node->src[0]->ne[3]; // Channels In - if (node->src[1]->type != vec_dot_type) { - cur = lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(node->src[1])); - } - } break; - case LM_GGML_OP_MUL_MAT_ID: - { - cur = 0; - const struct lm_ggml_tensor * src0 = node->src[0]; - const struct lm_ggml_tensor * src1 = node->src[1]; - const enum lm_ggml_type vec_dot_type = type_traits_cpu[src0->type].vec_dot_type; - if (src1->type != vec_dot_type) { - cur += lm_ggml_row_size(vec_dot_type, lm_ggml_nelements(src1)); - } - const int n_as = src0->ne[2]; - cur += LM_GGML_PAD(cur, sizeof(int64_t)); // align - cur += n_as * sizeof(int64_t); // matrix_row_counts - cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows - } break; - case LM_GGML_OP_OUT_PROD: - { - if (lm_ggml_is_quantized(node->src[0]->type)) { - cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks; - } - } break; - case LM_GGML_OP_SOFT_MAX: - case LM_GGML_OP_ROPE: - { - cur = lm_ggml_type_size(LM_GGML_TYPE_F32) * node->ne[0] * n_tasks; - } break; - case LM_GGML_OP_CONV_TRANSPOSE_1D: - { - LM_GGML_ASSERT(node->src[0]->ne[3] == 1); - LM_GGML_ASSERT(node->src[1]->ne[2] == 1); - LM_GGML_ASSERT(node->src[1]->ne[3] == 1); - - const int64_t ne00 = node->src[0]->ne[0]; // K - const int64_t ne01 = node->src[0]->ne[1]; // Cout - const int64_t ne02 = node->src[0]->ne[2]; // Cin - - const int64_t ne10 = node->src[1]->ne[0]; // L - const int64_t ne11 = node->src[1]->ne[1]; // Cin - - if ((node->src[0]->type == LM_GGML_TYPE_F16 || - node->src[0]->type == LM_GGML_TYPE_BF16) && - node->src[1]->type == LM_GGML_TYPE_F32) { - cur += sizeof(lm_ggml_fp16_t)*ne00*ne01*ne02; - cur += sizeof(lm_ggml_fp16_t)*ne10*ne11; - } else if (node->src[0]->type == LM_GGML_TYPE_F32 && - node->src[1]->type == 
LM_GGML_TYPE_F32) { - cur += sizeof(float)*ne00*ne01*ne02; - cur += sizeof(float)*ne10*ne11; - } else { - LM_GGML_ABORT("fatal error"); - } - } break; - case LM_GGML_OP_CONV_TRANSPOSE_2D: - { - const int64_t ne00 = node->src[0]->ne[0]; // W - const int64_t ne01 = node->src[0]->ne[1]; // H - const int64_t ne02 = node->src[0]->ne[2]; // Channels Out - const int64_t ne03 = node->src[0]->ne[3]; // Channels In - - const int64_t ne10 = node->src[1]->ne[0]; // W - const int64_t ne11 = node->src[1]->ne[1]; // H - const int64_t ne12 = node->src[1]->ne[2]; // Channels In - - cur += sizeof(lm_ggml_fp16_t)*ne00*ne01*ne02*ne03; - cur += sizeof(lm_ggml_fp16_t)*ne10*ne11*ne12; - } break; - case LM_GGML_OP_FLASH_ATTN_EXT: - { - const int64_t ne00 = node->src[0]->ne[0]; // D + const int64_t ne10 = node->src[1]->ne[0]; // W + const int64_t ne11 = node->src[1]->ne[1]; // H + const int64_t ne12 = node->src[1]->ne[2]; // Channels In - cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread - } break; - case LM_GGML_OP_FLASH_ATTN_BACK: - { - const int64_t D = node->src[0]->ne[0]; - const int64_t ne11 = lm_ggml_up(node->src[1]->ne[1], LM_GGML_SOFT_MAX_UNROLL); - const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in lm_ggml_compute_forward_flash_attn_back - if (node->src[1]->type == LM_GGML_TYPE_F32) { - cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == LM_GGML_TYPE_F16) { - cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } else if (node->src[1]->type == LM_GGML_TYPE_BF16) { - cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 - } - } break; + cur += sizeof(lm_ggml_fp16_t)*ne00*ne01*ne02*ne03; + cur += sizeof(lm_ggml_fp16_t)*ne10*ne11*ne12; + } break; + case LM_GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D - case LM_GGML_OP_CROSS_ENTROPY_LOSS: - { - cur = lm_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); - } break; - case LM_GGML_OP_COUNT: - { - LM_GGML_ABORT("fatal error"); - } - default: - break; + cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread + } break; + case LM_GGML_OP_FLASH_ATTN_BACK: + { + const int64_t D = node->src[0]->ne[0]; + const int64_t ne11 = lm_ggml_up(node->src[1]->ne[1], LM_GGML_SOFT_MAX_UNROLL); + const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in lm_ggml_compute_forward_flash_attn_back + if (node->src[1]->type == LM_GGML_TYPE_F32) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == LM_GGML_TYPE_F16) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } else if (node->src[1]->type == LM_GGML_TYPE_BF16) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } + } break; + + case LM_GGML_OP_CROSS_ENTROPY_LOSS: + { + cur = lm_ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); + } break; + case LM_GGML_OP_COUNT: + { + LM_GGML_ABORT("fatal error"); + } + default: + break; + } } work_size = MAX(work_size, cur); @@ -13578,29 +13746,6 @@ static void lm_ggml_graph_compute_kickoff(struct 
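/* [editorial sketch, not part of the patch] Caller-side contract for the
 * work-size accounting above, using the standard lm_ggml CPU API; `graph` and
 * `n_threads` are assumed inputs: */
struct lm_ggml_cplan cplan = lm_ggml_graph_plan(graph, n_threads, /*threadpool=*/NULL);
std::vector<uint8_t> work(cplan.work_size); // sized by the MAX over the per-op `cur` values
cplan.work_data = work.data();
lm_ggml_graph_compute(graph, &cplan);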
lm_ggml_threadpool * threadpool #endif // LM_GGML_USE_OPENMP -void lm_ggml_threadpool_params_init(struct lm_ggml_threadpool_params * p, int n_threads) { - p->n_threads = n_threads; - p->prio = 0; // default priority (usually means normal or inherited) - p->poll = 50; // hybrid-polling enabled - p->strict_cpu = false; // no strict placement (all threads share same cpumask) - p->paused = false; // threads are ready to go - memset(p->cpumask, 0, LM_GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited) -} - -struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads) { - struct lm_ggml_threadpool_params p; - lm_ggml_threadpool_params_init(&p, n_threads); - return p; -} - -bool lm_ggml_threadpool_params_match(const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1) { - if (p0->n_threads != p1->n_threads ) return false; - if (p0->prio != p1->prio ) return false; - if (p0->poll != p1->poll ) return false; - if (p0->strict_cpu != p1->strict_cpu ) return false; - return memcmp(p0->cpumask, p1->cpumask, LM_GGML_MAX_N_THREADS) == 0; -} - static struct lm_ggml_threadpool * lm_ggml_threadpool_new_impl( struct lm_ggml_threadpool_params * tpp, struct lm_ggml_cgraph * cgraph, @@ -13896,15 +14041,23 @@ int lm_ggml_cpu_has_vsx(void) { } int lm_ggml_cpu_has_neon(void) { -#if defined(__ARM_ARCH) +#if defined(__ARM_ARCH) && defined(__ARM_NEON) return lm_ggml_arm_arch_features.has_neon; #else return 0; #endif } +int lm_ggml_cpu_has_dotprod(void) { +#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_DOTPROD) + return lm_ggml_arm_arch_features.has_dotprod; +#else + return 0; +#endif +} + int lm_ggml_cpu_has_sve(void) { -#if defined(__ARM_ARCH) +#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE) return lm_ggml_arm_arch_features.has_sve; #else return 0; @@ -13912,7 +14065,7 @@ int lm_ggml_cpu_has_sve(void) { } int lm_ggml_cpu_has_matmul_int8(void) { -#if defined(__ARM_ARCH) +#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_MATMUL_INT8) return lm_ggml_arm_arch_features.has_i8mm; #else return 0; @@ -13920,7 +14073,7 @@ int lm_ggml_cpu_has_matmul_int8(void) { } int lm_ggml_cpu_get_sve_cnt(void) { -#if defined(__ARM_ARCH) +#if defined(__ARM_ARCH) && defined(__ARM_FEATURE_SVE) return lm_ggml_arm_arch_features.sve_cnt; #else return 0; diff --git a/cpp/ggml-cpu.cpp b/cpp/ggml-cpu.cpp index 56a2527b..79878f62 100644 --- a/cpp/ggml-cpu.cpp +++ b/cpp/ggml-cpu.cpp @@ -2,11 +2,18 @@ #include "ggml-backend-impl.h" #include "ggml-cpu.h" #include "ggml-cpu-aarch64.h" +#include "ggml-cpu-traits.h" #include "ggml-impl.h" +#include "amx/amx.h" + #include #include #include +#ifdef LM_GGML_USE_CPU_HBM +#include "ggml-cpu-hbm.h" +#endif + #if defined(__APPLE__) #include #include @@ -22,124 +29,20 @@ // ggml-backend interface -#ifdef LM_GGML_USE_CPU_HBM - -// buffer type HBM - -#include - -static const char * lm_ggml_backend_cpu_hbm_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { - return "CPU_HBM"; - - LM_GGML_UNUSED(buft); -} - -static void lm_ggml_backend_cpu_hbm_buffer_free_buffer(lm_ggml_backend_buffer_t buffer) { - hbw_free(buffer->context); -} - -static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_hbm_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { - void * ptr; - int result = hbw_posix_memalign(&ptr, lm_ggml_backend_cpu_buffer_type_get_alignment(buft), size); - if (result != 0) { - LM_GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size); - return NULL; - } - - lm_ggml_backend_buffer_t buffer 
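/* [editorial sketch, not part of the patch] The threadpool params helpers deleted
 * above move to ggml-base rather than disappear; the usage pattern is unchanged.
 * A minimal sketch, assuming `backend` is an initialized CPU lm_ggml_backend_t: */
struct lm_ggml_threadpool_params tpp = lm_ggml_threadpool_params_default(8); // 8 threads
struct lm_ggml_threadpool * threadpool = lm_ggml_threadpool_new(&tpp);
lm_ggml_backend_cpu_set_threadpool(backend, threadpool);
// ... compute graphs ...
lm_ggml_threadpool_free(threadpool);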
= lm_ggml_backend_cpu_buffer_from_ptr(ptr, size); - buffer->buft = buft; - buffer->iface.free_buffer = lm_ggml_backend_cpu_hbm_buffer_free_buffer; - - return buffer; -} - -lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_hbm_buffer_type(void) { - static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type_hbm = { - /* .iface = */ { - /* .get_name = */ lm_ggml_backend_cpu_hbm_buffer_type_get_name, - /* .alloc_buffer = */ lm_ggml_backend_cpu_hbm_buffer_type_alloc_buffer, - /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type_get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX - /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes - /* .is_host = */ lm_ggml_backend_cpu_buffer_type_is_host, - }, - /* .context = */ NULL, - }; - - return &lm_ggml_backend_cpu_buffer_type_hbm; -} -#endif - -// buffer type AARCH64 - -static void lm_ggml_backend_cpu_aarch64_buffer_init_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor) { - tensor->extra = (void *)lm_ggml_aarch64_get_optimal_repack_type(tensor); // NOLINT - - LM_GGML_UNUSED(buffer); -} - -static void lm_ggml_backend_cpu_aarch64_buffer_set_tensor(lm_ggml_backend_buffer_t buffer, struct lm_ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - LM_GGML_ASSERT(offset == 0); - LM_GGML_ASSERT(size == lm_ggml_nbytes(tensor)); - - enum lm_ggml_type repack_type = (enum lm_ggml_type)(intptr_t)tensor->extra; - - lm_ggml_aarch64_repack_tensor(tensor, repack_type, data, size); - - LM_GGML_UNUSED(buffer); -} - -static const char * lm_ggml_backend_cpu_aarch64_buffer_type_get_name(lm_ggml_backend_buffer_type_t buft) { - return "CPU_AARCH64"; - - LM_GGML_UNUSED(buft); -} - -static lm_ggml_backend_buffer_t lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer(lm_ggml_backend_buffer_type_t buft, size_t size) { - auto * buffer = lm_ggml_backend_buft_alloc_buffer(lm_ggml_backend_cpu_buffer_type(), size); - - if (buffer == NULL) { - return NULL; - } - - buffer->buft = buft; - buffer->iface.init_tensor = lm_ggml_backend_cpu_aarch64_buffer_init_tensor; - buffer->iface.set_tensor = lm_ggml_backend_cpu_aarch64_buffer_set_tensor; - - return buffer; -} - -lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_aarch64_buffer_type(void) { - static struct lm_ggml_backend_buffer_type lm_ggml_backend_cpu_buffer_type_aarch64 = { - /* .iface = */ { - /* .get_name = */ lm_ggml_backend_cpu_aarch64_buffer_type_get_name, - /* .alloc_buffer = */ lm_ggml_backend_cpu_aarch64_buffer_type_alloc_buffer, - /* .get_alignment = */ lm_ggml_backend_cpu_buffer_type()->iface.get_alignment, - /* .get_max_size = */ NULL, // defaults to SIZE_MAX - /* .get_alloc_size = */ NULL, // defaults to lm_ggml_nbytes - /* .is_host = */ NULL, - }, - /* .device = */ lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0), - /* .context = */ NULL, - }; - - return &lm_ggml_backend_cpu_buffer_type_aarch64; -} - -bool lm_ggml_backend_cpu_buft_is_aarch64(lm_ggml_backend_buffer_type_t buft) { - return buft == lm_ggml_backend_cpu_aarch64_buffer_type(); -} - -static lm_ggml_backend_buffer_type_t * lm_ggml_backend_cpu_get_extra_bufts(lm_ggml_backend_dev_t device) { +std::vector& lm_ggml_backend_cpu_get_extra_buffers_type() { static std::vector bufts = []() { std::vector bufts; -#ifdef LM_GGML_USE_CPU_HBM - bufts.push_back(lm_ggml_backend_cpu_hbm_buffer_type()); +#if defined(__AMX_INT8__) && defined(__AVX512VNNI__) + if (lm_ggml_backend_amx_buffer_type()) { + bufts.push_back(lm_ggml_backend_amx_buffer_type()); + } #endif #ifdef LM_GGML_USE_CPU_AARCH64 
-    bufts.push_back(lm_ggml_backend_cpu_aarch64_buffer_type());
+    if (lm_ggml_backend_cpu_aarch64_buffer_type()) {
+        bufts.push_back(lm_ggml_backend_cpu_aarch64_buffer_type());
+    }
 #endif

     bufts.push_back(NULL);
@@ -147,11 +50,22 @@ static lm_ggml_backend_buffer_type_t * lm_ggml_backend_cpu_get_extra_bufts(lm_gg
         return bufts;
     }();

-    return bufts.data();
+    return bufts;
+}
+
+static lm_ggml_backend_buffer_type_t * lm_ggml_backend_cpu_device_get_extra_buffers_type(lm_ggml_backend_dev_t device) {
+    return lm_ggml_backend_cpu_get_extra_buffers_type().data();

     LM_GGML_UNUSED(device);
 }

+static bool lm_ggml_backend_cpu_is_extra_buffer_type(lm_ggml_backend_buffer_type_t buft) {
+    for (auto extra : lm_ggml_backend_cpu_get_extra_buffers_type()) {
+        if (extra && extra == buft) return true;
+    }
+    return false;
+}
+
 // CPU backend - backend (stream)

 struct lm_ggml_backend_cpu_context {
@@ -456,14 +370,23 @@ static bool lm_ggml_backend_cpu_device_supports_op(lm_ggml_backend_dev_t dev, co
     const struct lm_ggml_tensor * src0 = op->src[0];
     const struct lm_ggml_tensor * src1 = op->src[1];

-    if (src0 && src0->buffer && lm_ggml_backend_cpu_buft_is_aarch64(src0->buffer->buft)) {
-        if (op->op != LM_GGML_OP_MUL_MAT || src0->type != LM_GGML_TYPE_Q4_0 || lm_ggml_aarch64_get_optimal_repack_type(src0) == LM_GGML_TYPE_Q4_0) {
-            return false;
+    if (op->op == LM_GGML_OP_NONE || op->op == LM_GGML_OP_RESHAPE || op->op == LM_GGML_OP_VIEW || op->op == LM_GGML_OP_PERMUTE || op->op == LM_GGML_OP_TRANSPOSE) {
+        return true;
+    }
+
+    // extra_buffer_op?
+    for (auto extra : lm_ggml_backend_cpu_get_extra_buffers_type()) {
+        if (extra) {
+            auto buf_extra = (ggml::cpu::extra_buffer_type*) extra->context;
+            if (buf_extra && buf_extra->supports_op(dev, op)) {
+                return true;
+            }
         }
     }

-    for (int i = 1; i < LM_GGML_MAX_SRC; i++) {
-        if (op->src[i] && op->src[i]->buffer && lm_ggml_backend_cpu_buft_is_aarch64(op->src[i]->buffer->buft)) {
+    // the other cases need a host buffer.
+    for (int i = 0; i < LM_GGML_MAX_SRC; i++) {
+        if (op->src[i] && op->src[i]->buffer && !lm_ggml_backend_buft_is_host(op->src[i]->buffer->buft)) {
             return false;
         }
     }
@@ -471,8 +394,11 @@ static bool lm_ggml_backend_cpu_device_supports_op(lm_ggml_backend_dev_t dev, co
     switch (op->op) {
         case LM_GGML_OP_CPY:
             return
+                op->type != LM_GGML_TYPE_IQ3_XXS &&
+                op->type != LM_GGML_TYPE_IQ3_S &&
                 op->type != LM_GGML_TYPE_IQ2_XXS &&
                 op->type != LM_GGML_TYPE_IQ2_XS &&
+                op->type != LM_GGML_TYPE_IQ2_S &&
                 op->type != LM_GGML_TYPE_IQ1_S &&
                 op->type != LM_GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case LM_GGML_OP_MUL_MAT:
@@ -486,13 +412,10 @@ static bool lm_ggml_backend_cpu_device_supports_op(lm_ggml_backend_dev_t dev, co
         default:
             return true;
     }
-
-    LM_GGML_UNUSED(dev);
 }

 static bool lm_ggml_backend_cpu_device_supports_buft(lm_ggml_backend_dev_t dev, lm_ggml_backend_buffer_type_t buft) {
-    return lm_ggml_backend_buft_is_host(buft) || lm_ggml_backend_cpu_buft_is_aarch64(buft);
-
+    return lm_ggml_backend_buft_is_host(buft) || lm_ggml_backend_cpu_is_extra_buffer_type(buft);
     LM_GGML_UNUSED(dev);
 }

@@ -541,16 +464,12 @@ static lm_ggml_backend_dev_t lm_ggml_backend_cpu_reg_get_device(lm_ggml_backend_
     return &lm_ggml_backend_cpu_device;
 }

-struct lm_ggml_backend_feature {
-    const char * name;
-    const char * value;
-};
-
-// Not used yet
 // This is intended to replace the lm_ggml_cpu_has_* functions when loading the CPU backend dynamically,
-// and additionally to allow other backends to expose their own list of features that applications can query using the same API.
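/* [editorial sketch, not part of the patch] The supports_op rewrite above casts
 * extra->context to ggml::cpu::extra_buffer_type and delegates. Per
 * cpp/ggml-cpu-traits.h in this sync, the contract is roughly (names approximate): */
namespace ggml::cpu {
class extra_buffer_type {
  public:
    virtual ~extra_buffer_type();
    virtual bool supports_op(lm_ggml_backend_dev_t dev, const struct lm_ggml_tensor * op) = 0;
    virtual tensor_traits * get_tensor_traits(const struct lm_ggml_tensor * op) = 0;
};
}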
+// and additionally to allow other backends to expose their own list of features that applications can query using the same API static lm_ggml_backend_feature * lm_ggml_backend_cpu_get_features(lm_ggml_backend_reg_t reg) { static std::vector features = []() { + lm_ggml_cpu_init(); + std::vector features; if (lm_ggml_cpu_has_sse3()) { features.push_back({ "SSE3", "1" }); @@ -561,6 +480,9 @@ static lm_ggml_backend_feature * lm_ggml_backend_cpu_get_features(lm_ggml_backen if (lm_ggml_cpu_has_avx()) { features.push_back({ "AVX", "1" }); } + if (lm_ggml_cpu_has_avx_vnni()) { + features.push_back({ "AVX_VNNI", "1" }); + } if (lm_ggml_cpu_has_avx2()) { features.push_back({ "AVX2", "1" }); } @@ -570,9 +492,6 @@ static lm_ggml_backend_feature * lm_ggml_backend_cpu_get_features(lm_ggml_backen if (lm_ggml_cpu_has_fma()) { features.push_back({ "FMA", "1" }); } - if (lm_ggml_cpu_has_avx_vnni()) { - features.push_back({ "AVX_VNNI", "1" }); - } if (lm_ggml_cpu_has_avx512()) { features.push_back({ "AVX512", "1" }); } @@ -603,6 +522,12 @@ static lm_ggml_backend_feature * lm_ggml_backend_cpu_get_features(lm_ggml_backen if (lm_ggml_cpu_has_sve()) { features.push_back({ "SVE", "1" }); } + if (lm_ggml_cpu_has_dotprod()) { + features.push_back({ "DOTPROD", "1" }); + } + if (lm_ggml_cpu_has_matmul_int8()) { + features.push_back({ "MATMUL_INT8", "1" }); + } if (lm_ggml_cpu_get_sve_cnt() > 0) { static std::string sve_cnt = std::to_string(lm_ggml_cpu_get_sve_cnt()); features.push_back({ "SVE_CNT", sve_cnt.c_str() }); @@ -619,6 +544,18 @@ static lm_ggml_backend_feature * lm_ggml_backend_cpu_get_features(lm_ggml_backen if (lm_ggml_cpu_has_llamafile()) { features.push_back({ "LLAMAFILE", "1" }); } + #ifdef LM_GGML_USE_ACCELERATE + features.push_back({ "ACCELERATE", "1" }); + #endif + #ifdef LM_GGML_USE_CPU_HBM + features.push_back({ "CPU_HBM", "1" }); + #endif + #ifdef LM_GGML_USE_OPENMP + features.push_back({ "OPENMP", "1" }); + #endif + #ifdef LM_GGML_USE_CPU_AARCH64 + features.push_back({ "AARCH64_REPACK", "1" }); + #endif features.push_back({ nullptr, nullptr }); @@ -632,10 +569,35 @@ static lm_ggml_backend_feature * lm_ggml_backend_cpu_get_features(lm_ggml_backen static void * lm_ggml_backend_cpu_get_proc_address(lm_ggml_backend_reg_t reg, const char * name) { if (strcmp(name, "lm_ggml_backend_set_n_threads") == 0) { - return (void *)lm_ggml_backend_cpu_set_n_threads; + lm_ggml_backend_set_n_threads_t fct = lm_ggml_backend_cpu_set_n_threads; + return (void *)fct; } if (strcmp(name, "lm_ggml_backend_dev_get_extra_bufts") == 0) { - return (void *)lm_ggml_backend_cpu_get_extra_bufts; + lm_ggml_backend_dev_get_extra_bufts_t fct = lm_ggml_backend_cpu_device_get_extra_buffers_type; + return (void *)fct; + } + if (strcmp(name, "lm_ggml_backend_get_features") == 0) { + return (void *)lm_ggml_backend_cpu_get_features; + } + if (strcmp(name, "lm_ggml_backend_set_abort_callback") == 0) { + return (void *)lm_ggml_backend_cpu_set_abort_callback; + } + if (strcmp(name, "lm_ggml_backend_cpu_numa_init") == 0) { + return (void *)lm_ggml_numa_init; + } + if (strcmp(name, "lm_ggml_backend_cpu_is_numa") == 0) { + return (void *)lm_ggml_is_numa; + } + + // threadpool - TODO: move to ggml-base + if (strcmp(name, "lm_ggml_threadpool_new") == 0) { + return (void *)lm_ggml_threadpool_new; + } + if (strcmp(name, "lm_ggml_threadpool_free") == 0) { + return (void *)lm_ggml_threadpool_free; + } + if (strcmp(name, "lm_ggml_backend_cpu_set_threadpool") == 0) { + return (void *)lm_ggml_backend_cpu_set_threadpool; } return NULL; @@ -655,9 
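/* [editorial sketch, not part of the patch] How a client consumes the proc table
 * above when the CPU backend is loaded through the registry; `reg` and `backend`
 * are assumed, already-obtained handles: */
auto set_n_threads = (lm_ggml_backend_set_n_threads_t)
    lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_set_n_threads");
if (set_n_threads != NULL) {
    set_n_threads(backend, 4); // degrades gracefully when the symbol is absent
}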
+617,12 @@ lm_ggml_backend_reg_t lm_ggml_backend_cpu_reg(void) { lm_ggml_cpu_init(); static struct lm_ggml_backend_reg lm_ggml_backend_cpu_reg = { - /* .iface = */ lm_ggml_backend_cpu_reg_i, - /* .context = */ NULL, + /* .api_version = */ LM_GGML_BACKEND_API_VERSION, + /* .iface = */ lm_ggml_backend_cpu_reg_i, + /* .context = */ NULL, }; return &lm_ggml_backend_cpu_reg; } + +LM_GGML_BACKEND_DL_IMPL(lm_ggml_backend_cpu_reg) diff --git a/cpp/ggml-cpu.h b/cpp/ggml-cpu.h index a49af8ff..eb56d139 100644 --- a/cpp/ggml-cpu.h +++ b/cpp/ggml-cpu.h @@ -7,29 +7,6 @@ extern "C" { #endif - // Scheduling priorities - enum lm_ggml_sched_priority { - LM_GGML_SCHED_PRIO_NORMAL, - LM_GGML_SCHED_PRIO_MEDIUM, - LM_GGML_SCHED_PRIO_HIGH, - LM_GGML_SCHED_PRIO_REALTIME - }; - - // Threadpool params - // Use lm_ggml_threadpool_params_default() or lm_ggml_threadpool_params_init() to populate the defaults - struct lm_ggml_threadpool_params { - bool cpumask[LM_GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings) - int n_threads; // number of threads - enum lm_ggml_sched_priority prio; // thread priority - uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling) - bool strict_cpu; // strict cpu placement - bool paused; // start in paused state - }; - - struct lm_ggml_threadpool; // forward declaration, see ggml.c - - typedef struct lm_ggml_threadpool * lm_ggml_threadpool_t; - // the compute plan that needs to be prepared for lm_ggml_graph_compute() // since https://github.com/ggerganov/ggml/issues/287 struct lm_ggml_cplan { @@ -75,14 +52,11 @@ extern "C" { LM_GGML_BACKEND_API float lm_ggml_get_f32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3); LM_GGML_BACKEND_API void lm_ggml_set_f32_nd(const struct lm_ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); - LM_GGML_BACKEND_API struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads); - LM_GGML_BACKEND_API void lm_ggml_threadpool_params_init (struct lm_ggml_threadpool_params * p, int n_threads); - LM_GGML_BACKEND_API bool lm_ggml_threadpool_params_match (const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1); - LM_GGML_BACKEND_API struct lm_ggml_threadpool * lm_ggml_threadpool_new (struct lm_ggml_threadpool_params * params); - LM_GGML_BACKEND_API void lm_ggml_threadpool_free (struct lm_ggml_threadpool * threadpool); - LM_GGML_BACKEND_API int lm_ggml_threadpool_get_n_threads(struct lm_ggml_threadpool * threadpool); - LM_GGML_BACKEND_API void lm_ggml_threadpool_pause (struct lm_ggml_threadpool * threadpool); - LM_GGML_BACKEND_API void lm_ggml_threadpool_resume (struct lm_ggml_threadpool * threadpool); + LM_GGML_BACKEND_API struct lm_ggml_threadpool * lm_ggml_threadpool_new (struct lm_ggml_threadpool_params * params); + LM_GGML_BACKEND_API void lm_ggml_threadpool_free (struct lm_ggml_threadpool * threadpool); + LM_GGML_BACKEND_API int lm_ggml_threadpool_get_n_threads (struct lm_ggml_threadpool * threadpool); + LM_GGML_BACKEND_API void lm_ggml_threadpool_pause (struct lm_ggml_threadpool * threadpool); + LM_GGML_BACKEND_API void lm_ggml_threadpool_resume (struct lm_ggml_threadpool * threadpool); // lm_ggml_graph_plan() has to be called before lm_ggml_graph_compute() // when plan.work_size > 0, caller must allocate memory for plan.work_data @@ -104,10 +78,10 @@ extern "C" { LM_GGML_BACKEND_API int lm_ggml_cpu_has_sse3 (void); LM_GGML_BACKEND_API int lm_ggml_cpu_has_ssse3 (void); LM_GGML_BACKEND_API int 
lm_ggml_cpu_has_avx (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx_vnni (void); LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx2 (void); LM_GGML_BACKEND_API int lm_ggml_cpu_has_f16c (void); LM_GGML_BACKEND_API int lm_ggml_cpu_has_fma (void); - LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx_vnni (void); LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx512 (void); LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx512_vbmi(void); LM_GGML_BACKEND_API int lm_ggml_cpu_has_avx512_vnni(void); @@ -117,6 +91,7 @@ extern "C" { LM_GGML_BACKEND_API int lm_ggml_cpu_has_neon (void); LM_GGML_BACKEND_API int lm_ggml_cpu_has_arm_fma (void); LM_GGML_BACKEND_API int lm_ggml_cpu_has_fp16_va (void); + LM_GGML_BACKEND_API int lm_ggml_cpu_has_dotprod (void); LM_GGML_BACKEND_API int lm_ggml_cpu_has_matmul_int8(void); LM_GGML_BACKEND_API int lm_ggml_cpu_has_sve (void); LM_GGML_BACKEND_API int lm_ggml_cpu_get_sve_cnt (void); // sve vector length in bytes @@ -128,24 +103,14 @@ extern "C" { // Internal types and functions exposed for tests and benchmarks - typedef void (*lm_ggml_from_float_to_mat_t) - (const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs); typedef void (*lm_ggml_vec_dot_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, size_t bx, const void * LM_GGML_RESTRICT y, size_t by, int nrc); - typedef void (*lm_ggml_gemv_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, - const void * LM_GGML_RESTRICT y, int nr, int nc); - typedef void (*lm_ggml_gemm_t) (int n, float * LM_GGML_RESTRICT s, size_t bs, const void * LM_GGML_RESTRICT x, - const void * LM_GGML_RESTRICT y, int nr, int nc); struct lm_ggml_type_traits_cpu { lm_ggml_from_float_t from_float; - lm_ggml_from_float_to_mat_t from_float_to_mat; lm_ggml_vec_dot_t vec_dot; enum lm_ggml_type vec_dot_type; int64_t nrows; // number of rows to process simultaneously - int64_t ncols; // number of columns to process simultaneously - lm_ggml_gemv_t gemv; - lm_ggml_gemm_t gemm; }; LM_GGML_BACKEND_API const struct lm_ggml_type_traits_cpu * lm_ggml_get_type_traits_cpu(enum lm_ggml_type type); @@ -165,13 +130,6 @@ extern "C" { LM_GGML_BACKEND_API lm_ggml_backend_reg_t lm_ggml_backend_cpu_reg(void); -#ifdef LM_GGML_USE_CPU_HBM - LM_GGML_BACKEND_API lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_hbm_buffer_type(void); -#endif - - LM_GGML_BACKEND_API lm_ggml_backend_buffer_type_t lm_ggml_backend_cpu_aarch64_buffer_type(void); - LM_GGML_BACKEND_API bool lm_ggml_backend_cpu_buft_is_aarch64(lm_ggml_backend_buffer_type_t buft); - #ifdef __cplusplus } #endif diff --git a/cpp/ggml-impl.h b/cpp/ggml-impl.h index 17cfe65d..2a283094 100644 --- a/cpp/ggml-impl.h +++ b/cpp/ggml-impl.h @@ -14,7 +14,7 @@ #include #endif // __ARM_FEATURE_SVE -#if defined(__ARM_NEON) +#if defined(__ARM_NEON) && !defined(__CUDACC__) // if YCM cannot find , make a symbolic link to it, for example: // // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/ @@ -30,11 +30,13 @@ extern "C" { #endif -#undef MIN -#undef MAX +#ifndef MIN +# define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#ifndef MAX +# define MAX(a, b) ((a) > (b) ? 
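/* [editorial sketch, not part of the patch] Typical use of the feature accessors
 * declared above, e.g. selecting an Arm int8 path at runtime: */
if (lm_ggml_cpu_has_neon() && lm_ggml_cpu_has_dotprod()) {
    // SDOT available: prefer the dot-product quantized kernels
}
if (lm_ggml_cpu_has_matmul_int8()) {
    // i8mm (smmla/ummla) available: prefer the int8 GEMM kernels
}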
(a) : (b)) +#endif // required for mmap as gguf only guarantees 32-byte alignment #define TENSOR_ALIGNMENT 32 @@ -72,8 +74,8 @@ static inline int lm_ggml_up(int n, int m) { // LM_GGML_ATTRIBUTE_FORMAT(2, 3) -void lm_ggml_log_internal (enum lm_ggml_log_level level, const char * format, ...); -void lm_ggml_log_callback_default(enum lm_ggml_log_level level, const char * text, void * user_data); +LM_GGML_API void lm_ggml_log_internal (enum lm_ggml_log_level level, const char * format, ...); +LM_GGML_API void lm_ggml_log_callback_default(enum lm_ggml_log_level level, const char * text, void * user_data); #define LM_GGML_LOG(...) lm_ggml_log_internal(LM_GGML_LOG_LEVEL_NONE , __VA_ARGS__) #define LM_GGML_LOG_INFO(...) lm_ggml_log_internal(LM_GGML_LOG_LEVEL_INFO , __VA_ARGS__) @@ -295,24 +297,27 @@ struct lm_ggml_cgraph { enum lm_ggml_cgraph_eval_order order; }; +// returns a slice of cgraph with nodes [i0, i1) +// the slice does not have leafs or gradients +// if you need the gradients, get them from the original graph struct lm_ggml_cgraph lm_ggml_graph_view(struct lm_ggml_cgraph * cgraph, int i0, int i1); // Memory allocation -void * lm_ggml_aligned_malloc(size_t size); -void lm_ggml_aligned_free(void * ptr, size_t size); +LM_GGML_API void * lm_ggml_aligned_malloc(size_t size); +LM_GGML_API void lm_ggml_aligned_free(void * ptr, size_t size); // FP16 to FP32 conversion #if defined(__ARM_NEON) - #ifdef _MSC_VER + #if defined(_MSC_VER) || (defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) typedef uint16_t lm_ggml_fp16_internal_t; #else typedef __fp16 lm_ggml_fp16_internal_t; #endif #endif -#if defined(__ARM_NEON) && !defined(_MSC_VER) +#if defined(__ARM_NEON) && !defined(_MSC_VER) && !(defined(__CUDACC__) && __CUDACC_VER_MAJOR__ <= 11) #define LM_GGML_COMPUTE_FP16_TO_FP32(x) lm_ggml_compute_fp16_to_fp32(x) #define LM_GGML_COMPUTE_FP32_TO_FP16(x) lm_ggml_compute_fp32_to_fp16(x) @@ -546,6 +551,22 @@ static inline lm_ggml_bf16_t lm_ggml_compute_fp32_to_bf16(float s) { #define LM_GGML_FP32_TO_BF16(x) lm_ggml_compute_fp32_to_bf16(x) #define LM_GGML_BF16_TO_FP32(x) lm_ggml_compute_bf16_to_fp32(x) +// expose GGUF internals for test code + +LM_GGML_API size_t lm_gguf_type_size(enum lm_gguf_type type); + +LM_GGML_API struct lm_gguf_context * lm_gguf_init_from_file_impl(FILE * file, struct lm_gguf_init_params params); + +struct lm_gguf_buf { + void * data; + size_t size; + size_t offset; +}; +LM_GGML_API struct lm_gguf_buf lm_gguf_buf_init(size_t size); +LM_GGML_API void lm_gguf_buf_free(struct lm_gguf_buf buf); + +LM_GGML_API void lm_gguf_write_to_buf(const struct lm_gguf_context * ctx, struct lm_gguf_buf * buf, bool only_meta); + #ifdef __cplusplus } #endif diff --git a/cpp/ggml-metal-impl.h b/cpp/ggml-metal-impl.h index 481e010d..7fd584be 100644 --- a/cpp/ggml-metal-impl.h +++ b/cpp/ggml-metal-impl.h @@ -102,6 +102,21 @@ typedef struct { uint64_t nb3; } lm_ggml_metal_kargs_cpy; +typedef struct { + int64_t ne10; + int64_t ne11; + int64_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; + bool inplace; +} lm_ggml_metal_kargs_set; + typedef struct { int32_t ne00; int32_t ne01; @@ -192,6 +207,30 @@ typedef struct { int16_t r3; } lm_ggml_metal_kargs_mul_mv; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + 
int32_t ne1; + int16_t r2; + int16_t r3; + int16_t nsg; + int16_t nxpsg; + int16_t r1ptg; +} lm_ggml_metal_kargs_mul_mv_ext; + typedef struct { int32_t nei0; int32_t nei1; diff --git a/cpp/ggml-metal.m b/cpp/ggml-metal.m index 468c783e..23ed294b 100644 --- a/cpp/ggml-metal.m +++ b/cpp/ggml-metal.m @@ -175,6 +175,46 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, + LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, @@ -266,8 +306,11 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, + LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, + LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, LM_GGML_METAL_KERNEL_TYPE_PAD_F32, + LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32, LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, @@ -329,6 +372,8 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device 
LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
     LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+    LM_GGML_METAL_KERNEL_TYPE_SET_I32,
+    LM_GGML_METAL_KERNEL_TYPE_SET_F32,
     LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
@@ -350,6 +395,7 @@ static void lm_ggml_backend_metal_device_rel(struct lm_ggml_backend_metal_device
     LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
+    LM_GGML_METAL_KERNEL_TYPE_ARGMAX,
     LM_GGML_METAL_KERNEL_TYPE_COUNT
 };

@@ -464,6 +510,35 @@ @implementation LMGGMLMetalClass
 #endif

     NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
+    if (path_lib == nil) {
+        // Try to find the resource in the directory where the current binary is located.
+        NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
+        NSString * bin_dir = [current_binary stringByDeletingLastPathComponent];
+        NSString * default_metallib_path = [NSString pathWithComponents:@[bin_dir, @"default.metallib"]];
+        if ([[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
+            LM_GGML_LOG_INFO("%s: found '%s'\n", __func__, [default_metallib_path UTF8String]);
+            NSDictionary * atts = [[NSFileManager defaultManager] attributesOfItemAtPath:default_metallib_path error:&error];
+            if (atts && atts[NSFileType] == NSFileTypeSymbolicLink) {
+                // Optionally, if this is a symlink, try to resolve it.
+                default_metallib_path = [[NSFileManager defaultManager] destinationOfSymbolicLinkAtPath:default_metallib_path error:&error];
+                if (default_metallib_path && [default_metallib_path length] > 0 && ![[default_metallib_path substringToIndex:1] isEqualToString:@"/"]) {
+                    // It is a relative path, so prepend the binary directory as a prefix.
+                    default_metallib_path = [NSString pathWithComponents:@[bin_dir, default_metallib_path]];
+                }
+                if (!default_metallib_path || ![[NSFileManager defaultManager] isReadableFileAtPath:default_metallib_path]) {
+                    // The link to the resource could not be resolved.
+                    default_metallib_path = nil;
+                } else {
+                    LM_GGML_LOG_INFO("%s: symlink resolved '%s'\n", __func__, [default_metallib_path UTF8String]);
+                }
+            }
+        } else {
+            // The resource couldn't be found in the binary's directory.
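/* [editorial sketch, not part of the patch] The fallback above is: bundle
 * resource first, then a "default.metallib" next to the executable, resolving a
 * relative symlink against the binary directory. The same logic in portable C++
 * terms (illustrative only; the patch itself uses NSFileManager): */
#include <filesystem>
namespace fs = std::filesystem;
static fs::path resolve_default_metallib(const fs::path & bin_dir) {
    fs::path p = bin_dir / "default.metallib";
    if (fs::is_symlink(p)) {
        const fs::path target = fs::read_symlink(p);
        p = target.is_absolute() ? target : bin_dir / target; // relative links resolve against bin_dir
    }
    return fs::exists(p) ? p : fs::path();
}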
+ default_metallib_path = nil; + } + path_lib = default_metallib_path; + } + if (try_metallib && path_lib != nil) { // pre-compiled library found NSURL * libURL = [NSURL fileURLWithPath:path_lib]; @@ -699,6 +774,46 @@ @implementation LMGGMLMetalClass LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, 
mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5, mul_mv_ext_q4_K_f32_r1_5, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2, mul_mv_ext_q5_K_f32_r1_2, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3, mul_mv_ext_q5_K_f32_r1_3, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4, mul_mv_ext_q5_K_f32_r1_4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5, mul_mv_ext_q5_K_f32_r1_5, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2, mul_mv_ext_q6_K_f32_r1_2, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3, mul_mv_ext_q6_K_f32_r1_3, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4, mul_mv_ext_q6_K_f32_r1_4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5, mul_mv_ext_q6_K_f32_r1_5, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2, mul_mv_ext_iq4_nl_f32_r1_2, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3, mul_mv_ext_iq4_nl_f32_r1_3, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4, mul_mv_ext_iq4_nl_f32_r1_4, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5, mul_mv_ext_iq4_nl_f32_r1_5, has_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); @@ -790,8 +905,11 @@ @implementation LMGGMLMetalClass LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32, im2col_ext_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32, conv_transpose_1d_f32_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32, conv_transpose_1d_f16_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32, pad_reflect_1d_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32, timestep_embedding_f32, true); 
LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARANGE_F32, arange_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); @@ -853,6 +971,8 @@ @implementation LMGGMLMetalClass LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256, flash_attn_ext_vec_q5_1_h256, has_simdgroup_reduction); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, flash_attn_ext_vec_q8_0_h256, has_simdgroup_reduction); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, cpy_f32_bf16, use_bfloat); @@ -872,6 +992,7 @@ @implementation LMGGMLMetalClass LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SIN, sin, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_COS, cos, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); + LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true); LM_GGML_METAL_ADD_KERNEL(LM_GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true); } @@ -989,6 +1110,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_ case LM_GGML_OP_REPEAT: case LM_GGML_OP_SCALE: case LM_GGML_OP_CLAMP: + case LM_GGML_OP_CONV_TRANSPOSE_1D: return true; case LM_GGML_OP_SQR: case LM_GGML_OP_SQRT: @@ -997,12 +1119,24 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_ return lm_ggml_is_contiguous(op->src[0]); case LM_GGML_OP_SUM_ROWS: case LM_GGML_OP_SOFT_MAX: - case LM_GGML_OP_RMS_NORM: case LM_GGML_OP_GROUP_NORM: return has_simdgroup_reduction; + case LM_GGML_OP_RMS_NORM: + return has_simdgroup_reduction && (op->ne[0] % 4 == 0); + case LM_GGML_OP_ARGMAX: case LM_GGML_OP_NORM: - case LM_GGML_OP_ROPE: return true; + case LM_GGML_OP_ROPE: + { + const int mode = ((const int32_t *) op->op_params)[2]; + if (mode & LM_GGML_ROPE_TYPE_MROPE) { + return false; + } + if (mode & LM_GGML_ROPE_TYPE_VISION) { + return false; + } + return true; + } case LM_GGML_OP_IM2COL: return op->src[0]->type == LM_GGML_TYPE_F16; case LM_GGML_OP_POOL_1D: @@ -1010,6 +1144,7 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_ case LM_GGML_OP_POOL_2D: case LM_GGML_OP_UPSCALE: case LM_GGML_OP_PAD: + case LM_GGML_OP_PAD_REFLECT_1D: case LM_GGML_OP_ARANGE: case LM_GGML_OP_TIMESTEP_EMBEDDING: case LM_GGML_OP_ARGSORT: @@ -1067,6 +1202,16 @@ static bool lm_ggml_metal_supports_op(const struct lm_ggml_backend_metal_device_ return false; }; } + case LM_GGML_OP_SET: + { + switch (op->src[0]->type) { + case LM_GGML_TYPE_F32: + case LM_GGML_TYPE_I32: + return true; + default: + return false; + }; + } case LM_GGML_OP_DIAG_MASK_INF: case LM_GGML_OP_GET_ROWS: { @@ -1927,340 +2072,490 @@ static void lm_ggml_metal_encode_node( // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel - int ne11_mm_min = 1; + const int ne11_mm_min = 4; + + // first try to use small-batch mat-mv 
kernels
+        // these should be efficient for BS [2, ~8]
+        if (src1t == LM_GGML_TYPE_F32 && (ne00%256 == 0) &&
+            (
+             (
+              (
+               src0t == LM_GGML_TYPE_F16 ||  // TODO: helper function
+               src0t == LM_GGML_TYPE_Q4_0 ||
+               src0t == LM_GGML_TYPE_Q4_1 ||
+               src0t == LM_GGML_TYPE_Q5_0 ||
+               src0t == LM_GGML_TYPE_Q5_1 ||
+               src0t == LM_GGML_TYPE_Q8_0 ||
+               src0t == LM_GGML_TYPE_IQ4_NL ||
+               false) && (ne11 >= 2 && ne11 <= 8)
+             ) ||
+             (
+              (
+               src0t == LM_GGML_TYPE_Q4_K ||
+               src0t == LM_GGML_TYPE_Q5_K ||
+               src0t == LM_GGML_TYPE_Q6_K ||
+               false) && (ne11 >= 4 && ne11 <= 8)
+             )
+            )
+           ) {
+            // TODO: determine the optimal parameters based on grid utilization
+            //       I still don't know why we should not always use the maximum available threads:
+            //
+            //       nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+            //
+            //       my current hypothesis is that the work grid is not evenly divisible for different nsg
+            //       values and there can be some tail effects when nsg is high. need to confirm this
+            //
+            const int nsg   = 2;                 // num simdgroups per threadgroup
+            const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+            const int nypsg = 32/nxpsg;          // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+            const int r0ptg = nypsg*nsg;         // num src0 rows per threadgroup
+            int r1ptg       = 4;                 // num src1 rows per threadgroup
+
+            // note: not sure how optimal these are across different hardware. there might be something cleverer
+            switch (ne11) {
+                case 2:
+                    r1ptg = 2; break;
+                case 3:
+                case 6:
+                    r1ptg = 3; break;
+                case 4:
+                case 7:
+                case 8:
+                    r1ptg = 4; break;
+                case 5:
+                    r1ptg = 5; break;
+            };
-#if 0
-    // the numbers below are measured on M2 Ultra for 7B and 13B models
-    // these numbers do not translate to other devices or model sizes
-    // TODO: need to find a better approach
-    if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
-        switch (src0t) {
-            case LM_GGML_TYPE_F16:  ne11_mm_min = 2;  break;
-            case LM_GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
-            case LM_GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
-            case LM_GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
-            case LM_GGML_TYPE_Q4_0:
-            case LM_GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
-            case LM_GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
-            case LM_GGML_TYPE_Q5_0: // not tested yet
-            case LM_GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
-            case LM_GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
-            case LM_GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
-            default:                ne11_mm_min = 1;  break;
-        }
-    }
-#endif
+            id pipeline = nil;
-
-    // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
-    // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-    if ([device supportsFamily:MTLGPUFamilyApple7] &&
-        !lm_ggml_is_transposed(src0) &&
-        !lm_ggml_is_transposed(src1) &&
-        src1t == LM_GGML_TYPE_F32 &&
-        ne00 % 32 == 0 && ne00 >= 64 &&
-        (ne11 > ne11_mm_min || (lm_ggml_is_quantized(src0t) && ne12 > 1))) {
-        //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
-        // some Metal matrix data types require aligned pointers
-        // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
-        switch (src0->type) {
-            case LM_GGML_TYPE_F32:  LM_GGML_ASSERT(nb01 % 16 == 0); break;
-            case LM_GGML_TYPE_F16:  LM_GGML_ASSERT(nb01 % 8  == 0); break;
-            case LM_GGML_TYPE_BF16: LM_GGML_ASSERT(nb01 % 8  == 0); break;
-            default: break;
-        }
+            switch (src0->type) {
+                case LM_GGML_TYPE_F16:
+                    switch (r1ptg) {
+                        case 2: pipeline =
ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + } break; + case LM_GGML_TYPE_Q4_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + } break; + case LM_GGML_TYPE_Q4_1: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + } break; + case LM_GGML_TYPE_Q5_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + } break; + case LM_GGML_TYPE_Q5_1: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + } break; + case LM_GGML_TYPE_Q8_0: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + } break; + case LM_GGML_TYPE_Q4_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + } break; + case LM_GGML_TYPE_Q5_K: + 
switch (r1ptg) { + case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + } break; + case LM_GGML_TYPE_Q6_K: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + } break; + case LM_GGML_TYPE_IQ4_NL: + switch (r1ptg) { + case 2: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break; + case 3: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break; + case 4: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break; + case 5: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break; + default: LM_GGML_ABORT("not implemented"); + } break; + default: LM_GGML_ABORT("not implemented"); + } - id pipeline = nil; - - switch (src0->type) { - case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; - case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; - case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; - case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; - case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; - case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; - case LM_GGML_TYPE_Q5_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; - case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; - case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; - case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; - case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; - case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; - case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; - case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; - case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ2_S: pipeline = 
ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; - case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; - default: LM_GGML_ABORT("MUL MAT-MAT not implemented"); - } + lm_ggml_metal_kargs_mul_mv_ext args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + /*.nsg =*/ nsg, + /*.nxpsg =*/ nxpsg, + /*.r1ptg =*/ r1ptg, + }; - lm_ggml_metal_kargs_mul_mm args = { - /*.ne00 =*/ ne00, - /*.ne02 =*/ ne02, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.r2 =*/ r2, - /*.r3 =*/ r3, - }; + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg); + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + if ([device supportsFamily:MTLGPUFamilyApple7] && + !lm_ggml_is_transposed(src0) && + !lm_ggml_is_transposed(src1) && + src1t == LM_GGML_TYPE_F32 && + ne00 % 32 == 0 && ne00 >= 64 && + (ne11 > ne11_mm_min || (lm_ggml_is_quantized(src0t) && ne12 > 1))) { + //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + // some Metal matrix data types require aligned pointers + // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) + switch (src0->type) { + case LM_GGML_TYPE_F32: LM_GGML_ASSERT(nb01 % 16 == 0); break; + case LM_GGML_TYPE_F16: LM_GGML_ASSERT(nb01 % 8 == 0); break; + case LM_GGML_TYPE_BF16: LM_GGML_ASSERT(nb01 % 8 == 0); break; + default: break; + } + + id pipeline = nil; + + switch (src0->type) { + case LM_GGML_TYPE_F32: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32 ].pipeline; break; + case LM_GGML_TYPE_F16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32 ].pipeline; break; + case LM_GGML_TYPE_BF16: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32 ].pipeline; break; + case LM_GGML_TYPE_Q4_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32 ].pipeline; break; + case LM_GGML_TYPE_Q4_1: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32 ].pipeline; break; + case LM_GGML_TYPE_Q5_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break; + case LM_GGML_TYPE_Q5_1: pipeline = 
ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break; + case LM_GGML_TYPE_Q8_0: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break; + case LM_GGML_TYPE_Q2_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break; + case LM_GGML_TYPE_Q3_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break; + case LM_GGML_TYPE_Q4_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break; + case LM_GGML_TYPE_Q5_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32 ].pipeline; break; + case LM_GGML_TYPE_Q6_K: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; + case LM_GGML_TYPE_IQ2_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; + case LM_GGML_TYPE_IQ3_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ2_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ1_S: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ1_M: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; + case LM_GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; + default: LM_GGML_ABORT("MUL MAT-MAT not implemented"); + } + + lm_ggml_metal_kargs_mul_mm args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - - [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } else { - int nth0 = 32; - int nth1 = 1; - int nrows = 1; - //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); - - id pipeline = nil; - - // use custom matrix x vector kernel - switch (src0t) { - case LM_GGML_TYPE_F32: - { - LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } else { + int nth0 = 32; + int nth1 = 1; + int nrows = 1; + 
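A quick note before the vector path continues below: the matrix-matrix dispatch just above rounds the output up to whole 64x32 tiles, one 128-thread threadgroup per tile (with 8192 bytes of threadgroup memory for staging) and one grid layer per batch element. A minimal sketch of that arithmetic, with mtl_size as a hypothetical stand-in for MTLSize:

struct mtl_size { int x, y, z; };               // illustrative stand-in for MTLSize

static mtl_size mul_mm_grid(int ne01, int ne11, int ne12, int ne13) {
    const int tiles_x = (ne11 + 31)/32;         // output columns, 32 per tile
    const int tiles_y = (ne01 + 63)/64;         // output rows, 64 per tile
    return { tiles_x, tiles_y, ne12*ne13 };     // one grid layer per (ne12, ne13) batch element
}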
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + + id pipeline = nil; + + // use custom matrix x vector kernel + switch (src0t) { + case LM_GGML_TYPE_F32: + { + LM_GGML_ASSERT(src1t == LM_GGML_TYPE_F32); + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; + nrows = 4; + } break; + case LM_GGML_TYPE_F16: + { + nth0 = 32; + nth1 = 1; + if (src1t == LM_GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; + nrows = ne11; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; nrows = 4; - } break; - case LM_GGML_TYPE_F16: - { - nth0 = 32; - nth1 = 1; - if (src1t == LM_GGML_TYPE_F32) { - if (ne11 * ne12 < 4) { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; - nrows = ne11; - } else { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; - nrows = 4; - } - } else { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; - nrows = 4; - } - } break; - case LM_GGML_TYPE_BF16: - { - nth0 = 32; - nth1 = 1; - if (src1t == LM_GGML_TYPE_F32) { - if (ne11 * ne12 < 4) { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; - nrows = ne11; - } else { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; - nrows = 4; - } - } else { - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; - nrows = 4; - } - } break; - case LM_GGML_TYPE_Q4_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; - } break; - case LM_GGML_TYPE_Q4_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; - } break; - case LM_GGML_TYPE_Q5_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; - } break; - case LM_GGML_TYPE_Q5_1: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; - } break; - case LM_GGML_TYPE_Q8_0: - { - nth0 = 8; - nth1 = 8; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; - } break; - case LM_GGML_TYPE_Q2_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; - } break; - case LM_GGML_TYPE_Q3_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; - } break; - case LM_GGML_TYPE_Q4_K: - { - nth0 = 4; //1; - nth1 = 8; //32; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; - } break; - case LM_GGML_TYPE_Q5_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; - } break; - case LM_GGML_TYPE_Q6_K: - { - nth0 = 2; - nth1 = 32; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ2_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = 
ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ2_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ3_XXS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ3_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ2_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ1_S: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ1_M: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ4_NL: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; - } break; - case LM_GGML_TYPE_IQ4_XS: - { - nth0 = 4; - nth1 = 16; - pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; - } break; - default: - { - LM_GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); - LM_GGML_ABORT("not implemented"); } - }; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; + nrows = 4; + } + } break; + case LM_GGML_TYPE_BF16: + { + nth0 = 32; + nth1 = 1; + if (src1t == LM_GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; + nrows = ne11; + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; + nrows = 4; + } + } else { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; + nrows = 4; + } + } break; + case LM_GGML_TYPE_Q4_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; + } break; + case LM_GGML_TYPE_Q4_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; + } break; + case LM_GGML_TYPE_Q5_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; + } break; + case LM_GGML_TYPE_Q5_1: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; + } break; + case LM_GGML_TYPE_Q8_0: + { + nth0 = 8; + nth1 = 8; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; + } break; + case LM_GGML_TYPE_Q2_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; + } break; + case LM_GGML_TYPE_Q3_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; + } break; + case LM_GGML_TYPE_Q4_K: + { + nth0 = 4; //1; + nth1 = 8; //32; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; + } break; + case LM_GGML_TYPE_Q5_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; + } break; + case LM_GGML_TYPE_Q6_K: + { + nth0 = 2; + nth1 = 32; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; + } break; + case 
LM_GGML_TYPE_IQ2_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ2_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ3_XXS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ3_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ2_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ1_S: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ1_M: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ4_NL: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; + } break; + case LM_GGML_TYPE_IQ4_XS: + { + nth0 = 4; + nth1 = 16; + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; + } break; + default: + { + LM_GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); + LM_GGML_ABORT("not implemented"); + } + }; - lm_ggml_metal_kargs_mul_mv args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb03 =*/ nb03, - /*.ne10 =*/ ne10, - /*.ne11 =*/ ne11, - /*.ne12 =*/ ne12, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.r2 =*/ r2, - /*.r3 =*/ r3, - }; + lm_ggml_metal_kargs_mul_mv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; - [encoder setComputePipelineState:pipeline]; - [encoder setBytes:&args length:sizeof(args) atIndex:0]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 || - src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K || - src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) { - const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 
256*8+128 : 512*8+128; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) { - const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) { - const int mem_size = 32*sizeof(float); - [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_Q3_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == LM_GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } else { - const int64_t ny = (ne11 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - } + if (src0t == LM_GGML_TYPE_Q4_0 || src0t == LM_GGML_TYPE_Q4_1 || src0t == LM_GGML_TYPE_Q5_0 || + src0t == LM_GGML_TYPE_Q5_1 || src0t == LM_GGML_TYPE_Q8_0 || src0t == LM_GGML_TYPE_Q2_K || + src0t == LM_GGML_TYPE_IQ1_S || src0t == LM_GGML_TYPE_IQ1_M || src0t == LM_GGML_TYPE_IQ2_S) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_IQ2_XXS || src0t == LM_GGML_TYPE_IQ2_XS) { + const int mem_size = src0t == LM_GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_IQ3_XXS || src0t == LM_GGML_TYPE_IQ3_S) { + const int mem_size = src0t == LM_GGML_TYPE_IQ3_XXS ? 
256*4+128 : 512*4; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_IQ4_NL || src0t == LM_GGML_TYPE_IQ4_XS) { + const int mem_size = 32*sizeof(float); + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_Q4_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_Q3_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == LM_GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } else { + const int64_t ny = (ne11 + nrows - 1)/nrows; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + } } break; case LM_GGML_OP_MUL_MAT_ID: { @@ -2672,7 +2967,6 @@ static void lm_ggml_metal_encode_node( } break; case LM_GGML_OP_GROUP_NORM: { - LM_GGML_ASSERT(ne00 % 4 == 0); LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); float eps; @@ -2742,7 +3036,9 @@ static void lm_ggml_metal_encode_node( } break; case LM_GGML_OP_ROPE: { - LM_GGML_ASSERT(ne10 == ne02); + // make sure we have one or more position id(ne10) per token(ne02) + LM_GGML_ASSERT(ne10 % ne02 == 0); + LM_GGML_ASSERT(ne10 >= ne02); const int nth = MIN(1024, ne00); @@ -2908,6 +3204,49 @@ static void lm_ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; } } break; + case LM_GGML_OP_CONV_TRANSPOSE_1D: + { + LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(src1)); + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F16 || src0->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(src1->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT( dst->type == LM_GGML_TYPE_F32); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + + const int32_t IC = src1->ne[1]; + const int32_t IL = src1->ne[0]; + + const int32_t K = src0->ne[0]; + + const int32_t OL = dst->ne[0]; + const int32_t OC = dst->ne[1]; + + id pipeline; + + switch (src0->type) { + case LM_GGML_TYPE_F32: { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F32_F32].pipeline; + } break; + case LM_GGML_TYPE_F16: { + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_CONV_TRANSPOSE_1D_F16_F32].pipeline; + } break; + default: LM_GGML_ABORT("fatal error"); + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&IC length:sizeof( int32_t) atIndex:3]; + [encoder setBytes:&IL length:sizeof( int32_t) atIndex:4]; + [encoder setBytes:&K length:sizeof( int32_t) atIndex:5]; + [encoder setBytes:&s0 length:sizeof( int32_t) atIndex:6]; + [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:7]; + [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:8]; + 
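A short aside on the OL and OC values bound above, before the dispatch that follows: they come from the transposed-convolution output-size formula that appears later in this patch in ggml.c, OL = (IL - 1)*s0 - 2*p0 + d0*(K - 1) + 1, which under the p0 == 0, d0 == 1 restriction asserted at graph-build time reduces to (IL - 1)*s0 + K. For example, IL = 4, K = 3, s0 = 2 gives OL = 9. A minimal sketch of the reduced form:

#include <cstdint>

static inline int64_t conv_transpose_1d_out_len(int64_t IL, int64_t K, int64_t s0) {
    // general: (IL - 1)*s0 - 2*p0 + d0*(K - 1) + 1; ggml asserts p0 == 0, d0 == 1
    return (IL - 1)*s0 + K;
}

The dispatch below then launches one single-thread threadgroup per output element, an OL x OC grid.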
+ [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case LM_GGML_OP_UPSCALE: { LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); @@ -2977,6 +3316,38 @@ static void lm_ggml_metal_encode_node( const int nth = MIN(1024, ne0); + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; + case LM_GGML_OP_PAD_REFLECT_1D: + { + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + + const int32_t p0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[1]; + + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:6]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:11]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:12]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:13]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:14]; + [encoder setBytes:&p0 length:sizeof(p0) atIndex:15]; + [encoder setBytes:&p1 length:sizeof(p1) atIndex:16]; + + const int nth = MIN(1024, ne0); + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case LM_GGML_OP_ARANGE: @@ -3508,6 +3879,68 @@ static void lm_ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; + case LM_GGML_OP_SET: + { + LM_GGML_ASSERT(lm_ggml_are_same_shape(src0, dst)); + LM_GGML_ASSERT(lm_ggml_is_contiguous(dst) && lm_ggml_is_contiguous(src0)); + + // src0 and dst as viewed during set + const size_t dst_nb0 = lm_ggml_element_size(src0); + + const size_t dst_nb1 = ((int32_t *) dst->op_params)[0]; + const size_t dst_nb2 = ((int32_t *) dst->op_params)[1]; + const size_t dst_nb3 = ((int32_t *) dst->op_params)[2]; + const size_t offset = ((int32_t *) dst->op_params)[3]; + const bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + memcpy(((char *) dst->data), ((char *) src0->data), lm_ggml_nbytes(dst)); + } + + const int im0 = (ne10 == 0 ? 0 : ne10-1); + const int im1 = (ne11 == 0 ? 0 : ne11-1); + const int im2 = (ne12 == 0 ? 0 : ne12-1); + const int im3 = (ne13 == 0 ? 
0 : ne13-1); + + LM_GGML_ASSERT(offset + im0*dst_nb0 + im1*dst_nb1 + im2*dst_nb2 + im3*dst_nb3 <= lm_ggml_nbytes(dst)); + + id pipeline = nil; + + switch (src0t) { + case LM_GGML_TYPE_F32: + LM_GGML_ASSERT(nb10 == sizeof(float)); + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break; + case LM_GGML_TYPE_I32: + LM_GGML_ASSERT(nb10 == sizeof(int32_t)); + pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break; + default: LM_GGML_ABORT("fatal error"); + } + + lm_ggml_metal_kargs_set args = { + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb1 =*/ dst_nb1, + /*.nb2 =*/ dst_nb2, + /*.nb3 =*/ dst_nb3, + /*.offs =*/ offset, + /*.inplace =*/ inplace, + }; + + const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10); + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; case LM_GGML_OP_POOL_2D: { LM_GGML_ASSERT(lm_ggml_is_contiguous(src0)); @@ -3567,6 +4000,31 @@ static void lm_ggml_metal_encode_node( [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)]; } break; + case LM_GGML_OP_ARGMAX: + { + LM_GGML_ASSERT(src0->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(lm_ggml_is_contiguous_1(src0)); + LM_GGML_ASSERT(nb00 == lm_ggml_type_size(src0->type)); + + const int64_t nrows = lm_ggml_nrows(src0); + + int nth = 32; // SIMD width + while (nth < ne00 && nth*ne01*ne02*ne03 < 256) { + nth *= 2; + } + + id pipeline = ctx->kernels[LM_GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline; + + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1]; + + [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; + } break; default: { LM_GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, lm_ggml_op_name(dst->op)); @@ -4372,19 +4830,45 @@ static lm_ggml_backend_dev_t lm_ggml_backend_metal_reg_device_get(lm_ggml_backen LM_GGML_UNUSED(index); } +static struct lm_ggml_backend_feature g_lm_ggml_backend_metal_features[] = { +#if defined(LM_GGML_METAL_EMBED_LIBRARY) + { "EMBED_LIBRARY", "1" }, +#endif +#if defined(LM_GGML_METAL_USE_BF16) + { "BF16", "1" }, +#endif + { nil, nil }, +}; + +static struct lm_ggml_backend_feature * lm_ggml_backend_metal_get_features(lm_ggml_backend_reg_t reg) { + return g_lm_ggml_backend_metal_features; + + LM_GGML_UNUSED(reg); +} + +static void * lm_ggml_backend_metal_get_proc_address(lm_ggml_backend_reg_t reg, const char * name) { + if (strcmp(name, "lm_ggml_backend_get_features") == 0) { + return (void *)lm_ggml_backend_metal_get_features; + } + + return NULL; + + LM_GGML_UNUSED(reg); +} static struct lm_ggml_backend_reg_i lm_ggml_backend_metal_reg_i = { /* .get_name = */ lm_ggml_backend_metal_reg_get_name, /* .device_count = */ 
lm_ggml_backend_metal_reg_device_count, /* .device_get = */ lm_ggml_backend_metal_reg_device_get, - /* .get_proc_address = */ NULL, + /* .get_proc_address = */ lm_ggml_backend_metal_get_proc_address, }; lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void) { // TODO: make this thread-safe somehow? { g_lm_ggml_backend_metal_reg = (struct lm_ggml_backend_reg) { - /* .iface = */ lm_ggml_backend_metal_reg_i, - /* .context = */ NULL, + /* .api_version = */ LM_GGML_BACKEND_API_VERSION, + /* .iface = */ lm_ggml_backend_metal_reg_i, + /* .context = */ NULL, }; g_lm_ggml_backend_metal_device = (struct lm_ggml_backend_device) { @@ -4396,3 +4880,5 @@ lm_ggml_backend_reg_t lm_ggml_backend_metal_reg(void) { return &g_lm_ggml_backend_metal_reg; } + +LM_GGML_BACKEND_DL_IMPL(lm_ggml_backend_metal_reg) diff --git a/cpp/ggml-opt.cpp b/cpp/ggml-opt.cpp index 5204eb3c..4f36c9bc 100644 --- a/cpp/ggml-opt.cpp +++ b/cpp/ggml-opt.cpp @@ -14,51 +14,51 @@ #include struct lm_ggml_opt_dataset { - struct lm_ggml_context * ctx; - lm_ggml_backend_buffer_t buf; - struct lm_ggml_tensor * data; - struct lm_ggml_tensor * labels; + struct lm_ggml_context * ctx = nullptr; + lm_ggml_backend_buffer_t buf = nullptr; + struct lm_ggml_tensor * data = nullptr; + struct lm_ggml_tensor * labels = nullptr; - int64_t ndata; - int64_t ndata_shard; - size_t nbs_data; - size_t nbs_labels; + int64_t ndata = -1; + int64_t ndata_shard = -1; + size_t nbs_data = -1; + size_t nbs_labels = -1; std::vector permutation; }; struct lm_ggml_opt_context { - lm_ggml_backend_sched_t backend_sched; - lm_ggml_cgraph * allocated_graph; - lm_ggml_cgraph * allocated_graph_copy; - struct lm_ggml_context * ctx_static; - struct lm_ggml_context * ctx_static_cpu; - struct lm_ggml_context * ctx_compute; - struct lm_ggml_context * ctx_copy; - lm_ggml_backend_buffer_t buf_static; - lm_ggml_backend_buffer_t buf_static_cpu; + lm_ggml_backend_sched_t backend_sched = nullptr; + lm_ggml_cgraph * allocated_graph = nullptr; + lm_ggml_cgraph * allocated_graph_copy = nullptr; + struct lm_ggml_context * ctx_static = nullptr; + struct lm_ggml_context * ctx_static_cpu = nullptr; + struct lm_ggml_context * ctx_compute = nullptr; + struct lm_ggml_context * ctx_copy = nullptr; + lm_ggml_backend_buffer_t buf_static = nullptr; + lm_ggml_backend_buffer_t buf_static_cpu = nullptr; std::mt19937 rng; - struct lm_ggml_tensor * inputs; - struct lm_ggml_tensor * outputs; - struct lm_ggml_tensor * labels; + struct lm_ggml_tensor * inputs = nullptr; + struct lm_ggml_tensor * outputs = nullptr; + struct lm_ggml_tensor * labels = nullptr; - struct lm_ggml_tensor * loss; - struct lm_ggml_tensor * pred; - struct lm_ggml_tensor * ncorrect; + struct lm_ggml_tensor * loss = nullptr; + struct lm_ggml_tensor * pred = nullptr; + struct lm_ggml_tensor * ncorrect = nullptr; - struct lm_ggml_cgraph * gf; - struct lm_ggml_cgraph * gb_grad; - struct lm_ggml_cgraph * gb_opt; + struct lm_ggml_cgraph * gf = nullptr; + struct lm_ggml_cgraph * gb_grad = nullptr; + struct lm_ggml_cgraph * gb_opt = nullptr; - int64_t iter; - int32_t opt_period; - int32_t opt_i; - bool loss_per_datapoint; + int64_t iter = 1; + int32_t opt_period = 1; + int32_t opt_i = 0; + bool loss_per_datapoint = false; - lm_ggml_opt_get_optimizer_params get_opt_pars; - void * get_opt_pars_ud; - struct lm_ggml_tensor * adamw_params; + lm_ggml_opt_get_optimizer_params get_opt_pars = nullptr; + void * get_opt_pars_ud = nullptr; + struct lm_ggml_tensor * adamw_params = nullptr; }; struct lm_ggml_opt_result { @@ -67,8 +67,8 @@ struct 
lm_ggml_opt_result { std::vector pred; int64_t ncorrect = 0; - bool loss_per_datapoint = false; - int64_t opt_period = -1; + int64_t opt_period = -1; + bool loss_per_datapoint = false; }; // ====== Dataset ====== @@ -188,11 +188,11 @@ struct lm_ggml_opt_optimizer_params lm_ggml_opt_get_default_optimizer_params(voi } struct lm_ggml_opt_params lm_ggml_opt_default_params( - lm_ggml_backend_sched_t backend_sched, - struct lm_ggml_context * ctx_compute, - struct lm_ggml_tensor * inputs, - struct lm_ggml_tensor * outputs, - enum lm_ggml_opt_loss_type loss_type) { + lm_ggml_backend_sched_t backend_sched, + struct lm_ggml_context * ctx_compute, + struct lm_ggml_tensor * inputs, + struct lm_ggml_tensor * outputs, + enum lm_ggml_opt_loss_type loss_type) { return { /*backend_sched =*/ backend_sched, /*ctx_compute =*/ ctx_compute, @@ -237,25 +237,33 @@ static lm_ggml_tensor * map_tensor(std::map return new_tensor; } -static lm_ggml_cgraph * dup_graph(lm_ggml_context * ctx, lm_ggml_cgraph * graph) { +static lm_ggml_cgraph * dup_graph(lm_ggml_context * ctx, lm_ggml_cgraph * src) { std::map tensor_map; - lm_ggml_cgraph * new_graph = lm_ggml_new_graph_custom(ctx, LM_GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); + lm_ggml_cgraph * dst = lm_ggml_new_graph_custom(ctx, src->size, /*grads =*/ true); - for (int i = 0; i < graph->n_leafs; i++) { - lm_ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->leafs[i])); + for (int i = 0; i < src->n_leafs; i++) { + lm_ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->leafs[i])); } - for (int i = 0; i < graph->n_nodes; i++) { - lm_ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i])); + LM_GGML_ASSERT(dst->n_leafs == src->n_leafs); + for (int i = 0; i < src->n_nodes; i++) { + lm_ggml_build_forward_expand(dst, map_tensor(tensor_map, ctx, src->nodes[i])); } - for (int i = 0; i < graph->n_nodes; ++i) { - const size_t igrad_src = lm_ggml_hash_find(&graph->visited_hash_set, graph->nodes[i]); - const size_t igrad_dst = lm_ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]); - graph->grads[igrad_dst] = new_graph->grads[igrad_src]; - graph->grad_accs[igrad_dst] = new_graph->grad_accs[igrad_src]; + LM_GGML_ASSERT(dst->n_nodes == src->n_nodes); + for (int i = 0; i < src->n_nodes; ++i) { + const size_t igrad_src = lm_ggml_hash_find(&src->visited_hash_set, src->nodes[i]); + const size_t igrad_dst = lm_ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]); + + LM_GGML_ASSERT(igrad_src != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(lm_ggml_bitset_get(src->visited_hash_set.used, igrad_src)); + LM_GGML_ASSERT(igrad_dst != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(lm_ggml_bitset_get(dst->visited_hash_set.used, igrad_dst)); + + dst->grads[igrad_dst] = src->grads[igrad_src]; + dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src]; } - return new_graph; + return dst; } static void lm_ggml_opt_alloc_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgraph * graph) { @@ -284,18 +292,13 @@ static void lm_ggml_opt_alloc_graph(lm_ggml_opt_context_t opt_ctx, lm_ggml_cgrap lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) { lm_ggml_opt_context_t result = new struct lm_ggml_opt_context; - result->backend_sched = params.backend_sched; - result->allocated_graph = nullptr; - result->allocated_graph_copy = nullptr; - result->ctx_compute = params.ctx_compute; - result->ctx_copy = nullptr; - result->inputs = params.inputs; - result->outputs = params.outputs; - result->iter = 1; - result->opt_period = params.opt_period; 
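An aside on this run of deletions: lm_ggml_opt_context now carries C++11 default member initializers (see the struct changes earlier in this file), so these constructor-time assignments became redundant and only the fields that genuinely depend on params are still set. A minimal sketch of the pattern, with hypothetical names:

#include <cstdint>

struct opt_state {                 // hypothetical miniature of lm_ggml_opt_context
    int64_t iter       = 1;        // the default lives with the declaration
    int32_t opt_period = 1;
    void *  user_data  = nullptr;  // pointers start out null without any memset
};

opt_state * make_opt_state() {
    return new opt_state;          // defaults applied by the implicit default constructor
}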
- result->opt_i = 0; - result->get_opt_pars = params.get_opt_pars; - result->get_opt_pars_ud = params.get_opt_pars_ud; + result->backend_sched = params.backend_sched; + result->ctx_compute = params.ctx_compute; + result->inputs = params.inputs; + result->outputs = params.outputs; + result->opt_period = params.opt_period; + result->get_opt_pars = params.get_opt_pars; + result->get_opt_pars_ud = params.get_opt_pars_ud; LM_GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically"); LM_GGML_ASSERT(result->opt_period >= 1); @@ -348,7 +351,6 @@ lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) { switch (params.loss_type) { case LM_GGML_OPT_LOSS_TYPE_MEAN: { - result->labels = nullptr; result->loss = lm_ggml_sum(result->ctx_static, result->outputs); lm_ggml_set_name(result->loss, "loss_sum"); const float scale = 1.0f / (result->opt_period * lm_ggml_nelements(result->outputs)); @@ -358,7 +360,6 @@ lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) { break; } case LM_GGML_OPT_LOSS_TYPE_SUM: { - result->labels = nullptr; result->loss = lm_ggml_sum(result->ctx_static, result->outputs); lm_ggml_set_name(result->loss, "loss_sum"); result->loss_per_datapoint = false; @@ -413,14 +414,7 @@ lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) { } if (params.build_type == LM_GGML_OPT_BUILD_TYPE_FORWARD) { - result->gb_grad = nullptr; - result->gb_opt = nullptr; - result->buf_static = lm_ggml_backend_alloc_ctx_tensors(result->ctx_static, lm_ggml_backend_sched_get_backend(result->backend_sched, 0)); - result->buf_static_cpu = nullptr; - - lm_ggml_opt_alloc_graph(result, result->gf); - return result; } @@ -429,14 +423,8 @@ lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) { lm_ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate); if (params.build_type == LM_GGML_OPT_BUILD_TYPE_GRAD) { - result->gb_opt = nullptr; - result->buf_static = lm_ggml_backend_alloc_ctx_tensors(result->ctx_static, lm_ggml_backend_sched_get_backend(result->backend_sched, 0)); - result->buf_static_cpu = nullptr; - - lm_ggml_opt_alloc_graph(result, result->gb_grad); lm_ggml_graph_reset(result->gb_grad); - return result; } @@ -466,7 +454,6 @@ lm_ggml_opt_context_t lm_ggml_opt_init(struct lm_ggml_opt_params params) { result->buf_static_cpu = lm_ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, lm_ggml_backend_cpu_buffer_type()); - lm_ggml_opt_alloc_graph(result, result->gb_opt); lm_ggml_graph_reset(result->gb_opt); return result; diff --git a/cpp/ggml-quants.c b/cpp/ggml-quants.c index 7b4f460f..ef5fe683 100644 --- a/cpp/ggml-quants.c +++ b/cpp/ggml-quants.c @@ -5220,15 +5220,6 @@ bool lm_ggml_validate_row_data(enum lm_ggml_type type, const void * data, size_t { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; - case LM_GGML_TYPE_Q4_0_4_4: - case LM_GGML_TYPE_Q4_0_4_8: - { - VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4); - } break; - case LM_GGML_TYPE_Q4_0_8_8: - { - VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8); - } break; case LM_GGML_TYPE_I8: case LM_GGML_TYPE_I16: diff --git a/cpp/ggml-threading.h b/cpp/ggml-threading.h index d453d269..a676c27a 100644 --- a/cpp/ggml-threading.h +++ b/cpp/ggml-threading.h @@ -1,11 +1,13 @@ #pragma once +#include "ggml.h" + #ifdef __cplusplus extern "C" { #endif -void lm_ggml_critical_section_start(void); -void lm_ggml_critical_section_end(void); 
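A brief aside on why these declarations gain LM_GGML_API and why ggml.h is now included: LM_GGML_API is the library's export/visibility annotation, and with the registry's new dynamic-loading path (see LM_GGML_BACKEND_DL_IMPL later in this patch) these symbols presumably have to stay reachable across a shared-library boundary. Its real definition lives in ggml.h; purely as an assumed illustration of the usual shape of such a macro:

// hypothetical stand-in, NOT the actual LM_GGML_API definition
#if defined(_WIN32) && defined(MYLIB_SHARED)
#    define MYLIB_API __declspec(dllexport)
#else
#    define MYLIB_API __attribute__((visibility("default")))
#endif

MYLIB_API void my_critical_section_start(void);  // exported rather than hidden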
+LM_GGML_API void lm_ggml_critical_section_start(void); +LM_GGML_API void lm_ggml_critical_section_end(void); #ifdef __cplusplus } diff --git a/cpp/ggml.c b/cpp/ggml.c index 4cbe3858..7d6f1ff4 100644 --- a/cpp/ggml.c +++ b/cpp/ggml.c @@ -8,7 +8,10 @@ // FIXME: required here for quantization functions #include "ggml-quants.h" -#include "ggml-aarch64.h" + +#ifdef LM_GGML_USE_CPU_HBM +#include +#endif #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW @@ -788,32 +791,23 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .to_float = (lm_ggml_to_float_t) lm_ggml_bf16_to_fp32_row, .from_float_ref = (lm_ggml_from_float_t) lm_ggml_fp32_to_bf16_row_ref, }, - [LM_GGML_TYPE_Q4_0_4_4] = { - .type_name = "q4_0_4x4", - .blck_size = QK4_0, - .blck_size_interleave = 4, - .type_size = sizeof(block_q4_0), - .is_quantized = true, - .to_float = NULL, - .from_float_ref = NULL, + [31] = { // LM_GGML_TYPE_Q4_0_4_4 + .type_name = "TYPE_Q4_0_4_4 REMOVED, use Q4_0 with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, }, - [LM_GGML_TYPE_Q4_0_4_8] = { - .type_name = "q4_0_4x8", - .blck_size = QK4_0, - .blck_size_interleave = 8, - .type_size = sizeof(block_q4_0), - .is_quantized = true, - .to_float = NULL, - .from_float_ref = NULL, + [32] = { // LM_GGML_TYPE_Q4_0_4_8 + .type_name = "TYPE_Q4_0_4_8 REMOVED, use Q4_0 with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, }, - [LM_GGML_TYPE_Q4_0_8_8] = { - .type_name = "q4_0_8x8", - .blck_size = QK4_0, - .blck_size_interleave = 8, - .type_size = sizeof(block_q4_0), - .is_quantized = true, - .to_float = NULL, - .from_float_ref = NULL, + [33] = { // LM_GGML_TYPE_Q4_0_8_8 + .type_name = "TYPE_Q4_0_8_8 REMOVED, use Q4_0 with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, }, [LM_GGML_TYPE_TQ1_0] = { .type_name = "tq1_0", @@ -831,6 +825,24 @@ static const struct lm_ggml_type_traits type_traits[LM_GGML_TYPE_COUNT] = { .to_float = (lm_ggml_to_float_t) dequantize_row_tq2_0, .from_float_ref = (lm_ggml_from_float_t) quantize_row_tq2_0_ref, }, + [36] = { // LM_GGML_TYPE_IQ4_NL_4_4 + .type_name = "TYPE_IQ4_NL_4_4 REMOVED, use IQ4_NL with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, + [37] = { // LM_GGML_TYPE_IQ4_NL_4_8 + .type_name = "TYPE_IQ4_NL_4_8 REMOVED, use IQ4_NL with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, + [38] = { // LM_GGML_TYPE_IQ4_NL_8_8 + .type_name = "TYPE_IQ4_NL_8_8 REMOVED, use IQ4_NL with runtime repacking", + .blck_size = 0, + .type_size = 0, + .is_quantized = false, + }, }; const struct lm_ggml_type_traits * lm_ggml_get_type_traits(enum lm_ggml_type type) { @@ -941,6 +953,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { "POOL_2D_BACK", "UPSCALE", "PAD", + "PAD_REFLECT_1D", "ARANGE", "TIMESTEP_EMBEDDING", "ARGSORT", @@ -974,7 +987,7 @@ static const char * LM_GGML_OP_NAME[LM_GGML_OP_COUNT] = { "OPT_STEP_ADAMW", }; -static_assert(LM_GGML_OP_COUNT == 81, "LM_GGML_OP_COUNT != 81"); +static_assert(LM_GGML_OP_COUNT == 82, "LM_GGML_OP_COUNT != 82"); static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "none", @@ -1036,6 +1049,7 @@ static const char * LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "pool_2d_back(x)", "upscale(x)", "pad(x)", + "pad_reflect_1d(x)", "arange(start, stop, step)", "timestep_embedding(timesteps, dim, max_period)", "argsort(x)", @@ -1069,7 +1083,7 @@ static const char * 
LM_GGML_OP_SYMBOL[LM_GGML_OP_COUNT] = { "adamw(x)", }; -static_assert(LM_GGML_OP_COUNT == 81, "LM_GGML_OP_COUNT != 81"); +static_assert(LM_GGML_OP_COUNT == 82, "LM_GGML_OP_COUNT != 82"); static_assert(LM_GGML_OP_POOL_COUNT == 2, "LM_GGML_OP_POOL_COUNT != 2"); @@ -1259,9 +1273,6 @@ enum lm_ggml_type lm_ggml_ftype_to_lm_ggml_type(enum lm_ggml_ftype ftype) { case LM_GGML_FTYPE_MOSTLY_IQ4_XS: wtype = LM_GGML_TYPE_IQ4_XS; break; case LM_GGML_FTYPE_MOSTLY_IQ3_S: wtype = LM_GGML_TYPE_IQ3_S; break; case LM_GGML_FTYPE_MOSTLY_IQ2_S: wtype = LM_GGML_TYPE_IQ2_S; break; - case LM_GGML_FTYPE_MOSTLY_Q4_0_4_4: wtype = LM_GGML_TYPE_Q4_0_4_4; break; - case LM_GGML_FTYPE_MOSTLY_Q4_0_4_8: wtype = LM_GGML_TYPE_Q4_0_4_8; break; - case LM_GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = LM_GGML_TYPE_Q4_0_8_8; break; case LM_GGML_FTYPE_UNKNOWN: wtype = LM_GGML_TYPE_COUNT; break; case LM_GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = LM_GGML_TYPE_COUNT; break; } @@ -2255,6 +2266,7 @@ struct lm_ggml_tensor * lm_ggml_argmax( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a) { LM_GGML_ASSERT(lm_ggml_is_matrix(a)); + LM_GGML_ASSERT(a->ne[0] <= INT32_MAX); struct lm_ggml_tensor * result = lm_ggml_new_tensor_1d(ctx, LM_GGML_TYPE_I32, a->ne[1]); @@ -3505,15 +3517,18 @@ static struct lm_ggml_tensor * lm_ggml_rope_impl( LM_GGML_ASSERT(c->ne[0] >= n_dims / 2); } + int sections[4] = {0, 0, 0, 0}; + struct lm_ggml_tensor * result = inplace ? lm_ggml_view_tensor(ctx, a) : lm_ggml_dup_tensor(ctx, a); - int32_t params[11] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; + int32_t params[15] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; memcpy(params + 5, &freq_base, sizeof(float)); memcpy(params + 6, &freq_scale, sizeof(float)); memcpy(params + 7, &ext_factor, sizeof(float)); memcpy(params + 8, &attn_factor, sizeof(float)); memcpy(params + 9, &beta_fast, sizeof(float)); memcpy(params + 10, &beta_slow, sizeof(float)); + memcpy(params + 11, §ions, sizeof(int)*4); lm_ggml_set_op_params(result, params, sizeof(params)); result->op = LM_GGML_OP_ROPE; @@ -3535,6 +3550,53 @@ struct lm_ggml_tensor * lm_ggml_rope( ); } +struct lm_ggml_tensor * lm_ggml_rope_multi( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow) { + // Multimodal Rotary Position Embedding + LM_GGML_ASSERT((mode & 1) == 0 && "mode & 1 == 1 is no longer supported"); + + LM_GGML_ASSERT(lm_ggml_is_vector(b)); + LM_GGML_ASSERT(b->type == LM_GGML_TYPE_I32); + LM_GGML_ASSERT(a->ne[2] * 4 == b->ne[0]); // mrope expecting 4 position ids per token + + if (c) { + LM_GGML_ASSERT(c->type == LM_GGML_TYPE_F32); + LM_GGML_ASSERT(c->ne[0] >= n_dims / 2); + } + + struct lm_ggml_tensor * result = lm_ggml_dup_tensor(ctx, a); + + int32_t params[11 + 4] = { /*n_past*/ 0, n_dims, mode, /*n_ctx*/ 0, n_ctx_orig }; + memcpy(params + 5, &freq_base, sizeof(float)); + memcpy(params + 6, &freq_scale, sizeof(float)); + memcpy(params + 7, &ext_factor, sizeof(float)); + memcpy(params + 8, &attn_factor, sizeof(float)); + memcpy(params + 9, &beta_fast, sizeof(float)); + memcpy(params + 10, &beta_slow, sizeof(float)); + memcpy(¶ms[11], sections, sizeof(int)*4); + lm_ggml_set_op_params(result, params, sizeof(params)); + + result->op = LM_GGML_OP_ROPE; + result->src[0] = a; + result->src[1] = b; + result->src[2] = c; + + return result; +} + struct 
lm_ggml_tensor * lm_ggml_rope_inplace( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, @@ -3698,13 +3760,84 @@ struct lm_ggml_tensor * lm_ggml_clamp( return result; } -// lm_ggml_conv_1d - static int64_t lm_ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) { return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; } -LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_1d( +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] +// a: [OC,IC, KH, KW] +// b: [N, IC, IH, IW] +// result: [N, OH, OW, IC*KH*KW] +struct lm_ggml_tensor * lm_ggml_im2col( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D, + enum lm_ggml_type dst_type) { + if (is_2D) { + LM_GGML_ASSERT(a->ne[2] == b->ne[2]); + } else { + //LM_GGML_ASSERT(b->ne[1] % a->ne[1] == 0); + LM_GGML_ASSERT(b->ne[1] == a->ne[1]); + LM_GGML_ASSERT(b->ne[3] == 1); + } + + const int64_t OH = is_2D ? lm_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; + const int64_t OW = lm_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + + LM_GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a"); + LM_GGML_ASSERT((OW > 0) && "b too small compared to a"); + + const int64_t ne[4] = { + is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], + OW, + is_2D ? OH : b->ne[2], + is_2D ? b->ne[3] : 1, + }; + + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, dst_type, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + lm_ggml_set_op_params(result, params, sizeof(params)); + + result->op = LM_GGML_OP_IM2COL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct lm_ggml_tensor * lm_ggml_im2col_back( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + int64_t * ne, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D) { + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 
1 : 0) }; + lm_ggml_set_op_params(result, params, sizeof(params)); + + result->op = LM_GGML_OP_IM2COL_BACK; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +// lm_ggml_conv_1d + +struct lm_ggml_tensor * lm_ggml_conv_1d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, @@ -3734,137 +3867,75 @@ struct lm_ggml_tensor* lm_ggml_conv_1d_ph( return lm_ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); } -// lm_ggml_conv_transpose_1d - -static int64_t lm_ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { - return (ins - 1) * s - 2 * p + d * (ks - 1) + 1; -} +// lm_ggml_conv_1d_dw -LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_1d( +struct lm_ggml_tensor * lm_ggml_conv_1d_dw( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, int s0, int p0, int d0) { - LM_GGML_ASSERT(lm_ggml_is_matrix(b)); - LM_GGML_ASSERT(a->ne[2] == b->ne[1]); - LM_GGML_ASSERT(a->ne[3] == 1); + struct lm_ggml_tensor * new_a = lm_ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]); + struct lm_ggml_tensor * new_b = lm_ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]); - LM_GGML_ASSERT(p0 == 0); - LM_GGML_ASSERT(d0 == 1); + struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, LM_GGML_TYPE_F16); - const int64_t ne[4] = { - lm_ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), - a->ne[1], b->ne[2], 1, - }; - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); + struct lm_ggml_tensor * result = lm_ggml_mul_mat(ctx, im2col, a); - int32_t params[] = { s0, p0, d0 }; - lm_ggml_set_op_params(result, params, sizeof(params)); - - result->op = LM_GGML_OP_CONV_TRANSPOSE_1D; - result->src[0] = a; - result->src[1] = b; + result = lm_ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1); return result; } -// lm_ggml_conv_depthwise +// lm_ggml_conv_1d_dw_ph -struct lm_ggml_tensor * lm_ggml_conv_depthwise_2d( +struct lm_ggml_tensor * lm_ggml_conv_1d_dw_ph( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, int s0, - int s1, - int p0, - int p1, - int d0, - int d1) { - struct lm_ggml_tensor * new_a = lm_ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); - struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, new_a, - lm_ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), - s0, s1, p0, p1, d0, d1, true, LM_GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW] - struct lm_ggml_tensor * new_b = lm_ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] + int d0) { + return lm_ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0); +} - new_a = lm_ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW] - struct lm_ggml_tensor * result = lm_ggml_mul_mat(ctx, new_a, new_b); - result = lm_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW] +// lm_ggml_conv_transpose_1d - return result; +static int64_t lm_ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { + return (ins - 1) * s - 2 * p + d * (ks - 1) + 1; } -// lm_ggml_conv_2d -// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] -// a: [OC,IC, KH, KW] -// b: [N, IC, IH, IW] -// result: [N, OH, OW, IC*KH*KW] -struct lm_ggml_tensor * lm_ggml_im2col( +LM_GGML_API 
struct lm_ggml_tensor * lm_ggml_conv_transpose_1d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, struct lm_ggml_tensor * b, int s0, - int s1, int p0, - int p1, - int d0, - int d1, - bool is_2D, - enum lm_ggml_type dst_type) { - if(is_2D) { - LM_GGML_ASSERT(a->ne[2] == b->ne[2]); - } else { - LM_GGML_ASSERT(a->ne[1] == b->ne[1]); - LM_GGML_ASSERT(b->ne[3] == 1); - } - - const int64_t OH = is_2D ? lm_ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1) : 0; - const int64_t OW = lm_ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0); + int d0) { + LM_GGML_ASSERT(lm_ggml_is_matrix(b)); + LM_GGML_ASSERT(a->ne[2] == b->ne[1]); + LM_GGML_ASSERT(a->ne[3] == 1); - LM_GGML_ASSERT((!is_2D || OH > 0) && "b too small compared to a"); - LM_GGML_ASSERT((OW > 0) && "b too small compared to a"); + LM_GGML_ASSERT(p0 == 0); + LM_GGML_ASSERT(d0 == 1); const int64_t ne[4] = { - is_2D ? (a->ne[2] * a->ne[1] * a->ne[0]) : a->ne[1] * a->ne[0], - OW, - is_2D ? OH : b->ne[2], - is_2D ? b->ne[3] : 1, + lm_ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/), + a->ne[1], b->ne[2], 1, }; + struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, dst_type, 4, ne); - int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 1 : 0) }; + int32_t params[] = { s0, p0, d0 }; lm_ggml_set_op_params(result, params, sizeof(params)); - result->op = LM_GGML_OP_IM2COL; + result->op = LM_GGML_OP_CONV_TRANSPOSE_1D; result->src[0] = a; result->src[1] = b; return result; } -struct lm_ggml_tensor * lm_ggml_im2col_back( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, - struct lm_ggml_tensor * b, - int64_t * ne, - int s0, - int s1, - int p0, - int p1, - int d0, - int d1, - bool is_2D) { - struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_F32, 4, ne); - int32_t params[] = { s0, s1, p0, p1, d0, d1, (is_2D ? 
1 : 0) }; - lm_ggml_set_op_params(result, params, sizeof(params)); - - result->op = LM_GGML_OP_IM2COL_BACK; - result->src[0] = a; - result->src[1] = b; - - return result; -} +// lm_ggml_conv_2d // a: [OC,IC, KH, KW] // b: [N, IC, IH, IW] @@ -3911,6 +3982,31 @@ struct lm_ggml_tensor * lm_ggml_conv_2d_s1_ph( return lm_ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1); } +// lm_ggml_conv_2d_dw + +struct lm_ggml_tensor * lm_ggml_conv_2d_dw( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1) { + struct lm_ggml_tensor * new_a = lm_ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]); + struct lm_ggml_tensor * im2col = lm_ggml_im2col(ctx, new_a, + lm_ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]), + s0, s1, p0, p1, d0, d1, true, LM_GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW] + struct lm_ggml_tensor * new_b = lm_ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW] + + new_a = lm_ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW] + struct lm_ggml_tensor * result = lm_ggml_mul_mat(ctx, new_a, new_b); + result = lm_ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW] + + return result; +} + // lm_ggml_conv_transpose_2d_p0 static int64_t lm_ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) { @@ -4087,6 +4183,37 @@ struct lm_ggml_tensor * lm_ggml_pad( return result; } +// lm_ggml_pad_reflect_1d + +struct lm_ggml_tensor * lm_ggml_pad_reflect_1d( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int p0, + int p1) { + LM_GGML_ASSERT(p0 >= 0); + LM_GGML_ASSERT(p1 >= 0); + + LM_GGML_ASSERT(p0 < a->ne[0]); // padding length on each size must be less than the + LM_GGML_ASSERT(p1 < a->ne[0]); // existing length of the dimension being padded + + LM_GGML_ASSERT(lm_ggml_is_contiguous(a)); + LM_GGML_ASSERT(a->type == LM_GGML_TYPE_F32); + + struct lm_ggml_tensor * result = lm_ggml_new_tensor_4d(ctx, a->type, + a->ne[0] + p0 + p1, + a->ne[1], + a->ne[2], + a->ne[3]); + + int32_t params[] = { p0, p1 }; + lm_ggml_set_op_params(result, params, sizeof(params)); + + result->op = LM_GGML_OP_PAD_REFLECT_1D; + result->src[0] = a; + + return result; +} + // lm_ggml_arange struct lm_ggml_tensor * lm_ggml_arange( @@ -4138,6 +4265,7 @@ struct lm_ggml_tensor * lm_ggml_argsort( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, enum lm_ggml_sort_order order) { + LM_GGML_ASSERT(a->ne[0] <= INT32_MAX); struct lm_ggml_tensor * result = lm_ggml_new_tensor(ctx, LM_GGML_TYPE_I32, LM_GGML_MAX_DIMS, a->ne); lm_ggml_set_op_params_i32(result, 0, (int32_t) order); @@ -5019,8 +5147,10 @@ static void lm_ggml_hash_map_free(struct hash_map * map) { } // utility functions to change gradients -// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator -// else if a is in zero_table, replace a +// isrc is the index of tensor in cgraph->visited_has_set.keys +// the corresponding gradient (accumulators) are also at position isrc +// if tensor has a gradient accumulator, modify that accumulator in-place +// else if there is no gradient for tensor, set the corresponding value // else, just add/subtract/etc. 
the gradients static void lm_ggml_add_or_set( @@ -5028,11 +5158,14 @@ static void lm_ggml_add_or_set( struct lm_ggml_cgraph * cgraph, size_t isrc, struct lm_ggml_tensor * tensor) { + struct lm_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + LM_GGML_ASSERT(src); if (cgraph->grads[isrc]) { - cgraph->grads[isrc] = lm_ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + cgraph->grads[isrc] = lm_ggml_add_impl(ctx, cgraph->grads[isrc], tensor, /*inplace =*/ cgraph->grad_accs[isrc]); } else { cgraph->grads[isrc] = tensor; } + lm_ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name); lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } @@ -5040,18 +5173,20 @@ static void lm_ggml_acc_or_set( struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, size_t isrc, - struct lm_ggml_tensor * src, struct lm_ggml_tensor * tensor, const size_t nb1, const size_t nb2, const size_t nb3, const size_t offset) { + struct lm_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + LM_GGML_ASSERT(src); if (cgraph->grads[isrc]) { cgraph->grads[isrc] = lm_ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]); } else { struct lm_ggml_tensor * a_zero = lm_ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN cgraph->grads[isrc] = lm_ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false); } + lm_ggml_format_name(cgraph->grads[isrc], "grad for %s", cgraph->visited_hash_set.keys[isrc]->name); lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } @@ -5059,13 +5194,15 @@ static void lm_ggml_add1_or_set( struct lm_ggml_context * ctx, struct lm_ggml_cgraph * cgraph, size_t isrc, - struct lm_ggml_tensor * src, struct lm_ggml_tensor * tensor) { + struct lm_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + LM_GGML_ASSERT(src); if (cgraph->grads[isrc]) { cgraph->grads[isrc] = lm_ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); } else { cgraph->grads[isrc] = lm_ggml_repeat(ctx, tensor, src); } + lm_ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name); lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } @@ -5074,11 +5211,14 @@ static void lm_ggml_sub_or_set( struct lm_ggml_cgraph * cgraph, size_t isrc, struct lm_ggml_tensor * tensor) { + struct lm_ggml_tensor * src = cgraph->visited_hash_set.keys[isrc]; + LM_GGML_ASSERT(src); if (cgraph->grads[isrc]) { cgraph->grads[isrc] = lm_ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); } else { cgraph->grads[isrc] = lm_ggml_neg(ctx, tensor); } + lm_ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name); lm_ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } @@ -5095,12 +5235,12 @@ static void lm_ggml_compute_backward( struct lm_ggml_tensor * src1 = tensor->src[1]; struct lm_ggml_tensor * src2 = tensor->src[2]; struct lm_ggml_hash_set * hash_set = &cgraph->visited_hash_set; - const size_t isrc0 = lm_ggml_hash_find(hash_set, src0); - const size_t isrc1 = lm_ggml_hash_find(hash_set, src1); - const size_t isrc2 = lm_ggml_hash_find(hash_set, src2); - const bool src0_needs_grads = isrc0 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0]; - const bool src1_needs_grads = isrc1 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1]; - const bool src2_needs_grads = isrc2 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2]; + const size_t isrc0 
= src0 ? lm_ggml_hash_find(hash_set, src0) : (size_t) -1; + const size_t isrc1 = src1 ? lm_ggml_hash_find(hash_set, src1) : (size_t) -1; + const size_t isrc2 = src2 ? lm_ggml_hash_find(hash_set, src2) : (size_t) -1; + const bool src0_needs_grads = src0 && isrc0 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0]; + const bool src1_needs_grads = src1 && isrc1 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1]; + const bool src2_needs_grads = src2 && isrc2 != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2]; switch (tensor->op) { case LM_GGML_OP_DUP: { @@ -5200,7 +5340,7 @@ static void lm_ggml_compute_backward( } break; case LM_GGML_OP_SUM: { if (src0_needs_grads) { - lm_ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad); + lm_ggml_add1_or_set(ctx, cgraph, isrc0, grad); } } break; case LM_GGML_OP_SUM_ROWS: { @@ -5210,7 +5350,7 @@ static void lm_ggml_compute_backward( } break; case LM_GGML_OP_MEAN: { if (src0_needs_grads) { - lm_ggml_add1_or_set(ctx, cgraph, isrc0, src0, lm_ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); + lm_ggml_add1_or_set(ctx, cgraph, isrc0, lm_ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); } } break; case LM_GGML_OP_REPEAT: { @@ -5363,7 +5503,7 @@ static void lm_ggml_compute_backward( nb3 = (nb3 / n0) * ng; } - lm_ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset); + lm_ggml_acc_or_set(ctx, cgraph, isrc0, grad, nb1, nb2, nb3, offset); } } break; case LM_GGML_OP_PERMUTE: { @@ -5597,10 +5737,9 @@ void lm_ggml_build_backward_expand( const int n_nodes_f = cgraph->n_nodes; - const size_t hash_size = lm_ggml_hash_size(2*cgraph->size); - memset(cgraph->grads, 0, hash_size*sizeof(struct lm_ggml_tensor *)); - memset(cgraph->grad_accs, 0, hash_size*sizeof(struct lm_ggml_tensor *)); - bool * grads_needed = calloc(hash_size, sizeof(bool)); + memset(cgraph->grads, 0, cgraph->visited_hash_set.size*sizeof(struct lm_ggml_tensor *)); + memset(cgraph->grad_accs, 0, cgraph->visited_hash_set.size*sizeof(struct lm_ggml_tensor *)); + bool * grads_needed = calloc(cgraph->visited_hash_set.size, sizeof(bool)); { bool any_params = false; @@ -5621,7 +5760,7 @@ void lm_ggml_build_backward_expand( continue; } - bool node_needs_grad = node->flags & LM_GGML_TENSOR_FLAG_PARAM; + bool node_needs_grad = (node->flags & LM_GGML_TENSOR_FLAG_PARAM) || (node->flags & LM_GGML_TENSOR_FLAG_LOSS); bool ignore_src[LM_GGML_MAX_SRC] = {false}; switch (node->op) { // gradients in node->src[0] for one reason or another have no effect on output gradients @@ -5638,7 +5777,7 @@ void lm_ggml_build_backward_expand( } break; // gradients in node->src[1] for one reason or another have no effect on output gradients - case LM_GGML_OP_CPY: // gradients in CPY target are irrelevant + case LM_GGML_OP_CPY: // gradients in CPY target are irrelevant case LM_GGML_OP_GET_ROWS: // row indices not differentiable case LM_GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS case LM_GGML_OP_ROPE: // positions not differentiable @@ -5665,9 +5804,12 @@ void lm_ggml_build_backward_expand( node->op == LM_GGML_OP_RESHAPE || node->op == LM_GGML_OP_PERMUTE || node->op == LM_GGML_OP_TRANSPOSE); const size_t igrad = lm_ggml_hash_find(&cgraph->visited_hash_set, node); + LM_GGML_ASSERT(igrad != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(lm_ggml_bitset_get(cgraph->visited_hash_set.used, igrad)); if ((accumulate && (node->flags & LM_GGML_TENSOR_FLAG_PARAM)) || (node->flags & LM_GGML_TENSOR_FLAG_LOSS)) { - 
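Each of these four helpers follows the same bookkeeping pattern: the gradient slot is addressed by the tensor's index in the graph's visited_hash_set, created lazily on the first contribution, and then given a readable name. A minimal sketch of that pattern, with `t` and `grad_contrib` as placeholder tensors (not the exact library code):

    // accumulate a gradient contribution for tensor `t` in `cgraph`
    size_t i = lm_ggml_hash_find(&cgraph->visited_hash_set, t);
    LM_GGML_ASSERT(i != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(cgraph->visited_hash_set.used, i));
    if (cgraph->grads[i]) {
        // later contributions are summed into the existing gradient
        cgraph->grads[i] = lm_ggml_add_impl(ctx, cgraph->grads[i], grad_contrib, /*inplace =*/ cgraph->grad_accs[i]);
    } else {
        // the first contribution becomes the gradient itself
        cgraph->grads[i] = grad_contrib;
    }
    lm_ggml_format_name(cgraph->grads[i], "grad for %s", t->name);

The src0/src1/src2 NULL guards above make lm_ggml_compute_backward skip the lookup entirely for absent sources instead of handing a NULL pointer to lm_ggml_hash_find.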
cgraph->grads[igrad] = lm_ggml_dup_tensor(ctx_static, node); - cgraph->grad_accs[igrad] = cgraph->grads[igrad]; + cgraph->grad_accs[igrad] = lm_ggml_dup_tensor(ctx_static, node); + cgraph->grads[igrad] = cgraph->grad_accs[igrad]; + lm_ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name); } grads_needed[igrad] = true; } @@ -5761,15 +5903,15 @@ struct lm_ggml_cgraph * lm_ggml_new_graph(struct lm_ggml_context * ctx) { struct lm_ggml_cgraph lm_ggml_graph_view(struct lm_ggml_cgraph * cgraph0, int i0, int i1) { struct lm_ggml_cgraph cgraph = { - /*.size =*/ 0, - /*.n_nodes =*/ i1 - i0, - /*.n_leafs =*/ 0, - /*.nodes =*/ cgraph0->nodes + i0, - /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL, - /*.grad_accs =*/ cgraph0->grad_accs ? cgraph0->grad_accs + i0 : NULL, - /*.leafs =*/ NULL, - /*.hash_table =*/ { 0, NULL, NULL }, - /*.order =*/ cgraph0->order, + /*.size =*/ 0, + /*.n_nodes =*/ i1 - i0, + /*.n_leafs =*/ 0, + /*.nodes =*/ cgraph0->nodes + i0, + /*.grads =*/ NULL, // gradients would need visited_hash_set + /*.grad_accs =*/ NULL, + /*.leafs =*/ NULL, + /*.visited_hash_set =*/ { 0, NULL, NULL }, + /*.order =*/ cgraph0->order, }; return cgraph; @@ -5799,12 +5941,22 @@ void lm_ggml_graph_cpy(struct lm_ggml_cgraph * src, struct lm_ggml_cgraph * dst) } } + if (dst->grads) { + memset(dst->grads, 0, dst->visited_hash_set.size*sizeof(struct lm_ggml_tensor *)); + memset(dst->grad_accs, 0, dst->visited_hash_set.size*sizeof(struct lm_ggml_tensor *)); + } if (src->grads) { LM_GGML_ASSERT(dst->grads != NULL); LM_GGML_ASSERT(dst->grad_accs != NULL); for (int i = 0; i < src->n_nodes; ++i) { const size_t igrad_src = lm_ggml_hash_find(&src->visited_hash_set, src->nodes[i]); const size_t igrad_dst = lm_ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]); + + LM_GGML_ASSERT(igrad_src != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(lm_ggml_bitset_get(src->visited_hash_set.used, igrad_src)); + LM_GGML_ASSERT(igrad_dst != LM_GGML_HASHSET_FULL); + LM_GGML_ASSERT(lm_ggml_bitset_get(dst->visited_hash_set.used, igrad_dst)); + dst->grads[igrad_dst] = src->grads[igrad_src]; dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src]; } @@ -5839,12 +5991,8 @@ void lm_ggml_graph_reset(struct lm_ggml_cgraph * cgraph) { if (node->op == LM_GGML_OP_OPT_STEP_ADAMW) { // clear momenta - if (node->src[2]->data) { - lm_ggml_set_zero(node->src[2]); - } - if (node->src[3]->data) { - lm_ggml_set_zero(node->src[3]); - } + lm_ggml_set_zero(node->src[2]); + lm_ggml_set_zero(node->src[3]); } // initial gradients of loss should be 1, 0 otherwise @@ -5923,12 +6071,12 @@ struct lm_ggml_tensor * lm_ggml_graph_get_tensor(const struct lm_ggml_cgraph * c struct lm_ggml_tensor * lm_ggml_graph_get_grad(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) { const size_t igrad = lm_ggml_hash_find(&cgraph->visited_hash_set, node); - return igrad != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grads[igrad] : NULL; + return igrad != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grads ? cgraph->grads[igrad] : NULL; } struct lm_ggml_tensor * lm_ggml_graph_get_grad_acc(const struct lm_ggml_cgraph * cgraph, const struct lm_ggml_tensor * node) { const size_t igrad = lm_ggml_hash_find(&cgraph->visited_hash_set, node); - return igrad != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? 
cgraph->grad_accs[igrad] : NULL; + return igrad != LM_GGML_HASHSET_FULL && lm_ggml_bitset_get(cgraph->visited_hash_set.used, igrad) && cgraph->grad_accs ? cgraph->grad_accs[igrad] : NULL; } void lm_ggml_graph_print(const struct lm_ggml_cgraph * cgraph) { @@ -6240,9 +6388,6 @@ size_t lm_ggml_quantize_chunk( case LM_GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case LM_GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case LM_GGML_TYPE_IQ4_XS: result = quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case LM_GGML_TYPE_Q4_0_4_4: result = quantize_q4_0_4x4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case LM_GGML_TYPE_Q4_0_4_8: result = quantize_q4_0_4x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case LM_GGML_TYPE_Q4_0_8_8: result = quantize_q4_0_8x8(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case LM_GGML_TYPE_F16: { size_t elemsize = sizeof(lm_ggml_fp16_t); @@ -6378,7 +6523,7 @@ struct lm_gguf_context { void * data; }; -static size_t lm_gguf_type_size(enum lm_gguf_type type) { +size_t lm_gguf_type_size(enum lm_gguf_type type) { LM_GGML_ASSERT(0 <= type && type < LM_GGUF_TYPE_COUNT); return LM_GGUF_TYPE_SIZE[type]; } @@ -6506,13 +6651,7 @@ struct lm_gguf_context * lm_gguf_init_empty(void) { return ctx; } -struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gguf_init_params params) { - FILE * file = lm_ggml_fopen(fname, "rb"); - if (!file) { - fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno)); - return NULL; - } - +struct lm_gguf_context * lm_gguf_init_from_file_impl(FILE * file, struct lm_gguf_init_params params) { // offset from start of file size_t offset = 0; @@ -6525,7 +6664,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg for (uint32_t i = 0; i < sizeof(magic); i++) { if (magic[i] != LM_GGUF_MAGIC[i]) { fprintf(stderr, "%s: invalid magic characters '%c%c%c%c'\n", __func__, magic[0], magic[1], magic[2], magic[3]); - fclose(file); return NULL; } } @@ -6536,7 +6674,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg struct lm_gguf_context * ctx = calloc(1, sizeof(struct lm_gguf_context)); if (!ctx) { fprintf(stderr, "%s: failed to allocate memory for context\n", __func__); - fclose(file); return NULL; } @@ -6554,7 +6691,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg if (ctx->header.version == 1) { fprintf(stderr, "%s: GGUFv1 is no longer supported. 
please use a more up-to-date version\n", __func__); - fclose(file); lm_gguf_free(ctx); return NULL; } @@ -6567,7 +6703,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg if (!ok) { fprintf(stderr, "%s: failed to read header\n", __func__); - fclose(file); lm_gguf_free(ctx); return NULL; } @@ -6577,12 +6712,13 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg { const uint64_t n_kv = ctx->header.n_kv; - ctx->kv = calloc(n_kv, sizeof(struct lm_gguf_kv)); - if (!ctx->kv) { - fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__); - fclose(file); - lm_gguf_free(ctx); - return NULL; + if (n_kv > 0) { + ctx->kv = calloc(n_kv, sizeof(struct lm_gguf_kv)); + if (!ctx->kv) { + fprintf(stderr, "%s: failed to allocate memory for kv pairs\n", __func__); + lm_gguf_free(ctx); + return NULL; + } } for (uint64_t i = 0; i < n_kv; ++i) { @@ -6629,7 +6765,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg // prevent from integer overflow in the malloc below if (kv->value.arr.n >= SIZE_MAX/lm_gguf_type_size(kv->value.arr.type)) { fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n); - fclose(file); lm_gguf_free(ctx); return NULL; } @@ -6637,7 +6772,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg kv->value.arr.data = calloc(kv->value.arr.n, lm_gguf_type_size(kv->value.arr.type)); if (!kv->value.arr.data) { fprintf(stderr, "%s: failed to allocate memory for array\n", __func__); - fclose(file); lm_gguf_free(ctx); return NULL; } @@ -6649,7 +6783,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg // prevent from integer overflow in the malloc below if (kv->value.arr.n >= SIZE_MAX/sizeof(struct lm_gguf_str)) { fprintf(stderr, "%s: array size is too large (%" PRIu64 ")\n", __func__, kv->value.arr.n); - fclose(file); lm_gguf_free(ctx); return NULL; } @@ -6657,7 +6790,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg kv->value.arr.data = calloc(kv->value.arr.n, sizeof(struct lm_gguf_str)); if (!kv->value.arr.data) { fprintf(stderr, "%s: failed to allocate memory for array\n", __func__); - fclose(file); lm_gguf_free(ctx); return NULL; } @@ -6688,7 +6820,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg if (!ok) { fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); - fclose(file); lm_gguf_free(ctx); return NULL; } @@ -6699,7 +6830,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg ctx->infos = calloc(ctx->header.n_tensors, sizeof(struct lm_gguf_tensor_info)); if (!ctx->infos) { fprintf(stderr, "%s: failed to allocate memory for tensor infos\n", __func__); - fclose(file); lm_gguf_free(ctx); return NULL; } @@ -6735,7 +6865,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg if (!ok) { fprintf(stderr, "%s: failed to read tensor info\n", __func__); - fclose(file); lm_gguf_free(ctx); return NULL; } @@ -6774,10 +6903,17 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg (int64_t) info->ne[2] * (int64_t) info->ne[3]; - if (lm_ggml_blck_size(info->type) == 0 || ne % lm_ggml_blck_size(info->type) != 0) { + if (lm_ggml_blck_size(info->type) == 0) { + // support for this tensor type has been removed: + fprintf(stderr, "%s: tensor '%s' of type %d: %s\n", + __func__, info->name.data, (int) info->type, 
lm_ggml_type_name(info->type)); + lm_gguf_free(ctx); + return NULL; + } + + if (ne % lm_ggml_blck_size(info->type) != 0) { fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n", __func__, info->name.data, (int) info->type, lm_ggml_type_name(info->type), ne, lm_ggml_blck_size(info->type)); - fclose(file); lm_gguf_free(ctx); return NULL; } @@ -6809,7 +6945,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg *params.ctx = lm_ggml_init(pdata); if (*params.ctx == NULL) { fprintf(stderr, "%s: failed to initialize context\n", __func__); - fclose(file); lm_gguf_free(ctx); return NULL; } @@ -6828,7 +6963,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg if (!ok) { fprintf(stderr, "%s: failed to read tensor data\n", __func__); - fclose(file); lm_ggml_free(ctx_data); lm_gguf_free(ctx); return NULL; @@ -6867,7 +7001,6 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg if (!ok) { fprintf(stderr, "%s: failed to read the tensor data\n", __func__); - fclose(file); lm_ggml_free(ctx_data); lm_gguf_free(ctx); return NULL; @@ -6876,11 +7009,21 @@ struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gg lm_ggml_set_no_alloc(ctx_data, params.no_alloc); } - fclose(file); - return ctx; } +struct lm_gguf_context * lm_gguf_init_from_file(const char * fname, struct lm_gguf_init_params params) { + FILE * file = lm_ggml_fopen(fname, "rb"); + if (!file) { + fprintf(stderr, "%s: failed to open '%s': '%s'\n", __func__, fname, strerror(errno)); + return NULL; + } + + struct lm_gguf_context * result = lm_gguf_init_from_file_impl(file, params); + fclose(file); + return result; +} + void lm_gguf_free(struct lm_gguf_context * ctx) { if (ctx == NULL) { return; @@ -7340,13 +7483,7 @@ void lm_gguf_set_tensor_data(struct lm_gguf_context * ctx, const char * name, co // fwrite(val, sizeof(char), size, file); //} -struct lm_gguf_buf { - void * data; - size_t size; - size_t offset; -}; - -static struct lm_gguf_buf lm_gguf_buf_init(size_t size) { +struct lm_gguf_buf lm_gguf_buf_init(size_t size) { struct lm_gguf_buf buf = { /*buf.data =*/ size == 0 ? NULL : LM_GGML_CALLOC(1, size), /*buf.size =*/ size, @@ -7356,7 +7493,7 @@ static struct lm_gguf_buf lm_gguf_buf_init(size_t size) { return buf; } -static void lm_gguf_buf_free(struct lm_gguf_buf buf) { +void lm_gguf_buf_free(struct lm_gguf_buf buf) { if (buf.data) { LM_GGML_FREE(buf.data); } @@ -7394,7 +7531,7 @@ static void lm_gguf_bwrite_el(struct lm_gguf_buf * buf, const void * val, size_t buf->offset += el_size; } -static void lm_gguf_write_to_buf(const struct lm_gguf_context * ctx, struct lm_gguf_buf * buf, bool only_meta) { +void lm_gguf_write_to_buf(const struct lm_gguf_context * ctx, struct lm_gguf_buf * buf, bool only_meta) { // write header lm_gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic)); lm_gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version)); @@ -7549,3 +7686,26 @@ void lm_ggml_log_set(lm_ggml_log_callback log_callback, void * user_data) { g_logger_state.log_callback = log_callback ? 
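Extracting lm_gguf_init_from_file_impl(FILE *, ...) is what allows all of the fclose(file) calls to be dropped from the error paths above: the thin lm_gguf_init_from_file wrapper now owns the handle and closes it at exactly one point, whatever the impl returns. A sketch of the resulting ownership pattern; `buffer`/`buffer_size` are placeholders and `fmemopen` is POSIX-only, used here purely to illustrate parsing GGUF from something other than a named file:

    // the caller (wrapper) owns the FILE handle; the impl never closes it
    FILE * f = fmemopen(buffer, buffer_size, "rb"); // or lm_ggml_fopen(fname, "rb")
    if (f) {
        struct lm_gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
        struct lm_gguf_context * gguf = lm_gguf_init_from_file_impl(f, params);
        fclose(f); // single cleanup point; gguf is NULL on any parse error
    }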
log_callback : lm_ggml_log_callback_default; g_logger_state.log_callback_user_data = user_data; } + +void lm_ggml_threadpool_params_init(struct lm_ggml_threadpool_params * p, int n_threads) { + p->n_threads = n_threads; + p->prio = 0; // default priority (usually means normal or inherited) + p->poll = 50; // hybrid-polling enabled + p->strict_cpu = false; // no strict placement (all threads share same cpumask) + p->paused = false; // threads are ready to go + memset(p->cpumask, 0, LM_GGML_MAX_N_THREADS); // all-zero means use the default affinity (usually inherited) +} + +struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads) { + struct lm_ggml_threadpool_params p; + lm_ggml_threadpool_params_init(&p, n_threads); + return p; +} + +bool lm_ggml_threadpool_params_match(const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1) { + if (p0->n_threads != p1->n_threads ) return false; + if (p0->prio != p1->prio ) return false; + if (p0->poll != p1->poll ) return false; + if (p0->strict_cpu != p1->strict_cpu ) return false; + return memcmp(p0->cpumask, p1->cpumask, LM_GGML_MAX_N_THREADS) == 0; +} diff --git a/cpp/ggml.h b/cpp/ggml.h index f86241df..27129d2e 100644 --- a/cpp/ggml.h +++ b/cpp/ggml.h @@ -237,7 +237,9 @@ #define LM_GGML_EXIT_SUCCESS 0 #define LM_GGML_EXIT_ABORTED 1 -#define LM_GGML_ROPE_TYPE_NEOX 2 +#define LM_GGML_ROPE_TYPE_NEOX 2 +#define LM_GGML_ROPE_TYPE_MROPE 8 +#define LM_GGML_ROPE_TYPE_VISION 24 #define LM_GGUF_MAGIC "GGUF" @@ -384,12 +386,15 @@ extern "C" { LM_GGML_TYPE_F64 = 28, LM_GGML_TYPE_IQ1_M = 29, LM_GGML_TYPE_BF16 = 30, - LM_GGML_TYPE_Q4_0_4_4 = 31, - LM_GGML_TYPE_Q4_0_4_8 = 32, - LM_GGML_TYPE_Q4_0_8_8 = 33, + // LM_GGML_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files + // LM_GGML_TYPE_Q4_0_4_8 = 32, + // LM_GGML_TYPE_Q4_0_8_8 = 33, LM_GGML_TYPE_TQ1_0 = 34, LM_GGML_TYPE_TQ2_0 = 35, - LM_GGML_TYPE_COUNT, + // LM_GGML_TYPE_IQ4_NL_4_4 = 36, + // LM_GGML_TYPE_IQ4_NL_4_8 = 37, + // LM_GGML_TYPE_IQ4_NL_8_8 = 38, + LM_GGML_TYPE_COUNT = 39, }; // precision @@ -430,9 +435,6 @@ extern "C" { LM_GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors LM_GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors LM_GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors - LM_GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors - LM_GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors - LM_GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors }; // available tensor operations: @@ -496,6 +498,7 @@ extern "C" { LM_GGML_OP_POOL_2D_BACK, LM_GGML_OP_UPSCALE, // nearest interpolate LM_GGML_OP_PAD, + LM_GGML_OP_PAD_REFLECT_1D, LM_GGML_OP_ARANGE, LM_GGML_OP_TIMESTEP_EMBEDDING, LM_GGML_OP_ARGSORT, @@ -1442,6 +1445,22 @@ extern "C" { float beta_fast, float beta_slow); + LM_GGML_API struct lm_ggml_tensor * lm_ggml_rope_multi( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + struct lm_ggml_tensor * b, + struct lm_ggml_tensor * c, + int n_dims, + int sections[4], + int mode, + int n_ctx_orig, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + // in-place, returns view(a) LM_GGML_API struct lm_ggml_tensor * lm_ggml_rope_ext_inplace( struct lm_ggml_context * ctx, @@ -1545,17 +1564,6 @@ extern "C" { int d1, // dilation dimension 1 bool is_2D); - LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_depthwise_2d( - struct lm_ggml_context * ctx, - struct lm_ggml_tensor * a, // convolution kernel - struct lm_ggml_tensor * b, // data - int s0, // stride dimension 0 - 
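Note that lm_ggml_threadpool_params_match, defined just above, compares only the scheduling identity of a pool (n_threads, prio, poll, strict_cpu and the cpumask bytes) and does not compare `paused`, which is start-up state rather than configuration. A hedged usage sketch; `current_params` is a placeholder for whatever parameters the caller cached:

    struct lm_ggml_threadpool_params p = lm_ggml_threadpool_params_default(8);
    p.prio = LM_GGML_SCHED_PRIO_HIGH; // see enum lm_ggml_sched_priority in ggml.h below
    p.poll = 0;                       // no polling: worker threads sleep between graphs
    if (!lm_ggml_threadpool_params_match(&p, &current_params)) {
        // configuration changed: tear down and recreate the threadpool
    }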
int s1, // stride dimension 1 - int p0, // padding dimension 0 - int p1, // padding dimension 1 - int d0, // dilation dimension 0 - int d1); // dilation dimension 1 - LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_1d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, // convolution kernel @@ -1573,6 +1581,23 @@ extern "C" { int s, // stride int d); // dilation + // depthwise + // TODO: this is very likely wrong for some cases! - needs more testing + LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_1d_dw( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, // convolution kernel + struct lm_ggml_tensor * b, // data + int s0, // stride + int p0, // padding + int d0); // dilation + + LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_1d_dw_ph( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, // convolution kernel + struct lm_ggml_tensor * b, // data + int s0, // stride + int d0); // dilation + LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_1d( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, // convolution kernel @@ -1592,7 +1617,6 @@ extern "C" { int d0, // dilation dimension 0 int d1); // dilation dimension 1 - // kernel size is a->ne[0] x a->ne[1] // stride is equal to kernel size // padding is zero @@ -1619,6 +1643,18 @@ extern "C" { struct lm_ggml_tensor * a, struct lm_ggml_tensor * b); + // depthwise + LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_2d_dw( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, // convolution kernel + struct lm_ggml_tensor * b, // data + int s0, // stride dimension 0 + int s1, // stride dimension 1 + int p0, // padding dimension 0 + int p1, // padding dimension 1 + int d0, // dilation dimension 0 + int d1); // dilation dimension 1 + LM_GGML_API struct lm_ggml_tensor * lm_ggml_conv_transpose_2d_p0( struct lm_ggml_context * ctx, struct lm_ggml_tensor * a, @@ -1692,6 +1728,13 @@ extern "C" { int p2, int p3); + // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c] + LM_GGML_API struct lm_ggml_tensor * lm_ggml_pad_reflect_1d( + struct lm_ggml_context * ctx, + struct lm_ggml_tensor * a, + int p0, + int p1); + // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 // timesteps: [N,] // return: [N, dim] @@ -2194,11 +2237,19 @@ extern "C" { LM_GGML_API size_t lm_gguf_get_meta_size(const struct lm_gguf_context * ctx); LM_GGML_API void lm_gguf_get_meta_data(const struct lm_gguf_context * ctx, void * data); -#ifdef __cplusplus -// restrict not standard in C++ -#define LM_GGML_RESTRICT +#ifdef __cplusplus + // restrict not standard in C++ +# if defined(__GNUC__) +# define LM_GGML_RESTRICT __restrict__ +# elif defined(__clang__) +# define LM_GGML_RESTRICT __restrict +# elif defined(_MSC_VER) +# define LM_GGML_RESTRICT __restrict +# else +# define LM_GGML_RESTRICT +# endif #else -#define LM_GGML_RESTRICT restrict +# define LM_GGML_RESTRICT restrict #endif typedef void (*lm_ggml_to_float_t) (const void * LM_GGML_RESTRICT x, float * LM_GGML_RESTRICT y, int64_t k); typedef void (*lm_ggml_from_float_t)(const float * LM_GGML_RESTRICT x, void * LM_GGML_RESTRICT y, int64_t k); @@ -2215,6 +2266,37 @@ extern "C" { LM_GGML_API const struct lm_ggml_type_traits * lm_ggml_get_type_traits(enum lm_ggml_type type); + // ggml threadpool + // TODO: currently, only a few functions are in the base ggml API, while the rest are in the CPU backend + // the goal should be to create an API that other backends can use and move everything to the ggml base + + // scheduling 
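The LM_GGML_RESTRICT change above closes a real gap: C++ has no `restrict` keyword, so the macro used to expand to nothing and C++ translation units silently lost the no-aliasing guarantees the C code gets. Mapping it onto the compiler extensions restores them (clang also defines __GNUC__, so the __clang__ branch is effectively a fallback). A small illustration of what the qualifier permits; `scale_rows` is a made-up example, not from the patch:

    // with LM_GGML_RESTRICT the compiler may assume x and y never alias,
    // so it can keep values in registers and vectorize the loop freely
    static void scale_rows(const float * LM_GGML_RESTRICT x,
                           float * LM_GGML_RESTRICT y,
                           int64_t k, float s) {
        for (int64_t i = 0; i < k; ++i) {
            y[i] = s * x[i];
        }
    }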
priorities + enum lm_ggml_sched_priority { + LM_GGML_SCHED_PRIO_NORMAL, + LM_GGML_SCHED_PRIO_MEDIUM, + LM_GGML_SCHED_PRIO_HIGH, + LM_GGML_SCHED_PRIO_REALTIME + }; + + // threadpool params + // Use lm_ggml_threadpool_params_default() or lm_ggml_threadpool_params_init() to populate the defaults + struct lm_ggml_threadpool_params { + bool cpumask[LM_GGML_MAX_N_THREADS]; // mask of cpu cores (all-zeros means use default affinity settings) + int n_threads; // number of threads + enum lm_ggml_sched_priority prio; // thread priority + uint32_t poll; // polling level (0 - no polling, 100 - aggressive polling) + bool strict_cpu; // strict cpu placement + bool paused; // start in paused state + }; + + struct lm_ggml_threadpool; // forward declaration, see ggml.c + + typedef struct lm_ggml_threadpool * lm_ggml_threadpool_t; + + LM_GGML_API struct lm_ggml_threadpool_params lm_ggml_threadpool_params_default(int n_threads); + LM_GGML_API void lm_ggml_threadpool_params_init (struct lm_ggml_threadpool_params * p, int n_threads); + LM_GGML_API bool lm_ggml_threadpool_params_match (const struct lm_ggml_threadpool_params * p0, const struct lm_ggml_threadpool_params * p1); + #ifdef __cplusplus } #endif diff --git a/cpp/llama-grammar.cpp b/cpp/llama-grammar.cpp index 6cc2b386..311cb3b0 100644 --- a/cpp/llama-grammar.cpp +++ b/cpp/llama-grammar.cpp @@ -822,15 +822,11 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & stacks_new) { - stacks_new.clear(); - stacks_new.reserve(stacks.size()); +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) { + llama_grammar_stacks stacks_new; + stacks_new.reserve(grammar->stacks.size()); - for (const auto & stack : stacks) { + for (const auto & stack : grammar->stacks) { if (stack.empty()) { continue; } @@ -844,9 +840,11 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack(rules, new_stack, stacks_new); + llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new); } } + + grammar->stacks = std::move(stacks_new); } llama_grammar_candidates llama_grammar_reject_candidates_for_stack( @@ -1051,7 +1049,12 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { } struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { - llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, }; + llama_grammar * result = new llama_grammar { + grammar.vocab, + grammar.rules, + grammar.stacks, + grammar.partial_utf8, + }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { @@ -1059,7 +1062,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { - result->stacks[is][ie] = &result->rules[ir0][ir1]; + result->stacks[is][ie] = &result->rules[ir0][ir1]; } } } @@ -1126,11 +1129,8 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; - llama_grammar_stacks stacks_new; - for (auto it = code_points.begin(), 
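llama_grammar_accept now advances the grammar in place instead of threading a caller-owned stacks_new through four parameters, and callers shrink accordingly, as the rewritten loop in llama_grammar_accept_impl below shows. A sketch of the new calling convention, mirroring that loop:

    // feed each decoded code point to the grammar; stacks update in place
    const auto decoded = decode_utf8(piece, grammar.partial_utf8);
    const auto & code_points = decoded.first;
    for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
        llama_grammar_accept(&grammar, *it); // replaces the old 4-argument form
    }
    grammar.partial_utf8 = decoded.second; // carry incomplete UTF-8 to the next piece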
end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new); - grammar.stacks = std::move(stacks_new); + llama_grammar_accept(&grammar, *it); } grammar.partial_utf8 = decoded.second; diff --git a/cpp/llama-grammar.h b/cpp/llama-grammar.h index f529ce35..13e940fb 100644 --- a/cpp/llama-grammar.h +++ b/cpp/llama-grammar.h @@ -58,6 +58,7 @@ using llama_grammar_rules = std::vector<llama_grammar_rule>; using llama_grammar_stacks = std::vector<llama_grammar_stack>; using llama_grammar_candidates = std::vector<llama_grammar_candidate>; +// TODO: remove, needed for tests atm const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); @@ -65,11 +66,7 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - uint32_t chr, - llama_grammar_stacks & stacks_new); +void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr); std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, diff --git a/cpp/llama-sampling.cpp b/cpp/llama-sampling.cpp index 755c72b0..f525d317 100644 --- a/cpp/llama-sampling.cpp +++ b/cpp/llama-sampling.cpp @@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab // penalties struct llama_sampler_penalties { - const int32_t n_vocab; - const llama_token special_eos_id; - const llama_token linefeed_id; - const int32_t penalty_last_n; const float penalty_repeat; const float penalty_freq; const float penalty_present; - const bool penalize_nl; - const bool ignore_eos; - ring_buffer<llama_token> prev; + + // a frequency map to count token occurrences + std::unordered_map<llama_token, int> token_count; }; static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) { @@ -1421,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to return; } - ctx->prev.push_back(token); -} - -static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_penalties *) smpl->ctx; + ctx->token_count[token]++; - if (ctx->ignore_eos) { - assert(ctx->special_eos_id >= 0); + // if the ring buffer is full, remove the oldest token + if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) { + const auto old = ctx->prev.front(); - // optimistically check if the candidates are not yet sorted/shuffled/truncated - if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) { - cur_p->data[ctx->special_eos_id].logit = -INFINITY; - } else { - // else, search for the special EOS token - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].id == ctx->special_eos_id) { - cur_p->data[i].logit = -INFINITY; - break; - } - } + ctx->token_count[old]--; + if (ctx->token_count[old] == 0) { + ctx->token_count.erase(old); } } - if ((ctx->penalty_last_n == 0) || - (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { - return; - } - - bool nl_found = false; - size_t nl_idx = 0; - float nl_logit = -INFINITY; - if (!ctx->penalize_nl) { - assert(ctx->linefeed_id >= 0); + ctx->prev.push_back(token); - // optimistically check if the candidates are not 
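Moving the frequency map into the sampler context turns what used to be an O(penalty_last_n) rebuild on every apply() (the deleted loop below) into an O(1) sliding-window update on every accept(): increment the incoming token, then decrement and possibly erase the token the full ring buffer is about to drop. The disabled sanity check above recomputes the map from prev and asserts the two agree; the invariant it documents is:

    // invariant kept by accept(): for every token t,
    //   token_count[t] == occurrences of t among the last
    //                     min(penalty_last_n, prev.size()) accepted tokens,
    // with zero-count entries erased rather than left at 0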
yet sorted/shuffled/truncated - if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) { - nl_found = true; - nl_idx = ctx->linefeed_id; - nl_logit = cur_p->data[ctx->linefeed_id].logit; - } else { - // else, search for the linefeed token - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].id == ctx->linefeed_id) { - nl_found = true; - nl_idx = i; - nl_logit = cur_p->data[i].logit; - break; - } - } - } +#if 0 + // sanity check + std::unordered_map<llama_token, int> tmp; + for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) { + tmp[ctx->prev.rat(i)]++; } - // Create a frequency map to count occurrences of each token in last_tokens - // TODO: optimize this by maintaining the token count in the sampler context - using llama_token_cnt = std::unordered_map<llama_token, int>; - llama_token_cnt token_count; + assert(ctx->token_count == tmp); +#endif +} + +static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_penalties *) smpl->ctx; - for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) { - token_count[ctx->prev.rat(i)]++; + if ((ctx->penalty_last_n == 0) || + (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { + return; } // Apply frequency and presence penalties to the cur_p for (size_t i = 0; i < cur_p->size; ++i) { - const auto token_iter = token_count.find(cur_p->data[i].id); - if (token_iter == token_count.end()) { + const auto token_iter = ctx->token_count.find(cur_p->data[i].id); + if (token_iter == ctx->token_count.end()) { continue; } const int count = token_iter->second; + assert(count > 0 && count <= ctx->penalty_last_n); + // The academic publication that described this technique actually just divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. // This is a common fix for this problem, which is to multiply by the penalty instead of dividing. 
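Concretely, with penalty_repeat = 1.5 the branch below pushes a repeated token's logit away from selection regardless of its sign, which is exactly why the comment above rejects plain division. A worked check of both signs:

    // repeat penalty, penalty_repeat = 1.5f
    float neg = -2.0f;  neg *= 1.5f;  // -> -3.0f : less likely (correct)
    float pos =  2.0f;  pos /= 1.5f;  // -> ~1.33f: less likely (correct)
    // dividing the negative logit instead: -2.0f / 1.5f ~= -1.33f,
    // which would raise it and reward repetition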
if (cur_p->data[i].logit <= 0) { @@ -1503,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok } cur_p->sorted = false; - - if (!ctx->penalize_nl && nl_found) { - // restore the logit of the newline token if it was penalized - cur_p->data[nl_idx].logit = nl_logit; - } } static void llama_sampler_penalties_reset(struct llama_sampler * smpl) { auto * ctx = (llama_sampler_penalties *) smpl->ctx; ctx->prev.clear(); + ctx->token_count.clear(); } static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; auto * result = llama_sampler_init_penalties( - ctx->n_vocab, - ctx->special_eos_id, - ctx->linefeed_id, ctx->penalty_last_n, ctx->penalty_repeat, ctx->penalty_freq, - ctx->penalty_present, - ctx->penalize_nl, - ctx->ignore_eos); + ctx->penalty_present); // copy the state { @@ -1552,38 +1513,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = { }; struct llama_sampler * llama_sampler_init_penalties( - int32_t n_vocab, - llama_token special_eos_id, - llama_token linefeed_id, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, - float penalty_present, - bool penalize_nl, - bool ignore_eos) { - if (linefeed_id == LLAMA_TOKEN_NULL) { - penalize_nl = true; - } - - if (special_eos_id == LLAMA_TOKEN_NULL) { - ignore_eos = false; - } - + float penalty_present) { penalty_last_n = std::max(penalty_last_n, 0); return new llama_sampler { /* .iface = */ &llama_sampler_penalties_i, /* .ctx = */ new llama_sampler_penalties { - /* .n_vocab = */ n_vocab, - /* .special_eos_id = */ special_eos_id, - /* .linefeed_id = */ linefeed_id, /* .penalty_last_n = */ penalty_last_n, /* .penalty_repeat = */ penalty_repeat, /* .penalty_freq = */ penalty_freq, /* .penalty_present = */ penalty_present, - /* .penalize_nl = */ penalize_nl, - /* .ignore_eos = */ ignore_eos, /* .prev = */ ring_buffer<llama_token>(penalty_last_n), + /* .token_count = */ {}, }, }; } @@ -1611,7 +1555,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std if (word.find(str) != std::string::npos) { token_sequences.emplace(token_id, std::vector<llama_token>()); } else { - size_t word_len = word.size(), str_len = str.size(); + size_t word_len = word.size(); + size_t str_len = str.size(); size_t pos = -1; while ((pos = word.find(str[0], pos + 1)) != std::string::npos) { bool match = true; diff --git a/cpp/llama-vocab.cpp b/cpp/llama-vocab.cpp index b0135070..6e13e30c 100644 --- a/cpp/llama-vocab.cpp +++ b/cpp/llama-vocab.cpp @@ -418,6 +418,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_SMOLLM: case LLAMA_VOCAB_PRE_TYPE_CODESHELL: case LLAMA_VOCAB_PRE_TYPE_EXAONE: + case LLAMA_VOCAB_PRE_TYPE_MINERVA: regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -737,7 +738,7 @@ struct llm_tokenizer_wpm_session { std::vector<std::string> words(1, ""); for (const uint32_t cpt : cpts_nfd) { - const auto flags = unicode_cpt_flags(cpt); + const auto flags = unicode_cpt_flags_from_cpt(cpt); if (flags.is_whitespace) { if (words.back().size()) { // finish previous word if any @@ -1866,6 +1867,10 @@ int32_t llama_detokenize_impl( int32_t text_len_max, bool remove_special, bool unparse_special) { + if (vocab.type == LLAMA_VOCAB_TYPE_NONE) { + return 0; + } + LM_GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. 
Call llama_vocab::init_tokenizer() first."); int32_t avail = text_len_max; diff --git a/cpp/llama.cpp b/cpp/llama.cpp index d6a6292c..e317960f 100644 --- a/cpp/llama.cpp +++ b/cpp/llama.cpp @@ -157,6 +157,7 @@ static std::string format(const char * fmt, ...) { enum llm_arch { LLM_ARCH_LLAMA, + LLM_ARCH_DECI, LLM_ARCH_FALCON, LLM_ARCH_BAICHUAN, LLM_ARCH_GROK, @@ -174,6 +175,7 @@ enum llm_arch { LLM_ARCH_QWEN, LLM_ARCH_QWEN2, LLM_ARCH_QWEN2MOE, + LLM_ARCH_QWEN2VL, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PLAMO, @@ -190,10 +192,11 @@ enum llm_arch { LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, LLM_ARCH_OLMO, - LLM_ARCH_OLMO_1124, + LLM_ARCH_OLMO2, LLM_ARCH_OLMOE, LLM_ARCH_OPENELM, LLM_ARCH_ARCTIC, + LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, LLM_ARCH_CHATGLM, LLM_ARCH_BITNET, @@ -206,61 +209,66 @@ enum llm_arch { LLM_ARCH_GRANITE, LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, + LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_UNKNOWN, }; static const std::map LLM_ARCH_NAMES = { - { LLM_ARCH_LLAMA, "llama" }, - { LLM_ARCH_FALCON, "falcon" }, - { LLM_ARCH_GROK, "grok" }, - { LLM_ARCH_GPT2, "gpt2" }, - { LLM_ARCH_GPTJ, "gptj" }, - { LLM_ARCH_GPTNEOX, "gptneox" }, - { LLM_ARCH_MPT, "mpt" }, - { LLM_ARCH_BAICHUAN, "baichuan" }, - { LLM_ARCH_STARCODER, "starcoder" }, - { LLM_ARCH_REFACT, "refact" }, - { LLM_ARCH_BERT, "bert" }, - { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, - { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, - { LLM_ARCH_BLOOM, "bloom" }, - { LLM_ARCH_STABLELM, "stablelm" }, - { LLM_ARCH_QWEN, "qwen" }, - { LLM_ARCH_QWEN2, "qwen2" }, - { LLM_ARCH_QWEN2MOE, "qwen2moe" }, - { LLM_ARCH_PHI2, "phi2" }, - { LLM_ARCH_PHI3, "phi3" }, - { LLM_ARCH_PLAMO, "plamo" }, - { LLM_ARCH_CODESHELL, "codeshell" }, - { LLM_ARCH_ORION, "orion" }, - { LLM_ARCH_INTERNLM2, "internlm2" }, - { LLM_ARCH_MINICPM, "minicpm" }, - { LLM_ARCH_MINICPM3, "minicpm3" }, - { LLM_ARCH_GEMMA, "gemma" }, - { LLM_ARCH_GEMMA2, "gemma2" }, - { LLM_ARCH_STARCODER2, "starcoder2" }, - { LLM_ARCH_MAMBA, "mamba" }, - { LLM_ARCH_XVERSE, "xverse" }, - { LLM_ARCH_COMMAND_R, "command-r" }, - { LLM_ARCH_DBRX, "dbrx" }, - { LLM_ARCH_OLMO, "olmo" }, - { LLM_ARCH_OLMO_1124, "olmo_1124" }, - { LLM_ARCH_OLMOE, "olmoe" }, - { LLM_ARCH_OPENELM, "openelm" }, - { LLM_ARCH_ARCTIC, "arctic" }, - { LLM_ARCH_DEEPSEEK2, "deepseek2" }, - { LLM_ARCH_CHATGLM, "chatglm" }, - { LLM_ARCH_BITNET, "bitnet" }, - { LLM_ARCH_T5, "t5" }, - { LLM_ARCH_T5ENCODER, "t5encoder" }, - { LLM_ARCH_JAIS, "jais" }, - { LLM_ARCH_NEMOTRON, "nemotron" }, - { LLM_ARCH_EXAONE, "exaone" }, - { LLM_ARCH_RWKV6, "rwkv6" }, - { LLM_ARCH_GRANITE, "granite" }, - { LLM_ARCH_GRANITE_MOE, "granitemoe" }, - { LLM_ARCH_CHAMELEON, "chameleon" }, - { LLM_ARCH_UNKNOWN, "(unknown)" }, + { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_DECI, "deci" }, + { LLM_ARCH_FALCON, "falcon" }, + { LLM_ARCH_GROK, "grok" }, + { LLM_ARCH_GPT2, "gpt2" }, + { LLM_ARCH_GPTJ, "gptj" }, + { LLM_ARCH_GPTNEOX, "gptneox" }, + { LLM_ARCH_MPT, "mpt" }, + { LLM_ARCH_BAICHUAN, "baichuan" }, + { LLM_ARCH_STARCODER, "starcoder" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BERT, "bert" }, + { LLM_ARCH_NOMIC_BERT, "nomic-bert" }, + { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_STABLELM, "stablelm" }, + { LLM_ARCH_QWEN, "qwen" }, + { LLM_ARCH_QWEN2, "qwen2" }, + { LLM_ARCH_QWEN2MOE, "qwen2moe" }, + { LLM_ARCH_QWEN2VL, "qwen2vl" }, + { LLM_ARCH_PHI2, "phi2" }, + { LLM_ARCH_PHI3, "phi3" }, + { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_CODESHELL, "codeshell" }, + { LLM_ARCH_ORION, "orion" }, + { LLM_ARCH_INTERNLM2, "internlm2" }, + { 
LLM_ARCH_MINICPM, "minicpm" }, + { LLM_ARCH_MINICPM3, "minicpm3" }, + { LLM_ARCH_GEMMA, "gemma" }, + { LLM_ARCH_GEMMA2, "gemma2" }, + { LLM_ARCH_STARCODER2, "starcoder2" }, + { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_XVERSE, "xverse" }, + { LLM_ARCH_COMMAND_R, "command-r" }, + { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, + { LLM_ARCH_OLMO2, "olmo2" }, + { LLM_ARCH_OLMOE, "olmoe" }, + { LLM_ARCH_OPENELM, "openelm" }, + { LLM_ARCH_ARCTIC, "arctic" }, + { LLM_ARCH_DEEPSEEK, "deepseek" }, + { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_CHATGLM, "chatglm" }, + { LLM_ARCH_BITNET, "bitnet" }, + { LLM_ARCH_T5, "t5" }, + { LLM_ARCH_T5ENCODER, "t5encoder" }, + { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_NEMOTRON, "nemotron" }, + { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_RWKV6, "rwkv6" }, + { LLM_ARCH_GRANITE, "granite" }, + { LLM_ARCH_GRANITE_MOE, "granitemoe" }, + { LLM_ARCH_CHAMELEON, "chameleon" }, + { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_UNKNOWN, "(unknown)" }, }; enum llm_kv { @@ -280,6 +288,7 @@ enum llm_kv { LLM_KV_VOCAB_SIZE, LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, + LLM_KV_FEATURES_LENGTH, LLM_KV_BLOCK_COUNT, LLM_KV_LEADING_DENSE_BLOCK_COUNT, LLM_KV_FEED_FORWARD_LENGTH, @@ -311,6 +320,8 @@ enum llm_kv { LLM_KV_ATTENTION_VALUE_LENGTH, LLM_KV_ATTENTION_LAYERNORM_EPS, LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, + LLM_KV_ATTENTION_GROUPNORM_EPS, + LLM_KV_ATTENTION_GROUPNORM_GROUPS, LLM_KV_ATTENTION_CAUSAL, LLM_KV_ATTENTION_Q_LORA_RANK, LLM_KV_ATTENTION_KV_LORA_RANK, @@ -319,6 +330,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, @@ -373,6 +385,12 @@ enum llm_kv { LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, + LLM_KV_POSNET_EMBEDDING_LENGTH, + LLM_KV_POSNET_BLOCK_COUNT, + + LLM_KV_CONVNEXT_EMBEDDING_LENGTH, + LLM_KV_CONVNEXT_BLOCK_COUNT, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, @@ -396,6 +414,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_FEATURES_LENGTH, "%s.features_length" }, { LLM_KV_BLOCK_COUNT, "%s.block_count" }, { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, { LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" }, @@ -427,6 +446,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_VALUE_LENGTH, "%s.attention.value_length" }, { LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" }, { LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_EPS, "%s.attention.group_norm_epsilon" }, + { LLM_KV_ATTENTION_GROUPNORM_GROUPS, "%s.attention.group_norm_groups" }, { LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" }, { LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" }, { LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" }, @@ -435,6 +456,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, @@ -456,6 +478,12 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, + { 
LLM_KV_POSNET_EMBEDDING_LENGTH, "%s.posnet.embedding_length" }, + { LLM_KV_POSNET_BLOCK_COUNT, "%s.posnet.block_count" }, + + { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, "%s.convnext.embedding_length" }, + { LLM_KV_CONVNEXT_BLOCK_COUNT, "%s.convnext.block_count" }, + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, @@ -614,6 +642,22 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CONV1D, + LLM_TENSOR_CONVNEXT_DW, + LLM_TENSOR_CONVNEXT_NORM, + LLM_TENSOR_CONVNEXT_PW1, + LLM_TENSOR_CONVNEXT_PW2, + LLM_TENSOR_CONVNEXT_GAMMA, + LLM_TENSOR_POS_NET_CONV1, + LLM_TENSOR_POS_NET_CONV2, + LLM_TENSOR_POS_NET_NORM, + LLM_TENSOR_POS_NET_NORM1, + LLM_TENSOR_POS_NET_NORM2, + LLM_TENSOR_POS_NET_ATTN_NORM, + LLM_TENSOR_POS_NET_ATTN_Q, + LLM_TENSOR_POS_NET_ATTN_K, + LLM_TENSOR_POS_NET_ATTN_V, + LLM_TENSOR_POS_NET_ATTN_OUT, }; static const std::map> LLM_TENSOR_NAMES = { @@ -643,6 +687,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_DECI, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_BAICHUAN, { @@ -909,6 +979,23 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_QWEN2VL, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_QWEN2MOE, { @@ -1047,6 +1134,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" }, + { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, @@ -1221,7 +1310,7 @@ static const std::map> LLM_TENSOR_N }, }, { - LLM_ARCH_OLMO_1124, + LLM_ARCH_OLMO2, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, @@ -1297,6 +1386,33 @@ static const std::map> LLM_TENSOR_N { 
LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_DEEPSEEK, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + }, + }, { LLM_ARCH_DEEPSEEK2, { @@ -1552,6 +1668,31 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, + { + LLM_ARCH_WAVTOKENIZER_DEC, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_CONV1D, "conv1d" }, + { LLM_TENSOR_CONVNEXT_DW, "convnext.%d.dw" }, + { LLM_TENSOR_CONVNEXT_NORM, "convnext.%d.norm" }, + { LLM_TENSOR_CONVNEXT_PW1, "convnext.%d.pw1" }, + { LLM_TENSOR_CONVNEXT_PW2, "convnext.%d.pw2" }, + { LLM_TENSOR_CONVNEXT_GAMMA, "convnext.%d.gamma" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_POS_NET_CONV1, "posnet.%d.conv1" }, + { LLM_TENSOR_POS_NET_CONV2, "posnet.%d.conv2" }, + { LLM_TENSOR_POS_NET_NORM, "posnet.%d.norm" }, + { LLM_TENSOR_POS_NET_NORM1, "posnet.%d.norm1" }, + { LLM_TENSOR_POS_NET_NORM2, "posnet.%d.norm2" }, + { LLM_TENSOR_POS_NET_ATTN_NORM, "posnet.%d.attn_norm" }, + { LLM_TENSOR_POS_NET_ATTN_Q, "posnet.%d.attn_q" }, + { LLM_TENSOR_POS_NET_ATTN_K, "posnet.%d.attn_k" }, + { LLM_TENSOR_POS_NET_ATTN_V, "posnet.%d.attn_v" }, + { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1560,6 +1701,73 @@ static const std::map> LLM_TENSOR_N }, }; +enum llm_chat_template { + LLM_CHAT_TEMPLATE_CHATML, + LLM_CHAT_TEMPLATE_LLAMA_2, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS, + LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP, + LLM_CHAT_TEMPLATE_MISTRAL_V1, + LLM_CHAT_TEMPLATE_MISTRAL_V3, + LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN, + LLM_CHAT_TEMPLATE_MISTRAL_V7, + LLM_CHAT_TEMPLATE_PHI_3, + LLM_CHAT_TEMPLATE_FALCON_3, + LLM_CHAT_TEMPLATE_ZEPHYR, + LLM_CHAT_TEMPLATE_MONARCH, + LLM_CHAT_TEMPLATE_GEMMA, + LLM_CHAT_TEMPLATE_ORION, + LLM_CHAT_TEMPLATE_OPENCHAT, + LLM_CHAT_TEMPLATE_VICUNA, + LLM_CHAT_TEMPLATE_VICUNA_ORCA, + LLM_CHAT_TEMPLATE_DEEPSEEK, + LLM_CHAT_TEMPLATE_DEEPSEEK_2, + LLM_CHAT_TEMPLATE_COMMAND_R, + LLM_CHAT_TEMPLATE_LLAMA_3, + LLM_CHAT_TEMPLATE_CHATGML_3, + LLM_CHAT_TEMPLATE_CHATGML_4, + LLM_CHAT_TEMPLATE_MINICPM, + LLM_CHAT_TEMPLATE_EXAONE_3, + LLM_CHAT_TEMPLATE_RWKV_WORLD, + LLM_CHAT_TEMPLATE_GRANITE, + LLM_CHAT_TEMPLATE_GIGACHAT, + LLM_CHAT_TEMPLATE_MEGREZ, + LLM_CHAT_TEMPLATE_UNKNOWN, +}; + +static const std::map LLM_CHAT_TEMPLATES = { + { "chatml", LLM_CHAT_TEMPLATE_CHATML }, + { "llama2", LLM_CHAT_TEMPLATE_LLAMA_2 }, 
+ { "llama2-sys", LLM_CHAT_TEMPLATE_LLAMA_2_SYS }, + { "llama2-sys-bos", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS }, + { "llama2-sys-strip", LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP }, + { "mistral-v1", LLM_CHAT_TEMPLATE_MISTRAL_V1 }, + { "mistral-v3", LLM_CHAT_TEMPLATE_MISTRAL_V3 }, + { "mistral-v3-tekken", LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN }, + { "mistral-v7", LLM_CHAT_TEMPLATE_MISTRAL_V7 }, + { "phi3", LLM_CHAT_TEMPLATE_PHI_3 }, + { "falcon3", LLM_CHAT_TEMPLATE_FALCON_3 }, + { "zephyr", LLM_CHAT_TEMPLATE_ZEPHYR }, + { "monarch", LLM_CHAT_TEMPLATE_MONARCH }, + { "gemma", LLM_CHAT_TEMPLATE_GEMMA }, + { "orion", LLM_CHAT_TEMPLATE_ORION }, + { "openchat", LLM_CHAT_TEMPLATE_OPENCHAT }, + { "vicuna", LLM_CHAT_TEMPLATE_VICUNA }, + { "vicuna-orca", LLM_CHAT_TEMPLATE_VICUNA_ORCA }, + { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, + { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, + { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, + { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, + { "chatglm3", LLM_CHAT_TEMPLATE_CHATGML_3 }, + { "chatglm4", LLM_CHAT_TEMPLATE_CHATGML_4 }, + { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, + { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, + { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, + { "granite", LLM_CHAT_TEMPLATE_GRANITE }, + { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, + { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, +}; + static llm_arch llm_arch_from_string(const std::string & name) { for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT if (kv.second == name) { @@ -1633,9 +1841,10 @@ struct LLM_TN { // static const std::map LLAMA_ROPE_SCALING_TYPES = { - { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, - { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, - { LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" }, + { LLAMA_ROPE_SCALING_TYPE_NONE, "none" }, + { LLAMA_ROPE_SCALING_TYPE_LINEAR, "linear" }, + { LLAMA_ROPE_SCALING_TYPE_YARN, "yarn" }, + { LLAMA_ROPE_SCALING_TYPE_LONGROPE, "longrope" }, }; static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) { @@ -1741,7 +1950,7 @@ struct llama_file { DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL); if (!bufLen) { - ret = format("Win32 error code: %s", error_code); + ret = format("Win32 error code: %lx", error_code); } else { ret = lpMsgBuf; LocalFree(lpMsgBuf); @@ -2079,7 +2288,7 @@ struct llama_mmap { HMODULE hKernel32 = GetModuleHandleW(L"kernel32.dll"); // may fail on pre-Windows 8 systems - pPrefetchVirtualMemory = reinterpret_cast (GetProcAddress(hKernel32, "PrefetchVirtualMemory")); + pPrefetchVirtualMemory = (decltype(pPrefetchVirtualMemory))(void *) GetProcAddress(hKernel32, "PrefetchVirtualMemory"); if (pPrefetchVirtualMemory) { // advise the kernel to preload the mapped memory @@ -2352,6 +2561,7 @@ enum e_model { MODEL_16B, MODEL_20B, MODEL_30B, + MODEL_32B, MODEL_34B, MODEL_35B, MODEL_40B, @@ -2377,15 +2587,26 @@ static const size_t kiB = 1024; static const size_t MiB = 1024*kiB; static const size_t GiB = 1024*MiB; +struct llama_hparams_posnet { + uint32_t n_embd; + uint32_t n_layer; +}; + +struct llama_hparams_convnext { + uint32_t n_embd; + uint32_t n_layer; +}; + struct llama_hparams { bool vocab_only; bool rope_finetuned; bool use_par_res; bool swin_norm; - uint32_t n_vocab; + uint32_t n_vocab = 0; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; + uint32_t n_embd_features = 0; uint32_t n_layer; uint32_t n_rot; uint32_t n_swa = 0; // 
sliding window attention (SWA) @@ -2396,6 +2617,10 @@ struct llama_hparams { uint32_t n_vocab_type = 0; // for BERT-style token types uint32_t n_rel_attn_bkts = 0; + // for WavTokenizer + struct llama_hparams_posnet posnet; + struct llama_hparams_convnext convnext; + std::array n_head_arr; std::array n_head_kv_arr; std::array n_ff_arr; @@ -2410,6 +2635,9 @@ struct llama_hparams { float f_norm_eps; float f_norm_rms_eps; + float f_norm_group_eps; + + uint32_t n_norm_groups; float f_attn_logit_softcapping = 50.0f; float f_final_logit_softcapping = 30.0f; @@ -2420,11 +2648,12 @@ struct llama_hparams { uint32_t time_decay_extra_dim = 0; uint32_t wkv_head_size = 0; - float rope_attn_factor = 1.0f; - float rope_freq_base_train; - float rope_freq_scale_train; - uint32_t n_ctx_orig_yarn; - float rope_yarn_log_mul; + float rope_attn_factor = 1.0f; + float rope_freq_base_train; + float rope_freq_scale_train; + uint32_t n_ctx_orig_yarn; + float rope_yarn_log_mul; + int rope_sections[4]; // for State Space Models uint32_t ssm_d_conv = 0; @@ -2454,63 +2683,6 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; - bool operator!=(const llama_hparams & other) const { - if (this->vocab_only != other.vocab_only) return true; - if (this->n_vocab != other.n_vocab) return true; - if (this->n_ctx_train != other.n_ctx_train) return true; - if (this->n_embd != other.n_embd) return true; - if (this->n_layer != other.n_layer) return true; - if (this->n_rot != other.n_rot) return true; - if (this->n_swa != other.n_swa) return true; - if (this->n_embd_head_k != other.n_embd_head_k) return true; - if (this->n_embd_head_v != other.n_embd_head_v) return true; - if (this->n_expert != other.n_expert) return true; - if (this->n_expert_used != other.n_expert_used) return true; - - if (this->n_head_arr != other.n_head_arr) return true; - if (this->n_head_kv_arr != other.n_head_kv_arr) return true; - if (this->n_ff_arr != other.n_ff_arr) return true; - - if (this->n_rel_attn_bkts != other.n_rel_attn_bkts) return true; - if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true; - if (this->n_lora_q != other.n_lora_q) return true; - if (this->n_lora_kv != other.n_lora_kv) return true; - if (this->n_ff_exp != other.n_ff_exp) return true; - if (this->n_ff_shexp != other.n_ff_shexp) return true; - if (this->n_expert_shared != other.n_expert_shared) return true; - - if (this->rope_finetuned != other.rope_finetuned) return true; - if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true; - - if (this->ssm_d_conv != other.ssm_d_conv) return true; - if (this->ssm_d_inner != other.ssm_d_inner) return true; - if (this->ssm_d_state != other.ssm_d_state) return true; - if (this->ssm_dt_rank != other.ssm_dt_rank) return true; - if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; - - if (this->rescale_every_n_layers != other.rescale_every_n_layers) return true; - if (this->time_mix_extra_dim != other.time_mix_extra_dim) return true; - if (this->time_decay_extra_dim != other.time_decay_extra_dim) return true; - if (this->wkv_head_size != other.wkv_head_size) return true; - - if (this->dec_start_token_id != other.dec_start_token_id) return true; - - const float EPSILON = 1e-9f; - - if (!is_float_close(this->f_norm_eps, other.f_norm_eps, EPSILON)) return true; - if (!is_float_close(this->f_norm_rms_eps, other.f_norm_rms_eps, EPSILON)) return true; - if (!is_float_close(this->rope_attn_factor, 
other.rope_attn_factor, EPSILON)) return true; - if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true; - if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true; - if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true; - if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true; - if (!is_float_close(this->f_residual_scale, other.f_residual_scale, EPSILON)) return true; - if (!is_float_close(this->f_embedding_scale, other.f_embedding_scale, EPSILON)) return true; - if (!is_float_close(this->f_attention_scale, other.f_attention_scale, EPSILON)) return true; - - return false; - } - uint32_t n_head(uint32_t il = 0) const { if (il < n_layer) { return n_head_arr[il]; @@ -2563,21 +2735,21 @@ struct llama_hparams { if (wkv_head_size != 0) { // for RWKV models return 2 * n_embd; - } else { - // TODO: maybe support other convolution strides than 1 - // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } + + // TODO: maybe support other convolution strides than 1 + // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings if (wkv_head_size != 0) { // corresponds to RWKV's wkv_states size return n_embd * wkv_head_size; - } else { - // corresponds to Mamba's ssm_states size - return ssm_d_state * ssm_d_inner; } + + // corresponds to Mamba's ssm_states size + return ssm_d_state * ssm_d_inner; } }; @@ -2615,142 +2787,187 @@ struct llama_cparams { void * cb_eval_user_data; }; -// TODO: separate into "llama_layer_enc" and "llama_layer_dec" -struct llama_layer { - llama_layer() { - // initialize all pointers to NULL - std::memset(this, 0, sizeof(*this)); - } +struct llama_layer_posnet { + // resnet + struct lm_ggml_tensor * norm1 = nullptr; + struct lm_ggml_tensor * norm1_b = nullptr; + + struct lm_ggml_tensor * conv1 = nullptr; + struct lm_ggml_tensor * conv1_b = nullptr; + + struct lm_ggml_tensor * norm2 = nullptr; + struct lm_ggml_tensor * norm2_b = nullptr; + + struct lm_ggml_tensor * conv2 = nullptr; + struct lm_ggml_tensor * conv2_b = nullptr; + + // attention + struct lm_ggml_tensor * attn_norm = nullptr; + struct lm_ggml_tensor * attn_norm_b = nullptr; + + struct lm_ggml_tensor * attn_q = nullptr; + struct lm_ggml_tensor * attn_q_b = nullptr; + + struct lm_ggml_tensor * attn_k = nullptr; + struct lm_ggml_tensor * attn_k_b = nullptr; + + struct lm_ggml_tensor * attn_v = nullptr; + struct lm_ggml_tensor * attn_v_b = nullptr; + + struct lm_ggml_tensor * attn_o = nullptr; + struct lm_ggml_tensor * attn_o_b = nullptr; + + // normalize + struct lm_ggml_tensor * norm = nullptr; + struct lm_ggml_tensor * norm_b = nullptr; +}; + +struct llama_layer_convnext { + struct lm_ggml_tensor * dw = nullptr; + struct lm_ggml_tensor * dw_b = nullptr; + struct lm_ggml_tensor * norm = nullptr; + struct lm_ggml_tensor * norm_b = nullptr; + + struct lm_ggml_tensor * pw1 = nullptr; + struct lm_ggml_tensor * pw1_b = nullptr; + + struct lm_ggml_tensor * pw2 = nullptr; + struct lm_ggml_tensor * pw2_b = nullptr; + + struct lm_ggml_tensor * gamma = nullptr; +}; + +struct llama_layer { // normalization - struct lm_ggml_tensor * attn_norm; - struct lm_ggml_tensor * attn_norm_b; 
- struct lm_ggml_tensor * attn_norm_2; - struct lm_ggml_tensor * attn_norm_2_b; - struct lm_ggml_tensor * attn_q_norm; - struct lm_ggml_tensor * attn_q_norm_b; - struct lm_ggml_tensor * attn_k_norm; - struct lm_ggml_tensor * attn_k_norm_b; - struct lm_ggml_tensor * attn_out_norm; - struct lm_ggml_tensor * attn_out_norm_b; - struct lm_ggml_tensor * attn_q_a_norm; - struct lm_ggml_tensor * attn_kv_a_norm; - struct lm_ggml_tensor * attn_sub_norm; - struct lm_ggml_tensor * attn_post_norm; - struct lm_ggml_tensor * ffn_sub_norm; - struct lm_ggml_tensor * attn_norm_cross; - struct lm_ggml_tensor * attn_norm_enc; + struct lm_ggml_tensor * attn_norm = nullptr; + struct lm_ggml_tensor * attn_norm_b = nullptr; + struct lm_ggml_tensor * attn_norm_2 = nullptr; + struct lm_ggml_tensor * attn_norm_2_b = nullptr; + struct lm_ggml_tensor * attn_q_norm = nullptr; + struct lm_ggml_tensor * attn_q_norm_b = nullptr; + struct lm_ggml_tensor * attn_k_norm = nullptr; + struct lm_ggml_tensor * attn_k_norm_b = nullptr; + struct lm_ggml_tensor * attn_out_norm = nullptr; + struct lm_ggml_tensor * attn_out_norm_b = nullptr; + struct lm_ggml_tensor * attn_q_a_norm = nullptr; + struct lm_ggml_tensor * attn_kv_a_norm = nullptr; + struct lm_ggml_tensor * attn_sub_norm = nullptr; + struct lm_ggml_tensor * attn_post_norm = nullptr; + struct lm_ggml_tensor * ffn_sub_norm = nullptr; + struct lm_ggml_tensor * attn_norm_cross = nullptr; + struct lm_ggml_tensor * attn_norm_enc = nullptr; // attention - struct lm_ggml_tensor * wq; - struct lm_ggml_tensor * wk; - struct lm_ggml_tensor * wv; - struct lm_ggml_tensor * wo; - struct lm_ggml_tensor * wqkv; - struct lm_ggml_tensor * wq_a; - struct lm_ggml_tensor * wq_b; - struct lm_ggml_tensor * wkv_a_mqa; - struct lm_ggml_tensor * wkv_b; - struct lm_ggml_tensor * wq_cross; - struct lm_ggml_tensor * wk_cross; - struct lm_ggml_tensor * wv_cross; - struct lm_ggml_tensor * wo_cross; - struct lm_ggml_tensor * wq_enc; - struct lm_ggml_tensor * wk_enc; - struct lm_ggml_tensor * wv_enc; - struct lm_ggml_tensor * wo_enc; + struct lm_ggml_tensor * wq = nullptr; + struct lm_ggml_tensor * wk = nullptr; + struct lm_ggml_tensor * wv = nullptr; + struct lm_ggml_tensor * wo = nullptr; + struct lm_ggml_tensor * wqkv = nullptr; + struct lm_ggml_tensor * wq_a = nullptr; + struct lm_ggml_tensor * wq_b = nullptr; + struct lm_ggml_tensor * wkv_a_mqa = nullptr; + struct lm_ggml_tensor * wkv_b = nullptr; + struct lm_ggml_tensor * wq_cross = nullptr; + struct lm_ggml_tensor * wk_cross = nullptr; + struct lm_ggml_tensor * wv_cross = nullptr; + struct lm_ggml_tensor * wo_cross = nullptr; + struct lm_ggml_tensor * wq_enc = nullptr; + struct lm_ggml_tensor * wk_enc = nullptr; + struct lm_ggml_tensor * wv_enc = nullptr; + struct lm_ggml_tensor * wo_enc = nullptr; // attention bias - struct lm_ggml_tensor * bq; - struct lm_ggml_tensor * bk; - struct lm_ggml_tensor * bv; - struct lm_ggml_tensor * bo; - struct lm_ggml_tensor * bqkv; + struct lm_ggml_tensor * bq = nullptr; + struct lm_ggml_tensor * bk = nullptr; + struct lm_ggml_tensor * bv = nullptr; + struct lm_ggml_tensor * bo = nullptr; + struct lm_ggml_tensor * bqkv = nullptr; // relative position bias - struct lm_ggml_tensor * attn_rel_b; - struct lm_ggml_tensor * attn_rel_b_enc; - struct lm_ggml_tensor * attn_rel_b_cross; + struct lm_ggml_tensor * attn_rel_b = nullptr; + struct lm_ggml_tensor * attn_rel_b_enc = nullptr; + struct lm_ggml_tensor * attn_rel_b_cross = nullptr; // normalization - struct lm_ggml_tensor * ffn_norm; - struct lm_ggml_tensor * 
ffn_norm_b; - struct lm_ggml_tensor * ffn_post_norm; - struct lm_ggml_tensor * layer_out_norm; - struct lm_ggml_tensor * layer_out_norm_b; - struct lm_ggml_tensor * ffn_norm_exps; - struct lm_ggml_tensor * ffn_norm_enc; + struct lm_ggml_tensor * ffn_norm = nullptr; + struct lm_ggml_tensor * ffn_norm_b = nullptr; + struct lm_ggml_tensor * ffn_post_norm = nullptr; + struct lm_ggml_tensor * layer_out_norm = nullptr; + struct lm_ggml_tensor * layer_out_norm_b = nullptr; + struct lm_ggml_tensor * ffn_norm_exps = nullptr; + struct lm_ggml_tensor * ffn_norm_enc = nullptr; // ff - struct lm_ggml_tensor * ffn_gate; // w1 - struct lm_ggml_tensor * ffn_down; // w2 - struct lm_ggml_tensor * ffn_up; // w3 - struct lm_ggml_tensor * ffn_gate_enc; - struct lm_ggml_tensor * ffn_down_enc; - struct lm_ggml_tensor * ffn_up_enc; + struct lm_ggml_tensor * ffn_gate = nullptr; // w1 + struct lm_ggml_tensor * ffn_down = nullptr; // w2 + struct lm_ggml_tensor * ffn_up = nullptr; // w3 + struct lm_ggml_tensor * ffn_gate_enc = nullptr; + struct lm_ggml_tensor * ffn_down_enc = nullptr; + struct lm_ggml_tensor * ffn_up_enc = nullptr; // ff MoE - struct lm_ggml_tensor * ffn_gate_inp; - struct lm_ggml_tensor * ffn_gate_exps; - struct lm_ggml_tensor * ffn_down_exps; - struct lm_ggml_tensor * ffn_up_exps ; + struct lm_ggml_tensor * ffn_gate_inp = nullptr; + struct lm_ggml_tensor * ffn_gate_exps = nullptr; + struct lm_ggml_tensor * ffn_down_exps = nullptr; + struct lm_ggml_tensor * ffn_up_exps = nullptr; // ff shared expert (shexp) - struct lm_ggml_tensor * ffn_gate_inp_shexp; - struct lm_ggml_tensor * ffn_gate_shexp; - struct lm_ggml_tensor * ffn_down_shexp; - struct lm_ggml_tensor * ffn_up_shexp; + struct lm_ggml_tensor * ffn_gate_inp_shexp = nullptr; + struct lm_ggml_tensor * ffn_gate_shexp = nullptr; + struct lm_ggml_tensor * ffn_down_shexp = nullptr; + struct lm_ggml_tensor * ffn_up_shexp = nullptr; // ff bias - struct lm_ggml_tensor * ffn_gate_b; - struct lm_ggml_tensor * ffn_down_b; // b2 - struct lm_ggml_tensor * ffn_up_b; // b3 - struct lm_ggml_tensor * ffn_act; + struct lm_ggml_tensor * ffn_gate_b = nullptr; + struct lm_ggml_tensor * ffn_down_b = nullptr; // b2 + struct lm_ggml_tensor * ffn_up_b = nullptr; // b3 + struct lm_ggml_tensor * ffn_act = nullptr; // mamba proj - struct lm_ggml_tensor * ssm_in; - struct lm_ggml_tensor * ssm_x; - struct lm_ggml_tensor * ssm_dt; - struct lm_ggml_tensor * ssm_out; + struct lm_ggml_tensor * ssm_in = nullptr; + struct lm_ggml_tensor * ssm_x = nullptr; + struct lm_ggml_tensor * ssm_dt = nullptr; + struct lm_ggml_tensor * ssm_out = nullptr; // mamba - struct lm_ggml_tensor * ssm_conv1d; - struct lm_ggml_tensor * ssm_a; - struct lm_ggml_tensor * ssm_d; + struct lm_ggml_tensor * ssm_conv1d = nullptr; + struct lm_ggml_tensor * ssm_a = nullptr; + struct lm_ggml_tensor * ssm_d = nullptr; // mamba bias - struct lm_ggml_tensor * ssm_conv1d_b; - struct lm_ggml_tensor * ssm_dt_b; + struct lm_ggml_tensor * ssm_conv1d_b = nullptr; + struct lm_ggml_tensor * ssm_dt_b = nullptr; // rwkv - struct lm_ggml_tensor * time_mix_w1; - struct lm_ggml_tensor * time_mix_w2; - struct lm_ggml_tensor * time_mix_lerp_x; - struct lm_ggml_tensor * time_mix_lerp_w; - struct lm_ggml_tensor * time_mix_lerp_k; - struct lm_ggml_tensor * time_mix_lerp_v; - struct lm_ggml_tensor * time_mix_lerp_r; - struct lm_ggml_tensor * time_mix_lerp_g; - - struct lm_ggml_tensor * time_mix_first; - struct lm_ggml_tensor * time_mix_decay; - struct lm_ggml_tensor * time_mix_decay_w1; - struct lm_ggml_tensor * time_mix_decay_w2; - 
struct lm_ggml_tensor * time_mix_key; - struct lm_ggml_tensor * time_mix_value; - struct lm_ggml_tensor * time_mix_receptance; - struct lm_ggml_tensor * time_mix_gate; - - struct lm_ggml_tensor * time_mix_ln; - struct lm_ggml_tensor * time_mix_ln_b; - struct lm_ggml_tensor * time_mix_output; - - struct lm_ggml_tensor * channel_mix_lerp_k; - struct lm_ggml_tensor * channel_mix_lerp_r; - - struct lm_ggml_tensor * channel_mix_key; - struct lm_ggml_tensor * channel_mix_receptance; - struct lm_ggml_tensor * channel_mix_value; + struct lm_ggml_tensor * time_mix_w1 = nullptr; + struct lm_ggml_tensor * time_mix_w2 = nullptr; + struct lm_ggml_tensor * time_mix_lerp_x = nullptr; + struct lm_ggml_tensor * time_mix_lerp_w = nullptr; + struct lm_ggml_tensor * time_mix_lerp_k = nullptr; + struct lm_ggml_tensor * time_mix_lerp_v = nullptr; + struct lm_ggml_tensor * time_mix_lerp_r = nullptr; + struct lm_ggml_tensor * time_mix_lerp_g = nullptr; + + struct lm_ggml_tensor * time_mix_first = nullptr; + struct lm_ggml_tensor * time_mix_decay = nullptr; + struct lm_ggml_tensor * time_mix_decay_w1 = nullptr; + struct lm_ggml_tensor * time_mix_decay_w2 = nullptr; + struct lm_ggml_tensor * time_mix_key = nullptr; + struct lm_ggml_tensor * time_mix_value = nullptr; + struct lm_ggml_tensor * time_mix_receptance = nullptr; + struct lm_ggml_tensor * time_mix_gate = nullptr; + + struct lm_ggml_tensor * time_mix_ln = nullptr; + struct lm_ggml_tensor * time_mix_ln_b = nullptr; + struct lm_ggml_tensor * time_mix_output = nullptr; + + struct lm_ggml_tensor * channel_mix_lerp_k = nullptr; + struct lm_ggml_tensor * channel_mix_lerp_r = nullptr; + + struct lm_ggml_tensor * channel_mix_key = nullptr; + struct lm_ggml_tensor * channel_mix_receptance = nullptr; + struct lm_ggml_tensor * channel_mix_value = nullptr; // long rope factors struct lm_ggml_tensor * rope_long = nullptr; @@ -2758,13 +2975,17 @@ struct llama_layer { struct lm_ggml_tensor * rope_freqs = nullptr; // bitnet scale - struct lm_ggml_tensor * wq_scale; - struct lm_ggml_tensor * wk_scale; - struct lm_ggml_tensor * wv_scale; - struct lm_ggml_tensor * wo_scale; - struct lm_ggml_tensor * ffn_gate_scale; - struct lm_ggml_tensor * ffn_up_scale; - struct lm_ggml_tensor * ffn_down_scale; + struct lm_ggml_tensor * wq_scale = nullptr; + struct lm_ggml_tensor * wk_scale = nullptr; + struct lm_ggml_tensor * wv_scale = nullptr; + struct lm_ggml_tensor * wo_scale = nullptr; + struct lm_ggml_tensor * ffn_gate_scale = nullptr; + struct lm_ggml_tensor * ffn_up_scale = nullptr; + struct lm_ggml_tensor * ffn_down_scale = nullptr; + + struct llama_layer_posnet posnet; + + struct llama_layer_convnext convnext; }; // very similar to llama_batch, @@ -2895,6 +3116,9 @@ struct llama_model { struct lm_ggml_tensor * cls_out = nullptr; struct lm_ggml_tensor * cls_out_b = nullptr; + struct lm_ggml_tensor * conv1d = nullptr; + struct lm_ggml_tensor * conv1d_b = nullptr; + std::vector<llama_layer> layers; // gguf metadata @@ -2979,6 +3203,7 @@ struct llama_sbatch { // batch indices of the output std::vector<size_t> out_ids; std::vector<llama_sbatch_seq> seq; + const llama_batch * batch = nullptr; // buffers for the ubatch @@ -3324,6 +3549,11 @@ struct llama_context { // whether we are computing encoder output or decoder output bool is_encoding = false; + // TODO: find a better way to accommodate multi-dimension position encoding methods + // number of position ids each token gets; 1 per token in most cases. + // when using m-rope, each token gets 3 position ids representing a 3-dimensional coordinate. 
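For illustration, the layout this comment implies can be sketched in isolation. This is a minimal sketch, assuming a dimension-major position buffer (all ids for dimension 0, then dimension 1, and so on) and a hypothetical helper name; the Qwen2-VL graph later in this patch allocates n_tokens * 4 position slots, with the fourth section unused for plain text.

#include <cstddef>
#include <cstdint>
#include <vector>

using llama_pos = std::int32_t;

// hypothetical sketch: fill an m-rope position buffer for a plain-text batch,
// where every dimension carries the same sequential id; image tokens would
// instead receive distinct temporal/height/width ids in their sections
static void fill_mrope_positions(std::vector<llama_pos> & pos, int n_tokens, int n_pos_per_token) {
    pos.resize(static_cast<std::size_t>(n_pos_per_token) * n_tokens);
    for (int d = 0; d < n_pos_per_token; ++d) {
        for (int i = 0; i < n_tokens; ++i) {
            // dimension-major: [d0 ids..., d1 ids..., d2 ids...]
            pos[static_cast<std::size_t>(d) * n_tokens + i] = i;
        }
    }
}

For a 3-dimensional m-rope this would be called with n_pos_per_token == 3.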
+ int n_pos_per_token = 1; + // output of the encoder part of the encoder-decoder models std::vector<float> embd_enc; std::vector<std::set<llama_seq_id>> seq_ids_enc; @@ -3394,6 +3624,17 @@ static int llama_get_device_count(const llama_model & model) { return (int) model.devices.size(); } +static struct lm_ggml_tensor * llama_get_model_tensor(const struct llama_model * model, const char * name) { + auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(), + [name](const std::pair<std::string, struct lm_ggml_tensor *> & it) { + return it.first == name; + }); + if (it == model->tensors_by_name.end()) { + return nullptr; + } + return it->second; +} + template<typename F> static bool buft_supported(lm_ggml_backend_buffer_type_t buft, lm_ggml_backend_dev_t dev, F & fn) { lm_ggml_init_params params = { @@ -3447,7 +3688,9 @@ static bool llama_kv_cache_init( const struct llama_hparams & hparams = model.hparams; - const int64_t n_layer = hparams.n_layer; + const int32_t n_layer = hparams.n_layer; + + LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", __func__, kv_size, offload, lm_ggml_type_name(type_k), lm_ggml_type_name(type_v), n_layer); cache.has_shift = false; @@ -3488,10 +3731,12 @@ static bool llama_kv_cache_init( cache.k_l.reserve(n_layer); cache.v_l.reserve(n_layer); - for (int i = 0; i < (int) n_layer; i++) { + for (int i = 0; i < n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + LLAMA_LOG_DEBUG("%s: layer %d: n_embd_k_gqa = %d, n_embd_v_gqa = %d\n", __func__, i, n_embd_k_gqa, n_embd_v_gqa); + lm_ggml_backend_buffer_type_t buft; if (offload) { auto * dev = model.dev_layer.at(i).dev; @@ -4524,9 +4769,6 @@ struct llama_model_loader { case LM_GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case LM_GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case LM_GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; - case LM_GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; - case LM_GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; - case LM_GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, lm_ggml_type_name(type_max)); @@ -4877,7 +5119,9 @@ struct llama_model_loader { mappings.reserve(files.size()); mmaps_used.reserve(files.size()); for (const auto & file : files) { - std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? -1 : 0, lm_ggml_is_numa())); + auto * reg = lm_ggml_backend_dev_backend_reg(lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU)); + auto * is_numa_fn = (decltype(lm_ggml_is_numa) *) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_cpu_is_numa"); + std::unique_ptr<llama_mmap> mapping(new llama_mmap(file.get(), prefetch ? 
-1 : 0, is_numa_fn())); mmaps_used.emplace_back(mapping->size, 0); if (mlock_mmaps) { std::unique_ptr<llama_mlock> mlock_mmap(new llama_mlock()); @@ -5288,9 +5532,6 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw"; case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw"; - case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: return "Q4_0_4_4"; - case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: return "Q4_0_4_8"; - case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: return "Q4_0_8_8"; default: return "unknown, may not work"; } @@ -5339,6 +5580,7 @@ static const char * llama_model_type_name(e_model type) { case MODEL_16B: return "16B"; case MODEL_20B: return "20B"; case MODEL_30B: return "30B"; + case MODEL_32B: return "32B"; case MODEL_34B: return "34B"; case MODEL_35B: return "35B"; case MODEL_40B: return "40B"; @@ -5407,7 +5649,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); // get hparams kv - ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); + ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, false); // everything past this point is not vocab-related if (hparams.vocab_only) { @@ -5420,6 +5662,16 @@ static void llm_load_hparams( ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + if (model.arch == LLM_ARCH_WAVTOKENIZER_DEC) { + ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); + + ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); + ml.get_key(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); + + ml.get_key(LLM_KV_CONVNEXT_EMBEDDING_LENGTH, hparams.convnext.n_embd); + ml.get_key(LLM_KV_CONVNEXT_BLOCK_COUNT, hparams.convnext.n_layer); + } + LM_GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); LM_GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); if (hparams.n_expert > 0) { @@ -5428,13 +5680,13 @@ static void llm_load_hparams( LM_GGML_ASSERT(hparams.n_expert_used == 0); } - // zero-out the per-layer hparams + // zero-out the array hparams std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); // n_head_kv is optional, default to n_head hparams.n_head_kv_arr = hparams.n_head_arr; @@ -5483,7 +5735,7 @@ static void llm_load_hparams( ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) { + if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_DECI || model.arch == LLM_ARCH_FALCON) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -5523,11 +5775,24 @@ static void llm_load_hparams( } } } break; + case LLM_ARCH_DECI: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; 
break; + case 80: model.type = e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_MINICPM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); switch (hparams.n_layer) { + case 52: model.type = e_model::MODEL_1B; break; case 40: model.type = e_model::MODEL_2B; break; default: model.type = e_model::MODEL_UNKNOWN; } @@ -5692,6 +5957,13 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_QWEN2VL: + { + std::array<int, 4> section_dims; + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, section_dims, 4, true); + std::copy(section_dims.begin(), section_dims.begin() + 4, std::begin(hparams.rope_sections)); + } + // fall through case LLM_ARCH_QWEN2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); switch (hparams.n_layer) { case 24: model.type = hparams.n_embd == 1024 ? e_model::MODEL_0_5B : e_model::MODEL_1B; break; case 28: model.type = hparams.n_embd == 1536 ? e_model::MODEL_1_5B : e_model::MODEL_7B; break; case 32: model.type = e_model::MODEL_7B; break; + case 36: model.type = e_model::MODEL_3B; break; case 40: model.type = hparams.n_head() == 20 ? e_model::MODEL_4B : e_model::MODEL_13B; break; + case 48: model.type = e_model::MODEL_14B; break; + case 64: model.type = e_model::MODEL_32B; break; case 80: model.type = e_model::MODEL_70B; break; default: model.type = e_model::MODEL_UNKNOWN; } @@ -5909,7 +6184,7 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; - case LLM_ARCH_OLMO_1124: + case LLM_ARCH_OLMO2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5999,6 +6274,19 @@ static void llm_load_hparams( model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_DEEPSEEK: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + + switch (hparams.n_layer) { + case 28: model.type = e_model::MODEL_20B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_DEEPSEEK2: { bool is_lite = (hparams.n_layer == 27); @@ -6152,6 +6440,13 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_WAVTOKENIZER_DEC: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + } break; default: (void)0; } @@ -6181,7 +6476,7 @@ static void llm_load_vocab( ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_model); ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false); - if (tokenizer_model == "no_vocab") { + if (tokenizer_model == "no_vocab" || tokenizer_model == "none") { vocab.type = LLAMA_VOCAB_TYPE_NONE; // default special tokens @@ -6319,7 +6614,8 @@ static void llm_load_vocab( } else if ( tokenizer_pre == "llama3" || tokenizer_pre == "llama-v3" || tokenizer_pre == 
"llama-bpe"|| + tokenizer_pre == "falcon3") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3; vocab.tokenizer_ignore_merges = true; vocab.tokenizer_add_bos = true; @@ -6345,10 +6641,12 @@ static void llm_load_vocab( tokenizer_pre == "phi-2" || tokenizer_pre == "jina-es" || tokenizer_pre == "jina-de" || + tokenizer_pre == "gigachat" || tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-es" || tokenizer_pre == "jina-v2-de" || - tokenizer_pre == "jina-v2-code") { + tokenizer_pre == "jina-v2-code" || + tokenizer_pre == "roberta-bpe") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "refact") { @@ -6415,6 +6713,12 @@ static void llm_load_vocab( vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; vocab.tokenizer_add_bos = true; vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "minerva-7b") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_MINERVA; + } else if ( + tokenizer_pre == "megrez") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -6993,6 +7297,13 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len); + if (model.arch == LLM_ARCH_DEEPSEEK) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + } + if (model.arch == LLM_ARCH_DEEPSEEK2) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); @@ -7008,7 +7319,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); } - if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) { + if (model.arch == LLM_ARCH_MINICPM || model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) { LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); @@ -7149,6 +7460,22 @@ static const std::map llm_tensor_info_mapping = { {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT_ID}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_NONE}}, + {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, LM_GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_NORM2, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_CONV1, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_CONV2, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_IM2COL}}, + {LLM_TENSOR_POS_NET_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}}, + {LLM_TENSOR_POS_NET_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}}, + {LLM_TENSOR_POS_NET_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}}, + {LLM_TENSOR_POS_NET_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}}, + 
{LLM_TENSOR_POS_NET_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_DW, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_IM2COL}}, + {LLM_TENSOR_CONVNEXT_NORM, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}}, + {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, LM_GGML_OP_MUL}}, }; // checks if the weight tensor can be used with the specified buffer type and device @@ -7192,12 +7519,12 @@ static bool weight_buft_supported(const llama_hparams & hparams, lm_ggml_tensor } break; case LM_GGML_OP_ADD: { - lm_ggml_tensor * a = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, w->ne[0], 512); + lm_ggml_tensor * a = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); op_tensor = lm_ggml_add(ctx, a, w); } break; case LM_GGML_OP_MUL: { - lm_ggml_tensor * a = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, w->ne[0], 512); + lm_ggml_tensor * a = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); op_tensor = lm_ggml_mul(ctx, a, w); } break; case LM_GGML_OP_DIV: @@ -7253,6 +7580,12 @@ static bool weight_buft_supported(const llama_hparams & hparams, lm_ggml_tensor lm_ggml_tensor * state = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, S, n_seqs, S, H); op_tensor = lm_ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); } break; + case LM_GGML_OP_IM2COL: + { + const int n_embd = hparams.n_embd; + lm_ggml_tensor * b = lm_ggml_new_tensor_4d(ctx, LM_GGML_TYPE_F32, n_embd, w->ne[1], 1, 1); + op_tensor = lm_ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, LM_GGML_TYPE_F16); + } break; default: LM_GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, lm_ggml_op_name(op), w->name); } @@ -7383,7 +7716,8 @@ static bool llm_load_tensors( model.main_gpu = main_gpu; model.n_gpu_layers = n_gpu_layers; - const int n_layer = hparams.n_layer; + const int n_layer = hparams.n_layer; + bool use_mmap_buffer = true; // build a list of buffer types for the CPU and GPU devices @@ -7633,7 +7967,13 @@ static bool llm_load_tensors( layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? 
llama_model_loader::TENSOR_DUPLICATED : 0)); + } if (n_expert == 0) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); @@ -7652,6 +7992,68 @@ static bool llm_load_tensors( } } } break; + case LLM_ARCH_DECI: + { + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_ff = hparams.n_ff(i); + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_kv = hparams.n_head_kv(i); + + if (n_head_kv == 0 && n_head > 0) { + // linear attention for DeciLMCausalModel + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + else if (n_head_kv > 0) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + } + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? 
llama_model_loader::TENSOR_DUPLICATED : 0)); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } break; case LLM_ARCH_MINICPM3: { const int64_t n_embd_head_qk_rope = hparams.n_rot; @@ -8100,6 +8502,7 @@ static bool llm_load_tensors( } } break; case LLM_ARCH_QWEN2: + case LLM_ARCH_QWEN2VL: { model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -8602,7 +9005,7 @@ static bool llm_load_tensors( layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; - case LLM_ARCH_OLMO_1124: + case LLM_ARCH_OLMO2: { model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -8760,15 +9163,8 @@ static bool llm_load_tensors( layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); } } break; - case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK: { - const bool is_lite = (hparams.n_layer == 27); - - const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; - - const int64_t q_lora_rank = hparams.n_lora_q; - const int64_t kv_lora_rank = hparams.n_lora_kv; const int64_t n_ff_exp = hparams.n_ff_exp; const int64_t n_expert_shared = hparams.n_expert_shared; @@ -8783,23 +9179,11 @@ static bool llm_load_tensors( auto & layer = model.layers[i]; layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!is_lite) { - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); - } - - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - - if (!is_lite) { - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); - } else { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - } - - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); if (i < (int) hparams.n_layer_dense_lead) { @@ -8828,12 +9212,80 @@ static bool llm_load_tensors( } } } break; - case 
LLM_ARCH_BITNET: + case LLM_ARCH_DEEPSEEK2: { - model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + const bool is_lite = (hparams.n_layer == 27); + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + if (!is_lite) { + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + } + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (!is_lite) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + } + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } + } break; + case LLM_ARCH_BITNET: + { + model.tok_embd = 
create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); for (int i = 0; i < n_layer; ++i) { auto & layer = model.layers[i]; @@ -9130,9 +9582,9 @@ static bool llm_load_tensors( } break; case LLM_ARCH_CHAMELEON: { - model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - // output + // output model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed @@ -9161,6 +9613,109 @@ static bool llm_load_tensors( layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_WAVTOKENIZER_DEC: + { + model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0); + + model.conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0); + model.conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); + + // posnet + { + const int64_t n_embd = hparams.posnet.n_embd; + + for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) { + auto & layer = model.layers[i].posnet; + + // posnet: + // + // - resnet + // - resnet + // - attn + // - resnet + // - resnet + // - norm + // + switch (i) { + case 0: + case 1: + case 3: + case 4: + { + layer.norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0); + layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", i), {1, n_embd}, 0); + + layer.conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", i), {1, n_embd}, 0); + + layer.norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0); + layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", i), {1, n_embd}, 0); + + layer.conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", i), {1, n_embd}, 0); + } break; + case 2: + { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + + layer.attn_q = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", i), {1, n_embd}, 0); + + layer.attn_k = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", i), {1, n_embd}, 0); + + layer.attn_v = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", i), {1, n_embd}, 0); + + layer.attn_o = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", i), {1, n_embd}, 0); + } break; + case 5: + { + layer.norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.norm_b = 
create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + } break; + default: LM_GGML_ABORT("unknown posnet layer"); + }; + } + } + + LM_GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); + + model.tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0); + model.tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {hparams.posnet.n_embd}, 0); + + // convnext + { + const int64_t n_embd = hparams.convnext.n_embd; + + for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) { + auto & layer = model.layers[i].convnext; + + layer.dw = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "weight", i), {7, 1, n_embd}, 0); + layer.dw_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "bias", i), {1, n_embd}, 0); + + layer.norm = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "weight", i), {n_embd}, 0); + layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "bias", i), {n_embd}, 0); + + layer.pw1 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "weight", i), {n_embd, n_ff}, 0); + layer.pw1_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "bias", i), {n_ff}, 0); + + layer.pw2 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "weight", i), {n_ff, n_embd}, 0); + layer.pw2_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "bias", i), {n_embd}, 0); + + layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); + } + + // output + model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + } + + model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); + model.output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); + } break; default: throw std::runtime_error("unknown architecture"); } @@ -9201,7 +9756,7 @@ static bool llm_load_tensors( lm_ggml_backend_dev_t dev = lm_ggml_backend_buft_get_device(buft); if (!dev) { // FIXME: workaround for CPU backend buft having a NULL device - dev = lm_ggml_backend_reg_dev_get(lm_ggml_backend_cpu_reg(), 0); + dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU); } lm_ggml_backend_dev_props props; lm_ggml_backend_dev_get_props(dev, &props); @@ -9380,6 +9935,7 @@ enum llm_ffn_gate_type { enum llm_norm_type { LLM_NORM, LLM_NORM_RMS, + LLM_NORM_GROUP, }; static struct lm_ggml_tensor * llm_build_inp_embd( @@ -9400,7 +9956,7 @@ static struct lm_ggml_tensor * llm_build_inp_embd( inpL = lm_ggml_get_rows(ctx, tok_embd, lctx.inp_tokens); } else { - lctx.inp_embd = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, n_embd, batch.n_tokens); + lctx.inp_embd = lm_ggml_new_tensor_2d(ctx, LM_GGML_TYPE_F32, n_embd, batch.n_tokens); inpL = lctx.inp_embd; lm_ggml_set_input(lctx.inp_embd); } @@ -9521,8 +10077,14 @@ static struct lm_ggml_tensor * llm_build_norm( const llm_build_cb & cb, int il) { switch (type) { - case LLM_NORM: cur = lm_ggml_norm (ctx, cur, hparams.f_norm_eps); break; - case LLM_NORM_RMS: cur = lm_ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps); break; + case LLM_NORM: cur = lm_ggml_norm (ctx, cur, hparams.f_norm_eps); break; + case LLM_NORM_RMS: cur = lm_ggml_rms_norm (ctx, cur, hparams.f_norm_rms_eps); break; + case LLM_NORM_GROUP: + { + cur = lm_ggml_reshape_3d(ctx, cur, cur->ne[0], 1, cur->ne[1]); + cur = lm_ggml_group_norm(ctx, cur, hparams.n_norm_groups, hparams.f_norm_group_eps); + cur = lm_ggml_reshape_2d(ctx, cur, cur->ne[0], cur->ne[2]); + } break; } if (mw || mb) { @@ -10861,6 +11423,167 @@ struct 
llm_build_context { return gf; } + struct lm_ggml_cgraph * build_deci() { + struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + LM_GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct lm_ggml_tensor * cur; + struct lm_ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct lm_ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + struct lm_ggml_tensor * inpSA = inpL; + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_head = hparams.n_head(il); + + if (n_head == 0) { + // attention-free layer of Llama-3_1-Nemotron-51B + cur = inpL; + } else { + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + } + + if (n_head > 0 && n_head_kv == 0) { + // "linear attention" of Llama-3_1-Nemotron-51B + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + cb(cur, "wo", il); + } else if (n_head > 0) { + // self-attention + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct lm_ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = lm_ggml_rope_ext( + ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = lm_ggml_rope_ext( + ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architecture + if (hparams.f_residual_scale) { + cur = lm_ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + + // 
modified to support attention-free layer of Llama-3_1-Nemotron-51B + struct lm_ggml_tensor * ffn_inp = cur; + if (n_head > 0) { + ffn_inp = lm_ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + } + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } + + // For Granite architecture + if (hparams.f_residual_scale) { + cur = lm_ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + + cur = lm_ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + + // For Granite architecture + if (hparams.f_logit_scale) { + cur = lm_ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + } + + cb(cur, "result_output", -1); + + lm_ggml_build_forward_expand(gf, cur); + + return gf; + } + struct lm_ggml_cgraph * build_baichuan() { struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -12489,12 +13212,8 @@ struct llm_build_context { return gf; } - struct lm_ggml_cgraph * build_qwen2moe() { + struct lm_ggml_cgraph * build_qwen2vl() { struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); - - // mutable variable, needed during the last layer of the computation to skip unused tokens - int32_t n_tokens = this->n_tokens; - const int64_t n_embd_head = hparams.n_embd_head_v; LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); LM_GGML_ASSERT(n_embd_head == hparams.n_rot); @@ -12505,10 +13224,15 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); // inp_pos - contains the positions - struct lm_ggml_tensor * inp_pos = build_inp_pos(); + lctx.inp_pos = lm_ggml_new_tensor_1d(ctx0, LM_GGML_TYPE_I32, n_tokens * 4); + cb(lctx.inp_pos, "inp_pos", -1); + lm_ggml_set_input(lctx.inp_pos); + struct lm_ggml_tensor * inp_pos = lctx.inp_pos; // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask(); + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); for (int il = 0; il < n_layer; ++il) { struct lm_ggml_tensor * inpSA = inpL; @@ -12519,7 +13243,124 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - // self_attention + // self-attention + { + // compute Q and K and RoPE them + struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + 
cb(Vcur, "Vcur", il); + Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + Qcur = lm_ggml_rope_multi( + ctx0, + lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = lm_ggml_rope_multi( + ctx0, + lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = lm_ggml_add(ctx0, cur, ffn_inp); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + lm_ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct lm_ggml_cgraph * build_qwen2moe() { + struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + LM_GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct lm_ggml_tensor * cur; + struct lm_ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct lm_ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct lm_ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self_attention { // compute Q and K and RoPE them struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); @@ -12772,7 +13613,13 @@ struct llm_build_context { struct lm_ggml_tensor * inp_pos = build_inp_pos(); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct lm_ggml_tensor * KQ_mask_swa = build_inp_KQ_mask_swa(); + struct lm_ggml_tensor * KQ_mask = nullptr; + if (hparams.n_swa == 0) { + // Phi-4 doesn't use sliding window 
attention + KQ_mask = build_inp_KQ_mask(); + } else { + KQ_mask = build_inp_KQ_mask_swa(); + } for (int il = 0; il < n_layer; ++il) { auto residual = inpL; @@ -12830,7 +13677,7 @@ struct llm_build_context { cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask_swa, n_tokens, kv_head, n_kv, 1.0f, cb, il); + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il); } if (il == n_layer - 1) { @@ -13440,153 +14287,6 @@ struct llm_build_context { return gf; } - // ref: https://arxiv.org/abs/2203.03466 - // https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738 - // based on the original build_llama() function - struct lm_ggml_cgraph * build_minicpm() { - struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); - - const int64_t n_embd_head = hparams.n_embd_head_v; - LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - LM_GGML_ASSERT(n_embd_head == hparams.n_rot); - - const int64_t n_embd = hparams.n_embd; - //TODO: if the model varies, these parameters need to be read from the model - const int64_t n_embd_base = 256; - const float scale_embd = 12.0f; - const float scale_depth = 1.4f; - - struct lm_ggml_tensor * cur; - struct lm_ggml_tensor * inpL; - - inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); - - // scale the input embeddings - inpL = lm_ggml_scale(ctx0, inpL, scale_embd); - cb(inpL, "inp_scaled", -1); - - // inp_pos - contains the positions - struct lm_ggml_tensor * inp_pos = build_inp_pos(); - - // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask(); - - for (int il = 0; il < n_layer; ++il) { - struct lm_ggml_tensor * inpSA = inpL; - - // norm - cur = llm_build_norm(ctx0, inpL, hparams, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "attn_norm", il); - - // self-attention - { - // compute Q and K and RoPE them - struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = lm_ggml_rope_ext( - ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(Qcur, "Qcur", il); - - Kcur = lm_ggml_rope_ext( - ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - cb(Kcur, "Kcur", il); - - cur = llm_build_kv(ctx0, lctx, kv_self, gf, - model.layers[il].wo, model.layers[il].bo, - Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); - } - - if (il == n_layer - 1) { - // skip computing output for unused tokens - struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids(); - cur = lm_ggml_get_rows(ctx0, cur, 
inp_out_ids); - inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - // scale_res - scale the hidden states for residual connection - const float scale_res = scale_depth/sqrtf(float(n_layer)); - cur = lm_ggml_scale(ctx0, cur, scale_res); - cb(cur, "hidden_scaled", -1); - - struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = llm_build_norm(ctx0, ffn_inp, hparams, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm", il); - - cur = llm_build_ffn(ctx0, lctx, cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, cb, il); - cb(cur, "ffn_out", il); - } - - // scale the hidden states for residual connection - cur = lm_ggml_scale(ctx0, cur, scale_res); - cb(cur, "hidden_scaled_ffn", -1); - - cur = lm_ggml_add(ctx0, cur, ffn_inp); - cur = lctx.cvec.apply_to(ctx0, cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = llm_build_norm(ctx0, cur, hparams, - model.output_norm, NULL, - LLM_NORM_RMS, cb, -1); - cb(cur, "result_norm", -1); - - // lm_head scaling - const float scale_lmhead = float(n_embd_base)/float(n_embd); - cur = lm_ggml_scale(ctx0, cur, scale_lmhead); - cb(cur, "lmhead_scaling", -1); - - // lm_head - cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); - cb(cur, "result_output", -1); - - lm_ggml_build_forward_expand(gf, cur); - - return gf; - } - struct lm_ggml_cgraph * build_minicpm3() { struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); @@ -14492,7 +15192,7 @@ struct llm_build_context { return gf; } - struct lm_ggml_cgraph * build_olmo_1124() { + struct lm_ggml_cgraph * build_olmo2() { struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens @@ -15104,22 +15804,176 @@ struct llm_build_context { cur = llm_build_norm(ctx0, inpSA, hparams, model.layers[il].ffn_norm_exps, NULL, LLM_NORM_RMS, cb, il); - cb(cur, "ffn_norm_exps", il); + cb(cur, "ffn_norm_exps", il); + + cur = llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + cb, il); + cb(cur, "ffn_moe_out", il); + + cur = lm_ggml_add(ctx0, cur, ffn_out); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + lm_ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct lm_ggml_cgraph * build_deepseek() { + struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + LM_GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + LM_GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct lm_ggml_tensor * cur; + struct lm_ggml_tensor * inpL; + + inpL = 
llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct lm_ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct lm_ggml_tensor * KQ_mask = build_inp_KQ_mask(); + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + for (int il = 0; il < n_layer; ++il) { + struct lm_ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + struct lm_ggml_tensor * rope_factors = build_rope_factors(il); + + // compute Q and K and RoPE them + struct lm_ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = lm_ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + struct lm_ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = lm_ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + struct lm_ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = lm_ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = lm_ggml_rope_ext( + ctx0, lm_ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = lm_ggml_rope_ext( + ctx0, lm_ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct lm_ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = lm_ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = lm_ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + + struct lm_ggml_tensor * ffn_inp = lm_ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); - cur = llm_build_moe_ffn(ctx0, lctx, cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - false, 0.0, - cb, il); - cb(cur, "ffn_moe_out", il); + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + lm_ggml_tensor * moe_out = + llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, false, 
+ false, hparams.expert_weights_scale, + cb, il); + cb(moe_out, "ffn_moe_out", il); - cur = lm_ggml_add(ctx0, cur, ffn_out); - cb(cur, "ffn_out", il); + // FFN shared expert + { + lm_ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = lm_ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + cur = lm_ggml_add(ctx0, cur, ffn_inp); cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); @@ -15136,6 +15990,7 @@ struct llm_build_context { // lm_head cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); lm_ggml_build_forward_expand(gf, cur); @@ -15522,7 +16377,7 @@ struct llm_build_context { return gf; } - struct lm_ggml_cgraph * build_t5_encoder() { + struct lm_ggml_cgraph * build_t5_enc() { struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens @@ -15654,7 +16509,7 @@ struct llm_build_context { return gf; } - struct lm_ggml_cgraph * build_t5_decoder() { + struct lm_ggml_cgraph * build_t5_dec() { struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // mutable variable, needed during the last layer of the computation to skip unused tokens @@ -16603,6 +17458,158 @@ struct llm_build_context { return gf; } + + struct lm_ggml_cgraph * build_wavtokenizer_dec() { + struct lm_ggml_cgraph * gf = lm_ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + + struct lm_ggml_tensor * cur; + struct lm_ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); + + cur = lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, inpL)); + + cur = lm_ggml_conv_1d_ph(ctx0, model.conv1d, cur, 1, 1); + cur = lm_ggml_add(ctx0, cur, model.conv1d_b); + + // posnet + for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) { + const auto & layer = model.layers[il].posnet; + + inpL = cur; + + switch (il) { + case 0: + case 1: + case 3: + case 4: + { + cur = llm_build_norm(ctx0, cur, hparams, + layer.norm1, + layer.norm1_b, + LLM_NORM_GROUP, cb, 0); + + cur = lm_ggml_mul(ctx0, lm_ggml_sigmoid(ctx0, cur), cur); + + cur = lm_ggml_conv_1d_ph(ctx0, layer.conv1, cur, 1, 1); + cur = lm_ggml_add(ctx0, cur, layer.conv1_b); + + cur = llm_build_norm(ctx0, cur, hparams, + layer.norm2, + layer.norm2_b, + LLM_NORM_GROUP, cb, 0); + + cur = lm_ggml_mul(ctx0, lm_ggml_sigmoid(ctx0, cur), cur); + + cur = lm_ggml_conv_1d_ph(ctx0, layer.conv2, cur, 1, 1); + cur = lm_ggml_add(ctx0, cur, layer.conv2_b); + + cur = lm_ggml_add(ctx0, cur, inpL); + } break; + case 2: + { + cur = llm_build_norm(ctx0, cur, hparams, + layer.attn_norm, + layer.attn_norm_b, + LLM_NORM_GROUP, cb, 0); + + struct lm_ggml_tensor * q; + struct lm_ggml_tensor * k; + struct lm_ggml_tensor * v; + + q = lm_ggml_conv_1d_ph(ctx0, layer.attn_q, cur, 1, 1); + k = lm_ggml_conv_1d_ph(ctx0, layer.attn_k, cur, 1, 1); + v = lm_ggml_conv_1d_ph(ctx0, layer.attn_v, cur, 1, 1); + + q = lm_ggml_add(ctx0, q, layer.attn_q_b); + k = lm_ggml_add(ctx0, k, layer.attn_k_b); + v = lm_ggml_add(ctx0, v, layer.attn_v_b); + + q = lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, q)); + k = lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, k)); + + struct lm_ggml_tensor * kq = lm_ggml_mul_mat(ctx0, k, q); + + kq = 
lm_ggml_soft_max_ext(ctx0, kq, nullptr, 1.0f/sqrtf(float(hparams.posnet.n_embd)), 0.0f);
+
+ cur = lm_ggml_mul_mat(ctx0, kq, v);
+
+ cur = lm_ggml_conv_1d_ph(ctx0, layer.attn_o, cur, 1, 1);
+ cur = lm_ggml_add(ctx0, cur, layer.attn_o_b);
+
+ cur = lm_ggml_add(ctx0, cur, inpL);
+ } break;
+ case 5:
+ {
+ cur = llm_build_norm(ctx0, cur, hparams,
+ layer.norm,
+ layer.norm_b,
+ LLM_NORM_GROUP, cb, 0);
+ } break;
+ default: LM_GGML_ABORT("unknown posnet layer");
+ };
+ }
+
+ cur = lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, cur));
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.tok_norm,
+ model.tok_norm_b,
+ LLM_NORM, cb, -1);
+
+ cur = lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, cur));
+
+ inpL = cur;
+
+ // convnext
+ for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) {
+ const auto & layer = model.layers[il].convnext;
+
+ cur = inpL;
+
+ cur = lm_ggml_conv_1d_dw_ph(ctx0, layer.dw, cur, 1, 1);
+ cur = lm_ggml_add(ctx0, cur, layer.dw_b);
+
+ cur = lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, cur));
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ layer.norm,
+ layer.norm_b,
+ LLM_NORM, cb, -1);
+
+ cur = llm_build_ffn(ctx0, lctx, cur,
+ layer.pw1, layer.pw1_b, NULL,
+ NULL, NULL, NULL,
+ layer.pw2, layer.pw2_b, NULL,
+ NULL,
+ LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+
+ cur = lm_ggml_mul(ctx0, cur, layer.gamma);
+
+ cur = lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, cur));
+
+ inpL = lm_ggml_add(ctx0, cur, inpL);
+ }
+
+ cur = inpL;
+
+ cur = lm_ggml_cont(ctx0, lm_ggml_transpose(ctx0, cur));
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm,
+ model.output_norm_b,
+ LLM_NORM, cb, -1);
+
+ // lm_head
+ cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+ cur = lm_ggml_add(ctx0, cur, model.output_b);
+ cb(cur, "result_embd", -1);
+
+ lm_ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
 };
 static struct lm_ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -16685,11 +17692,16 @@ static struct lm_ggml_cgraph * llama_build_graph(
 switch (model.arch) {
 case LLM_ARCH_LLAMA:
+ case LLM_ARCH_MINICPM:
 case LLM_ARCH_GRANITE:
 case LLM_ARCH_GRANITE_MOE:
 {
 result = llm.build_llama();
 } break;
+ case LLM_ARCH_DECI:
+ {
+ result = llm.build_deci();
+ } break;
 case LLM_ARCH_BAICHUAN:
 {
 result = llm.build_baichuan();
@@ -16736,6 +17748,11 @@ static struct lm_ggml_cgraph * llama_build_graph(
 {
 result = llm.build_qwen2();
 } break;
+ case LLM_ARCH_QWEN2VL:
+ {
+ lctx.n_pos_per_token = 4;
+ result = llm.build_qwen2vl();
+ } break;
 case LLM_ARCH_QWEN2MOE:
 {
 result = llm.build_qwen2moe();
@@ -16768,10 +17785,6 @@ static struct lm_ggml_cgraph * llama_build_graph(
 {
 result = llm.build_internlm2();
 } break;
- case LLM_ARCH_MINICPM:
- {
- result = llm.build_minicpm();
- } break;
 case LLM_ARCH_MINICPM3:
 {
 result = llm.build_minicpm3();
@@ -16808,9 +17821,9 @@ static struct lm_ggml_cgraph * llama_build_graph(
 {
 result = llm.build_olmo();
 } break;
- case LLM_ARCH_OLMO_1124:
+ case LLM_ARCH_OLMO2:
 {
- result = llm.build_olmo_1124();
+ result = llm.build_olmo2();
 } break;
 case LLM_ARCH_OLMOE:
 {
@@ -16828,6 +17841,10 @@ static struct lm_ggml_cgraph * llama_build_graph(
 {
 result = llm.build_arctic();
 } break;
+ case LLM_ARCH_DEEPSEEK:
+ {
+ result = llm.build_deepseek();
+ } break;
 case LLM_ARCH_DEEPSEEK2:
 {
 result = llm.build_deepseek2();
@@ -16843,14 +17860,14 @@ static struct lm_ggml_cgraph * llama_build_graph(
 case LLM_ARCH_T5:
 {
 if (lctx.is_encoding) {
- result = llm.build_t5_encoder();
+ result = llm.build_t5_enc();
 } else {
- result = 
llm.build_t5_decoder(); + result = llm.build_t5_dec(); } } break; case LLM_ARCH_T5ENCODER: { - result = llm.build_t5_encoder(); + result = llm.build_t5_enc(); } break; case LLM_ARCH_JAIS: { @@ -16872,6 +17889,10 @@ static struct lm_ggml_cgraph * llama_build_graph( { result = llm.build_chameleon(); } break; + case LLM_ARCH_WAVTOKENIZER_DEC: + { + result = llm.build_wavtokenizer_dec(); + } break; default: LM_GGML_ABORT("fatal error"); } @@ -16958,35 +17979,40 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) if (ubatch.pos && lctx.inp_pos) { const int64_t n_tokens = ubatch.n_tokens; - - lm_ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*lm_ggml_element_size(lctx.inp_pos)); + auto n_pos = lctx.n_pos_per_token; + lm_ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*n_pos*lm_ggml_element_size(lctx.inp_pos)); } if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { - LM_GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); - const int64_t n_tokens = ubatch.n_tokens; + //LM_GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); + + if (!lctx.inp_out_ids) { + LLAMA_LOG_WARN("%s: 'lctx.inp_out_ids' is not created\n", __func__); + } else { + const int64_t n_tokens = ubatch.n_tokens; - LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer)); - int32_t * data = (int32_t *) lctx.inp_out_ids->data; + LM_GGML_ASSERT(lm_ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer)); + int32_t * data = (int32_t *) lctx.inp_out_ids->data; - if (lctx.n_outputs == n_tokens) { - for (int i = 0; i < n_tokens; ++i) { - data[i] = i; - } - } else if (ubatch.output) { - int32_t n_outputs = 0; - for (int i = 0; i < n_tokens; ++i) { - if (ubatch.output[i]) { - data[n_outputs++] = i; + if (lctx.n_outputs == n_tokens) { + for (int i = 0; i < n_tokens; ++i) { + data[i] = i; + } + } else if (ubatch.output) { + int32_t n_outputs = 0; + for (int i = 0; i < n_tokens; ++i) { + if (ubatch.output[i]) { + data[n_outputs++] = i; + } } + // the graph needs to have been passed the correct number of outputs + LM_GGML_ASSERT(lctx.n_outputs == n_outputs); + } else if (lctx.n_outputs == 1) { + // only keep last output + data[0] = n_tokens - 1; + } else { + LM_GGML_ASSERT(lctx.n_outputs == 0); } - // the graph needs to have been passed the correct number of outputs - LM_GGML_ASSERT(lctx.n_outputs == n_outputs); - } else if (lctx.n_outputs == 1) { - // only keep last output - data[0] = n_tokens - 1; - } else { - LM_GGML_ASSERT(lctx.n_outputs == 0); } } @@ -17454,8 +18480,9 @@ static enum lm_ggml_status llama_graph_compute( int n_threads, lm_ggml_threadpool * threadpool) { if (lctx.backend_cpu != nullptr) { - lm_ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool); - lm_ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data); + auto * reg = lm_ggml_backend_dev_backend_reg(lm_ggml_backend_get_device(lctx.backend_cpu)); + auto * set_threadpool_fn = (decltype(lm_ggml_backend_cpu_set_threadpool) *) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_cpu_set_threadpool"); + set_threadpool_fn(lctx.backend_cpu, threadpool); } // set the number of threads for all the backends @@ -17656,6 +18683,7 @@ static int llama_decode_internal( embd = nullptr; // do not extract embeddings when not needed LM_GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); } + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, 
%d leafs)\n", (lm_ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); lm_ggml_backend_sched_alloc_graph(lctx.sched.get(), gf); @@ -18222,13 +19250,13 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { static void llama_kv_cache_update_internal(struct llama_context & lctx) { bool need_reserve = false; - // apply K-shift if needed - if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) { + if (lctx.kv_self.has_shift) { if (!llama_kv_cache_can_shift(&lctx)) { - LM_GGML_ABORT("Deepseek2 does not support K-shift"); + LM_GGML_ABORT("The current context does not support K-shift"); } - { + // apply K-shift if needed + if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { lm_ggml_backend_sched_reset(lctx.sched.get()); lm_ggml_cgraph * gf = llama_build_graph_k_shift(lctx); @@ -18443,10 +19471,6 @@ static lm_ggml_type llama_tensor_get_type(quantize_state_internal & qs, lm_ggml_ else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = LM_GGML_TYPE_IQ3_S; } - else if (new_type == LM_GGML_TYPE_Q4_0_4_4 || new_type == LM_GGML_TYPE_Q4_0_4_8 || - new_type == LM_GGML_TYPE_Q4_0_8_8) { - new_type = LM_GGML_TYPE_Q4_0; - } else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) { new_type = LM_GGML_TYPE_Q4_K; } @@ -18769,9 +19793,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = LM_GGML_TYPE_IQ4_XS; break; case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = LM_GGML_TYPE_IQ3_S; break; case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = LM_GGML_TYPE_IQ3_S; break; - case LLAMA_FTYPE_MOSTLY_Q4_0_4_4: default_type = LM_GGML_TYPE_Q4_0_4_4; break; - case LLAMA_FTYPE_MOSTLY_Q4_0_4_8: default_type = LM_GGML_TYPE_Q4_0_4_8; break; - case LLAMA_FTYPE_MOSTLY_Q4_0_8_8: default_type = LM_GGML_TYPE_Q4_0_8_8; break; default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } @@ -19110,14 +20131,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s f32_data = (float *) f32_conv_buf.data(); } - int chunk_size_multiplier = 1; - if (new_type == LM_GGML_TYPE_Q4_0_4_4 || new_type == LM_GGML_TYPE_Q4_0_4_8 || new_type == LM_GGML_TYPE_Q4_0_8_8) { - if ((new_type == LM_GGML_TYPE_Q4_0_8_8) && (tensor->ne[1] % 8 != 0)) new_type = LM_GGML_TYPE_Q4_0; - else if (tensor->ne[1] % 4 != 0) new_type = LM_GGML_TYPE_Q4_0; - if (new_type == LM_GGML_TYPE_Q4_0_8_8) chunk_size_multiplier = 8; - else if (new_type == LM_GGML_TYPE_Q4_0_4_4 || new_type == LM_GGML_TYPE_Q4_0_4_8) chunk_size_multiplier = 4; - } - LLAMA_LOG_INFO("converting to %s .. ", lm_ggml_type_name(new_type)); fflush(stdout); @@ -19130,8 +20143,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const int64_t nrows = tensor->ne[1]; static const int64_t min_chunk_size = 32 * 512; - const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)) * - chunk_size_multiplier; + const int64_t chunk_size = (n_per_row >= min_chunk_size ? 
n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row));
 const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1];
 const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size;
@@ -19372,6 +20384,7 @@ void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
 //
 struct llama_model_params llama_model_default_params() {
 struct llama_model_params result = {
+ /*.devices =*/ nullptr,
 /*.n_gpu_layers =*/ 0,
 /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
 /*.main_gpu =*/ 0,
@@ -19489,7 +20502,11 @@ void llama_backend_init(void) {
 void llama_numa_init(enum lm_ggml_numa_strategy numa) {
 if (numa != LM_GGML_NUMA_STRATEGY_DISABLED) {
- lm_ggml_numa_init(numa);
+ auto * dev = lm_ggml_backend_dev_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU);
+ LM_GGML_ASSERT(dev && "CPU backend is not loaded");
+ auto * reg = lm_ggml_backend_dev_backend_reg(dev);
+ auto * numa_init_fn = (decltype(lm_ggml_numa_init) *) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_cpu_numa_init");
+ numa_init_fn(numa);
 }
 }
@@ -19580,19 +20597,24 @@ struct llama_model * llama_load_model_from_file(
 }
 // create list of devices to use with this model
- // currently, we use all available devices
- // TODO: rework API to give user more control over device selection
- for (size_t i = 0; i < lm_ggml_backend_dev_count(); ++i) {
- lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i);
- switch (lm_ggml_backend_dev_type(dev)) {
- case LM_GGML_BACKEND_DEVICE_TYPE_CPU:
- case LM_GGML_BACKEND_DEVICE_TYPE_ACCEL:
- // skip CPU backends since they are handled separately
- break;
+ if (params.devices) {
+ for (lm_ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
+ model->devices.push_back(*dev);
+ }
+ } else {
+ // use all available devices
+ for (size_t i = 0; i < lm_ggml_backend_dev_count(); ++i) {
+ lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i);
+ switch (lm_ggml_backend_dev_type(dev)) {
+ case LM_GGML_BACKEND_DEVICE_TYPE_CPU:
+ case LM_GGML_BACKEND_DEVICE_TYPE_ACCEL:
+ // skip CPU backends since they are handled separately
+ break;
- case LM_GGML_BACKEND_DEVICE_TYPE_GPU:
- model->devices.push_back(dev);
- break;
+ case LM_GGML_BACKEND_DEVICE_TYPE_GPU:
+ model->devices.push_back(dev);
+ break;
+ }
 }
 }
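Editorial note: the new `devices` field in `llama_model_params` is a NULL-terminated array, so a caller can now pin offloading to specific devices instead of the former use-everything behavior. A minimal sketch, assuming the lm_-prefixed names used in this tree and an already-initialized backend registry:

```cpp
// Sketch only (not part of the patch): offload to the first GPU reported by
// the registry; a NULL devices pointer keeps the old use-all-devices behavior.
lm_ggml_backend_dev_t devices[2] = { nullptr, nullptr };
for (size_t i = 0; i < lm_ggml_backend_dev_count(); ++i) {
    lm_ggml_backend_dev_t dev = lm_ggml_backend_dev_get(i);
    if (lm_ggml_backend_dev_type(dev) == LM_GGML_BACKEND_DEVICE_TYPE_GPU) {
        devices[0] = dev; // keep the list NULL-terminated
        break;
    }
}

llama_model_params mparams = llama_model_default_params();
mparams.devices = devices[0] ? devices : nullptr;
```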
@@ -19763,9 +20785,6 @@ struct llama_context * llama_new_context_with_model(
 __func__, n_ctx_per_seq, hparams.n_ctx_train);
 }
- ctx->abort_callback = params.abort_callback;
- ctx->abort_callback_data = params.abort_callback_data;
-
 ctx->logits_all = params.logits_all;
 // build worst-case graph for encoder if a model contains encoder
@@ -19814,7 +20833,7 @@ struct llama_context * llama_new_context_with_model(
 }
 // add CPU backend
- ctx->backend_cpu = lm_ggml_backend_cpu_init();
+ ctx->backend_cpu = lm_ggml_backend_init_by_type(LM_GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
 if (ctx->backend_cpu == nullptr) {
 LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);
 llama_free(ctx);
@@ -19834,6 +20853,8 @@ struct llama_context * llama_new_context_with_model(
 }
 }
+ llama_set_abort_callback(ctx, params.abort_callback, params.abort_callback_data);
+
 if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
 LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
 llama_free(ctx);
@@ -19879,7 +20900,8 @@ struct llama_context * llama_new_context_with_model(
 std::vector<lm_ggml_backend_t> backend_ptrs;
 for (auto & backend : ctx->backends) {
 auto * buft = lm_ggml_backend_get_default_buffer_type(backend.get());
- if (lm_ggml_backend_is_cpu(backend.get()) && !model->devices.empty()) {
+ auto backend_type = lm_ggml_backend_dev_type(lm_ggml_backend_get_device(backend.get()));
+ if (backend_type == LM_GGML_BACKEND_DEVICE_TYPE_CPU && !model->devices.empty()) {
 // use the host buffer of the first device CPU for faster transfer of the intermediate state
 auto * dev = model->devices[0];
 auto * host_buft = lm_ggml_backend_dev_host_buffer_type(dev);
@@ -19907,7 +20929,8 @@ struct llama_context * llama_new_context_with_model(
 // pipeline parallelism requires support for async compute and events in all devices
 if (pipeline_parallel) {
 for (auto & backend : ctx->backends) {
- if (lm_ggml_backend_is_cpu(backend.get())) {
+ auto dev_type = lm_ggml_backend_dev_type(lm_ggml_backend_get_device(backend.get()));
+ if (dev_type == LM_GGML_BACKEND_DEVICE_TYPE_CPU) {
 // ignore CPU backend
 continue;
 }
@@ -20049,10 +21072,12 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
 case LLM_ARCH_T5ENCODER:
 case LLM_ARCH_JAIS:
 case LLM_ARCH_RWKV6:
+ case LLM_ARCH_WAVTOKENIZER_DEC:
 return LLAMA_ROPE_TYPE_NONE;
 // use what we call a normal RoPE, operating on pairs of consecutive head values
 case LLM_ARCH_LLAMA:
+ case LLM_ARCH_DECI:
 case LLM_ARCH_BAICHUAN:
 case LLM_ARCH_STARCODER:
 case LLM_ARCH_PLAMO:
@@ -20063,6 +21088,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
 case LLM_ARCH_COMMAND_R:
 case LLM_ARCH_OLMO:
 case LLM_ARCH_ARCTIC:
+ case LLM_ARCH_DEEPSEEK:
 case LLM_ARCH_DEEPSEEK2:
 case LLM_ARCH_CHATGLM:
 case LLM_ARCH_GRANITE:
@@ -20081,7 +21107,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
 case LLM_ARCH_QWEN:
 case LLM_ARCH_QWEN2:
 case LLM_ARCH_QWEN2MOE:
- case LLM_ARCH_OLMO_1124:
+ case LLM_ARCH_OLMO2:
 case LLM_ARCH_OLMOE:
 case LLM_ARCH_PHI2:
 case LLM_ARCH_PHI3:
@@ -20096,6 +21122,9 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
 case LLM_ARCH_MINICPM3:
 return LLAMA_ROPE_TYPE_NEOX;
+ case LLM_ARCH_QWEN2VL:
+ return LLAMA_ROPE_TYPE_MROPE;
+
 // all model arches should be listed explicitly here
 case LLM_ARCH_UNKNOWN:
 LM_GGML_ABORT("unknown architecture");
@@ -20162,17 +21191,6 @@ uint64_t llama_model_n_params(const struct llama_model * model) {
 return model->n_elements;
 }
-struct lm_ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name) {
- auto it = std::find_if(model->tensors_by_name.begin(), model->tensors_by_name.end(),
- [name](const std::pair<std::string, struct lm_ggml_tensor *> & it) {
- return it.first == name;
- });
- if (it == model->tensors_by_name.end()) {
- return nullptr;
- }
- return it->second;
-}
-
 bool llama_model_has_encoder(const struct llama_model * model) {
 switch (model->arch) {
 case LLM_ARCH_T5: return true;
@@ -20474,7 +21492,7 @@ void llama_kv_cache_update(struct llama_context * ctx) {
 }
 bool llama_kv_cache_can_shift(struct llama_context * ctx) {
- return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
+ return !ctx->kv_self.recurrent && ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
 }
 // deprecated
@@ -21461,6 +22479,14 @@ int32_t llama_n_threads_batch(struct llama_context * ctx) {
 void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) {
 ctx->abort_callback = abort_callback;
 ctx->abort_callback_data = abort_callback_data;
+
+ for (auto & backend : ctx->backends) {
+ auto * reg = lm_ggml_backend_dev_backend_reg(lm_ggml_backend_get_device(backend.get()));
+ auto * set_abort_callback_fn = (lm_ggml_backend_set_abort_callback_t) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_set_abort_callback");
+ if (set_abort_callback_fn) {
+ set_abort_callback_fn(backend.get(), ctx->abort_callback, ctx->abort_callback_data);
+ }
+ }
 }
 void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
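Editorial note: `llama_set_abort_callback` now forwards the callback to every backend that exposes "lm_ggml_backend_set_abort_callback" through its registry, so one call covers CPU and GPU graph computation alike. A hedged sketch of a wall-clock deadline callback; the `deadline_state` struct and the 5-second budget are illustrative, not from the patch:

```cpp
// Sketch only: abort decoding once a time budget is exhausted.
struct deadline_state { int64_t t_end_us; };

static bool abort_after_deadline(void * data) {
    auto * st = (deadline_state *) data;
    return lm_ggml_time_us() > st->t_end_us; // returning true aborts graph compute
}

// usage, assuming an existing llama_context * ctx:
static deadline_state st = { lm_ggml_time_us() + 5 * 1000 * 1000 }; // ~5 seconds
llama_set_abort_callback(ctx, abort_after_deadline, &st);
```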
@@ -21656,7 +22682,7 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
 throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
 }
 } else if ((size_t) i >= ctx->output_ids.size()) {
- throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+ throw std::runtime_error(format("out of range [0, %zu)", ctx->output_ids.size()));
 } else {
 j = ctx->output_ids[i];
 }
@@ -21827,18 +22853,115 @@ int32_t llama_detokenize(
 //
 // chat templates
 //
+static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
+ if (LLM_CHAT_TEMPLATES.find(tmpl) != LLM_CHAT_TEMPLATES.end()) {
+ return LLM_CHAT_TEMPLATES.at(tmpl);
+ }
+ auto tmpl_contains = [&tmpl](const char * haystack) -> bool {
+ return tmpl.find(haystack) != std::string::npos;
+ };
+ if (tmpl_contains("<|im_start|>")) {
+ return LLM_CHAT_TEMPLATE_CHATML;
+ } else if (tmpl.find("mistral") == 0 || tmpl_contains("[INST]")) {
+ if (tmpl_contains("[SYSTEM_PROMPT]")) {
+ return LLM_CHAT_TEMPLATE_MISTRAL_V7;
+ } else if (
+ // catches official 'v1' template
+ tmpl_contains("' [INST] ' + system_message")
+ // catches official 'v3' and 'v3-tekken' templates
+ || tmpl_contains("[AVAILABLE_TOOLS]")
+ ) {
+ // Official mistral 'v1', 'v3' and 'v3-tekken' templates
+ // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
+ // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
+ if (tmpl_contains(" [INST]")) {
+ return LLM_CHAT_TEMPLATE_MISTRAL_V1;
+ } else if (tmpl_contains("\"[INST]\"")) {
+ return LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN;
+ }
+ return LLM_CHAT_TEMPLATE_MISTRAL_V3;
+ } else {
+ // llama2 template and its variants
+ // [variant] support system message
+ // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+ bool support_system_message = tmpl_contains("<<SYS>>");
+ bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
+ bool strip_message = tmpl_contains("content.strip()");
+ if (strip_message) {
+ return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
+ } else if (add_bos_inside_history) {
+ return LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
+ } else if (support_system_message) {
+ return LLM_CHAT_TEMPLATE_LLAMA_2_SYS;
+ } else {
+ return LLM_CHAT_TEMPLATE_LLAMA_2;
+ }
+ }
+ } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>")) {
+ return LLM_CHAT_TEMPLATE_PHI_3;
+ } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) {
+ return LLM_CHAT_TEMPLATE_FALCON_3;
+ } else if (tmpl_contains("<|user|>") && tmpl_contains("<|endoftext|>")) {
+ return LLM_CHAT_TEMPLATE_ZEPHYR;
+ } else if (tmpl_contains("bos_token + message['role']")) {
+ return LLM_CHAT_TEMPLATE_MONARCH;
+ } else if (tmpl_contains("<start_of_turn>")) {
+ return LLM_CHAT_TEMPLATE_GEMMA;
+ } else if (tmpl_contains("'\\n\\nAssistant: ' + eos_token")) {
+ // OrionStarAI/Orion-14B-Chat
+ return LLM_CHAT_TEMPLATE_ORION;
+ } else if (tmpl_contains("GPT4 Correct ")) {
+ // openchat/openchat-3.5-0106
+ return LLM_CHAT_TEMPLATE_OPENCHAT;
+ } else if (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: ")) {
+ // eachadea/vicuna-13b-1.1 (and Orca variant)
+ if 
(tmpl_contains("SYSTEM: ")) { + return LLM_CHAT_TEMPLATE_VICUNA_ORCA; + } + return LLM_CHAT_TEMPLATE_VICUNA; + } else if (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>")) { + // deepseek-ai/deepseek-coder-33b-instruct + return LLM_CHAT_TEMPLATE_DEEPSEEK; + } else if (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>")) { + // CohereForAI/c4ai-command-r-plus + return LLM_CHAT_TEMPLATE_COMMAND_R; + } else if (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>")) { + return LLM_CHAT_TEMPLATE_LLAMA_3; + } else if (tmpl_contains("[gMASK]sop")) { + // chatglm3-6b + return LLM_CHAT_TEMPLATE_CHATGML_3; + } else if (tmpl_contains("[gMASK]")) { + return LLM_CHAT_TEMPLATE_CHATGML_4; + } else if (tmpl_contains(LU8("<用户>"))) { + // MiniCPM-3B-OpenHermes-2.5-v2-GGUF + return LLM_CHAT_TEMPLATE_MINICPM; + } else if (tmpl_contains("'Assistant: ' + message['content'] + eos_token")) { + return LLM_CHAT_TEMPLATE_DEEPSEEK_2; + } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { + // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb + // EXAONE-3.0-7.8B-Instruct + return LLM_CHAT_TEMPLATE_EXAONE_3; + } else if (tmpl_contains("rwkv-world")) { + return LLM_CHAT_TEMPLATE_RWKV_WORLD; + } else if (tmpl_contains("<|start_of_role|>")) { + return LLM_CHAT_TEMPLATE_GRANITE; + } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { + return LLM_CHAT_TEMPLATE_GIGACHAT; + } else if (tmpl_contains("<|role_start|>")) { + return LLM_CHAT_TEMPLATE_MEGREZ; + } + return LLM_CHAT_TEMPLATE_UNKNOWN; +} + // Simple version of "llama_apply_chat_template" that only works with strings // This function uses heuristic checks to determine commonly used template. It is not a jinja parser. 
static int32_t llama_chat_apply_template_internal(
- const std::string & tmpl,
+ const llm_chat_template tmpl,
 const std::vector<const llama_chat_message *> & chat,
 std::string & dest, bool add_ass) {
 // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527
 std::stringstream ss;
- auto tmpl_contains = [&tmpl](std::string haystack) -> bool {
- return tmpl.find(haystack) != std::string::npos;
- };
- if (tmpl == "chatml" || tmpl_contains("<|im_start|>")) {
+ if (tmpl == LLM_CHAT_TEMPLATE_CHATML) {
 // chatml template
 for (auto message : chat) {
 ss << "<|im_start|>" << message->role << "\n" << message->content << "<|im_end|>\n";
@@ -21846,16 +22969,59 @@ static int32_t llama_chat_apply_template_internal(
 if (add_ass) {
 ss << "<|im_start|>assistant\n";
 }
- } else if (tmpl == "llama2" || tmpl == "mistral" || tmpl_contains("[INST]")) {
+ } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V7) {
+ // Official mistral 'v7' template
+ // See: https://huggingface.co/mistralai/Mistral-Large-Instruct-2411#basic-instruct-template-v7
+ for (auto message : chat) {
+ std::string role(message->role);
+ std::string content(message->content);
+ if (role == "system") {
+ ss << "[SYSTEM_PROMPT] " << content << "[/SYSTEM_PROMPT]";
+ } else if (role == "user") {
+ ss << "[INST] " << content << "[/INST]";
+ }
+ else {
+ ss << " " << content << "</s>";
+ }
+ }
+ } else if (tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1
+ || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3
+ || tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN) {
+ // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/chat_templates.md
+ // See: https://github.com/mistralai/cookbook/blob/main/concept-deep-dive/tokenization/templates.md
+ std::string leading_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V1 ? " " : "";
+ std::string trailing_space = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3_TEKKEN ? "" : " ";
+ bool trim_assistant_message = tmpl == LLM_CHAT_TEMPLATE_MISTRAL_V3;
+ bool is_inside_turn = false;
+ for (auto message : chat) {
+ if (!is_inside_turn) {
+ ss << leading_space << "[INST]" << trailing_space;
+ is_inside_turn = true;
+ }
+ std::string role(message->role);
+ std::string content(message->content);
+ if (role == "system") {
+ ss << content << "\n\n";
+ } else if (role == "user") {
+ ss << content << leading_space << "[/INST]";
+ } else {
+ ss << trailing_space << (trim_assistant_message ? 
trim(content) : content) << "</s>";
+ is_inside_turn = false;
+ }
+ }
+ } else if (
+ tmpl == LLM_CHAT_TEMPLATE_LLAMA_2
+ || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS
+ || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS
+ || tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP) {
 // llama2 template and its variants
 // [variant] support system message
- bool support_system_message = tmpl_contains("<<SYS>>") || tmpl == "mistral";
- // [variant] space before + after response
- bool space_around_response = tmpl_contains("' ' + eos_token");
+ // See: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
+ bool support_system_message = tmpl != LLM_CHAT_TEMPLATE_LLAMA_2;
 // [variant] add BOS inside history
- bool add_bos_inside_history = tmpl_contains("bos_token + '[INST]");
+ bool add_bos_inside_history = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_BOS;
 // [variant] trim spaces from the input message
- bool strip_message = tmpl_contains("content.strip()");
+ bool strip_message = tmpl == LLM_CHAT_TEMPLATE_LLAMA_2_SYS_STRIP;
 // construct the prompt
 bool is_inside_turn = true; // skip BOS at the beginning
 ss << "[INST] ";
@@ -21876,12 +23042,11 @@ static int32_t llama_chat_apply_template_internal(
 } else if (role == "user") {
 ss << content << " [/INST]";
 } else {
- ss << (space_around_response ? " " : "") << content << (space_around_response ? " " : "") << "</s>";
+ ss << content << "</s>";
 is_inside_turn = false;
 }
 }
- // llama2 templates seem to not care about "add_generation_prompt"
- } else if (tmpl == "phi3" || (tmpl_contains("<|assistant|>") && tmpl_contains("<|end|>"))) {
+ } else if (tmpl == LLM_CHAT_TEMPLATE_PHI_3) {
 // Phi 3
 for (auto message : chat) {
 std::string role(message->role);
@@ -21890,7 +23055,16 @@ static int32_t llama_chat_apply_template_internal(
 if (add_ass) {
 ss << "<|assistant|>\n";
 }
- } else if (tmpl == "zephyr" || tmpl_contains("<|user|>")) {
+ } else if (tmpl == LLM_CHAT_TEMPLATE_FALCON_3) {
+ // Falcon 3
+ for (auto message : chat) {
+ std::string role(message->role);
+ ss << "<|" << role << "|>\n" << message->content << "\n";
+ }
+ if (add_ass) {
+ ss << "<|assistant|>\n";
+ }
+ } else if (tmpl == LLM_CHAT_TEMPLATE_ZEPHYR) {
 // zephyr template
 for (auto message : chat) {
 ss << "<|" << message->role << "|>" << "\n" << message->content << "<|endoftext|>\n";
@@ -21898,7 +23072,7 @@ static int32_t llama_chat_apply_template_internal(
 if (add_ass) {
 ss << "<|assistant|>\n";
 }
- } else if (tmpl == "monarch" || tmpl_contains("bos_token + message['role']")) {
+ } else if (tmpl == LLM_CHAT_TEMPLATE_MONARCH) {
 // mlabonne/AlphaMonarch-7B template (the <s> is included inside history)
 for (auto message : chat) {
 std::string bos = (message == chat.front()) ? 
"" : ""; // skip BOS for first message @@ -21907,7 +23081,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "assistant\n"; } - } else if (tmpl == "gemma" || tmpl == "gemma2" || tmpl_contains("")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_GEMMA) { // google/gemma-7b-it std::string system_prompt = ""; for (auto message : chat) { @@ -21929,7 +23103,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "model\n"; } - } else if (tmpl == "orion" || tmpl_contains("'\\n\\nAssistant: ' + eos_token")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_ORION) { // OrionStarAI/Orion-14B-Chat std::string system_prompt = ""; for (auto message : chat) { @@ -21949,7 +23123,7 @@ static int32_t llama_chat_apply_template_internal( ss << message->content << ""; } } - } else if (tmpl == "openchat" || tmpl_contains("GPT4 Correct ")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_OPENCHAT) { // openchat/openchat-3.5-0106, for (auto message : chat) { std::string role(message->role); @@ -21963,13 +23137,13 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "GPT4 Correct Assistant:"; } - } else if (tmpl == "vicuna" || tmpl == "vicuna-orca" || (tmpl_contains("USER: ") && tmpl_contains("ASSISTANT: "))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_VICUNA || tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { // eachadea/vicuna-13b-1.1 (and Orca variant) for (auto message : chat) { std::string role(message->role); if (role == "system") { // Orca-Vicuna variant uses a system prefix - if (tmpl == "vicuna-orca" || tmpl_contains("SYSTEM: ")) { + if (tmpl == LLM_CHAT_TEMPLATE_VICUNA_ORCA) { ss << "SYSTEM: " << message->content << "\n"; } else { ss << message->content << "\n\n"; @@ -21983,7 +23157,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "ASSISTANT:"; } - } else if (tmpl == "deepseek" || (tmpl_contains("### Instruction:") && tmpl_contains("<|EOT|>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK) { // deepseek-ai/deepseek-coder-33b-instruct for (auto message : chat) { std::string role(message->role); @@ -21998,7 +23172,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "### Response:\n"; } - } else if (tmpl == "command-r" || (tmpl_contains("<|START_OF_TURN_TOKEN|>") && tmpl_contains("<|USER_TOKEN|>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_COMMAND_R) { // CohereForAI/c4ai-command-r-plus for (auto message : chat) { std::string role(message->role); @@ -22013,7 +23187,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; } - } else if (tmpl == "llama3" || (tmpl_contains("<|start_header_id|>") && tmpl_contains("<|end_header_id|>"))) { + } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA_3) { // Llama 3 for (auto message : chat) { std::string role(message->role); @@ -22022,7 +23196,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|start_header_id|>assistant<|end_header_id|>\n\n"; } - } else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_3) { // chatglm3-6b ss << "[gMASK]" << "sop"; for (auto message : chat) { @@ -22032,7 +23206,7 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|assistant|>"; } - } else if (tmpl == "chatglm4" || tmpl_contains("[gMASK]")) { + } else if (tmpl == LLM_CHAT_TEMPLATE_CHATGML_4) { ss << "[gMASK]" << ""; for (auto message : chat) { std::string role(message->role); @@ -22041,7 +23215,7 @@ static int32_t 
llama_chat_apply_template_internal(
 if (add_ass) {
 ss << "<|assistant|>";
 }
- } else if (tmpl == "minicpm" || tmpl_contains(LU8("<用户>"))) {
+ } else if (tmpl == LLM_CHAT_TEMPLATE_MINICPM) {
 // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
 for (auto message : chat) {
 std::string role(message->role);
@@ -22053,7 +23227,7 @@ static int32_t llama_chat_apply_template_internal(
 ss << trim(message->content);
 }
 }
- } else if (tmpl == "deepseek2" || tmpl_contains("'Assistant: ' + message['content'] + eos_token")) {
+ } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_2) {
 // DeepSeek-V2
 for (auto message : chat) {
 std::string role(message->role);
@@ -22068,7 +23242,7 @@ static int32_t llama_chat_apply_template_internal(
 if (add_ass) {
 ss << "Assistant:";
 }
- } else if (tmpl == "exaone3" || (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]"))) {
+ } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) {
 // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
 // EXAONE-3.0-7.8B-Instruct
 for (auto message : chat) {
@@ -22084,7 +23258,7 @@ static int32_t llama_chat_apply_template_internal(
 if (add_ass) {
 ss << "[|assistant|]";
 }
- } else if (tmpl == "rwkv-world" || tmpl_contains("rwkv-world")) {
+ } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
 // this template requires the model to have "\n\n" as EOT token
 for (auto message : chat) {
 std::string role(message->role);
@@ -22094,7 +23268,7 @@ static int32_t llama_chat_apply_template_internal(
 ss << message->content << "\n\n";
 }
 }
- } else if (tmpl == "granite" || tmpl_contains("<|start_of_role|>")) {
+ } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
 // IBM Granite template
 for (const auto & message : chat) {
 std::string role(message->role);
@@ -22107,6 +23281,42 @@ static int32_t llama_chat_apply_template_internal(
 if (add_ass) {
 ss << "<|start_of_role|>assistant<|end_of_role|>\n";
 }
+ } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
+ // GigaChat template
+ bool has_system = !chat.empty() && std::string(chat[0]->role) == "system";
+
+ // Handle system message if present
+ if (has_system) {
+ ss << "<s>" << chat[0]->content << "<|message_sep|>";
+ } else {
+ ss << "<s>";
+ }
+
+ // Process remaining messages
+ for (size_t i = has_system ? 1 : 0; i < chat.size(); i++) {
+ std::string role(chat[i]->role);
+ if (role == "user") {
+ ss << "user<|role_sep|>" << chat[i]->content << "<|message_sep|>"
+ << "available functions<|role_sep|>[]<|message_sep|>";
+ } else if (role == "assistant") {
+ ss << "assistant<|role_sep|>" << chat[i]->content << "<|message_sep|>";
+ }
+ }
+
+ // Add generation prompt if needed
+ if (add_ass) {
+ ss << "assistant<|role_sep|>";
+ }
+ } else if (tmpl == LLM_CHAT_TEMPLATE_MEGREZ) {
+ // Megrez template
+ for (auto message : chat) {
+ std::string role(message->role);
+ ss << "<|role_start|>" << role << "<|role_end|>" << message->content << "<|turn_end|>";
+ }
+
+ if (add_ass) {
+ ss << "<|role_start|>assistant<|role_end|>";
+ }
 } else {
 // template not supported
 return -1;
@@ -22126,15 +23336,15 @@ int32_t llama_chat_apply_template(
 std::string curr_tmpl(tmpl == nullptr ? 
"" : tmpl); if (tmpl == nullptr) { LM_GGML_ASSERT(model != nullptr); - // load template from model - std::vector model_template(2048, 0); // longest known template is about 1200 bytes - std::string template_key = "tokenizer.chat_template"; - int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size()); - if (res < 0) { + + // load template from model, if available + const auto & it = model->lm_gguf_kv.find("tokenizer.chat_template"); + if (it != model->lm_gguf_kv.end() && it->second.size() > 0) { + curr_tmpl = it->second; + } + else { // worst case: there is no information about template, we will use chatml by default - curr_tmpl = "chatml"; // see llama_chat_apply_template_internal - } else { - curr_tmpl = std::string(model_template.data(), model_template.size()); + curr_tmpl = "chatml"; // see llama_chat_apply_template_internal } } @@ -22146,7 +23356,11 @@ int32_t llama_chat_apply_template( } std::string formatted_chat; - int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass); + llm_chat_template detected_tmpl = llama_chat_detect_template(curr_tmpl); + if (detected_tmpl == LLM_CHAT_TEMPLATE_UNKNOWN) { + return -1; + } + int32_t res = llama_chat_apply_template_internal(detected_tmpl, chat_vec, formatted_chat, add_ass); if (res < 0) { return res; } @@ -22156,6 +23370,15 @@ int32_t llama_chat_apply_template( return res; } +int32_t llama_chat_builtin_templates(const char ** output, size_t len) { + auto it = LLM_CHAT_TEMPLATES.begin(); + for (size_t i = 0; i < std::min(len, LLM_CHAT_TEMPLATES.size()); i++) { + output[i] = it->first.c_str(); + std::advance(it, 1); + } + return (int32_t) LLM_CHAT_TEMPLATES.size(); +} + // // sampling // @@ -22202,32 +23425,23 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int } const char * llama_print_system_info(void) { - lm_ggml_cpu_init(); // some ARM features are detected at runtime - static std::string s; - s = ""; - s += "AVX = " + std::to_string(lm_ggml_cpu_has_avx()) + " | "; - s += "AVX_VNNI = " + std::to_string(lm_ggml_cpu_has_avx_vnni()) + " | "; - s += "AVX2 = " + std::to_string(lm_ggml_cpu_has_avx2()) + " | "; - s += "AVX512 = " + std::to_string(lm_ggml_cpu_has_avx512()) + " | "; - s += "AVX512_VBMI = " + std::to_string(lm_ggml_cpu_has_avx512_vbmi()) + " | "; - s += "AVX512_VNNI = " + std::to_string(lm_ggml_cpu_has_avx512_vnni()) + " | "; - s += "AVX512_BF16 = " + std::to_string(lm_ggml_cpu_has_avx512_bf16()) + " | "; - s += "AMX_INT8 = " + std::to_string(lm_ggml_cpu_has_amx_int8()) + " | "; - s += "FMA = " + std::to_string(lm_ggml_cpu_has_fma()) + " | "; - s += "NEON = " + std::to_string(lm_ggml_cpu_has_neon()) + " | "; - s += "SVE = " + std::to_string(lm_ggml_cpu_has_sve()) + " | "; - s += "ARM_FMA = " + std::to_string(lm_ggml_cpu_has_arm_fma()) + " | "; - s += "F16C = " + std::to_string(lm_ggml_cpu_has_f16c()) + " | "; - s += "FP16_VA = " + std::to_string(lm_ggml_cpu_has_fp16_va()) + " | "; - s += "RISCV_VECT = " + std::to_string(lm_ggml_cpu_has_riscv_v()) + " | "; - s += "WASM_SIMD = " + std::to_string(lm_ggml_cpu_has_wasm_simd()) + " | "; - s += "SSE3 = " + std::to_string(lm_ggml_cpu_has_sse3()) + " | "; - s += "SSSE3 = " + std::to_string(lm_ggml_cpu_has_ssse3()) + " | "; - s += "VSX = " + std::to_string(lm_ggml_cpu_has_vsx()) + " | "; - s += "MATMUL_INT8 = " + std::to_string(lm_ggml_cpu_has_matmul_int8()) + " | "; - s += "LLAMAFILE = " + std::to_string(lm_ggml_cpu_has_llamafile()) + " | "; + for (size_t i = 0; i < 
@@ -22202,32 +23425,23 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int
 }
 const char * llama_print_system_info(void) {
- lm_ggml_cpu_init(); // some ARM features are detected at runtime
- static std::string s;
- s = "";
- s += "AVX = " + std::to_string(lm_ggml_cpu_has_avx()) + " | ";
- s += "AVX_VNNI = " + std::to_string(lm_ggml_cpu_has_avx_vnni()) + " | ";
- s += "AVX2 = " + std::to_string(lm_ggml_cpu_has_avx2()) + " | ";
- s += "AVX512 = " + std::to_string(lm_ggml_cpu_has_avx512()) + " | ";
- s += "AVX512_VBMI = " + std::to_string(lm_ggml_cpu_has_avx512_vbmi()) + " | ";
- s += "AVX512_VNNI = " + std::to_string(lm_ggml_cpu_has_avx512_vnni()) + " | ";
- s += "AVX512_BF16 = " + std::to_string(lm_ggml_cpu_has_avx512_bf16()) + " | ";
- s += "AMX_INT8 = " + std::to_string(lm_ggml_cpu_has_amx_int8()) + " | ";
- s += "FMA = " + std::to_string(lm_ggml_cpu_has_fma()) + " | ";
- s += "NEON = " + std::to_string(lm_ggml_cpu_has_neon()) + " | ";
- s += "SVE = " + std::to_string(lm_ggml_cpu_has_sve()) + " | ";
- s += "ARM_FMA = " + std::to_string(lm_ggml_cpu_has_arm_fma()) + " | ";
- s += "F16C = " + std::to_string(lm_ggml_cpu_has_f16c()) + " | ";
- s += "FP16_VA = " + std::to_string(lm_ggml_cpu_has_fp16_va()) + " | ";
- s += "RISCV_VECT = " + std::to_string(lm_ggml_cpu_has_riscv_v()) + " | ";
- s += "WASM_SIMD = " + std::to_string(lm_ggml_cpu_has_wasm_simd()) + " | ";
- s += "SSE3 = " + std::to_string(lm_ggml_cpu_has_sse3()) + " | ";
- s += "SSSE3 = " + std::to_string(lm_ggml_cpu_has_ssse3()) + " | ";
- s += "VSX = " + std::to_string(lm_ggml_cpu_has_vsx()) + " | ";
- s += "MATMUL_INT8 = " + std::to_string(lm_ggml_cpu_has_matmul_int8()) + " | ";
- s += "LLAMAFILE = " + std::to_string(lm_ggml_cpu_has_llamafile()) + " | ";
+ for (size_t i = 0; i < lm_ggml_backend_reg_count(); i++) {
+ auto * reg = lm_ggml_backend_reg_get(i);
+ auto * get_features_fn = (lm_ggml_backend_get_features_t) lm_ggml_backend_reg_get_proc_address(reg, "lm_ggml_backend_get_features");
+ if (get_features_fn) {
+ lm_ggml_backend_feature * features = get_features_fn(reg);
+ s += lm_ggml_backend_reg_name(reg);
+ s += " : ";
+ for (; features->name; features++) {
+ s += features->name;
+ s += " = ";
+ s += features->value;
+ s += " | ";
+ }
+ }
+ }
 return s.c_str();
 }
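Editorial note: llama_print_system_info() now walks the backend registry and asks each backend for its feature list via "lm_ggml_backend_get_features", so the report is grouped per backend instead of being a fixed list of CPU flags. A hedged sketch; the output shape is illustrative, not captured from a real run:

```cpp
// Sketch only: the report is now assembled per registered backend.
printf("%s\n", llama_print_system_info());
// Possible shape of the output on an ARM device:
//   CPU : NEON = 1 | ARM_FMA = 1 | FP16_VA = 1 | LLAMAFILE = 1 |
```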
diff --git a/cpp/llama.h b/cpp/llama.h
index 8d14f494..05429c38 100644
--- a/cpp/llama.h
+++ b/cpp/llama.h
@@ -104,12 +104,15 @@ extern "C" {
 LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
 LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
 LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
+ LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
 };
 enum llama_rope_type {
- LLAMA_ROPE_TYPE_NONE = -1,
- LLAMA_ROPE_TYPE_NORM = 0,
- LLAMA_ROPE_TYPE_NEOX = LM_GGML_ROPE_TYPE_NEOX,
+ LLAMA_ROPE_TYPE_NONE = -1,
+ LLAMA_ROPE_TYPE_NORM = 0,
+ LLAMA_ROPE_TYPE_NEOX = LM_GGML_ROPE_TYPE_NEOX,
+ LLAMA_ROPE_TYPE_MROPE = LM_GGML_ROPE_TYPE_MROPE,
+ LLAMA_ROPE_TYPE_VISION = LM_GGML_ROPE_TYPE_VISION,
 };
 enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
@@ -171,9 +174,9 @@ extern "C" {
 LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
 LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
 LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors
- LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
+ //LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // removed from gguf files, use Q4_0 and runtime repack
+ //LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // removed from gguf files, use Q4_0 and runtime repack
+ //LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack
 LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
 LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
@@ -185,7 +188,8 @@ extern "C" {
 LLAMA_ROPE_SCALING_TYPE_NONE = 0,
 LLAMA_ROPE_SCALING_TYPE_LINEAR = 1,
 LLAMA_ROPE_SCALING_TYPE_YARN = 2,
- LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_YARN,
+ LLAMA_ROPE_SCALING_TYPE_LONGROPE = 3,
+ LLAMA_ROPE_SCALING_TYPE_MAX_VALUE = LLAMA_ROPE_SCALING_TYPE_LONGROPE,
 };
 enum llama_pooling_type {
@@ -272,6 +276,9 @@ extern "C" {
 };
 struct llama_model_params {
+ // NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
+ lm_ggml_backend_dev_t * devices;
+
 int32_t n_gpu_layers; // number of layers to store in VRAM
 enum llama_split_mode split_mode; // how to split the model across multiple GPUs
@@ -451,6 +458,7 @@ extern "C" {
 // Functions to access the model's GGUF metadata scalar values
 // - The functions return the length of the string on success, or -1 on failure
 // - The output string is always null-terminated and cleared on failure
+ // - When retrieving a string, an extra byte must be allocated to account for the null terminator
 // - GGUF array values are not supported by these functions
 // Get metadata value as a string by key name
@@ -474,9 +482,6 @@ extern "C" {
 // Returns the total number of parameters in the model
 LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
- // Get a llama model tensor
- LLAMA_API struct lm_ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
-
 // Returns true if the model contains an encoder that requires llama_encode() call
 LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
@@ -987,6 +992,9 @@ extern "C" {
 char * buf,
 int32_t length);
+ // Get list of built-in chat templates
+ LLAMA_API int32_t llama_chat_builtin_templates(const char ** output, size_t len);
+
 //
 // Sampling API
 //
@@ -1128,16 +1136,12 @@ extern "C" {
 const char * grammar_str,
 const char * grammar_root);
+ /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
 LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
- int32_t n_vocab, // llama_n_vocab()
- llama_token special_eos_id, // llama_token_eos()
- llama_token linefeed_id, // llama_token_nl()
- int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
- float penalty_repeat, // 1.0 = disabled
- float penalty_freq, // 0.0 = disabled
- float penalty_present, // 0.0 = disabled
- bool penalize_nl, // consider newlines as a repeatable token
- bool ignore_eos); // ignore the end-of-sequence token
+ int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size)
+ float penalty_repeat, // 1.0 = disabled
+ float penalty_freq, // 0.0 = disabled
+ float penalty_present); // 0.0 = disabled
 /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
 LLAMA_API struct llama_sampler * llama_sampler_init_dry(
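Editorial note: llama_sampler_init_penalties() above drops the vocabulary and special-token arguments and becomes an ordinary chain link (sampling.cpp below adds the matching COMMON_SAMPLER_TYPE_PENALTIES case). A hedged sketch of the new call; the parameter values are chosen for illustration:

```cpp
// Sketch only: build a small chain with the new 4-argument penalties sampler.
llama_sampler_chain_params sp = llama_sampler_chain_default_params();
llama_sampler * chain = llama_sampler_chain_init(sp);

// per the NOTE above, narrow the candidate set before applying penalties
llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
llama_sampler_chain_add(chain, llama_sampler_init_penalties(
    /*penalty_last_n  =*/ 64,     // 0 = disable, -1 = context size
    /*penalty_repeat  =*/ 1.1f,   // 1.0 = disabled
    /*penalty_freq    =*/ 0.0f,   // 0.0 = disabled
    /*penalty_present =*/ 0.0f)); // 0.0 = disabled
llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
```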
"" : common_token_to_piece(ctx, token_with_probs.tok); generated_text += token_text; - if (params.sparams.n_probs > 0) + if (params.sampling.n_probs > 0) { generated_token_probs.push_back(token_with_probs); } diff --git a/cpp/sampling.cpp b/cpp/sampling.cpp index 66a2311c..b02a299b 100644 --- a/cpp/sampling.cpp +++ b/cpp/sampling.cpp @@ -99,7 +99,7 @@ struct ring_buffer { }; struct common_sampler { - common_sampler_params params; + common_params_sampling params; struct llama_sampler * grmr; struct llama_sampler * chain; @@ -125,7 +125,7 @@ struct common_sampler { } }; -std::string common_sampler_params::print() const { +std::string common_params_sampling::print() const { char result[1024]; snprintf(result, sizeof(result), @@ -141,7 +141,7 @@ std::string common_sampler_params::print() const { return std::string(result); } -struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) { +struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); lparams.no_perf = params.no_perf; @@ -161,32 +161,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co params.logit_bias.size(), params.logit_bias.data())); - llama_sampler_chain_add(result->chain, - llama_sampler_init_penalties( - llama_n_vocab (model), - llama_token_eos(model), - llama_token_nl (model), - params.penalty_last_n, - params.penalty_repeat, - params.penalty_freq, - params.penalty_present, - params.penalize_nl, - params.ignore_eos)); - if (params.mirostat == 0) { for (const auto & cnstr : params.samplers) { switch (cnstr) { - case COMMON_SAMPLER_TYPE_DRY: + case COMMON_SAMPLER_TYPE_DRY: { - std::vector c_breakers; + std::vector c_breakers; c_breakers.reserve(params.dry_sequence_breakers.size()); - for (const auto& str : params.dry_sequence_breakers) { + for (const auto & str : params.dry_sequence_breakers) { c_breakers.push_back(str.c_str()); } llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } - break; + break; case COMMON_SAMPLER_TYPE_TOP_K: llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); break; @@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co case COMMON_SAMPLER_TYPE_INFILL: llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); break; + case COMMON_SAMPLER_TYPE_PENALTIES: + llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); + break; default: LM_GGML_ASSERT(false && "unknown sampler type"); } @@ -320,6 +311,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co return cur_p.data[cur_p.selected].id; } +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { + LM_GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); + + std::vector result; + result.reserve(idxs.size()); + + size_t i = 0; + for (; i < draft.size(); i++) { + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + + common_sampler_accept(gsmpl, id, true); + + 
diff --git a/cpp/sampling.cpp b/cpp/sampling.cpp
index 66a2311c..b02a299b 100644
--- a/cpp/sampling.cpp
+++ b/cpp/sampling.cpp
@@ -99,7 +99,7 @@ struct ring_buffer {
 };
 
 struct common_sampler {
-    common_sampler_params params;
+    common_params_sampling params;
 
     struct llama_sampler * grmr;
     struct llama_sampler * chain;
@@ -125,7 +125,7 @@ struct common_sampler {
     }
 };
 
-std::string common_sampler_params::print() const {
+std::string common_params_sampling::print() const {
     char result[1024];
 
     snprintf(result, sizeof(result),
@@ -141,7 +141,7 @@ std::string common_sampler_params::print() const {
     return std::string(result);
 }
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
 
     lparams.no_perf = params.no_perf;
@@ -161,32 +161,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
             params.logit_bias.size(),
             params.logit_bias.data()));
 
-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_penalties(
-                llama_n_vocab  (model),
-                llama_token_eos(model),
-                llama_token_nl (model),
-                params.penalty_last_n,
-                params.penalty_repeat,
-                params.penalty_freq,
-                params.penalty_present,
-                params.penalize_nl,
-                params.ignore_eos));
-
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
-                case COMMON_SAMPLER_TYPE_DRY:
+                case COMMON_SAMPLER_TYPE_DRY:
                     {
-                        std::vector<const char*> c_breakers;
+                        std::vector<const char *> c_breakers;
                         c_breakers.reserve(params.dry_sequence_breakers.size());
-                        for (const auto& str : params.dry_sequence_breakers) {
+                        for (const auto & str : params.dry_sequence_breakers) {
                             c_breakers.push_back(str.c_str());
                         }
 
                         llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
-                    break;
+                    break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                     break;
@@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 case COMMON_SAMPLER_TYPE_INFILL:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
                     break;
+                case COMMON_SAMPLER_TYPE_PENALTIES:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    break;
                 default:
                     LM_GGML_ASSERT(false && "unknown sampler type");
             }
@@ -320,6 +311,45 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return cur_p.data[cur_p.selected].id;
 }
 
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
+    LM_GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
+
+    std::vector<llama_token> result;
+    result.reserve(idxs.size());
+
+    size_t i = 0;
+    for (; i < draft.size(); i++) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        common_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+
+        if (draft[i] != id) {
+            break;
+        }
+    }
+
+    if (i == draft.size()) {
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
+
+        common_sampler_accept(gsmpl, id, true);
+
+        result.push_back(id);
+    }
+
+    return result;
+}
+
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
+    std::vector<int> idxs(draft.size() + 1);
+    for (size_t i = 0; i < idxs.size(); ++i) {
+        idxs[i] = i;
+    }
+
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
+}
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
    return llama_sampler_get_seed(gsmpl->chain);
 }
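Note (sync review): the new common_sampler_sample_and_accept_n helper exists to serve speculative decoding. A hedged sketch of the intended loop, where get_draft_tokens() is a hypothetical stand-in for a draft model:

    // sketch: verify a whole draft batch against the target model in one call
    llama_tokens draft = get_draft_tokens();  // hypothetical draft source
    std::vector<llama_token> accepted =
        common_sampler_sample_and_accept_n(gsmpl, ctx, draft, /* grammar_first = */ false);
    // invariant from the assert above: at most one sample per draft token plus one extra;
    // accepted.size() < draft.size() + 1 means the sampler diverged from the draft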
@@ -376,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
         case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
         default : return '?';
     }
 }
@@ -390,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
         case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
         default : return "";
     }
 }
@@ -404,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
         { "xtc",         COMMON_SAMPLER_TYPE_XTC },
         { "infill",      COMMON_SAMPLER_TYPE_INFILL },
+        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
     };
 
     // since samplers names are written multiple ways
@@ -450,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
     };
 
     std::vector<common_sampler_type> samplers;
diff --git a/cpp/sampling.h b/cpp/sampling.h
index d37f25ad..348911b1 100644
--- a/cpp/sampling.h
+++ b/cpp/sampling.h
@@ -36,7 +36,7 @@ struct common_sampler;
 
 // llama_sampler API overloads
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_sampler_params & params);
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
 
 void common_sampler_free(struct common_sampler * gsmpl);
 
@@ -60,6 +60,27 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
 //
 llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
+// generalized version of common_sampler_sample
+//
+// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
+// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
+//
+//      common_sampler_sample_n(gsmpl, ctx, { idx }, {});
+//
+//   is equivalent to
+//
+//      common_sampler_sample(gsmpl, ctx, idx);
+//      common_sampler_accept(gsmpl, token, true);
+//
+// requires: idxs.size() == draft.size() + 1
+//
+// returns at least 1 token, up to idxs.size()
+//
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
+
+// assume idxs == [ 0, 1, 2, ..., draft.size() ]
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
 // helpers
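Note (sync review): with the implicit penalties pass removed from common_sampler_init, penalties is now an ordinary, user-orderable chain entry ('e' in the short string form). A sketch of a configuration that restores the old always-first behavior explicitly; field names follow common_params_sampling and the values are illustrative:

    common_params_sampling sparams;
    sparams.penalty_last_n = 64;           // 0 = disable, -1 = ctx size
    sparams.penalty_repeat = 1.1f;         // 1.0 = disabled
    sparams.samplers = {
        COMMON_SAMPLER_TYPE_PENALTIES,     // runs first, like the old hard-coded pass
        COMMON_SAMPLER_TYPE_TOP_K,
        COMMON_SAMPLER_TYPE_TEMPERATURE,
    };
    common_sampler * smpl = common_sampler_init(model, sparams);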
diff --git a/cpp/sgemm.cpp b/cpp/sgemm.cpp
index 5f8cc2d5..14f6eac6 100644
--- a/cpp/sgemm.cpp
+++ b/cpp/sgemm.cpp
@@ -204,6 +204,7 @@ template <> inline float32x4_t load(const float *p) {
     return vld1q_f32(p);
 }
 
 #if !defined(_MSC_VER)
+// FIXME: this should check for __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 template <> inline float16x8_t load(const lm_ggml_fp16_t *p) {
     return vld1q_f16((const float16_t *)p);
 }
diff --git a/cpp/unicode.cpp b/cpp/unicode.cpp
index 50b35bbb..8ed6b1a5 100644
--- a/cpp/unicode.cpp
+++ b/cpp/unicode.cpp
@@ -71,15 +71,15 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
     throw std::invalid_argument("failed to convert utf8 to codepoint");
 }
 
-//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) {
+//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
 //    std::vector<uint16_t> result;
-//    if (/* 0x0000 <= cp && */ cp <= 0xffff) {
-//        result.emplace_back(cp);
+//    if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
+//        result.emplace_back(cpt);
 //        return result;
 //    }
-//    if (0x10000 <= cp && cp <= 0x10ffff) {
-//        result.emplace_back(0xd800 | ((cp - 0x10000) >> 10));
-//        result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff));
+//    if (0x10000 <= cpt && cpt <= 0x10ffff) {
+//        result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
+//        result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
 //        return result;
 //    }
 //    throw std::invalid_argument("failed to convert codepoint to utf16");
@@ -120,8 +120,8 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
 //    return result;
 //}
 
-static std::vector<codepoint_flags> unicode_cpt_flags_array() {
-    std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
+static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
+    std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
 
     assert (unicode_ranges_flags.begin()[0].first == 0);
     assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
@@ -201,7 +201,18 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
 }
 
 static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
+#if defined(__clang__)
+    // disable C++17 deprecation warning for std::codecvt_utf8
+#    pragma clang diagnostic push
+#    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#endif
+
     std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
+
+#if defined(__clang__)
+#    pragma clang diagnostic pop
+#endif
+
     return conv.from_bytes(s);
 }
 
@@ -242,8 +253,8 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
         return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
     };
 
-    auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-        return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+    auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+        return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
     };
 
     size_t _prev_end = offset_ini;
@@ -360,8 +371,8 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
     };
 
-    auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
-        return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{};
+    auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
+        return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
     };
 
     size_t _prev_end = offset_ini;
@@ -561,29 +572,29 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
 // interface
 //
 
-std::string unicode_cpt_to_utf8(uint32_t cp) {
+std::string unicode_cpt_to_utf8(uint32_t cpt) {
     std::string result;
 
-    if (/* 0x00 <= cp && */ cp <= 0x7f) {
-        result.push_back(cp);
+    if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
+        result.push_back(cpt);
         return result;
     }
-    if (0x80 <= cp && cp <= 0x7ff) {
-        result.push_back(0xc0 | ((cp >> 6) & 0x1f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x80 <= cpt && cpt <= 0x7ff) {
+        result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
-    if (0x800 <= cp && cp <= 0xffff) {
-        result.push_back(0xe0 | ((cp >> 12) & 0x0f));
-        result.push_back(0x80 | ((cp >> 6) & 0x3f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x800 <= cpt && cpt <= 0xffff) {
+        result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
+        result.push_back(0x80 | ((cpt >> 6) & 0x3f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
-    if (0x10000 <= cp && cp <= 0x10ffff) {
-        result.push_back(0xf0 | ((cp >> 18) & 0x07));
-        result.push_back(0x80 | ((cp >> 12) & 0x3f));
-        result.push_back(0x80 | ((cp >> 6) & 0x3f));
-        result.push_back(0x80 | (cp & 0x3f));
+    if (0x10000 <= cpt && cpt <= 0x10ffff) {
+        result.push_back(0xf0 | ((cpt >> 18) & 0x07));
+        result.push_back(0x80 | ((cpt >> 12) & 0x3f));
+        result.push_back(0x80 | ((cpt >> 6) & 0x3f));
+        result.push_back(0x80 | (cpt & 0x3f));
         return result;
     }
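Note (sync review): the cp -> cpt renames above are mechanical; encoding behavior is unchanged. A tiny round-trip sketch for reference (U+1F600 exercises the 4-byte branch; expected values noted in comments):

    const uint32_t cpt = 0x1F600;                        // emoji, 4-byte UTF-8
    const std::string utf8 = unicode_cpt_to_utf8(cpt);   // "\xF0\x9F\x98\x80"
    const std::vector<uint32_t> cpts = unicode_cpts_from_utf8(utf8);
    // expected: cpts.size() == 1 && cpts[0] == cpt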
@@ -613,19 +624,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     return result;
 }
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
+    static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
     static const auto cpt_flags = unicode_cpt_flags_array();
-    return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
+    return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
 }
 
-codepoint_flags unicode_cpt_flags(const std::string & utf8) {
-    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
+    static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
     if (utf8.empty()) {
         return undef;  // undefined
     }
     size_t offset = 0;
-    return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
+    return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
 }
 
 std::string unicode_byte_to_utf8(uint8_t byte) {
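Note (sync review): splitting the old unicode_cpt_flags overloads into two distinct names removes the uint32_t vs std::string ambiguity at call sites. A usage sketch:

    const unicode_cpt_flags f1 = unicode_cpt_flags_from_cpt(0x0041); // 'A' by codepoint
    const unicode_cpt_flags f2 = unicode_cpt_flags_from_utf8("A");   // same lookup via UTF-8
    // expected: f1.is_letter && f2.is_letter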
@@ -638,41 +649,41 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8) {
     return map.at(utf8);
 }
 
-uint32_t unicode_tolower(uint32_t cp) {
+uint32_t unicode_tolower(uint32_t cpt) {
     // binary search
-    auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cp,
+    auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
         [](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
             return pair.first < value;
         });
-    if (it != unicode_map_lowercase.end() && it->first == cp) {
+    if (it != unicode_map_lowercase.end() && it->first == cpt) {
         return it->second;
     }
-    return cp;  // Return the original code point if no lowercase mapping is found
+    return cpt;  // Return the original code point if no lowercase mapping is found
 }
 
 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
     // unicode categories
     static const std::map<std::string, int> k_ucat_enum = {
-        { "\\p{N}", codepoint_flags::NUMBER },
-        { "\\p{L}", codepoint_flags::LETTER },
-        { "\\p{P}", codepoint_flags::PUNCTUATION },
+        { "\\p{N}", unicode_cpt_flags::NUMBER },
+        { "\\p{L}", unicode_cpt_flags::LETTER },
+        { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
     };
 
     static const std::map<int, int> k_ucat_cpt = {
-        { codepoint_flags::NUMBER,      0xD1 },
-        { codepoint_flags::LETTER,      0xD2 },
-        { codepoint_flags::PUNCTUATION, 0xD3 },
+        { unicode_cpt_flags::NUMBER,      0xD1 },
+        { unicode_cpt_flags::LETTER,      0xD2 },
+        { unicode_cpt_flags::PUNCTUATION, 0xD3 },
     };
 
     static const std::map<int, std::string> k_ucat_map = {
-        { codepoint_flags::NUMBER,      "\x30-\x39" }, // 0-9
-        { codepoint_flags::LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
-        { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { unicode_cpt_flags::NUMBER,      "\x30-\x39" }, // 0-9
+        { unicode_cpt_flags::LETTER,      "\x41-\x5A\x61-\x7A" }, // A-Za-z
+        { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
     };
 
     // compute collapsed codepoints only if needed by at least one regex
     bool need_collapse = false;
-    for (auto & regex_expr : regex_exprs) {
+    for (const auto & regex_expr : regex_exprs) {
         // search for unicode categories
         for (const auto & ucat : k_ucat_enum) {
             if (std::string::npos != regex_expr.find(ucat.first)) {
@@ -698,7 +709,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
                 continue;
             }
 
-            const auto flags = unicode_cpt_flags(cpts[i]);
+            const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
 
             if (flags.is_whitespace) {
                 //NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
@@ -714,7 +725,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
 
     std::vector<size_t> bpe_offsets = { cpts.size() };
 
-    for (auto & regex_expr : regex_exprs) {
+    for (const auto & regex_expr : regex_exprs) {
         // first, see if we have an efficient custom regex implementation
         auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
 
@@ -728,7 +739,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
             // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
             // with the corresponding collapsed representation
             bool use_collapsed = false;
-            for (auto & ucat : k_ucat_enum) {
+            for (const auto & ucat : k_ucat_enum) {
                 if (std::string::npos != regex_expr.find(ucat.first)) {
                     use_collapsed = true;
                     break;
@@ -794,7 +805,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
             // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
             std::wstring wtext(cpts.begin(), cpts.end());
             for (size_t i = 0; i < wtext.size(); ++i) {
-                if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
+                if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
                     wtext[i] = 0x0B;
                 }
             }
diff --git a/cpp/unicode.h b/cpp/unicode.h
index 008532a2..c27098df 100644
--- a/cpp/unicode.h
+++ b/cpp/unicode.h
@@ -4,9 +4,7 @@
 #include <string>
 #include <vector>
 
-// TODO: prefix all symbols with "llama_"
-
-struct codepoint_flags {
+struct unicode_cpt_flags {
     enum {
         UNDEFINED       = 0x0001,
         NUMBER          = 0x0002,  // regex: \p{N}
@@ -35,7 +33,7 @@ struct codepoint_flags {
     uint16_t is_nfd : 1;
 
     // decode from uint16
-    inline codepoint_flags(const uint16_t flags=0) {
+    inline unicode_cpt_flags(const uint16_t flags = 0) {
         *reinterpret_cast<uint16_t*>(this) = flags;
     }
 
@@ -50,18 +48,19 @@ struct codepoint_flags {
 
 size_t unicode_len_utf8(char src);
 
-std::string unicode_cpt_to_utf8(uint32_t cp);
-uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
+std::string unicode_cpt_to_utf8  (uint32_t cpt);
+uint32_t    unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
+
 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
 
 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
 
-codepoint_flags unicode_cpt_flags(const uint32_t cp);
-codepoint_flags unicode_cpt_flags(const std::string & utf8);
+unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
+unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
 
 std::string unicode_byte_to_utf8(uint8_t byte);
-uint8_t unicode_utf8_to_byte(const std::string & utf8);
+uint8_t     unicode_utf8_to_byte(const std::string & utf8);
 
-uint32_t unicode_tolower(uint32_t cp);
+uint32_t unicode_tolower(uint32_t cpt);
 
 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);
diff --git a/docs/API/README.md b/docs/API/README.md
index c307fb18..814da0ae 100644
--- a/docs/API/README.md
+++ b/docs/API/README.md
@@ -63,7 +63,7 @@ llama.rn
 
 #### Defined in
 
-[index.ts:103](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L103)
+[index.ts:103](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L103)
 
 ___
 
@@ -73,7 +73,7 @@ ___
 
 #### Defined in
 
-[index.ts:94](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L94)
+[index.ts:94](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L94)
 
 ___
 
@@ -83,7 +83,7 @@ ___
 
 #### Defined in
 
-[index.ts:67](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L67)
+[index.ts:67](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L67) ___ @@ -93,7 +93,7 @@ ___ #### Defined in -[index.ts:92](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L92) +[index.ts:92](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L92) ___ @@ -121,7 +121,6 @@ ___ | `n_predict?` | `number` | Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0,no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity. | | `n_probs?` | `number` | If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0` | | `n_threads?` | `number` | - | -| `penalize_nl?` | `boolean` | Penalize newline tokens when applying the repeat penalty. Default: `false` | | `penalty_freq?` | `number` | Repeat alpha frequency penalty. Default: `0.0`, which is disabled. | | `penalty_last_n?` | `number` | Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size. | | `penalty_present?` | `number` | Repeat alpha presence penalty. Default: `0.0`, which is disabled. | @@ -138,7 +137,7 @@ ___ #### Defined in -[NativeRNLlama.ts:60](https://github.com/mybigday/llama.rn/blob/276a90a/src/NativeRNLlama.ts#L60) +[NativeRNLlama.ts:60](https://github.com/mybigday/llama.rn/blob/86adb39/src/NativeRNLlama.ts#L60) ___ @@ -164,7 +163,7 @@ ___ #### Defined in -[NativeRNLlama.ts:205](https://github.com/mybigday/llama.rn/blob/276a90a/src/NativeRNLlama.ts#L205) +[NativeRNLlama.ts:201](https://github.com/mybigday/llama.rn/blob/86adb39/src/NativeRNLlama.ts#L201) ___ @@ -187,7 +186,7 @@ ___ #### Defined in -[NativeRNLlama.ts:194](https://github.com/mybigday/llama.rn/blob/276a90a/src/NativeRNLlama.ts#L194) +[NativeRNLlama.ts:190](https://github.com/mybigday/llama.rn/blob/86adb39/src/NativeRNLlama.ts#L190) ___ @@ -204,7 +203,7 @@ ___ #### Defined in -[NativeRNLlama.ts:189](https://github.com/mybigday/llama.rn/blob/276a90a/src/NativeRNLlama.ts#L189) +[NativeRNLlama.ts:185](https://github.com/mybigday/llama.rn/blob/86adb39/src/NativeRNLlama.ts#L185) ___ @@ -221,7 +220,7 @@ ___ #### Defined in -[NativeRNLlama.ts:184](https://github.com/mybigday/llama.rn/blob/276a90a/src/NativeRNLlama.ts#L184) +[NativeRNLlama.ts:180](https://github.com/mybigday/llama.rn/blob/86adb39/src/NativeRNLlama.ts#L180) ___ @@ -257,7 +256,7 @@ ___ #### Defined in -[NativeRNLlama.ts:8](https://github.com/mybigday/llama.rn/blob/276a90a/src/NativeRNLlama.ts#L8) +[NativeRNLlama.ts:8](https://github.com/mybigday/llama.rn/blob/86adb39/src/NativeRNLlama.ts#L8) ___ @@ -273,7 +272,7 @@ ___ #### Defined in -[NativeRNLlama.ts:4](https://github.com/mybigday/llama.rn/blob/276a90a/src/NativeRNLlama.ts#L4) +[NativeRNLlama.ts:4](https://github.com/mybigday/llama.rn/blob/86adb39/src/NativeRNLlama.ts#L4) ___ @@ -289,7 +288,7 @@ ___ #### Defined in -[NativeRNLlama.ts:225](https://github.com/mybigday/llama.rn/blob/276a90a/src/NativeRNLlama.ts#L225) +[NativeRNLlama.ts:221](https://github.com/mybigday/llama.rn/blob/86adb39/src/NativeRNLlama.ts#L221) ___ @@ -308,7 +307,7 @@ ___ #### Defined in 
-[NativeRNLlama.ts:229](https://github.com/mybigday/llama.rn/blob/276a90a/src/NativeRNLlama.ts#L229) +[NativeRNLlama.ts:225](https://github.com/mybigday/llama.rn/blob/86adb39/src/NativeRNLlama.ts#L225) ___ @@ -325,7 +324,7 @@ ___ #### Defined in -[NativeRNLlama.ts:236](https://github.com/mybigday/llama.rn/blob/276a90a/src/NativeRNLlama.ts#L236) +[NativeRNLlama.ts:232](https://github.com/mybigday/llama.rn/blob/86adb39/src/NativeRNLlama.ts#L232) ___ @@ -341,7 +340,7 @@ ___ #### Defined in -[NativeRNLlama.ts:221](https://github.com/mybigday/llama.rn/blob/276a90a/src/NativeRNLlama.ts#L221) +[NativeRNLlama.ts:217](https://github.com/mybigday/llama.rn/blob/86adb39/src/NativeRNLlama.ts#L217) ___ @@ -357,7 +356,7 @@ ___ #### Defined in -[chat.ts:3](https://github.com/mybigday/llama.rn/blob/276a90a/src/chat.ts#L3) +[chat.ts:3](https://github.com/mybigday/llama.rn/blob/86adb39/src/chat.ts#L3) ___ @@ -374,7 +373,7 @@ ___ #### Defined in -[chat.ts:7](https://github.com/mybigday/llama.rn/blob/276a90a/src/chat.ts#L7) +[chat.ts:7](https://github.com/mybigday/llama.rn/blob/86adb39/src/chat.ts#L7) ___ @@ -391,7 +390,7 @@ ___ #### Defined in -[index.ts:57](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L57) +[index.ts:57](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L57) ## Functions @@ -415,7 +414,7 @@ ___ #### Defined in -[grammar.ts:826](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L826) +[grammar.ts:826](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L826) ___ @@ -436,7 +435,7 @@ ___ #### Defined in -[index.ts:295](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L295) +[index.ts:295](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L295) ___ @@ -456,7 +455,7 @@ ___ #### Defined in -[index.ts:280](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L280) +[index.ts:280](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L280) ___ @@ -470,7 +469,7 @@ ___ #### Defined in -[index.ts:354](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L354) +[index.ts:354](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L354) ___ @@ -490,4 +489,4 @@ ___ #### Defined in -[index.ts:266](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L266) +[index.ts:266](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L266) diff --git a/docs/API/classes/LlamaContext.md b/docs/API/classes/LlamaContext.md index c6e0ea69..847a915f 100644 --- a/docs/API/classes/LlamaContext.md +++ b/docs/API/classes/LlamaContext.md @@ -45,7 +45,7 @@ #### Defined in -[index.ts:124](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L124) +[index.ts:124](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L124) ## Properties @@ -55,7 +55,7 @@ #### Defined in -[index.ts:116](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L116) +[index.ts:116](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L116) ___ @@ -65,7 +65,7 @@ ___ #### Defined in -[index.ts:114](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L114) +[index.ts:114](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L114) ___ @@ -81,7 +81,7 @@ ___ #### Defined in -[index.ts:120](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L120) +[index.ts:120](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L120) ___ @@ -91,7 +91,7 @@ ___ #### Defined in 
-[index.ts:118](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L118) +[index.ts:118](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L118) ## Methods @@ -111,7 +111,7 @@ ___ #### Defined in -[index.ts:239](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L239) +[index.ts:239](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L239) ___ @@ -134,7 +134,7 @@ ___ #### Defined in -[index.ts:219](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L219) +[index.ts:219](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L219) ___ @@ -155,7 +155,7 @@ ___ #### Defined in -[index.ts:160](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L160) +[index.ts:160](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L160) ___ @@ -175,7 +175,7 @@ ___ #### Defined in -[index.ts:208](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L208) +[index.ts:208](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L208) ___ @@ -196,7 +196,7 @@ ___ #### Defined in -[index.ts:212](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L212) +[index.ts:212](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L212) ___ @@ -217,7 +217,7 @@ ___ #### Defined in -[index.ts:150](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L150) +[index.ts:150](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L150) ___ @@ -231,7 +231,7 @@ ___ #### Defined in -[index.ts:255](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L255) +[index.ts:255](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L255) ___ @@ -253,7 +253,7 @@ Load cached prompt & completion state from a file. #### Defined in -[index.ts:134](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L134) +[index.ts:134](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L134) ___ @@ -267,7 +267,7 @@ ___ #### Defined in -[index.ts:261](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L261) +[index.ts:261](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L261) ___ @@ -281,7 +281,7 @@ ___ #### Defined in -[index.ts:251](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L251) +[index.ts:251](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L251) ___ @@ -305,7 +305,7 @@ Save current cached prompt & completion state to a file. 
#### Defined in -[index.ts:143](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L143) +[index.ts:143](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L143) ___ @@ -319,7 +319,7 @@ ___ #### Defined in -[index.ts:200](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L200) +[index.ts:200](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L200) ___ @@ -339,4 +339,4 @@ ___ #### Defined in -[index.ts:204](https://github.com/mybigday/llama.rn/blob/276a90a/src/index.ts#L204) +[index.ts:204](https://github.com/mybigday/llama.rn/blob/86adb39/src/index.ts#L204) diff --git a/docs/API/classes/SchemaGrammarConverter.md b/docs/API/classes/SchemaGrammarConverter.md index 0daa36b6..458db18e 100644 --- a/docs/API/classes/SchemaGrammarConverter.md +++ b/docs/API/classes/SchemaGrammarConverter.md @@ -46,7 +46,7 @@ #### Defined in -[grammar.ts:213](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L213) +[grammar.ts:213](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L213) ## Properties @@ -56,7 +56,7 @@ #### Defined in -[grammar.ts:203](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L203) +[grammar.ts:203](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L203) ___ @@ -66,7 +66,7 @@ ___ #### Defined in -[grammar.ts:205](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L205) +[grammar.ts:205](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L205) ___ @@ -76,7 +76,7 @@ ___ #### Defined in -[grammar.ts:201](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L201) +[grammar.ts:201](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L201) ___ @@ -90,7 +90,7 @@ ___ #### Defined in -[grammar.ts:209](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L209) +[grammar.ts:209](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L209) ___ @@ -100,7 +100,7 @@ ___ #### Defined in -[grammar.ts:211](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L211) +[grammar.ts:211](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L211) ___ @@ -114,7 +114,7 @@ ___ #### Defined in -[grammar.ts:207](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L207) +[grammar.ts:207](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L207) ## Methods @@ -135,7 +135,7 @@ ___ #### Defined in -[grammar.ts:695](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L695) +[grammar.ts:695](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L695) ___ @@ -156,7 +156,7 @@ ___ #### Defined in -[grammar.ts:226](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L226) +[grammar.ts:226](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L226) ___ @@ -179,7 +179,7 @@ ___ #### Defined in -[grammar.ts:712](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L712) +[grammar.ts:712](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L712) ___ @@ -200,7 +200,7 @@ ___ #### Defined in -[grammar.ts:314](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L314) +[grammar.ts:314](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L314) ___ @@ -220,7 +220,7 @@ ___ #### Defined in -[grammar.ts:520](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L520) +[grammar.ts:520](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L520) ___ @@ -241,7 +241,7 @@ ___ #### Defined in 
-[grammar.ts:325](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L325)
+[grammar.ts:325](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L325)
 
 ___
 
@@ -255,7 +255,7 @@ ___
 
 #### Defined in
 
-[grammar.ts:815](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L815)
+[grammar.ts:815](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L815)
 
 ___
 
@@ -276,7 +276,7 @@ ___
 
 #### Defined in
 
-[grammar.ts:249](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L249)
+[grammar.ts:249](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L249)
 
 ___
 
@@ -297,4 +297,4 @@ ___
 
 #### Defined in
 
-[grammar.ts:531](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L531)
+[grammar.ts:531](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L531)
diff --git a/docs/API/classes/SchemaGrammarConverterBuiltinRule.md b/docs/API/classes/SchemaGrammarConverterBuiltinRule.md
index c758d7ea..7c24d601 100644
--- a/docs/API/classes/SchemaGrammarConverterBuiltinRule.md
+++ b/docs/API/classes/SchemaGrammarConverterBuiltinRule.md
@@ -28,7 +28,7 @@
 
 #### Defined in
 
-[grammar.ts:82](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L82)
+[grammar.ts:82](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L82)
 
 ## Properties
 
@@ -38,7 +38,7 @@
 
 #### Defined in
 
-[grammar.ts:78](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L78)
+[grammar.ts:78](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L78)
 
 ___
 
@@ -48,4 +48,4 @@ ___
 
 #### Defined in
 
-[grammar.ts:80](https://github.com/mybigday/llama.rn/blob/276a90a/src/grammar.ts#L80)
+[grammar.ts:80](https://github.com/mybigday/llama.rn/blob/86adb39/src/grammar.ts#L80)
diff --git a/example/ios/.xcode.env.local b/example/ios/.xcode.env.local
index 39e576b1..e3806c10 100644
--- a/example/ios/.xcode.env.local
+++ b/example/ios/.xcode.env.local
@@ -1 +1 @@
-export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1734671296534-0.19020674937051285/node
+export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1734674930237-0.4736887321888663/node
diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock
index f1ab5ebe..0b56fda7 100644
--- a/example/ios/Podfile.lock
+++ b/example/ios/Podfile.lock
@@ -1270,7 +1270,7 @@ SPEC CHECKSUMS:
   glog: 04b94705f318337d7ead9e6d17c019bd9b1f6b1b
   hermes-engine: 10fbd3f62405c41ea07e71973ea61e1878d07322
   libevent: 4049cae6c81cdb3654a443be001fb9bdceff7913
-  llama-rn: 714a935f3080c60c393daae9d1e682c20cb5d647
+  llama-rn: 94fb38e08bc2bc683a3e625f60d956a3739fb8eb
   RCT-Folly: 424b8c9a7a0b9ab2886ffe9c3b041ef628fd4fb1
   RCTRequired: a2faf4bad4e438ca37b2040cb8f7799baa065c18
   RCTTypeSafety: cb09f3e4747b6d18331a15eb05271de7441ca0b3
diff --git a/example/ios/RNLlamaExample.xcodeproj/project.pbxproj b/example/ios/RNLlamaExample.xcodeproj/project.pbxproj
index a499a241..5618edc5 100644
--- a/example/ios/RNLlamaExample.xcodeproj/project.pbxproj
+++ b/example/ios/RNLlamaExample.xcodeproj/project.pbxproj
@@ -440,7 +440,7 @@
 					"$(inherited)",
 				);
 				INFOPLIST_FILE = RNLlamaExampleTests/Info.plist;
-				IPHONEOS_DEPLOYMENT_TARGET = 12.4;
+				IPHONEOS_DEPLOYMENT_TARGET = 13.0;
 				LD_RUNPATH_SEARCH_PATHS = (
 					"$(inherited)",
 					"@executable_path/Frameworks",
@@ -464,7 +464,7 @@
 				BUNDLE_LOADER = "$(TEST_HOST)";
 				COPY_PHASE_STRIP = NO;
 				INFOPLIST_FILE = RNLlamaExampleTests/Info.plist;
-				IPHONEOS_DEPLOYMENT_TARGET = 12.4;
+				IPHONEOS_DEPLOYMENT_TARGET = 13.0;
 				LD_RUNPATH_SEARCH_PATHS = (
 					"$(inherited)",
 					"@executable_path/Frameworks",
diff --git a/example/src/App.tsx b/example/src/App.tsx
index 3849753c..55d856cb 100644
--- a/example/src/App.tsx
+++ b/example/src/App.tsx
@@ -453,7 +453,6 @@ export default function App() {
                   mirostat: 0,
                   mirostat_tau: 5,
                   mirostat_eta: 0.1,
-                  penalize_nl: false,
                   ignore_eos: false,
                   stop: [
                     '</s>',
diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm
index 8d364889..9d08af43 100644
--- a/ios/RNLlamaContext.mm
+++ b/ios/RNLlamaContext.mm
@@ -115,8 +115,8 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig
     if (params[@"flash_attn"] && [params[@"flash_attn"] boolValue]) defaultParams.flash_attn = true;
 
-    if (params[@"cache_type_k"]) defaultParams.cache_type_k = [params[@"cache_type_k"] UTF8String];
-    if (params[@"cache_type_v"]) defaultParams.cache_type_v = [params[@"cache_type_v"] UTF8String];
+    if (params[@"cache_type_k"]) defaultParams.cache_type_k = rnllama::kv_cache_type_from_str([params[@"cache_type_k"] UTF8String]);
+    if (params[@"cache_type_v"]) defaultParams.cache_type_v = rnllama::kv_cache_type_from_str([params[@"cache_type_v"] UTF8String]);
 
     int nThreads = params[@"n_threads"] ? [params[@"n_threads"] intValue] : 0;
     const int maxThreads = (int) [[NSProcessInfo processInfo] processorCount];
@@ -280,7 +280,7 @@ - (NSDictionary *)completion:(NSDictionary *)params
     NSString *prompt = [params objectForKey:@"prompt"];
     llama->params.prompt = [prompt UTF8String];
-    llama->params.sparams.seed = params[@"seed"] ? [params[@"seed"] intValue] : -1;
+    llama->params.sampling.seed = params[@"seed"] ? [params[@"seed"] intValue] : -1;
 
     if (params[@"n_threads"]) {
         int nThreads = params[@"n_threads"] ? [params[@"n_threads"] intValue] : llama->params.cpuparams.n_threads;
@@ -290,9 +290,9 @@ - (NSDictionary *)completion:(NSDictionary *)params
         llama->params.cpuparams.n_threads = nThreads > 0 ? nThreads : defaultNThreads;
     }
     if (params[@"n_predict"]) llama->params.n_predict = [params[@"n_predict"] intValue];
-    if (params[@"ignore_eos"]) llama->params.sparams.ignore_eos = [params[@"ignore_eos"] boolValue];
+    if (params[@"ignore_eos"]) llama->params.sampling.ignore_eos = [params[@"ignore_eos"] boolValue];
 
-    auto & sparams = llama->params.sparams;
+    auto & sparams = llama->params.sampling;
 
     if (params[@"temperature"]) sparams.temp = [params[@"temperature"] doubleValue];
@@ -306,7 +306,6 @@ - (NSDictionary *)completion:(NSDictionary *)params
     if (params[@"mirostat"]) sparams.mirostat = [params[@"mirostat"] intValue];
     if (params[@"mirostat_tau"]) sparams.mirostat_tau = [params[@"mirostat_tau"] doubleValue];
     if (params[@"mirostat_eta"]) sparams.mirostat_eta = [params[@"mirostat_eta"] doubleValue];
-    if (params[@"penalize_nl"]) sparams.penalize_nl = [params[@"penalize_nl"] boolValue];
 
     if (params[@"top_k"]) sparams.top_k = [params[@"top_k"] intValue];
     if (params[@"top_p"]) sparams.top_p = [params[@"top_p"] doubleValue];
@@ -410,7 +409,7 @@ - (NSDictionary *)completion:(NSDictionary *)params
             NSMutableDictionary *tokenResult = [[NSMutableDictionary alloc] init];
             tokenResult[@"token"] = [NSString stringWithUTF8String:to_send.c_str()];
 
-            if (llama->params.sparams.n_probs > 0) {
+            if (llama->params.sampling.n_probs > 0) {
                 const std::vector<llama_token> to_send_toks = common_tokenize(llama->ctx, to_send, false);
                 size_t probs_pos = std::min(sent_token_probs_index, llama->generated_token_probs.size());
                 size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama->generated_token_probs.size());
diff --git a/llama-rn.podspec b/llama-rn.podspec
index 22bfcd8b..b62c95a4 100644
--- a/llama-rn.podspec
+++ b/llama-rn.podspec
@@ -2,7 +2,7 @@ require "json"
 package = JSON.parse(File.read(File.join(__dir__, "package.json")))
 
 base_ld_flags = "-framework Accelerate -framework Foundation -framework Metal -framework MetalKit"
-base_compiler_flags = "-fno-objc-arc -DLM_GGML_USE_ACCELERATE -Wno-shorten-64-to-32"
+base_compiler_flags = "-fno-objc-arc -DLM_GGML_USE_CPU -DLM_GGML_USE_ACCELERATE -Wno-shorten-64-to-32"
 
 if ENV["RNLLAMA_DISABLE_METAL"] != "1" then
   base_compiler_flags += " -DLM_GGML_USE_METAL" # -DLM_GGML_METAL_NDEBUG
@@ -20,7 +20,7 @@ Pod::Spec.new do |s|
   s.license      = package["license"]
   s.authors      = package["author"]
 
-  s.platforms    = { :ios => "11.0", :tvos => "11.0" }
+  s.platforms    = { :ios => "13.0", :tvos => "13.0" }
   s.source       = { :git => "https://github.com/mybigday/llama.rn.git", :tag => "#{s.version}" }
   s.source_files = "ios/**/*.{h,m,mm}", "cpp/**/*.{h,cpp,hpp,c,m,mm}"
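Note (sync review): -DLM_GGML_USE_CPU is now required on both platforms (this podspec change matches the CMake change earlier in the diff) because the reworked backend registry only registers backends whose compile flag is set. Paraphrased shape of the guard, as a sketch and not a quote of ggml-backend-reg.cpp:

    // sketch: without the define, no CPU backend gets registered and
    // context creation finds no usable backend at runtime
    #ifdef LM_GGML_USE_CPU
        register_backend(lm_ggml_backend_cpu_reg());
    #endif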
./llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c ./cpp/ggml-cpu-quants.c
+cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h ./cpp/ggml-cpu-traits.h
+cp ./llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp ./cpp/ggml-cpu-traits.cpp
+
+cp -r ./llama.cpp/ggml/src/ggml-cpu/amx ./cpp/
 
 cp ./llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h ./cpp/sgemm.h
 cp ./llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp ./cpp/sgemm.cpp
@@ -37,8 +41,6 @@ cp ./llama.cpp/ggml/src/ggml-common.h ./cpp/ggml-common.h
 cp ./llama.cpp/ggml/src/ggml-opt.cpp ./cpp/ggml-opt.cpp
 cp ./llama.cpp/ggml/src/ggml-quants.h ./cpp/ggml-quants.h
 cp ./llama.cpp/ggml/src/ggml-quants.c ./cpp/ggml-quants.c
-cp ./llama.cpp/ggml/src/ggml-aarch64.c ./cpp/ggml-aarch64.c
-cp ./llama.cpp/ggml/src/ggml-aarch64.h ./cpp/ggml-aarch64.h
 cp ./llama.cpp/ggml/src/ggml-threading.cpp ./cpp/ggml-threading.cpp
 cp ./llama.cpp/ggml/src/ggml-threading.h ./cpp/ggml-threading.h
@@ -105,13 +107,18 @@ files_add_lm_prefix=(
   "./cpp/ggml-cpu.c"
   "./cpp/ggml-cpu.cpp"
   "./cpp/ggml-cpu-aarch64.h"
-  "./cpp/ggml-cpu-aarch64.c"
+  "./cpp/ggml-cpu-aarch64.cpp"
   "./cpp/ggml-cpu-quants.h"
   "./cpp/ggml-cpu-quants.c"
+  "./cpp/ggml-cpu-traits.h"
+  "./cpp/ggml-cpu-traits.cpp"
   "./cpp/ggml-threading.h"
   "./cpp/ggml-threading.cpp"
-  "./cpp/ggml-aarch64.h"
-  "./cpp/ggml-aarch64.c"
+  "./cpp/amx/amx.h"
+  "./cpp/amx/amx.cpp"
+  "./cpp/amx/mmq.h"
+  "./cpp/amx/mmq.cpp"
+  "./cpp/amx/common.h"
 )
 
 # Loop through each file and run the sed commands
@@ -167,9 +174,6 @@ patch -p0 -d ./cpp < ./scripts/ggml-metal.m.patch
 patch -p0 -d ./cpp < ./scripts/ggml-backend-reg.cpp.patch
 patch -p0 -d ./cpp < ./scripts/ggml.c.patch
 patch -p0 -d ./cpp < ./scripts/ggml-quants.c.patch
-patch -p0 -d ./cpp < ./scripts/ggml-cpu-aarch64.c.patch
-patch -p0 -d ./cpp < ./scripts/sgemm.cpp.patch
-
 if [ "$OS" = "Darwin" ]; then
   # Build metallib (~1.4MB)
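Note (sync review): scripts/common.cpp.patch and scripts/common.h.patch below strip the libcurl download path (unused in React Native) and thread a load-progress callback through common_params. A hedged sketch of how the callback is meant to be wired up, where emitProgress is a hypothetical bridge hook, not an API from this PR:

    // sketch: surface model-load progress to JS; a capture-free lambda
    // decays to the llama_progress_callback function pointer
    params.progress_callback = [](float progress, void * user_data) -> bool {
        emitProgress(progress);   // hypothetical bridge call
        return true;              // returning false would abort the load
    };
    params.progress_callback_user_data = nullptr;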
diff --git a/scripts/common.cpp.patch b/scripts/common.cpp.patch
index 44268d4a..a37f38bd 100644
--- a/scripts/common.cpp.patch
+++ b/scripts/common.cpp.patch
@@ -1,5 +1,5 @@
---- common.cpp.orig	2024-12-20 13:06:22
-+++ common.cpp	2024-12-20 13:05:42
+--- common.cpp.orig	2024-12-20 13:45:13
++++ common.cpp	2024-12-20 13:47:23
 @@ -4,10 +4,6 @@
 
  #include "common.h"
@@ -33,7 +33,7 @@
 
  //
  // CPU utils
-@@ -985,6 +985,8 @@
+@@ -1017,6 +1017,8 @@
      if (params.n_gpu_layers != -1) {
          mparams.n_gpu_layers = params.n_gpu_layers;
      }
@@ -42,7 +42,7 @@
      mparams.rpc_servers = params.rpc_servers.c_str();
      mparams.main_gpu    = params.main_gpu;
      mparams.split_mode  = params.split_mode;
-@@ -999,6 +1001,11 @@
+@@ -1031,6 +1033,11 @@
          mparams.kv_overrides = params.kv_overrides.data();
      }
 
@@ -54,13 +54,13 @@
      return mparams;
  }
 
-@@ -1123,221 +1130,7 @@
+@@ -1116,221 +1123,6 @@
+     LOG_ERR("%s: curl_easy_perform() failed after %d attempts\n", __func__, max_attempts);
      return false;
- }
+-}
 -
 -static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {
--
 -    // Initialize libcurl
 -    std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
 -    if (!curl) {
 -        return false;
 -    }
 -
 -    // Check if the file already exists locally
 -    auto file_exists = std::filesystem::exists(path);
 -
 -    // If the file exists, check its JSON metadata companion file.
 -    std::string metadata_path = path + ".json";
 -    nlohmann::json metadata;
 -    std::string etag;
 -    std::string last_modified;
 
 -    if (file_exists) {
 -        // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block).
 -        std::ifstream metadata_in(metadata_path);
@@ -133,11 +133,13 @@
 -        std::string etag;
 -        std::string last_modified;
 -    };
+-
 -    common_load_model_from_url_headers headers;
+-
 -    {
 -        typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
 -        auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
--            common_load_model_from_url_headers *headers = (common_load_model_from_url_headers *) userdata;
+-            common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
 -
 -            static std::regex header_regex("([^:]+): (.*)\r\n");
 -            static std::regex etag_regex("ETag", std::regex_constants::icase);
@@ -271,8 +273,6 @@
 -    }
 -
 -    return true;
--}
--
+ }
+ 
 struct llama_model * common_load_model_from_url(
 -    const char * model_url,
 -    const char * path_model,
diff --git a/scripts/common.h.patch b/scripts/common.h.patch
index 0f6b0a02..dd27da8b 100644
--- a/scripts/common.h.patch
+++ b/scripts/common.h.patch
@@ -1,6 +1,6 @@
---- common.h.orig	2024-12-20 13:06:22
-+++ common.h	2024-12-20 13:05:53
-@@ -41,6 +41,17 @@
+--- common.h.orig	2024-12-20 13:44:36
++++ common.h	2024-12-20 13:44:25
+@@ -43,6 +43,17 @@
 
  struct common_control_vector_load_info;
 
@@ -18,7 +18,7 @@
 //
 // CPU utils
 //
-@@ -154,6 +165,7 @@
+@@ -183,6 +194,7 @@
 };
 
 struct common_params {
@@ -26,13 +26,13 @@
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx     = 4096; // context size
     int32_t n_batch   = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS)
-@@ -270,6 +282,9 @@
+@@ -300,6 +312,9 @@
     bool warmup        = true;  // warmup run
     bool check_tensors = false; // validate tensor data
 
+    llama_progress_callback progress_callback = nullptr;
+    void * progress_callback_user_data = nullptr;
+
-    std::string cache_type_k = "f16"; // KV cache data type for the K
-    std::string cache_type_v = "f16"; // KV cache data type for the V
+    lm_ggml_type cache_type_k = LM_GGML_TYPE_F16; // KV cache data type for the K
+    lm_ggml_type cache_type_v = LM_GGML_TYPE_F16; // KV cache data type for the V
diff --git a/scripts/ggml-backend-reg.cpp.patch b/scripts/ggml-backend-reg.cpp.patch
index 8b80d697..486eb8e2 100644
--- a/scripts/ggml-backend-reg.cpp.patch
+++ b/scripts/ggml-backend-reg.cpp.patch
@@ -1,30 +1,29 @@
---- ggml-backend-reg.cpp.orig	2024-11-17 11:53:44
-+++ ggml-backend-reg.cpp	2024-11-17 11:53:17
-@@ -12,9 +12,14 @@
+--- ggml-backend-reg.cpp.orig	2024-12-20 13:38:33
++++ ggml-backend-reg.cpp	2024-12-20 13:38:34
+@@ -35,9 +35,14 @@
  #endif
 
  #ifdef LM_GGML_USE_METAL
 +#include <TargetConditionals.h>
 +
 +#if !TARGET_OS_SIMULATOR
  #include "ggml-metal.h"
  #endif
 
 +#endif
 +
  #ifdef LM_GGML_USE_SYCL
  #include "ggml-sycl.h"
  #endif
-@@ -52,8 +57,12 @@
+@@ -142,7 +147,11 @@
          register_backend(lm_ggml_backend_cuda_reg());
  #endif
  #ifdef LM_GGML_USE_METAL
 +
 +#if !TARGET_OS_SIMULATOR
          register_backend(lm_ggml_backend_metal_reg());
- #endif
-+
 +#endif
++
+ #endif
  #ifdef LM_GGML_USE_SYCL
          register_backend(lm_ggml_backend_sycl_reg());
  #endif
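Note (sync review): the ggml-backend-reg.cpp patch above exists because Metal is unavailable in the iOS simulator; guarding both the include and the registration keeps simulator builds on the CPU backend. The effective shape after patching, sketched:

    #include <TargetConditionals.h>
    #if defined(LM_GGML_USE_METAL) && !TARGET_OS_SIMULATOR
    #include "ggml-metal.h"   // device builds register Metal; simulator falls through to CPU
    #endif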
diff --git a/scripts/ggml-cpu-aarch64.c.patch b/scripts/ggml-cpu-aarch64.c.patch
deleted file mode 100644
index 5cadaa07..00000000
--- a/scripts/ggml-cpu-aarch64.c.patch
+++ /dev/null
@@ -1,11 +0,0 @@
---- ggml-cpu-aarch64.c.orig	2024-11-17 12:15:45
-+++ ggml-cpu-aarch64.c	2024-11-17 12:15:56
-@@ -8,7 +8,7 @@
- #include "ggml-quants.h"
- #include "ggml-impl.h"
- #include "ggml-cpu.h"
--#include "ggml-cpu/ggml-cpu-impl.h"
-+#include "ggml-cpu-impl.h"
- 
- #include <math.h>
- #include <string.h>
diff --git a/scripts/ggml-metal.m.patch b/scripts/ggml-metal.m.patch
index 3da0d3d2..b971bf32 100644
--- a/scripts/ggml-metal.m.patch
+++ b/scripts/ggml-metal.m.patch
@@ -1,11 +1,11 @@
---- ggml-metal.m.orig	2024-11-21 11:03:19
-+++ ggml-metal.m	2024-11-21 11:03:20
-@@ -463,7 +463,7 @@
+--- ggml-metal.m.orig	2024-12-20 13:36:22
++++ ggml-metal.m	2024-12-20 13:37:17
+@@ -509,7 +509,7 @@
      const bool try_metallib = true;
  #endif
 
 -    NSString * path_lib = [bundle pathForResource:@"default" ofType:@"metallib"];
 +    NSString * path_lib = [bundle pathForResource:@"ggml-llama" ofType:@"metallib"];
 
-     if (try_metallib && path_lib != nil) {
-         // pre-compiled library found
-         NSURL * libURL = [NSURL fileURLWithPath:path_lib];
+     if (path_lib == nil) {
+         // Try to find the resource in the directory where the current binary located.
+         NSString * current_binary = [[NSProcessInfo processInfo] arguments][0];
diff --git a/scripts/ggml.c.patch b/scripts/ggml.c.patch
index f104777b..aef334a3 100644
--- a/scripts/ggml.c.patch
+++ b/scripts/ggml.c.patch
@@ -1,6 +1,6 @@
---- ggml.c.orig	2024-11-17 12:20:04
-+++ ggml.c	2024-11-17 12:20:05
-@@ -114,9 +114,9 @@
+--- ggml.c.orig	2024-12-20 13:38:33
++++ ggml.c	2024-12-20 13:38:34
+@@ -117,9 +117,9 @@
  #elif defined(__linux__) && defined(__GLIBC__)
  #include <execinfo.h>
  static void lm_ggml_print_backtrace_symbols(void) {
diff --git a/scripts/llama.cpp.patch b/scripts/llama.cpp.patch
index 56e5908c..2aafb402 100644
--- a/scripts/llama.cpp.patch
+++ b/scripts/llama.cpp.patch
@@ -1,5 +1,5 @@
---- llama.cpp.orig	2024-11-21 11:03:19
-+++ llama.cpp	2024-11-21 11:03:20
+--- llama.cpp.orig	2024-12-23 10:40:08
++++ llama.cpp	2024-12-23 10:40:09
 @@ -80,6 +80,17 @@
  #define LLAMA_MAX_LAYERS  512
  #define LLAMA_MAX_EXPERTS 160  // DeepSeekV2
@@ -18,7 +18,7 @@
  //
  // helpers
  //
-@@ -1951,16 +1962,16 @@
+@@ -2160,16 +2171,16 @@
      if (prefetch > 0) {
          // advise the kernel to preload the mapped memory
diff --git a/scripts/sgemm.cpp.patch b/scripts/sgemm.cpp.patch
deleted file mode 100644
index 590be27a..00000000
--- a/scripts/sgemm.cpp.patch
+++ /dev/null
@@ -1,12 +0,0 @@
---- sgemm.cpp.orig	2024-11-17 12:18:43
-+++ sgemm.cpp	2024-11-17 12:19:12
-@@ -50,8 +50,7 @@
-
- #include "sgemm.h"
- #include "ggml-impl.h"
--// hack until moved into the CPU backend
--#include "../ggml-cpu-impl.h"
-+#include "ggml-cpu-impl.h"
- #include "ggml-quants.h"
-
- #ifdef _MSC_VER
diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts
index e69f37f9..92d8458b 100644
--- a/src/NativeRNLlama.ts
+++ b/src/NativeRNLlama.ts
@@ -125,10 +125,6 @@ export type NativeCompletionParams = {
    * Repeat alpha presence penalty. Default: `0.0`, which is disabled.
    */
   penalty_present?: number
-  /**
-   * Penalize newline tokens when applying the repeat penalty. Default: `false`
-   */
-  penalize_nl?: boolean
   /**
    * Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
    */