diff --git a/.github/workflows/build-wheels-metal.yaml b/.github/workflows/build-wheels-metal.yaml index e56b16c7..f0011a98 100644 --- a/.github/workflows/build-wheels-metal.yaml +++ b/.github/workflows/build-wheels-metal.yaml @@ -11,7 +11,7 @@ jobs: runs-on: macos-${{ matrix.os }} strategy: matrix: - os: [12, 13, 14] + os: [13, 14, 15] steps: - uses: actions/checkout@v4 diff --git a/.gitmodules b/.gitmodules index e3daca82..2a64b27c 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,7 +5,7 @@ [submodule "dependency/llama.cpp"] path = dependency/llama.cpp url = https://github.com/NexaAI/llama.cpp.git - branch = master + branch = release [submodule "nexa/eval/benchmark_tasks"] path = nexa/eval/benchmark_tasks url = https://github.com/NexaAI/benchmark-tasks.git diff --git a/README.md b/README.md index a61f4633..1bd83309 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ ## Latest News 🔥 -- Support Nexa AI's own vision language model (0.9B parameters): `nexa run omnivision` and audio language model (2.9B parameters): `nexa run omniaudio` +- Support Nexa AI's own vision language model (0.9B parameters): `nexa run omniVLM` and audio language model (2.9B parameters): `nexa run omniaudio` - Support audio language model: `nexa run qwen2audio`, **we are the first open-source toolkit to support audio language model with GGML tensor library.** - Support iOS Swift binding for local inference on **iOS mobile** devices. - Support embedding model: `nexa embed ` @@ -33,13 +33,13 @@ Welcome to submit your requests through [issues](https://github.com/NexaAI/nexa- ## Install Option 1: Executable Installer

- + macOS Installer

- + Windows Installer

@@ -228,7 +228,7 @@ Supported model examples (full list at [Model Hub](https://nexa.ai/models)): | [qwen2audio](https://nexa.ai/Qwen/Qwen2-Audio-7.8B-Instruct/gguf-q4_K_M/readme) | AudioLM | GGUF | `nexa run qwen2audio` | | [octopus-v2](https://www.nexaai.com/NexaAI/Octopus-v2/gguf-q4_0/readme) | Function Call | GGUF | `nexa run octopus-v2` | | [octo-net](https://www.nexaai.com/NexaAI/Octo-net/gguf-q4_0/readme) | Text | GGUF | `nexa run octo-net` | -| [omnivision](https://nexa.ai/NexaAI/omnivision/gguf-fp16/readme) | Multimodal | GGUF | `nexa run omnivision` | +| [omniVLM](https://nexa.ai/NexaAI/omniVLM/gguf-fp16/readme) | Multimodal | GGUF | `nexa run omniVLM` | | [nanollava](https://www.nexaai.com/qnguyen3/nanoLLaVA/gguf-fp16/readme) | Multimodal | GGUF | `nexa run nanollava` | | [llava-phi3](https://www.nexaai.com/xtuner/llava-phi-3-mini/gguf-q4_0/readme) | Multimodal | GGUF | `nexa run llava-phi3` | | [llava-llama3](https://www.nexaai.com/xtuner/llava-llama-3-8b-v1.1/gguf-q4_0/readme) | Multimodal | GGUF | `nexa run llava-llama3` | diff --git a/dependency/llama.cpp b/dependency/llama.cpp index 64a6001a..b2958b33 160000 --- a/dependency/llama.cpp +++ b/dependency/llama.cpp @@ -1 +1 @@ -Subproject commit 64a6001a1a408129eb510f49840947876220c5fa +Subproject commit b2958b33ddd4c8f13c98fb1c1249ac067769df91 diff --git a/docs/README.md b/docs/README.md index 252116f7..d4081d2e 100644 --- a/docs/README.md +++ b/docs/README.md @@ -28,12 +28,16 @@ pip install nexaai[onnx] # if you need ONNX support ``` ### build from source + To build C++ only + ``` cmake -B build -S . cmake --build build --config Release -j32 ``` + To build C++ and install python package from source, run the following commands: + ```bash git clone --recursive https://github.com/NexaAI/nexa-sdk.git cd nexa-sdk @@ -75,7 +79,7 @@ python -m nexa.gguf.nexa_inference_text gemma python -m nexa.gguf.nexa_inference_text octopusv2 --stop_words "" wget https://assets-c4akfrf5b4d3f4b7.z01.azurefd.net/assets/2024/04/BMDataViz_661fb89f3845e.png -O test.png python -m nexa.gguf.nexa_inference_vlm nanollava -python -m nexa.gguf.nexa_inference_vlm_omni omnivision +python -m nexa.gguf.nexa_inference_vlm_omni omniVLM python -m nexa.gguf.nexa_inference_image sd1-4 python -m nexa.gguf.nexa_inference_image sd1-4 --img2img wget -O control_normal-fp16.safetensors https://huggingface.co/webui/ControlNet-modules-safetensors/resolve/main/control_normal-fp16.safetensors @@ -235,7 +239,9 @@ dumpbin /dependents your_executable_or_dll.dll # in Developer PowerShell for Vi ``` ### Debug dynamic lib + According to [isse](https://github.com/abetlen/llama-cpp-python/issues/1346), below can check the exported symbols on linux. 
+ ``` readelf -Ws --dyn-syms libllama.so -``` \ No newline at end of file +``` diff --git a/nexa/__init__.py b/nexa/__init__.py index 26e7b666..af51d3c5 100644 --- a/nexa/__init__.py +++ b/nexa/__init__.py @@ -1 +1 @@ -__version__ = "0.0.9.5" +__version__ = "0.0.9.6" diff --git a/nexa/constants.py b/nexa/constants.py index 24acd195..51d6e051 100644 --- a/nexa/constants.py +++ b/nexa/constants.py @@ -188,8 +188,8 @@ class ModelType(Enum): "omnivision-preview": "omnivision-preview:projector-fp16", "omnivision-preview:fp16": "omnivision-preview:projector-fp16", "omnivision-preview:q4_0": "omnivision-preview:projector-q4_0", - "omnivision": "omnivision:projector-fp16", - "omnivision:fp16": "omnivision:projector-fp16", + "omniVLM": "omniVLM:projector-fp16", + "omniVLM:fp16": "omniVLM:projector-fp16", "omnivision-ocr": "omnivision-ocr:projector-fp16", "omnivision-ocr:fp16": "omnivision-ocr:projector-fp16", } @@ -198,8 +198,8 @@ class ModelType(Enum): "omnivision-preview": "omnivision-preview:model-fp16", "omnivision-preview:fp16": "omnivision-preview:model-fp16", "omnivision-preview:q4_0": "omnivision-preview:model-q4_0", - "omnivision": "omnivision:model-fp16", - "omnivision:fp16": "omnivision:model-fp16", + "omniVLM": "omniVLM:model-fp16", + "omniVLM:fp16": "omniVLM:model-fp16", "omnivision-ocr": "omnivision-ocr:model-fp16", "omnivision-ocr:fp16": "omnivision-ocr:model-fp16", } @@ -461,7 +461,7 @@ class ModelType(Enum): "FLUX.1-schnell": ModelType.COMPUTER_VISION, "Phi-3-vision-128k-instruct": ModelType.MULTIMODAL, "omnivision-preview": ModelType.MULTIMODAL, - "omnivision": ModelType.MULTIMODAL, + "omniVLM": ModelType.MULTIMODAL, "omnivision-ocr": ModelType.MULTIMODAL, "nanoLLaVA": ModelType.MULTIMODAL, "llava-v1.6-mistral-7b": ModelType.MULTIMODAL, diff --git a/nexa/gguf/llama/__init__.py b/nexa/gguf/llama/__init__.py index e69de29b..b3dcd6ed 100644 --- a/nexa/gguf/llama/__init__.py +++ b/nexa/gguf/llama/__init__.py @@ -0,0 +1,2 @@ +from nexa.gguf.llama.llama_cpp import * +from nexa.gguf.llama.llama import * diff --git a/nexa/gguf/llama/_ctypes_extensions.py b/nexa/gguf/llama/_ctypes_extensions.py new file mode 100644 index 00000000..c27f5c04 --- /dev/null +++ b/nexa/gguf/llama/_ctypes_extensions.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import sys +import os +import ctypes +import functools +import pathlib + +from typing import ( + Any, + Callable, + List, + Union, + Optional, + TYPE_CHECKING, + TypeVar, + Generic, +) +from typing_extensions import TypeAlias + + +# ctypes sane type hint helpers +# +# - Generic Pointer and Array types +# - PointerOrRef type with a type hinted byref function +# +# NOTE: Only use these for static type checking not for runtime checks +# no good will come of that + +if TYPE_CHECKING: + CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore + + CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore + + CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore + + CtypesVoidPointer: TypeAlias = ctypes.c_void_p + + class CtypesRef(Generic[CtypesCData]): + pass + + CtypesPointerOrRef: TypeAlias = Union[ + CtypesPointer[CtypesCData], CtypesRef[CtypesCData] + ] + + CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore + +F = TypeVar("F", bound=Callable[..., Any]) + + +def ctypes_function_for_shared_library(lib: ctypes.CDLL): + """Decorator for defining ctypes functions with type hints""" + + def ctypes_function( + name: str, argtypes: List[Any], restype: Any, enabled: bool = True + ): + def 
decorator(f: F) -> F: + if enabled: + func = getattr(lib, name) + func.argtypes = argtypes + func.restype = restype + functools.wraps(f)(func) + return func + else: + return f + + return decorator + + return ctypes_function + + +def _byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]: + """Type-annotated version of ctypes.byref""" + ... + + +byref = _byref if TYPE_CHECKING else ctypes.byref diff --git a/nexa/gguf/llama/_ggml.py b/nexa/gguf/llama/_ggml.py new file mode 100644 index 00000000..5b175d4c --- /dev/null +++ b/nexa/gguf/llama/_ggml.py @@ -0,0 +1,11 @@ +"""Internal module use at your own risk + +This module provides a minimal interface for working with ggml tensors from llama-cpp-python +""" +import os +import pathlib + +from nexa.gguf.lib_utils import load_library + +libggml = load_library("ggml") + diff --git a/nexa/gguf/llama/_internals_transformers.py b/nexa/gguf/llama/_internals_transformers.py index 7646563f..5d625b8d 100644 --- a/nexa/gguf/llama/_internals_transformers.py +++ b/nexa/gguf/llama/_internals_transformers.py @@ -6,6 +6,7 @@ from typing import ( Dict, List, + Tuple, Optional, Sequence, ) @@ -19,13 +20,12 @@ from nexa.gguf.llama.llama_grammar import LlamaGrammar from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr -import nexa.gguf.llama.llama_cpp as llama_cpp - +from nexa.gguf.llama.llama_cpp import * # Python wrappers over llama.h structs -class _LlamaModel: +class LlamaModel: """Intermediate Python wrapper for a llama.cpp llama_model. NOTE: For stability it's recommended you use the Llama class instead.""" @@ -33,7 +33,7 @@ def __init__( self, *, path_model: str, - params: llama_cpp.llama_model_params, + params: llama_model_params, verbose: bool = True, ): self.path_model = path_model @@ -41,23 +41,25 @@ def __init__( self.verbose = verbose self._exit_stack = ExitStack() - self.model = None + model = None if not os.path.exists(path_model): raise ValueError(f"Model path does not exist: {path_model}") with suppress_stdout_stderr(disable=verbose): - self.model = llama_cpp.llama_load_model_from_file( + model = llama_load_model_from_file( self.path_model.encode("utf-8"), self.params ) - if self.model is None: + if model is None: raise ValueError(f"Failed to load model from file: {path_model}") + self.model = model + def free_model(): if self.model is None: return - llama_cpp.llama_free_model(self.model) + llama_free_model(self.model) self.model = None self._exit_stack.callback(free_model) @@ -69,137 +71,92 @@ def __del__(self): self.close() def vocab_type(self) -> int: - assert self.model is not None - return llama_cpp.llama_vocab_type(self.model) + return llama_vocab_type(self.model) def n_vocab(self) -> int: - assert self.model is not None - return llama_cpp.llama_n_vocab(self.model) + return llama_n_vocab(self.model) def n_ctx_train(self) -> int: - assert self.model is not None - return llama_cpp.llama_n_ctx_train(self.model) + return llama_n_ctx_train(self.model) def n_embd(self) -> int: - assert self.model is not None - return llama_cpp.llama_n_embd(self.model) + return llama_n_embd(self.model) def rope_freq_scale_train(self) -> float: - assert self.model is not None - return llama_cpp.llama_rope_freq_scale_train(self.model) + return llama_rope_freq_scale_train(self.model) def desc(self) -> str: - assert self.model is not None buf = ctypes.create_string_buffer(1024) - llama_cpp.llama_model_desc(self.model, buf, 1024) + llama_model_desc(self.model, buf, 1024) return buf.value.decode("utf-8") def size(self) -> int: - 
assert self.model is not None - return llama_cpp.llama_model_size(self.model) + return llama_model_size(self.model) def n_params(self) -> int: - assert self.model is not None - return llama_cpp.llama_model_n_params(self.model) + return llama_model_n_params(self.model) def get_tensor(self, name: str) -> ctypes.c_void_p: - assert self.model is not None - return llama_cpp.llama_get_model_tensor(self.model, name.encode("utf-8")) - - def apply_lora_from_file( - self, - lora_path: str, - scale: float, - path_base_model: Optional[str], - n_threads: int, - ): - assert self.model is not None - return llama_cpp.llama_model_apply_lora_from_file( - self.model, - lora_path.encode("utf-8"), - scale, - ( - path_base_model.encode("utf-8") - if path_base_model is not None - else ctypes.c_char_p(0) - ), - n_threads, - ) + return llama_get_model_tensor(self.model, name.encode("utf-8")) # Vocab def token_get_text(self, token: int) -> str: - # TODO: Fix - assert self.model is not None - return llama_cpp.llama_token_get_text(self.model, token).decode("utf-8") + return llama_token_get_text(self.model, token).decode("utf-8") def token_get_score(self, token: int) -> float: - assert self.model is not None - return llama_cpp.llama_token_get_score(self.model, token) + return llama_token_get_score(self.model, token) def token_get_attr(self, token: int) -> int: - assert self.model is not None - return llama_cpp.llama_token_get_attr(self.model, token) + return llama_token_get_attr(self.model, token) # Special tokens def token_bos(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_bos(self.model) + return llama_token_bos(self.model) def token_eos(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_eos(self.model) + return llama_token_eos(self.model) def token_cls(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_cls(self.model) + return llama_token_cls(self.model) def token_sep(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_sep(self.model) + return llama_token_sep(self.model) def token_nl(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_nl(self.model) + return llama_token_nl(self.model) def token_prefix(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_prefix(self.model) + return llama_token_prefix(self.model) def token_middle(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_middle(self.model) + return llama_token_middle(self.model) def token_suffix(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_suffix(self.model) + return llama_token_suffix(self.model) def token_eot(self) -> int: - assert self.model is not None - return llama_cpp.llama_token_eot(self.model) + return llama_token_eot(self.model) def add_bos_token(self) -> bool: - assert self.model is not None - return llama_cpp.llama_add_bos_token(self.model) + return llama_add_bos_token(self.model) def add_eos_token(self) -> bool: - assert self.model is not None - return llama_cpp.llama_add_eos_token(self.model) + return llama_add_eos_token(self.model) # Tokenization def tokenize(self, text: bytes, add_bos: bool, special: bool): - assert self.model is not None n_ctx = self.n_ctx_train() - tokens = (llama_cpp.llama_token * n_ctx)() - n_tokens = llama_cpp.llama_tokenize( + tokens = (llama_token * n_ctx)() + n_tokens = llama_tokenize( self.model, text, len(text), tokens, n_ctx, add_bos, special ) if n_tokens < 0: n_tokens = abs(n_tokens) - tokens = 
(llama_cpp.llama_token * n_tokens)() - n_tokens = llama_cpp.llama_tokenize( + tokens = (llama_token * n_tokens)() + n_tokens = llama_tokenize( self.model, text, len(text), tokens, n_tokens, add_bos, special ) if n_tokens < 0: @@ -209,19 +166,17 @@ def tokenize(self, text: bytes, add_bos: bool, special: bool): return list(tokens[:n_tokens]) def token_to_piece(self, token: int, special: bool = False) -> bytes: - assert self.model is not None buf = ctypes.create_string_buffer(32) - llama_cpp.llama_token_to_piece(self.model, token, buf, 32, 0, special) + llama_token_to_piece(self.model, token, buf, 32, 0, special) return bytes(buf) def detokenize(self, tokens: List[int], special: bool = False) -> bytes: - assert self.model is not None output = b"" size = 32 buffer = (ctypes.c_char * size)() for token in tokens: - n = llama_cpp.llama_token_to_piece( - self.model, llama_cpp.llama_token(token), buffer, size, 0, special + n = llama_token_to_piece( + self.model, llama_token(token), buffer, size, 0, special ) assert n <= size output += bytes(buffer[:n]) @@ -235,31 +190,30 @@ def detokenize(self, tokens: List[int], special: bool = False) -> bytes: # Extra def metadata(self) -> Dict[str, str]: - assert self.model is not None metadata: Dict[str, str] = {} buffer_size = 1024 buffer = ctypes.create_string_buffer(buffer_size) # zero the buffer buffer.value = b"\0" * buffer_size # iterate over model keys - for i in range(llama_cpp.llama_model_meta_count(self.model)): - nbytes = llama_cpp.llama_model_meta_key_by_index( + for i in range(llama_model_meta_count(self.model)): + nbytes = llama_model_meta_key_by_index( self.model, i, buffer, buffer_size ) if nbytes > buffer_size: buffer_size = nbytes + 1 buffer = ctypes.create_string_buffer(buffer_size) - nbytes = llama_cpp.llama_model_meta_key_by_index( + nbytes = llama_model_meta_key_by_index( self.model, i, buffer, buffer_size ) key = buffer.value.decode("utf-8") - nbytes = llama_cpp.llama_model_meta_val_str_by_index( + nbytes = llama_model_meta_val_str_by_index( self.model, i, buffer, buffer_size ) if nbytes > buffer_size: buffer_size = nbytes + 1 buffer = ctypes.create_string_buffer(buffer_size) - nbytes = llama_cpp.llama_model_meta_val_str_by_index( + nbytes = llama_model_meta_val_str_by_index( self.model, i, buffer, buffer_size ) value = buffer.value.decode("utf-8") @@ -269,18 +223,18 @@ def metadata(self) -> Dict[str, str]: @staticmethod def default_params(): """Get the default llama_model_params.""" - return llama_cpp.llama_model_default_params() + return llama_model_default_params() -class _LlamaContext: +class LlamaContext: """Intermediate Python wrapper for a llama.cpp llama_context. 
NOTE: For stability it's recommended you use the Llama class instead.""" def __init__( self, *, - model: _LlamaModel, - params: llama_cpp.llama_context_params, + model: LlamaModel, + params: llama_context_params, verbose: bool = True, ): self.model = model @@ -288,19 +242,17 @@ def __init__( self.verbose = verbose self._exit_stack = ExitStack() - self.ctx = None - - assert self.model.model is not None - - self.ctx = llama_cpp.llama_new_context_with_model(self.model.model, self.params) + ctx = llama_new_context_with_model(self.model.model, self.params) - if self.ctx is None: + if ctx is None: raise ValueError("Failed to create llama_context") + self.ctx = ctx + def free_ctx(): if self.ctx is None: return - llama_cpp.llama_free(self.ctx) + llama_free(self.ctx) self.ctx = None self._exit_stack.callback(free_ctx) @@ -312,41 +264,39 @@ def __del__(self): self.close() def n_ctx(self) -> int: - assert self.ctx is not None - return llama_cpp.llama_n_ctx(self.ctx) + return llama_n_ctx(self.ctx) def pooling_type(self) -> int: - assert self.ctx is not None - return llama_cpp.llama_pooling_type(self.ctx) + return llama_pooling_type(self.ctx) def kv_cache_clear(self): - assert self.ctx is not None - llama_cpp.llama_kv_cache_clear(self.ctx) + llama_kv_cache_clear(self.ctx) def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1) + llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1) def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1) + llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1) def kv_cache_seq_keep(self, seq_id: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id) + llama_kv_cache_seq_keep(self.ctx, seq_id) def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int): - assert self.ctx is not None - llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift) + llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift) def get_state_size(self) -> int: - assert self.ctx is not None - return llama_cpp.llama_get_state_size(self.ctx) + return llama_get_state_size(self.ctx) + + # TODO: copy_state_data + + # TODO: set_state_data + + # TODO: llama_load_session_file + + # TODO: llama_save_session_file - def decode(self, batch: "_LlamaBatch"): - assert self.ctx is not None - assert batch.batch is not None - return_code = llama_cpp.llama_decode( + def decode(self, batch: LlamaBatch): + return_code = llama_decode( self.ctx, batch.batch, ) @@ -354,40 +304,35 @@ def decode(self, batch: "_LlamaBatch"): raise RuntimeError(f"llama_decode returned {return_code}") def set_n_threads(self, n_threads: int, n_threads_batch: int): - assert self.ctx is not None - llama_cpp.llama_set_n_threads(self.ctx, n_threads, n_threads_batch) + llama_set_n_threads(self.ctx, n_threads, n_threads_batch) def get_logits(self): - assert self.ctx is not None - return llama_cpp.llama_get_logits(self.ctx) + return llama_get_logits(self.ctx) def get_logits_ith(self, i: int): - assert self.ctx is not None - return llama_cpp.llama_get_logits_ith(self.ctx, i) + return llama_get_logits_ith(self.ctx, i) def get_embeddings(self): - assert self.ctx is not None - return llama_cpp.llama_get_embeddings(self.ctx) + return llama_get_embeddings(self.ctx) # Sampling functions def set_rng_seed(self, seed: int): - assert self.ctx is not None - 
llama_cpp.llama_set_rng_seed(self.ctx, seed) + # TODO: Fix + llama_set_rng_seed(self.ctx, seed) def sample_repetition_penalties( self, candidates: "_LlamaTokenDataArray", - last_tokens_data: "llama_cpp.Array[llama_cpp.llama_token]", + last_tokens_data: "Array[llama_token]", penalty_last_n: int, penalty_repeat: float, penalty_freq: float, penalty_present: float, ): - assert self.ctx is not None - llama_cpp.llama_sample_repetition_penalties( + llama_sample_repetition_penalties( self.ctx, - llama_cpp.byref(candidates.candidates), + byref(candidates.candidates), last_tokens_data, penalty_last_n, penalty_repeat, @@ -396,58 +341,42 @@ def sample_repetition_penalties( ) def sample_softmax(self, candidates: "_LlamaTokenDataArray"): - assert self.ctx is not None - llama_cpp.llama_sample_softmax( + llama_sample_softmax( self.ctx, - llama_cpp.byref(candidates.candidates), + byref(candidates.candidates), ) def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int): - assert self.ctx is not None - llama_cpp.llama_sample_top_k( - self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep + llama_sample_top_k( + self.ctx, byref(candidates.candidates), k, min_keep ) def sample_top_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): - assert self.ctx is not None - llama_cpp.llama_sample_top_p( - self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep + llama_sample_top_p( + self.ctx, byref(candidates.candidates), p, min_keep ) def sample_min_p(self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int): - assert self.ctx is not None - llama_cpp.llama_sample_min_p( - self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep - ) - - def sample_tail_free( - self, candidates: "_LlamaTokenDataArray", z: float, min_keep: int - ): - assert self.ctx is not None - llama_cpp.llama_sample_tail_free( - self.ctx, llama_cpp.byref(candidates.candidates), z, min_keep + llama_sample_min_p( + self.ctx, byref(candidates.candidates), p, min_keep ) def sample_typical( self, candidates: "_LlamaTokenDataArray", p: float, min_keep: int ): - assert self.ctx is not None - llama_cpp.llama_sample_typical( - self.ctx, llama_cpp.byref(candidates.candidates), p, min_keep + llama_sample_typical( + self.ctx, byref(candidates.candidates), p, min_keep ) def sample_temp(self, candidates: "_LlamaTokenDataArray", temp: float): - assert self.ctx is not None - llama_cpp.llama_sample_temp( - self.ctx, llama_cpp.byref(candidates.candidates), temp + llama_sample_temp( + self.ctx, byref(candidates.candidates), temp ) def sample_grammar(self, candidates: "_LlamaTokenDataArray", grammar: LlamaGrammar): - assert self.ctx is not None - assert grammar.grammar is not None - llama_cpp.llama_sample_grammar( + llama_sample_grammar( self.ctx, - llama_cpp.byref(candidates.candidates), + byref(candidates.candidates), grammar.grammar, ) @@ -457,12 +386,11 @@ def sample_token_mirostat( tau: float, eta: float, m: int, - mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float], + mu: CtypesPointerOrRef[ctypes.c_float], ) -> int: - assert self.ctx is not None - return llama_cpp.llama_sample_token_mirostat( + return llama_sample_token_mirostat( self.ctx, - llama_cpp.byref(candidates.candidates), + byref(candidates.candidates), tau, eta, m, @@ -474,53 +402,46 @@ def sample_token_mirostat_v2( candidates: "_LlamaTokenDataArray", tau: float, eta: float, - mu: llama_cpp.CtypesPointerOrRef[ctypes.c_float], + mu: CtypesPointerOrRef[ctypes.c_float], ) -> int: - assert self.ctx is not None - return 
llama_cpp.llama_sample_token_mirostat_v2( + return llama_sample_token_mirostat_v2( self.ctx, - llama_cpp.byref(candidates.candidates), + byref(candidates.candidates), tau, eta, mu, ) def sample_token_greedy(self, candidates: "_LlamaTokenDataArray") -> int: - assert self.ctx is not None - return llama_cpp.llama_sample_token_greedy( + return llama_sample_token_greedy( self.ctx, - llama_cpp.byref(candidates.candidates), + byref(candidates.candidates), ) def sample_token(self, candidates: "_LlamaTokenDataArray") -> int: - assert self.ctx is not None - return llama_cpp.llama_sample_token( + return llama_sample_token( self.ctx, - llama_cpp.byref(candidates.candidates), + byref(candidates.candidates), ) # Grammar def grammar_accept_token(self, grammar: LlamaGrammar, token: int): - assert self.ctx is not None - assert grammar.grammar is not None - llama_cpp.llama_grammar_accept_token(grammar.grammar, self.ctx, token) + llama_grammar_accept_token(grammar.grammar, self.ctx, token) def reset_timings(self): - assert self.ctx is not None - llama_cpp.llama_reset_timings(self.ctx) + llama_perf_context_reset(self.ctx) def print_timings(self): - assert self.ctx is not None - llama_cpp.llama_print_timings(self.ctx) + llama_perf_context_print(self.ctx) # Utility functions @staticmethod def default_params(): """Get the default llama_context_params.""" - return llama_cpp.llama_context_default_params() + return llama_context_default_params() -class _LlamaBatch: +class LlamaBatch: def __init__( self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True ): @@ -530,15 +451,17 @@ def __init__( self.verbose = verbose self._exit_stack = ExitStack() - self.batch = None - self.batch = llama_cpp.llama_batch_init( - self._n_tokens, self.embd, self.n_seq_max - ) + batch = llama_batch_init(self._n_tokens, self.embd, self.n_seq_max) + + if batch is None: + raise ValueError("Failed to create llama_batch") + + self.batch = batch def free_batch(): if self.batch is None: return - llama_cpp.llama_batch_free(self.batch) + llama_batch_free(self.batch) self.batch = None self._exit_stack.callback(free_batch) @@ -550,15 +473,12 @@ def __del__(self): self.close() def n_tokens(self) -> int: - assert self.batch is not None return self.batch.n_tokens def reset(self): - assert self.batch is not None self.batch.n_tokens = 0 def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): - assert self.batch is not None n_tokens = len(batch) self.batch.n_tokens = n_tokens for i in range(n_tokens): @@ -570,7 +490,6 @@ def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): self.batch.logits[n_tokens - 1] = True def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool): - assert self.batch is not None n_tokens = len(batch) n_tokens0 = self.batch.n_tokens self.batch.n_tokens += n_tokens @@ -584,7 +503,7 @@ def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool): self.batch.logits[n_tokens - 1] = True -class _LlamaTokenDataArray: +class LlamaTokenDataArray: def __init__(self, *, n_vocab: int): self.n_vocab = n_vocab self.candidates_data = np.recarray( @@ -593,8 +512,8 @@ def __init__(self, *, n_vocab: int): [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True ), ) - self.candidates = llama_cpp.llama_token_data_array( - data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), + self.candidates = llama_token_data_array( + data=self.candidates_data.ctypes.data_as(llama_token_data_p), size=self.n_vocab, sorted=False, ) @@ -609,90 +528,10 @@ 
def copy_logits(self, logits: npt.NDArray[np.single]): self.candidates.size = self.n_vocab -# Python wrappers over common/common -def _tokenize(model: _LlamaModel, text: str, add_bos: bool, special: bool) -> list[int]: - assert model.model is not None - n_tokens = len(text) + 1 if add_bos else len(text) - result = (llama_cpp.llama_token * n_tokens)() - n_tokens = llama_cpp.llama_tokenize( - model.model, - text.encode("utf-8"), - len(text), - result, - n_tokens, - add_bos, - special, - ) - if n_tokens < 0: - result = (llama_cpp.llama_token * -n_tokens)() - check = llama_cpp.llama_tokenize( - model.model, - text.encode("utf-8"), - len(text), - result, - len(result), - add_bos, - special, - ) - if check != -n_tokens: - raise RuntimeError(f'Failed to tokenize: text="{text}" n_tokens={n_tokens}') - else: - result = result[:n_tokens] - return list(result) - - -def _token_to_piece(model: _LlamaModel, token: int, special: bool = False) -> str: - assert model.model is not None - result = (ctypes.c_char * 8)(0) - n_tokens = llama_cpp.llama_token_to_piece( - model.model, token, result, 0, len(result), special - ) - if n_tokens < 0: - result = (ctypes.c_char * -n_tokens)(0) - check = llama_cpp.llama_token_to_piece( - model.model, token, result, 0, len(result), special - ) - if check != -n_tokens: - raise RuntimeError(f"Failed to get piece: token={token}") - else: - result = result[:n_tokens] - return bytes(result).decode("utf-8") - - -def _detokenize_spm(model: _LlamaModel, tokens: List[int]) -> str: - bos_id = model.token_bos() - result = "" - for i, token in enumerate(tokens): - piece = _token_to_piece(model, token) - if ( - (tokens[0] == bos_id and i == 1) or (tokens[0] != bos_id and i == 0) - ) and piece[0] == " ": - piece = piece[1:] - result += piece - return result - - -def _detokenize_bpe(model: _LlamaModel, tokens: List[int]) -> str: - result = "" - for token in tokens: - piece = _token_to_piece(model, token) - result += piece - return result - - -def _should_add_bos(model: _LlamaModel) -> bool: - assert model.model is not None - add_bos = llama_cpp.llama_add_bos_token(model.model) - if add_bos: - return add_bos - else: - return llama_cpp.llama_vocab_type(model.model) == llama_cpp.LLAMA_VOCAB_TYPE_SPM - - # Embedding functions -def _normalize_embedding(embedding): +def normalize_embedding(embedding): norm = float(np.linalg.norm(embedding)) if norm == 0.0: return embedding @@ -703,7 +542,7 @@ def _normalize_embedding(embedding): @dataclass -class _LlamaSamplingParams: +class LlamaSamplingParams: n_prev: int = 64 n_probs: int = 0 top_k: int = 40 @@ -730,13 +569,13 @@ class _LlamaSamplingParams: @dataclass -class _LlamaSamplingContext: - params: _LlamaSamplingParams = field(default_factory=_LlamaSamplingParams) +class LlamaSamplingContext: + params: LlamaSamplingParams = field(default_factory=LlamaSamplingParams) mirostat_mu: ctypes.c_float = field(default_factory=ctypes.c_float) grammar: Optional[LlamaGrammar] = None # NOTE: Missing parsed_grammar prev: list[int] = field(default_factory=list) - cur: list[llama_cpp.llama_token_data] = field(default_factory=list) + cur: list[llama_token_data] = field(default_factory=list) def reset(self): self.prev = [] @@ -745,7 +584,7 @@ def reset(self): self.grammar.reset() def cp(self): - return _LlamaSamplingContext( + return LlamaSamplingContext( params=self.params, mirostat_mu=self.mirostat_mu, grammar=self.grammar, @@ -759,12 +598,12 @@ def last(self) -> Optional[int]: else: return None - def prev_str(self, ctx_main: _LlamaContext, n: int) -> str: + def 
prev_str(self, ctx_main: LlamaContext, n: int) -> str: return ctx_main.model.detokenize(self.prev[-n:]).decode("utf-8") def sample( self, - ctx_main: _LlamaContext, + ctx_main: LlamaContext, idx: int = 0, logits_array: Optional[npt.NDArray[np.single]] = None, ): @@ -782,7 +621,7 @@ def sample( for token, logit_bias in self.params.logit_bias.items(): logits_array[token] += logit_bias - token_data_array = _LlamaTokenDataArray( + token_data_array = LlamaTokenDataArray( n_vocab=n_vocab ) # TODO: Only create this once token_data_array.copy_logits(logits_array) @@ -794,7 +633,7 @@ def sample( last_tokens = self.prev[-self.params.penalty_last_n :] last_tokens_size = min(len(last_tokens), self.params.penalty_last_n) if last_tokens_size > 0: - last_tokens_p = (llama_cpp.llama_token * len(last_tokens))(*last_tokens) + last_tokens_p = (llama_token * len(last_tokens))(*last_tokens) ctx_main.sample_repetition_penalties( token_data_array, last_tokens_p, @@ -838,9 +677,6 @@ def sample( ctx_main.sample_top_k( token_data_array, self.params.top_k, min_keep=min_keep ) - ctx_main.sample_tail_free( - token_data_array, self.params.tfs_z, min_keep=min_keep - ) ctx_main.sample_typical( token_data_array, self.params.typical_p, min_keep=min_keep ) @@ -854,7 +690,173 @@ def sample( id = ctx_main.sample_token(token_data_array) return id - def accept(self, ctx_main: _LlamaContext, id: int, apply_grammar: bool): + def accept(self, ctx_main: LlamaContext, id: int, apply_grammar: bool): if apply_grammar and self.grammar is not None: ctx_main.grammar_accept_token(self.grammar, id) - self.prev.append(id) \ No newline at end of file + self.prev.append(id) + + +from typing import List, Callable, Optional, Union +import ctypes +import nexa.gguf.llama.llama_cpp + + +class CustomSampler: + def __init__( + self, apply_func: typing.Callable[[llama_token_data_array], None] + ): + self.apply_func = apply_func + + def apply_wrapper( + sampler: llama_sampler_p, + cur_p: llama_token_data_array_p, + ): + self.apply_func(cur_p) + + def free_wrapper(sampler: llama_sampler_p): + pass + + sampler_i = llama_sampler_i() + sampler_i.apply = llama_sampler_i_apply(apply_wrapper) + self._apply_wrapper_ref = apply_wrapper + + sampler_i.name = llama_sampler_i_name(0) + sampler_i.accept = llama_sampler_i_accept(0) + sampler_i.reset = llama_sampler_i_reset(0) + sampler_i.clone = llama_sampler_i_clone(0) + sampler_i.free = llama_sampler_i_free(0) + + self.sampler = llama_sampler() + self.sampler.iface = ctypes.pointer(sampler_i) + self.sampler.ctx = None + + def get_sampler(self) -> llama_sampler_p: + return ctypes.pointer(self.sampler) + + +class LlamaSampler: + def __init__(self): + params = llama_sampler_chain_params() + self.sampler = llama_sampler_chain_init(params) + self.samplers: List[llama_sampler_p] = [] + self.custom_samplers: List[Tuple[int, CustomSampler]] = [] + + def add_greedy(self): + sampler = llama_sampler_init_greedy() + self._add_sampler(sampler) + + def add_dist(self, seed: int): + sampler = llama_sampler_init_dist(seed) + self._add_sampler(sampler) + + def add_softmax(self): + sampler = llama_sampler_init_softmax() + self._add_sampler(sampler) + + def add_top_k(self, k: int): + sampler = llama_sampler_init_top_k(k) + self._add_sampler(sampler) + + def add_top_p(self, p: float, min_keep: int): + sampler = llama_sampler_init_top_p(p, min_keep) + self._add_sampler(sampler) + + def add_min_p(self, p: float, min_keep: int): + sampler = llama_sampler_init_min_p(p, min_keep) + self._add_sampler(sampler) + + def add_typical(self, p: 
float, min_keep: int): + sampler = llama_sampler_init_typical(p, min_keep) + self._add_sampler(sampler) + + def add_temp(self, temp: float): + sampler = llama_sampler_init_temp(temp) + self._add_sampler(sampler) + + def add_temp_ext(self, t: float, delta: float, exponent: float): + sampler = llama_sampler_init_temp_ext(t, delta, exponent) + self._add_sampler(sampler) + + def add_mirostat(self, n_vocab: int, seed: int, tau: float, eta: float, m: int): + sampler = llama_sampler_init_mirostat(n_vocab, seed, tau, eta, m) + self._add_sampler(sampler) + + def add_mirostat_v2(self, seed: int, tau: float, eta: float): + sampler = llama_sampler_init_mirostat_v2(seed, tau, eta) + self._add_sampler(sampler) + + def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar): + sampler = llama_sampler_init_grammar( + model.model, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8") + ) + self._add_sampler(sampler) + + def add_penalties( + self, + n_vocab: int, + special_eos_id: int, + linefeed_id: int, + penalty_last_n: int, + penalty_repeat: float, + penalty_freq: float, + penalty_present: float, + penalize_nl: bool, + ignore_eos: bool, + ): + sampler = llama_sampler_init_penalties( + n_vocab, + special_eos_id, + linefeed_id, + penalty_last_n, + penalty_repeat, + penalty_freq, + penalty_present, + penalize_nl, + ignore_eos, + ) + self._add_sampler(sampler) + + def init_logit_bias( + self, n_vocab: int, n_logit_bias, logit_bias: llama_logit_bias_p + ): + sampler = llama_sampler_init_logit_bias( + n_vocab, n_logit_bias, logit_bias + ) + self._add_sampler(sampler) + + def add_custom( + self, apply_func: Callable[[llama_token_data_array], None] + ): + custom_sampler = CustomSampler(apply_func) + sampler = custom_sampler.get_sampler() + self._add_sampler(sampler) + # NOTE: Must remove custom samplers before free or llama.cpp will try to free them + self.custom_samplers.append( + (llama_sampler_chain_n(self.sampler) - 1, custom_sampler) + ) + + def _add_sampler(self, sampler: llama_sampler_p): + assert self.sampler is not None + llama_sampler_chain_add(self.sampler, sampler) + self.samplers.append(sampler) + + def get_seed(self) -> int: + assert self.sampler is not None + return llama_sampler_get_seed(self.sampler) + + def sample(self, ctx: LlamaContext, idx: int) -> int: + assert self.sampler is not None + return llama_sampler_sample(self.sampler, ctx.ctx, idx) + + def close(self): + if self.sampler: + # NOTE: Must remove custom samplers before free or llama.cpp will try to free them + for i, _ in reversed(self.custom_samplers): + llama_sampler_chain_remove(self.sampler, i) + llama_sampler_free(self.sampler) + self.sampler = None + self.samplers.clear() + self.custom_samplers.clear() + + def __del__(self): + self.close() diff --git a/nexa/gguf/llama/_logger_transformers.py b/nexa/gguf/llama/_logger_transformers.py index 83721274..2fb0b209 100644 --- a/nexa/gguf/llama/_logger_transformers.py +++ b/nexa/gguf/llama/_logger_transformers.py @@ -2,8 +2,24 @@ import ctypes import logging -from nexa.gguf.llama import llama_cpp +import nexa.gguf.llama as llama_cpp +# enum ggml_log_level { +# GGML_LOG_LEVEL_NONE = 0, +# GGML_LOG_LEVEL_INFO = 1, +# GGML_LOG_LEVEL_WARN = 2, +# GGML_LOG_LEVEL_ERROR = 3, +# GGML_LOG_LEVEL_DEBUG = 4, +# GGML_LOG_LEVEL_CONT = 5, // continue previous log +# }; +GGML_LOG_LEVEL_TO_LOGGING_LEVEL = { + 0: logging.CRITICAL, + 1: logging.INFO, + 2: logging.WARNING, + 3: logging.ERROR, + 4: logging.DEBUG, + 5: logging.DEBUG, +} # Mapping ggml log levels to Python logging levels 
GGML_LOG_LEVEL_TO_LOGGING_LEVEL = { 2: logging.ERROR, @@ -15,21 +31,6 @@ # Initialize the logger for llama-cpp-python logger = logging.getLogger("nexa-transformers") -# # Define the log callback function -# @llama_cpp.llama_log_callback -# def llama_log_callback( -# level: int, -# text: bytes, -# user_data: ctypes.c_void_p, -# ): -# # Check if the logger is set to log the provided level -# if logger.level <= GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level]: -# # Print the log message to stderr -# print(text.decode("utf-8"), end="", flush=True, file=sys.stderr) - -# # Set the log callback function for llama_cpp -# llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0)) - # Utility function to set verbosity def set_verbose(verbose: bool): logger.setLevel(logging.DEBUG if verbose else logging.ERROR) diff --git a/nexa/gguf/llama/_utils_transformers.py b/nexa/gguf/llama/_utils_transformers.py index 0049e9cc..29628193 100644 --- a/nexa/gguf/llama/_utils_transformers.py +++ b/nexa/gguf/llama/_utils_transformers.py @@ -17,7 +17,7 @@ class suppress_stdout_stderr(object): sys = sys os = os - def __init__(self, disable: bool = False): + def __init__(self, disable: bool = True): self.disable = disable # Oddly enough this works better than the contextlib version @@ -75,4 +75,4 @@ class Singleton(object, metaclass=MetaSingleton): """ def __init__(self): - super(Singleton, self).__init__() \ No newline at end of file + super(Singleton, self).__init__() diff --git a/nexa/gguf/llama/audio_lm_cpp.py b/nexa/gguf/llama/audio_lm_cpp.py index 88db2a33..76187f8c 100644 --- a/nexa/gguf/llama/audio_lm_cpp.py +++ b/nexa/gguf/llama/audio_lm_cpp.py @@ -40,7 +40,7 @@ def _load_shared_library(lib_base_name: str, base_path: Path = None): def _get_lib(is_qwen: bool = True): # Specify the base name of the shared library to load - _lib_base_name = "nexa-qwen2-audio-lib_shared" if is_qwen else "nexa-omni-audio-lib_shared" + _lib_base_name = "nexa-qwen2-audio-lib_shared" if is_qwen else "omni_audio_shared" base_path = ( Path(__file__).parent.parent.parent.parent.resolve() / "nexa" diff --git a/nexa/gguf/llama/llama.py b/nexa/gguf/llama/llama.py index 0007b515..4ceb378f 100644 --- a/nexa/gguf/llama/llama.py +++ b/nexa/gguf/llama/llama.py @@ -7,6 +7,7 @@ import json import ctypes import typing +import random import fnmatch import warnings import contextlib @@ -31,7 +32,12 @@ from nexa.gguf.llama.llama_types import * from nexa.gguf.llama.llama_grammar import LlamaGrammar -from nexa.gguf.llama.llama_cache import BaseLlamaCache +from nexa.gguf.llama.llama_cache import ( + BaseLlamaCache, + LlamaCache, # type: ignore + LlamaDiskCache, # type: ignore + LlamaRAMCache, # type: ignore +) from nexa.gguf.llama.llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer import nexa.gguf.llama.llama_cpp as llama_cpp import nexa.gguf.llama.llama_chat_format as llama_chat_format @@ -41,15 +47,7 @@ import numpy as np import numpy.typing as npt -from nexa.gguf.llama._internals_transformers import ( - _LlamaModel, # type: ignore - _LlamaContext, # type: ignore - _LlamaBatch, # type: ignore - _LlamaTokenDataArray, # type: ignore - _LlamaSamplingParams, # type: ignore - _LlamaSamplingContext, # type: ignore - _normalize_embedding, # type: ignore -) +import nexa.gguf.llama._internals_transformers as internals from nexa.gguf.llama._logger_transformers import set_verbose from nexa.gguf.llama._utils_transformers import suppress_stdout_stderr @@ -77,6 +75,7 @@ def __init__( seed: int = llama_cpp.LLAMA_DEFAULT_SEED, n_ctx: int = 512, n_batch: int = 512, + 
n_ubatch: int = 512, n_threads: Optional[int] = None, n_threads_batch: Optional[int] = None, rope_scaling_type: Optional[ @@ -90,7 +89,7 @@ def __init__( yarn_beta_fast: float = 32.0, yarn_beta_slow: float = 1.0, yarn_orig_ctx: int = 0, - logits_all: bool = True, # switch + logits_all: bool = False, embedding: bool = False, offload_kqv: bool = True, flash_attn: bool = False, @@ -158,6 +157,7 @@ def __init__( seed: RNG seed, -1 for random n_ctx: Text context, 0 = from model n_batch: Prompt processing maximum batch size + n_ubatch: Physical batch size n_threads: Number of threads to use for generation n_threads_batch: Number of threads to use for batch processing rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054 @@ -258,28 +258,28 @@ def __init__( for i, (k, v) in enumerate(kv_overrides.items()): self._kv_overrides_array[i].key = k.encode("utf-8") if isinstance(v, bool): - self._kv_overrides_array[i].tag = ( - llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL - ) + self._kv_overrides_array[ + i + ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_BOOL self._kv_overrides_array[i].value.val_bool = v elif isinstance(v, int): - self._kv_overrides_array[i].tag = ( - llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT - ) + self._kv_overrides_array[ + i + ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_INT self._kv_overrides_array[i].value.val_i64 = v elif isinstance(v, float): - self._kv_overrides_array[i].tag = ( - llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT - ) + self._kv_overrides_array[ + i + ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_FLOAT self._kv_overrides_array[i].value.val_f64 = v elif isinstance(v, str): # type: ignore v_bytes = v.encode("utf-8") if len(v_bytes) > 128: # TODO: Make this a constant raise ValueError(f"Value for {k} is too long: {v}") v_bytes = v_bytes.ljust(128, b"\0") - self._kv_overrides_array[i].tag = ( - llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR - ) + self._kv_overrides_array[ + i + ].tag = llama_cpp.LLAMA_KV_OVERRIDE_TYPE_STR # copy min(v_bytes, 128) to str_value address = typing.cast( int, @@ -295,20 +295,23 @@ def __init__( else: raise ValueError(f"Unknown value type for {k}: {v}") - self._kv_overrides_array[-1].key = ( - b"\0" # ensure sentinel element is zeroed - ) + self._kv_overrides_array[ + -1 + ].key = b"\0" # ensure sentinel element is zeroed self.model_params.kv_overrides = self._kv_overrides_array self.n_batch = min(n_ctx, n_batch) # ??? 
self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1) self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count() + # Used by the sampler + self._seed = seed or llama_cpp.LLAMA_DEFAULT_SEED + # Context Params self.context_params = llama_cpp.llama_context_default_params() - self.context_params.seed = seed self.context_params.n_ctx = n_ctx self.context_params.n_batch = self.n_batch + self.context_params.n_ubatch = min(self.n_batch, n_ubatch) self.context_params.n_threads = self.n_threads self.context_params.n_threads_batch = self.n_threads_batch self.context_params.rope_scaling_type = ( @@ -336,10 +339,9 @@ def __init__( yarn_beta_slow if yarn_beta_slow != 0.0 else 0 ) self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0 - # self.context_params.logits_all = ( - # logits_all if draft_model is None else True - # ) # Must be set to True for speculative decoding - self.context_params.logits_all = True + self.context_params.logits_all = ( + logits_all if draft_model is None else True + ) # Must be set to True for speculative decoding self.context_params.embeddings = embedding # TODO: Rename to embeddings self.context_params.offload_kqv = offload_kqv self.context_params.flash_attn = flash_attn @@ -364,7 +366,7 @@ def __init__( self._model = self._stack.enter_context( contextlib.closing( - _LlamaModel( + internals.LlamaModel( path_model=self.model_path, params=self.model_params, verbose=self.verbose, @@ -381,10 +383,11 @@ def __init__( self.n_batch = min(n_ctx, n_batch) self.context_params.n_ctx = self._model.n_ctx_train() self.context_params.n_batch = self.n_batch + self.context_params.n_ubatch = min(self.n_batch, n_ubatch) self._ctx = self._stack.enter_context( contextlib.closing( - _LlamaContext( + internals.LlamaContext( model=self._model, params=self.context_params, verbose=self.verbose, @@ -394,7 +397,7 @@ def __init__( self._batch = self._stack.enter_context( contextlib.closing( - _LlamaBatch( + internals.LlamaBatch( n_tokens=self.n_batch, embd=0, n_seq_max=self.context_params.n_ctx, @@ -406,7 +409,6 @@ def __init__( self._lora_adapter: Optional[llama_cpp.llama_lora_adapter_p] = None if self.lora_path: - assert self._model.model is not None self._lora_adapter = llama_cpp.llama_lora_adapter_init( self._model.model, self.lora_path.encode("utf-8"), @@ -424,7 +426,6 @@ def free_lora_adapter(): self._stack.callback(free_lora_adapter) - assert self._ctx.ctx is not None if llama_cpp.llama_lora_adapter_set( self._ctx.ctx, self._lora_adapter, self.lora_scale ): @@ -437,9 +438,9 @@ def free_lora_adapter(): self.chat_format = chat_format self.chat_handler = chat_handler - self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = ( - {} - ) + self._chat_handlers: Dict[ + str, llama_chat_format.LlamaChatCompletionHandler + ] = {} self.draft_model = draft_model @@ -449,12 +450,12 @@ def free_lora_adapter(): self._token_nl = self.token_nl() self._token_eos = self.token_eos() - self._candidates = _LlamaTokenDataArray(n_vocab=self._n_vocab) + self._candidates = internals.LlamaTokenDataArray(n_vocab=self._n_vocab) self.n_tokens = 0 self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc) self.scores: npt.NDArray[np.single] = np.ndarray( - (n_ctx, self._n_vocab), dtype=np.single + (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single ) self._mirostat_mu = ctypes.c_float( @@ -538,14 +539,14 @@ def free_lora_adapter(): f"Using fallback chat format: {self.chat_format}", file=sys.stderr ) + 
self._sampler = None + @property def ctx(self) -> llama_cpp.llama_context_p: - assert self._ctx.ctx is not None return self._ctx.ctx @property def model(self) -> llama_cpp.llama_model_p: - assert self._model.model is not None return self._model.model @property @@ -586,7 +587,10 @@ def tokenize( return self.tokenizer_.tokenize(text, add_bos, special) def detokenize( - self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False + self, + tokens: List[int], + prev_tokens: Optional[List[int]] = None, + special: bool = False, ) -> bytes: """Detokenize a list of tokens. @@ -598,8 +602,10 @@ def detokenize( Returns: The detokenized string. """ - return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens, special=special) - + return self.tokenizer_.detokenize( + tokens, prev_tokens=prev_tokens, special=special + ) + def set_cache(self, cache: Optional[BaseLlamaCache]): """Set the cache. @@ -614,8 +620,7 @@ def set_seed(self, seed: int): Args: seed: The random seed. """ - assert self._ctx.ctx is not None - llama_cpp.llama_set_rng_seed(self._ctx.ctx, seed) + self._seed = seed def reset(self): """Reset the model state.""" @@ -627,8 +632,6 @@ def eval(self, tokens: Sequence[int]): Args: tokens: The list of tokens to evaluate. """ - assert self._ctx.ctx is not None - assert self._batch.batch is not None self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) for i in range(0, len(tokens), self.n_batch): batch = tokens[i : min(len(tokens), i + self.n_batch)] @@ -649,15 +652,106 @@ def eval(self, tokens: Sequence[int]): ) self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits else: - rows = 1 - cols = self._n_vocab - logits = np.ctypeslib.as_array( - self._ctx.get_logits(), shape=(rows * cols,) - ) - self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits + # rows = 1 + # cols = self._n_vocab + # logits = np.ctypeslib.as_array( + # self._ctx.get_logits(), shape=(rows * cols,) + # ) + # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits + # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all + pass # Update n_tokens self.n_tokens += n_tokens + def _init_sampler( + self, + top_k: int = 40, + top_p: float = 0.95, + min_p: float = 0.05, + typical_p: float = 1.0, + temp: float = 0.80, + repeat_penalty: float = 1.0, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_eta: float = 0.1, + mirostat_tau: float = 5.0, + penalize_nl: bool = True, + logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, + ): + sampler = internals.LlamaSampler() + + if logits_processor is not None: + # Create and add a custom sampler + def apply_func(token_data_array: llama_cpp.llama_token_data_array_p): + size = token_data_array.contents.size + data_soa = token_data_array.contents.data + data_soa_address = ctypes.addressof(data_soa.contents) + # NOTE: This is probably broken + recarray = np.recarray( + shape=(size,), + dtype=np.dtype( + [("id", np.intc), ("logit", np.single), ("p", np.single)], + align=True, + ), + buf=(llama_cpp.llama_token_data * size).from_address( + data_soa_address + ), + ) + for logit_processor in logits_processor: + recarray.logit[:] = logit_processor(self._input_ids, recarray.logit) + + sampler.add_custom(apply_func) + + sampler.add_penalties( + n_vocab=self._n_vocab, + special_eos_id=self._token_eos, + linefeed_id=self._token_nl, + penalty_last_n=self.last_n_tokens_size, + 
penalty_repeat=repeat_penalty, + penalty_freq=frequency_penalty, + penalty_present=presence_penalty, + penalize_nl=penalize_nl, + ignore_eos=False, + ) + + if grammar is not None: + sampler.add_grammar(self._model, grammar) + + if temp < 0.0: + sampler.add_softmax() + sampler.add_dist(self._seed) + elif temp == 0.0: + sampler.add_greedy() + else: + if mirostat_mode == 1: + mirostat_m = 100 + sampler.add_mirostat( + self._n_vocab, + self._seed, + mirostat_tau, + mirostat_eta, + mirostat_m, + ) + elif mirostat_mode == 2: + sampler.add_mirostat_v2( + self._seed, + mirostat_tau, + mirostat_eta, + ) + else: + n_probs = 0 + min_keep = max(1, n_probs) + sampler.add_top_k(top_k) + sampler.add_typical(typical_p, min_keep) + sampler.add_top_p(top_p, min_keep) + sampler.add_min_p(min_p, min_keep) + sampler.add_temp(temp) + sampler.add_dist(self._seed) + return sampler + def sample( self, top_k: int = 40, @@ -674,8 +768,6 @@ def sample( mirostat_tau: float = 5.0, penalize_nl: bool = True, logits_processor: Optional[LogitsProcessorList] = None, - logprobs: Optional[bool] = None, - top_logprobs: Optional[int] = None, grammar: Optional[LlamaGrammar] = None, idx: Optional[int] = None, ): @@ -690,69 +782,37 @@ def sample( Returns: The sampled token. """ - assert self._ctx is not None assert self.n_tokens > 0 - if idx is None: - logits: npt.NDArray[np.single] = self._scores[-1, :] - else: - logits = self._scores[idx, :] - - if logits_processor is not None: - logits[:] = ( - logits_processor(self._input_ids, logits) - if idx is None - else logits_processor(self._input_ids[: idx + 1], logits) + tmp_sampler = False + + if self._sampler is None: + tmp_sampler = True + self._sampler = self._init_sampler( + top_k=top_k, + top_p=top_p, + min_p=min_p, + typical_p=typical_p, + temp=temp, + repeat_penalty=repeat_penalty, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + penalize_nl=penalize_nl, + logits_processor=logits_processor, + grammar=grammar, ) - sampling_params = _LlamaSamplingParams( - top_k=top_k, - top_p=top_p, - min_p=min_p, - tfs_z=tfs_z, - typical_p=typical_p, - temp=temp, - penalty_last_n=self.last_n_tokens_size, - penalty_repeat=repeat_penalty, - penalty_freq=frequency_penalty, - penalty_present=presence_penalty, - mirostat=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - penalize_nl=penalize_nl, - ) - sampling_context = _LlamaSamplingContext( - params=sampling_params, - grammar=grammar, - ) - sampling_context.prev = list(self.eval_tokens) - id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits) - sampling_context.accept( - ctx_main=self._ctx, - id=id, - apply_grammar=grammar is not None, - ) + ridx = idx - self.n_tokens if idx is not None else -1 - if logprobs is not None and (top_logprobs is not None and top_logprobs > 0): - sampled_logprobs = self.logits_to_logprobs(logits) - token_logprob = float(sampled_logprobs[id]) - - top_logprobs_dict = None - if top_logprobs is not None: - sorted_indices = sampled_logprobs.argsort()[::-1] - top_indices = sorted_indices[:top_logprobs] - top_logprobs_dict = { - self.detokenize([i]).decode("utf-8", errors="ignore"): float(sampled_logprobs[i]) - for i in top_indices - } - - return { - "token": id, - "token_logprob": token_logprob, - "top_logprobs": top_logprobs_dict - } - else: - return id + assert self.ctx is not None + token = self._sampler.sample(self._ctx, ridx) + if tmp_sampler: + 
self._sampler = None + return token def generate( self, @@ -772,8 +832,6 @@ def generate( mirostat_eta: float = 0.1, penalize_nl: bool = True, logits_processor: Optional[LogitsProcessorList] = None, - logprobs: Optional[bool] = None, - top_logprobs: Optional[int] = None, stopping_criteria: Optional[StoppingCriteriaList] = None, grammar: Optional[LlamaGrammar] = None, ) -> Generator[int, Optional[Sequence[int]], None]: @@ -798,6 +856,23 @@ def generate( """ # Reset mirostat sampling self._mirostat_mu = ctypes.c_float(2.0 * mirostat_tau) + self._sampler = self._init_sampler( + top_k=top_k, + top_p=top_p, + min_p=min_p, + typical_p=typical_p, + temp=temp, + repeat_penalty=repeat_penalty, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + penalize_nl=penalize_nl, + logits_processor=logits_processor, + grammar=grammar, + ) # Check for kv cache prefix match if reset and self.n_tokens > 0: @@ -812,16 +887,19 @@ def generate( tokens = tokens[longest_prefix:] self.n_tokens = longest_prefix if self.verbose: - print(f"Llama.generate: {longest_prefix} prefix-match hit, " - f"remaining {len(tokens)} prompt tokens to eval", file=sys.stderr) + print( + f"Llama.generate: {longest_prefix} prefix-match hit, " + f"remaining {len(tokens)} prompt tokens to eval", + file=sys.stderr, + ) # Reset the model state if reset: self.reset() - # Reset the grammar - if grammar is not None: - grammar.reset() + # # Reset the grammar + # if grammar is not None: + # grammar.reset() sample_idx = self.n_tokens + len(tokens) - 1 tokens = list(tokens) @@ -830,7 +908,7 @@ def generate( while True: self.eval(tokens) while sample_idx < self.n_tokens: - result = self.sample( + token = self.sample( top_k=top_k, top_p=top_p, min_p=min_p, @@ -844,26 +922,17 @@ def generate( mirostat_tau=mirostat_tau, mirostat_eta=mirostat_eta, logits_processor=logits_processor, - logprobs=logprobs, - top_logprobs=top_logprobs, grammar=grammar, penalize_nl=penalize_nl, idx=sample_idx, ) - if isinstance(result, dict): - token = result["token"] - logprobs_info = result - else: - token = result - logprobs_info = None - sample_idx += 1 if stopping_criteria is not None and stopping_criteria( - self._input_ids, self._scores[-1, :] + self._input_ids[: sample_idx], self._scores[sample_idx - self.n_tokens, :] ): return - tokens_or_none = yield token, logprobs_info + tokens_or_none = yield token tokens.clear() tokens.append(token) if tokens_or_none is not None: @@ -896,7 +965,6 @@ def create_embedding( Returns: An embedding object. 
""" - assert self._model.model is not None model_name: str = model if model is not None else self.model_path input = input if isinstance(input, list) else [input] @@ -941,7 +1009,6 @@ def embed( Returns: A list of embeddings """ - assert self._ctx.ctx is not None n_embd = self.n_embd() n_batch = self.n_batch @@ -955,7 +1022,7 @@ def embed( ) if self.verbose: - llama_cpp.llama_reset_timings(self._ctx.ctx) + llama_cpp.llama_perf_context_reset(self._ctx.ctx) if isinstance(input, str): inputs = [input] @@ -969,7 +1036,6 @@ def embed( data: Union[List[List[float]], List[List[List[float]]]] = [] def decode_batch(seq_sizes: List[int]): - assert self._ctx.ctx is not None llama_cpp.llama_kv_cache_clear(self._ctx.ctx) self._ctx.decode(self._batch) self._batch.reset() @@ -984,7 +1050,9 @@ def decode_batch(seq_sizes: List[int]): for j in range(size) ] if normalize: - embedding = [_normalize_embedding(e) for e in embedding] + embedding = [ + internals.normalize_embedding(e) for e in embedding + ] data.append(embedding) pos += size else: @@ -992,7 +1060,7 @@ def decode_batch(seq_sizes: List[int]): ptr = llama_cpp.llama_get_embeddings_seq(self._ctx.ctx, i) embedding: List[float] = ptr[:n_embd] if normalize: - embedding = _normalize_embedding(embedding) + embedding = internals.normalize_embedding(embedding) data.append(embedding) # init state @@ -1035,7 +1103,7 @@ def decode_batch(seq_sizes: List[int]): decode_batch(s_batch) if self.verbose: - llama_cpp.llama_print_timings(self._ctx.ctx) + llama_cpp.llama_perf_context_print(self._ctx.ctx) output = data[0] if isinstance(input, str) else data @@ -1077,7 +1145,6 @@ def _create_completion( ) -> Union[ Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse] ]: - assert self._ctx is not None assert suffix is None or suffix.__class__ is str completion_id: str = f"cmpl-{str(uuid.uuid4())}" @@ -1222,7 +1289,7 @@ def logit_bias_processor( raise ValueError( "logprobs is not supported for models created with logits_all=False" ) - + if self.cache: try: cache_item = self.cache[prompt_tokens] @@ -1241,13 +1308,13 @@ def logit_bias_processor( print("Llama._create_completion: cache miss", file=sys.stderr) if seed is not None: - self._ctx.set_rng_seed(seed) + self.set_seed(seed) + else: + self.set_seed(random.Random(self._seed).randint(0, 2 ** 32)) finish_reason = "length" multibyte_fix = 0 - logprobs_or_none = None - - for token, logprobs_info in self.generate( + for token in self.generate( prompt_tokens, top_k=top_k, top_p=top_p, @@ -1263,11 +1330,8 @@ def logit_bias_processor( repeat_penalty=repeat_penalty, stopping_criteria=stopping_criteria, logits_processor=logits_processor, - logprobs=logprobs, - top_logprobs=logprobs, grammar=grammar, ): - assert self._model.model is not None if llama_cpp.llama_token_is_eog(self._model.model, token): text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens) finish_reason = "stop" @@ -1275,20 +1339,6 @@ def logit_bias_processor( completion_tokens.append(token) - if logprobs_info and logprobs_or_none is None: - logprobs_or_none = { - "tokens": [], - "text_offset": [], - "token_logprobs": [], - "top_logprobs": [] - } - - if logprobs_info: - logprobs_or_none["tokens"].append(self.detokenize([token]).decode("utf-8", errors="ignore")) - logprobs_or_none["text_offset"].append(len(self.detokenize(completion_tokens[:-1]))) - logprobs_or_none["token_logprobs"].append(logprobs_info["token_logprob"]) - logprobs_or_none["top_logprobs"].append(logprobs_info["top_logprobs"]) - all_text = 
self.detokenize(completion_tokens, prev_tokens=prompt_tokens) # Contains multi-byte UTF8 @@ -1468,15 +1518,15 @@ def logit_bias_processor( if stream: remaining_tokens = completion_tokens[returned_tokens:] - all_text = self.detokenize( + remaining_text = self.detokenize( remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens], ) - any_stop = [s for s in stop_sequences if s in all_text] + any_stop = [s for s in stop_sequences if s in remaining_text] if len(any_stop) > 0: - end = min(all_text.index(stop) for stop in any_stop) + end = min(remaining_text.index(stop) for stop in any_stop) else: - end = len(all_text) + end = len(remaining_text) token_end_position = 0 for token in remaining_tokens: @@ -1487,7 +1537,7 @@ def logit_bias_processor( ) ) - # logprobs_or_none: Optional[CompletionLogprobs] = None + logprobs_or_none: Optional[CompletionLogprobs] = None if logprobs is not None: if token == bos_token_id: continue @@ -1572,10 +1622,7 @@ def logit_bias_processor( { "text": "", "index": 0, - "delta": { - "content": "", - }, - "logprobs": logprobs_or_none, + "logprobs": None, "finish_reason": finish_reason, } ], @@ -1601,7 +1648,7 @@ def logit_bias_processor( if suffix_token_id < 0 and suffix is not None: text_str = text_str + suffix - # logprobs_or_none: Optional[CompletionLogprobs] = None + logprobs_or_none: Optional[CompletionLogprobs] = None if logprobs is not None: text_offset = 0 if echo else len(prompt) token_offset = 0 if echo else len(prompt_tokens[1:]) @@ -1985,7 +2032,7 @@ def create_chat_completion_openai_v1( *args: Any, **kwargs: Any, ): - """Generate a chat completion with return type based on the OpenAI v1 API. + """Generate a chat completion with return type based on the the OpenAI v1 API. OpenAI python package is required to use this method. 
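# A minimal, standalone sketch of the stop-sequence scan used in the streaming
# path above: only the earliest stop match inside the not-yet-returned text is
# honoured and everything after it is dropped. The helper name and sample
# values below are illustrative only, not part of this patch.
def first_stop_offset(remaining_text: bytes, stop_sequences: list) -> int:
    matches = [s for s in stop_sequences if s in remaining_text]
    if matches:
        return min(remaining_text.index(s) for s in matches)
    return len(remaining_text)

# first_stop_offset(b"Hello</s> world", [b"</s>"]) -> 5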
@@ -2029,6 +2076,7 @@ def __getstate__(self): seed=self.context_params.seed, n_ctx=self.context_params.n_ctx, n_batch=self.n_batch, + n_ubatch=self.context_params.n_ubatch, n_threads=self.context_params.n_threads, n_threads_batch=self.context_params.n_threads_batch, rope_scaling_type=self.context_params.rope_scaling_type, @@ -2069,7 +2117,6 @@ def __setstate__(self, state): self.__init__(**state) def save_state(self) -> LlamaState: - assert self._ctx.ctx is not None if self.verbose: print("Llama.save_state: saving llama state", file=sys.stderr) state_size = llama_cpp.llama_get_state_size(self._ctx.ctx) @@ -2096,15 +2143,17 @@ def save_state(self) -> LlamaState: n_tokens=self.n_tokens, llama_state=bytes(llama_state_compact), llama_state_size=n_bytes, + seed=self._seed, ) def load_state(self, state: LlamaState) -> None: - assert self._ctx.ctx is not None # Only filling in up to `n_tokens` and then zero-ing out the rest self.scores[: state.n_tokens, :] = state.scores.copy() - self.scores[state.n_tokens :, :] = 0.0 + rest = self.scores[state.n_tokens :, :] + rest[rest > 0] = 0.0 self.input_ids = state.input_ids.copy() self.n_tokens = state.n_tokens + self._seed = state.seed state_size = state.llama_state_size LLamaStateArrayType = ctypes.c_uint8 * state_size llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state) @@ -2147,62 +2196,6 @@ def pooling_type(self) -> str: def close(self) -> None: """Explicitly free the model from memory.""" self._stack.close() - - def unload_lora(self): - """Unload the LoRA adapter while keeping the base model in memory.""" - if self._lora_adapter is not None: - llama_cpp.llama_lora_adapter_clear(self._ctx.ctx) - llama_cpp.llama_lora_adapter_free(self._lora_adapter) - self._lora_adapter = None - self.lora_path = None - self.lora_scale = 1.0 - - def reload_lora(self, lora_path: str, lora_scale: float = 1.0): - """Reload a LoRA adapter from the given path. - - Args: - lora_path: Path to the LoRA adapter file - lora_scale: Scale to apply to the LoRA adapter (default: 1.0) - - Raises: - RuntimeError: If initialization or setting of the LoRA adapter fails - """ - # First unload any existing LoRA adapter - if self._lora_adapter is not None: - self.unload_lora() - - # Initialize new LoRA adapter - assert self._model.model is not None - self._lora_adapter = llama_cpp.llama_lora_adapter_init( - self._model.model, - lora_path.encode("utf-8"), - ) - if self._lora_adapter is None: - raise RuntimeError( - f"Failed to initialize LoRA adapter from lora path: {lora_path}" - ) - - def free_lora_adapter(): - if self._lora_adapter is None: - return - llama_cpp.llama_lora_adapter_free(self._lora_adapter) - self._lora_adapter = None - - self._stack.callback(free_lora_adapter) - - # Apply the LoRA adapter - assert self._ctx.ctx is not None - if llama_cpp.llama_lora_adapter_set( - self._ctx.ctx, self._lora_adapter, lora_scale - ): - # Clean up on failure - self.unload_lora() - raise RuntimeError( - f"Failed to set LoRA adapter from lora path: {lora_path}" - ) - - self.lora_path = lora_path - self.lora_scale = lora_scale def __del__(self) -> None: self.close() @@ -2240,6 +2233,7 @@ def from_pretrained( cls, repo_id: str, filename: Optional[str], + additional_files: Optional[List] = None, local_dir: Optional[Union[str, os.PathLike[str]]] = None, local_dir_use_symlinks: Union[bool, Literal["auto"]] = "auto", cache_dir: Optional[Union[str, os.PathLike[str]]] = None, @@ -2252,6 +2246,7 @@ def from_pretrained( Args: repo_id: The model repo id. 
filename: A filename or glob pattern to match the model file in the repo. + additional_files: A list of filenames or glob patterns to match additional model files in the repo. local_dir: The local directory to save the model to. local_dir_use_symlinks: Whether to use symlinks when downloading the model. **kwargs: Additional keyword arguments to pass to the Llama constructor. @@ -2282,6 +2277,7 @@ def from_pretrained( rel_path = Path(file).relative_to(repo_id) file_list.append(str(rel_path)) + # find the only/first shard file: matching_files = [file for file in file_list if fnmatch.fnmatch(file, filename)] # type: ignore if len(matching_files) == 0: @@ -2311,6 +2307,35 @@ def from_pretrained( cache_dir=cache_dir, ) + if additional_files: + for additonal_file_name in additional_files: + # find the additional shard file: + matching_additional_files = [file for file in file_list if fnmatch.fnmatch(file, additonal_file_name)] + + if len(matching_additional_files) == 0: + raise ValueError( + f"No file found in {repo_id} that match {additonal_file_name}\n\n" + f"Available Files:\n{json.dumps(file_list)}" + ) + + if len(matching_additional_files) > 1: + raise ValueError( + f"Multiple files found in {repo_id} matching {additonal_file_name}\n\n" + f"Available Files:\n{json.dumps(files)}" + ) + + (matching_additional_file,) = matching_additional_files + + # download the additional file + hf_hub_download( + repo_id=repo_id, + filename=matching_additional_file, + subfolder=subfolder, + local_dir=local_dir, + local_dir_use_symlinks=local_dir_use_symlinks, + cache_dir=cache_dir, + ) + if local_dir is None: model_path = hf_hub_download( repo_id=repo_id, @@ -2324,6 +2349,7 @@ def from_pretrained( else: model_path = os.path.join(local_dir, filename) + # loading the first file of a sharded GGUF loads all remaining shard files in the subfolder return cls( model_path=model_path, **kwargs, @@ -2338,12 +2364,14 @@ def __init__( n_tokens: int, llama_state: bytes, llama_state_size: int, + seed: int, ): self.input_ids = input_ids self.scores = scores self.n_tokens = n_tokens self.llama_state = llama_state self.llama_state_size = llama_state_size + self.seed = seed LogitsProcessor = Callable[ diff --git a/nexa/gguf/llama/llama_cache.py b/nexa/gguf/llama/llama_cache.py index 54f22eb7..6b05e11e 100644 --- a/nexa/gguf/llama/llama_cache.py +++ b/nexa/gguf/llama/llama_cache.py @@ -9,7 +9,7 @@ import diskcache -import nexa.gguf.llama as llama_cpp +import nexa.gguf.llama.llama from nexa.gguf.llama.llama_types import * @@ -32,7 +32,7 @@ def _find_longest_prefix_key( pass @abstractmethod - def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": + def __getitem__(self, key: Sequence[int]) -> "nexa.gguf.llama.LlamaState": raise NotImplementedError @abstractmethod @@ -41,7 +41,7 @@ def __contains__(self, key: Sequence[int]) -> bool: @abstractmethod def __setitem__( - self, key: Sequence[int], value: "llama_cpp.llama.LlamaState" + self, key: Sequence[int], value: "nexa.gguf.llama.LlamaState" ) -> None: raise NotImplementedError @@ -52,9 +52,9 @@ class LlamaRAMCache(BaseLlamaCache): def __init__(self, capacity_bytes: int = (2 << 30)): super().__init__(capacity_bytes) self.capacity_bytes = capacity_bytes - self.cache_state: OrderedDict[Tuple[int, ...], "llama_cpp.llama.LlamaState"] = ( - OrderedDict() - ) + self.cache_state: OrderedDict[ + Tuple[int, ...], "nexa.gguf.llama.LlamaState" + ] = OrderedDict() @property def cache_size(self): @@ -67,7 +67,7 @@ def _find_longest_prefix_key( min_len = 0 min_key = None 
keys = ( - (k, llama_cpp.llama.Llama.longest_token_prefix(k, key)) + (k, nexa.gguf.llama.Llama.longest_token_prefix(k, key)) for k in self.cache_state.keys() ) for k, prefix_len in keys: @@ -76,7 +76,7 @@ def _find_longest_prefix_key( min_key = k return min_key - def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": + def __getitem__(self, key: Sequence[int]) -> "nexa.gguf.llama.LlamaState": key = tuple(key) _key = self._find_longest_prefix_key(key) if _key is None: @@ -88,7 +88,7 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": def __contains__(self, key: Sequence[int]) -> bool: return self._find_longest_prefix_key(tuple(key)) is not None - def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"): + def __setitem__(self, key: Sequence[int], value: "nexa.gguf.llama.LlamaState"): key = tuple(key) if key in self.cache_state: del self.cache_state[key] @@ -121,18 +121,18 @@ def _find_longest_prefix_key( min_len = 0 min_key: Optional[Tuple[int, ...]] = None for k in self.cache.iterkeys(): # type: ignore - prefix_len = llama_cpp.llama.Llama.longest_token_prefix(k, key) + prefix_len = nexa.gguf.llama.Llama.longest_token_prefix(k, key) if prefix_len > min_len: min_len = prefix_len min_key = k # type: ignore return min_key - def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": + def __getitem__(self, key: Sequence[int]) -> "nexa.gguf.llama.LlamaState": key = tuple(key) _key = self._find_longest_prefix_key(key) if _key is None: raise KeyError("Key not found") - value: "llama_cpp.llama.LlamaState" = self.cache.pop(_key) # type: ignore + value: "nexa.gguf.llama.LlamaState" = self.cache.pop(_key) # type: ignore # NOTE: This puts an integer as key in cache, which breaks, # Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens # self.cache.push(_key, side="front") # type: ignore @@ -141,7 +141,7 @@ def __getitem__(self, key: Sequence[int]) -> "llama_cpp.llama.LlamaState": def __contains__(self, key: Sequence[int]) -> bool: return self._find_longest_prefix_key(tuple(key)) is not None - def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"): + def __setitem__(self, key: Sequence[int], value: "nexa.gguf.llama.LlamaState"): print("LlamaDiskCache.__setitem__: called", file=sys.stderr) key = tuple(key) if key in self.cache: @@ -152,4 +152,4 @@ def __setitem__(self, key: Sequence[int], value: "llama_cpp.llama.LlamaState"): while self.cache_size > self.capacity_bytes and len(self.cache) > 0: key_to_remove = next(iter(self.cache)) del self.cache[key_to_remove] - print("LlamaDiskCache.__setitem__: trim", file=sys.stderr) \ No newline at end of file + print("LlamaDiskCache.__setitem__: trim", file=sys.stderr) diff --git a/nexa/gguf/llama/llama_chat_format.py b/nexa/gguf/llama/llama_chat_format.py index ff5cd06d..aeee3399 100644 --- a/nexa/gguf/llama/llama_chat_format.py +++ b/nexa/gguf/llama/llama_chat_format.py @@ -304,7 +304,6 @@ def _convert_text_completion_chunks_to_chat( } ], } - yield { "id": "chat" + chunk["id"], "model": chunk["model"], @@ -1010,7 +1009,7 @@ def format_qwen( **kwargs: Any, ) -> ChatFormatterResponse: _roles = dict(user="<|im_start|>user", assistant="<|im_start|>assistant") - system_message = "You are a helpful assistant." + system_message = _get_system_message(messages) or "You are a helpful assistant." 
system_template = "<|im_start|>system\n{system_message}" system_message = system_template.format(system_message=system_message) _messages = _map_roles(messages, _roles) @@ -1364,34 +1363,6 @@ def format_gemma( return ChatFormatterResponse(prompt=_prompt, stop=_sep) -@register_chat_format("octopusv2") -def format_octopus_v2( - messages: List[llama_types.ChatCompletionRequestMessage], - **kwargs: Any, -) -> ChatFormatterResponse: - system_message = "Below is the query from the users, please call the correct function and generate the parameters to call the function.\n\n" - _roles = dict(user="Query:", assistant="Response:") - _sep = "\n\n" - _messages = _map_roles(messages, _roles) - - # Assuming the last message should be the assistant's response - _messages.append((_roles["assistant"], None)) - - # Concatenating the prompt - _prompt = system_message - for role, content in _messages: - if content: - _prompt += f"{role} {content.strip()}{_sep}" - else: - _prompt += f"{role} " - - # The final prompt - _prompt = _prompt.strip() - - # Returning the formatted response - return ChatFormatterResponse(prompt=_prompt, stop=_sep) - - # Tricky chat formats that require custom chat handlers @@ -2736,6 +2707,31 @@ def last_image_embed_free(): def load_image(self, image_url: str) -> bytes: return self._load_image(image_url) + def _embed_image_bytes(self, image_bytes: bytes, n_threads_batch: int = 1): + if ( + self._last_image_embed is not None + and self._last_image_hash is not None + and hash(image_bytes) == self._last_image_hash + ): + return self._last_image_embed + with suppress_stdout_stderr(disable=self.verbose): + # Free the previous image embed + if self._last_image_embed is not None: + self._llava_cpp.llava_image_embed_free(self._last_image_embed) + self._last_image_embed = None + self._last_image_hash = None + embed = self._llava_cpp.llava_image_embed_make_with_bytes( + self.clip_ctx, + n_threads_batch, + (ctypes.c_uint8 * len(image_bytes)).from_buffer( + bytearray(image_bytes) + ), + len(image_bytes), + ) + self._last_image_embed = embed + self._last_image_hash = hash(image_bytes) + return embed + def __call__( self, *, @@ -2798,30 +2794,9 @@ def __call__( ) split_text = self.split_text_on_image_urls(text, image_urls) - def embed_image_bytes(image_bytes: bytes): - if ( - self._last_image_embed is not None - and self._last_image_hash is not None - and hash(image_bytes) == self._last_image_hash - ): - return self._last_image_embed - with suppress_stdout_stderr(disable=self.verbose): - # Free the previous image embed - if self._last_image_embed is not None: - self._llava_cpp.llava_image_embed_free(self._last_image_embed) - self._last_image_embed = None - self._last_image_hash = None - embed = self._llava_cpp.llava_image_embed_make_with_bytes( - self.clip_ctx, - llama.context_params.n_threads_batch, - (ctypes.c_uint8 * len(image_bytes)).from_buffer( - bytearray(image_bytes) - ), - len(image_bytes), - ) - self._last_image_embed = embed - self._last_image_hash = hash(image_bytes) - return embed + if self.verbose: + print(text, file=sys.stderr) + # Evaluate prompt llama.reset() @@ -2838,7 +2813,7 @@ def embed_image_bytes(image_bytes: bytes): llama.eval(tokens) else: image_bytes = self.load_image(value) - embed = embed_image_bytes(image_bytes) + embed = self._embed_image_bytes(image_bytes, llama.context_params.n_threads_batch) if llama.n_tokens + embed.contents.n_image_pos > llama.n_ctx(): raise ValueError( f"Prompt exceeds n_ctx: {llama.n_tokens + embed.contents.n_image_pos} > {llama.n_ctx()}" @@ 
-3337,6 +3312,44 @@ class Llama3VisionAlphaChatHandler(Llava15ChatHandler): Llama3VisionAlpha = Llama3VisionAlphaChatHandler +class MiniCPMv26ChatHandler(Llava15ChatHandler): + DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant." + + CHAT_FORMAT = ( + "{% for message in messages %}" + "{% if loop.first and messages[0]['role'] != 'system' %}" + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + "{% endif %}" + "<|im_start|>{{ message['role'] }}\n" + "{% if message['content'] is iterable %}" + "{% for content in message['content'] %}" + "{% if content.type == 'image_url' %}" + "{% if content.image_url is string %}" + "{{ content.image_url }}" + "{% endif %}" + "{% if content.image_url is mapping %}" + "{{ content.image_url.url }}" + "{% endif %}" + "{% endif %}" + "{% endfor %}" + + "{% for content in message['content'] %}" + "{% if content.type == 'text' %}" + "{{ content.text }}" + "{% endif %}" + "{% endfor %}" + "{% endif %}" + "{% if message['content'] is string %}" + "{{ message['content'] }}" + "{% endif %}" + "<|im_end|>\n" + "{% endfor %}" + "{% if add_generation_prompt %}" + "<|im_start|>assistant\n" + "{% endif %}" + ) + + @register_chat_completion_handler("chatml-function-calling") def chatml_function_calling( llama: llama.Llama, @@ -3777,4 +3790,4 @@ def chatml_function_calling( }, } - raise ValueError("Automatic streaming tool choice is not supported") \ No newline at end of file + raise ValueError("Automatic streaming tool choice is not supported") diff --git a/nexa/gguf/llama/llama_cpp.py b/nexa/gguf/llama/llama_cpp.py index 442d2e86..b8b8702e 100644 --- a/nexa/gguf/llama/llama_cpp.py +++ b/nexa/gguf/llama/llama_cpp.py @@ -1,90 +1,44 @@ from __future__ import annotations -import sys import os import ctypes -import functools import pathlib from typing import ( - Any, Callable, - List, Union, NewType, Optional, TYPE_CHECKING, - TypeVar, - Generic, ) -from typing_extensions import TypeAlias +from nexa.gguf.llama._ctypes_extensions import ( + byref, + ctypes_function_for_shared_library, +) + from nexa.gguf.lib_utils import load_library +if TYPE_CHECKING: + from nexa.gguf.llama._ctypes_extensions import ( + CtypesCData, + CtypesArray, + CtypesPointer, + CtypesVoidPointer, + CtypesRef, + CtypesPointerOrRef, + CtypesFuncPointer, + ) + + # Specify the base name of the shared library to load _lib_base_name = "llama" - # Load the library _lib = load_library(_lib_base_name) -# ctypes sane type hint helpers -# -# - Generic Pointer and Array types -# - PointerOrRef type with a type hinted byref function -# -# NOTE: Only use these for static type checking not for runtime checks -# no good will come of that - -if TYPE_CHECKING: - CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore - - CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore - - CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore - - CtypesVoidPointer: TypeAlias = ctypes.c_void_p - - class CtypesRef(Generic[CtypesCData]): - pass - - CtypesPointerOrRef: TypeAlias = Union[ - CtypesPointer[CtypesCData], CtypesRef[CtypesCData] - ] - - CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore - -F = TypeVar("F", bound=Callable[..., Any]) - - -def ctypes_function_for_shared_library(lib: ctypes.CDLL): - def ctypes_function( - name: str, argtypes: List[Any], restype: Any, enabled: bool = True - ): - def decorator(f: F) -> F: - if enabled: - func = getattr(lib, name) - func.argtypes = argtypes - func.restype = restype - functools.wraps(f)(func) - return func - 
else: - return f - - return decorator - - return ctypes_function - - ctypes_function = ctypes_function_for_shared_library(_lib) -def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]: - """Type-annotated version of ctypes.byref""" - ... - - -byref = ctypes.byref # type: ignore - # from ggml.h # // NOTE: always add types at the end of the enum to keep backward compatibility # enum ggml_type { @@ -148,11 +102,7 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa GGML_TYPE_I64 = 27 GGML_TYPE_F64 = 28 GGML_TYPE_IQ1_M = 29 -GGML_TYPE_BF16 = 30, -GGML_TYPE_Q4_0_4_4 = 31 -GGML_TYPE_Q4_0_4_8 = 32 -GGML_TYPE_Q4_0_8_8 = 33 -GGML_TYPE_COUNT = 34 +GGML_TYPE_COUNT = 30 # from ggml-backend.h # typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); @@ -176,6 +126,9 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa # define LLAMA_DEFAULT_SEED 0xFFFFFFFF LLAMA_DEFAULT_SEED = 0xFFFFFFFF +# define LLAMA_TOKEN_NULL -1 +LLAMA_TOKEN_NULL = -1 + # define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' LLAMA_FILE_MAGIC_GGLA = 0x67676C61 @@ -187,8 +140,8 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa # define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN LLAMA_SESSION_MAGIC = LLAMA_FILE_MAGIC_GGSN -# define LLAMA_SESSION_VERSION 8 -LLAMA_SESSION_VERSION = 8 +# define LLAMA_SESSION_VERSION 9 +LLAMA_SESSION_VERSION = 9 # define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ LLAMA_STATE_SEQ_MAGIC = LLAMA_FILE_MAGIC_GGSQ @@ -203,6 +156,9 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa llama_context_p = NewType("llama_context_p", int) llama_context_p_ctypes = ctypes.c_void_p +# # struct llama_sampler; +# llama_sampler_p = NewType("llama_sampler_p", int) +# llama_sampler_p_ctypes = ctypes.c_void_p # typedef int32_t llama_pos; llama_pos = ctypes.c_int32 @@ -263,6 +219,7 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa # LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, # LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, # LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, +# LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, # }; LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0 LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1 @@ -290,6 +247,7 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa LLAMA_VOCAB_PRE_TYPE_BLOOM = 23 LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24 LLAMA_VOCAB_PRE_TYPE_EXAONE = 25 +LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26 # // note: these values should be synchronized with ggml_rope @@ -447,12 +405,14 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa # LLAMA_POOLING_TYPE_MEAN = 1, # LLAMA_POOLING_TYPE_CLS = 2, # LLAMA_POOLING_TYPE_LAST = 3, +# LLAMA_POOLING_TYPE_RANK = 4, // used by reranking models to attach the classification head to the graph # }; LLAMA_POOLING_TYPE_UNSPECIFIED = -1 LLAMA_POOLING_TYPE_NONE = 0 LLAMA_POOLING_TYPE_MEAN = 1 LLAMA_POOLING_TYPE_CLS = 2 LLAMA_POOLING_TYPE_LAST = 3 +LLAMA_POOLING_TYPE_RANK = 4 # enum llama_attention_type { # LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1, @@ -463,10 +423,11 @@ def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCDa LLAMA_ATTENTION_TYPE_CAUSAL = 0 LLAMA_ATTENTION_TYPE_NON_CAUSAL = 1 + # enum llama_split_mode { -# LLAMA_SPLIT_MODE_NONE = 0, // single GPU -# LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs -# LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs +# LLAMA_SPLIT_MODE_NONE = 0, // single GPU +# 
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs +# LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs # }; LLAMA_SPLIT_MODE_NONE = 0 LLAMA_SPLIT_MODE_LAYER = 1 @@ -502,8 +463,11 @@ class llama_token_data(ctypes.Structure): # typedef struct llama_token_data_array { +# // TODO: consider SoA +# // NOTE: this pointer can be modified by the samplers # llama_token_data * data; # size_t size; +# int64_t selected; // this is the index in the data array (i.e. not the token id) # bool sorted; # } llama_token_data_array; class llama_token_data_array(ctypes.Structure): @@ -512,16 +476,19 @@ class llama_token_data_array(ctypes.Structure): Attributes: data (ctypes.Array[llama_token_data]): token data size (int): size of the array + selected (int): index in the data array (i.e. not the token id) sorted (bool): whether the array is sorted""" if TYPE_CHECKING: data: CtypesArray[llama_token_data] size: int + selected: int sorted: bool _fields_ = [ ("data", llama_token_data_p), ("size", ctypes.c_size_t), + ("selected", ctypes.c_int64), ("sorted", ctypes.c_bool), ] @@ -541,8 +508,11 @@ class llama_token_data_array(ctypes.Structure): # // - token : the token ids of the input (used when embd is NULL) # // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) # // - pos : the positions of the respective token in the sequence +# // (if set to NULL, the token position will be tracked automatically by llama_decode) # // - seq_id : the sequence to which the respective token belongs +# // (if set to NULL, the sequence ID will be assumed to be 0) # // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output +# // (if set to NULL, only the logits for last token will be returned) # // # typedef struct llama_batch { # int32_t n_tokens; @@ -553,16 +523,6 @@ class llama_token_data_array(ctypes.Structure): # int32_t * n_seq_id; # llama_seq_id ** seq_id; # int8_t * logits; // TODO: rename this to "output" - - -# // NOTE: helpers for smooth API transition - can be deprecated in the future -# // for future-proof code, use the above fields instead and ignore everything below -# // -# // pos[i] = all_pos_0 + i*all_pos_1 -# // -# llama_pos all_pos_0; // used if pos == NULL -# llama_pos all_pos_1; // used if pos == NULL -# llama_seq_id all_seq_id; // used if seq_id == NULL # } llama_batch; class llama_batch(ctypes.Structure): """Input data for llama_decode @@ -597,9 +557,6 @@ class llama_batch(ctypes.Structure): ("n_seq_id", ctypes.POINTER(ctypes.c_int32)), ("seq_id", ctypes.POINTER(ctypes.POINTER(llama_seq_id))), ("logits", ctypes.POINTER(ctypes.c_int8)), - ("all_pos_0", llama_pos), - ("all_pos_1", llama_pos), - ("all_seq_id", llama_seq_id), ] @@ -740,7 +697,6 @@ class llama_model_params(ctypes.Structure): # // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations # // https://github.com/ggerganov/llama.cpp/pull/7544 # struct llama_context_params { -# uint32_t seed; // RNG seed, -1 for random # uint32_t n_ctx; // text context, 0 = from model # uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode # uint32_t n_ubatch; // physical maximum batch size @@ -773,6 +729,7 @@ class llama_model_params(ctypes.Structure): # bool embeddings; // if true, extract embeddings (together with logits) # bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU # bool flash_attn; // whether to use flash attention 
[EXPERIMENTAL] +# bool no_perf; // whether to measure performance timings # // Abort callback @@ -785,7 +742,6 @@ class llama_context_params(ctypes.Structure): """Parameters for llama_context Attributes: - seed (int): RNG seed, -1 for random n_ctx (int): text context, 0 = from model n_batch (int): logical maximum batch size that can be submitted to llama_decode n_ubatch (int): physical maximum batch size @@ -816,7 +772,6 @@ class llama_context_params(ctypes.Structure): """ if TYPE_CHECKING: - seed: int n_ctx: int n_batch: int n_ubatch: int @@ -846,7 +801,6 @@ class llama_context_params(ctypes.Structure): abort_callback_data: ctypes.c_void_p _fields_ = [ - ("seed", ctypes.c_uint32), ("n_ctx", ctypes.c_uint32), ("n_batch", ctypes.c_uint32), ("n_ubatch", ctypes.c_uint32), @@ -952,101 +906,44 @@ class llama_model_quantize_params(ctypes.Structure): ] -# // grammar types -# struct llama_grammar; -llama_grammar_p = ctypes.c_void_p - -# // grammar element type -# enum llama_gretype { -# // end of rule definition -# LLAMA_GRETYPE_END = 0, - -# // start of alternate definition for rule -# LLAMA_GRETYPE_ALT = 1, - -# // non-terminal element: reference to rule -# LLAMA_GRETYPE_RULE_REF = 2, - -# // terminal element: character (code point) -# LLAMA_GRETYPE_CHAR = 3, - -# // inverse char(s) ([^a], [^a-b] [^abc]) -# LLAMA_GRETYPE_CHAR_NOT = 4, +# typedef struct llama_logit_bias { +# llama_token token; +# float bias; +# } llama_logit_bias; +class llama_logit_bias(ctypes.Structure): + """Used to store logit bias -# // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to -# // be an inclusive range ([a-z]) -# LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, - -# // modifies a preceding LLAMA_GRETYPE_CHAR or -# // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) -# LLAMA_GRETYPE_CHAR_ALT = 6, + Attributes: + token (llama_token): token id + bias (float): bias""" -# // any character (.) 
-# LLAMA_GRETYPE_CHAR_ANY = 7, -# }; -LLAMA_GRETYPE_END = 0 -LLAMA_GRETYPE_ALT = 1 -LLAMA_GRETYPE_RULE_REF = 2 -LLAMA_GRETYPE_CHAR = 3 -LLAMA_GRETYPE_CHAR_NOT = 4 -LLAMA_GRETYPE_CHAR_RNG_UPPER = 5 -LLAMA_GRETYPE_CHAR_ALT = 6 -LLAMA_GRETYPE_CHAR_ANY = 7 - - -# typedef struct llama_grammar_element { -# enum llama_gretype type; -# uint32_t value; // Unicode code point or rule ID -# } llama_grammar_element; -class llama_grammar_element(ctypes.Structure): if TYPE_CHECKING: - type: int - value: int + token: llama_token + bias: float _fields_ = [ - ("type", ctypes.c_int), - ("value", ctypes.c_uint32), + ("token", llama_token), + ("bias", ctypes.c_float), ] -llama_grammar_element_p = ctypes.POINTER(llama_grammar_element) +llama_logit_bias_p = ctypes.POINTER(llama_logit_bias) -# // performance timing information -# struct llama_timings { -# double t_start_ms; -# double t_end_ms; -# double t_load_ms; -# double t_sample_ms; -# double t_p_eval_ms; -# double t_eval_ms; +# typedef struct llama_sampler_chain_params { +# bool no_perf; // whether to measure performance timings +# } llama_sampler_chain_params; +class llama_sampler_chain_params(ctypes.Structure): + """Parameters for llama_sampler_chain + + Attributes: + no_perf (bool): whether to measure performance timings""" -# int32_t n_sample; -# int32_t n_p_eval; -# int32_t n_eval; -# }; -class llama_timings(ctypes.Structure): if TYPE_CHECKING: - t_start_ms: float - t_end_ms: float - t_load_ms: float - t_sample_ms: float - t_p_eval_ms: float - t_eval_ms: float - n_sample: int - n_p_eval: int - n_eval: int + no_perf: bool _fields_ = [ - ("t_start_ms", ctypes.c_double), - ("t_end_ms", ctypes.c_double), - ("t_load_ms", ctypes.c_double), - ("t_sample_ms", ctypes.c_double), - ("t_p_eval_ms", ctypes.c_double), - ("t_eval_ms", ctypes.c_double), - ("n_sample", ctypes.c_int32), - ("n_p_eval", ctypes.c_int32), - ("n_eval", ctypes.c_int32), + ("no_perf", ctypes.c_bool), ] @@ -1069,7 +966,7 @@ class llama_chat_message(ctypes.Structure): # // Helpers for getting default parameters -# LLAMA_API struct llama_model_params llama_model_default_params(void); +# LLAMA_API struct llama_model_params llama_model_default_params(void); @ctypes_function( "llama_model_default_params", [], @@ -1080,7 +977,7 @@ def llama_model_default_params() -> llama_model_params: ... -# LLAMA_API struct llama_context_params llama_context_default_params(void); +# LLAMA_API struct llama_context_params llama_context_default_params(void); @ctypes_function( "llama_context_default_params", [], @@ -1091,6 +988,17 @@ def llama_context_default_params() -> llama_context_params: ... +# LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void); +@ctypes_function( + "llama_sampler_chain_default_params", + [], + llama_sampler_chain_params, +) +def llama_sampler_chain_default_params() -> llama_sampler_chain_params: + """Get default parameters for llama_sampler_chain""" + ... + + # LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); @ctypes_function( "llama_model_quantize_default_params", @@ -1171,7 +1079,7 @@ def llama_backend_free(): # LLAMA_API struct llama_model * llama_load_model_from_file( # const char * path_model, -# struct llama_model_params params); +# struct llama_model_params params); @ctypes_function( "llama_load_model_from_file", [ctypes.c_char_p, llama_model_params], @@ -1253,9 +1161,9 @@ def llama_supports_gpu_offload() -> bool: ... 
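# A small sketch of packing token/bias pairs into the llama_logit_bias struct
# introduced above. The token ids are placeholders, and the sampler that
# consumes such an array (a logit-bias sampler binding) is assumed to be
# declared elsewhere in this module.
from nexa.gguf.llama import llama_cpp

pairs = [(15043, 5.0), (50256, -100.0)]  # hypothetical token ids
logit_bias_arr = (llama_cpp.llama_logit_bias * len(pairs))(
    *[llama_cpp.llama_logit_bias(token=tok, bias=b) for tok, b in pairs]
)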
-# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); -@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes) -def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: +# LLAMA_API bool llama_supports_rpc (void); +@ctypes_function("llama_supports_rpc", [], ctypes.c_bool) +def llama_supports_rpc() -> bool: ... @@ -1283,24 +1191,6 @@ def llama_n_seq_max(ctx: llama_context_p, /) -> int: ... -# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); -@ctypes_function("llama_pooling_type", [llama_context_p_ctypes], ctypes.c_int) -def llama_pooling_type(ctx: llama_context_p, /) -> int: - ... - - -# LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); -@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int) -def llama_vocab_type(model: llama_model_p, /) -> int: - ... - - -# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); -@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int) -def llama_rope_type(model: llama_model_p, /) -> int: - ... - - # LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); @ctypes_function("llama_n_vocab", [llama_model_p_ctypes], ctypes.c_int32) def llama_n_vocab(model: llama_model_p, /) -> int: @@ -1325,6 +1215,36 @@ def llama_n_layer(model: llama_model_p, /) -> int: ... +# LLAMA_API int32_t llama_n_head (const struct llama_model * model); +@ctypes_function("llama_n_head", [llama_model_p_ctypes], ctypes.c_int32) +def llama_n_head(model: llama_model_p, /) -> int: + ... + + +# LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); +@ctypes_function("llama_get_model", [llama_context_p_ctypes], llama_model_p_ctypes) +def llama_get_model(ctx: llama_context_p, /) -> Optional[llama_model_p]: + ... + + +# LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); +@ctypes_function("llama_pooling_type", [llama_context_p_ctypes], ctypes.c_int) +def llama_pooling_type(ctx: llama_context_p, /) -> int: + ... + + +# LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); +@ctypes_function("llama_vocab_type", [llama_model_p_ctypes], ctypes.c_int) +def llama_vocab_type(model: llama_model_p, /) -> int: + ... + + +# LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); +@ctypes_function("llama_rope_type", [llama_model_p_ctypes], ctypes.c_int) +def llama_rope_type(model: llama_model_p, /) -> int: + ... + + # // Get the model's RoPE frequency scaling factor # LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); @ctypes_function("llama_rope_freq_scale_train", [llama_model_p_ctypes], ctypes.c_float) @@ -1492,10 +1412,10 @@ def llama_model_decoder_start_token(model: llama_model_p, /) -> int: # // Returns true if the model is recurrent (like Mamba, RWKV, etc.) # LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); -# @ctypes_function("llama_model_is_recurrent", [llama_model_p_ctypes], ctypes.c_bool) -# def llama_model_is_recurrent(model: llama_model_p, /) -> bool: -# """Returns true if the model is recurrent (like Mamba, RWKV, etc.)""" -# ... +@ctypes_function("llama_model_is_recurrent", [llama_model_p_ctypes], ctypes.c_bool) +def llama_model_is_recurrent(model: llama_model_p, /) -> bool: + """Returns true if the model is recurrent (like Mamba, RWKV, etc.)""" + ... 
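# A hedged sketch of loading a model and reading the metadata exposed by the
# getters above. The model path is a placeholder; llama_backend_init and
# llama_free_model are assumed to be bound elsewhere in this module.
from nexa.gguf.llama import llama_cpp

llama_cpp.llama_backend_init()
mparams = llama_cpp.llama_model_default_params()
model = llama_cpp.llama_load_model_from_file(b"/path/to/model.gguf", mparams)
if model is not None:
    print("n_vocab   :", llama_cpp.llama_n_vocab(model))
    print("n_embd    :", llama_cpp.llama_n_embd(model))
    print("n_head    :", llama_cpp.llama_n_head(model))
    print("recurrent :", llama_cpp.llama_model_is_recurrent(model))
    llama_cpp.llama_free_model(model)
llama_cpp.llama_backend_free()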
# // Returns 0 on success @@ -1983,7 +1903,7 @@ def llama_kv_cache_update(ctx: llama_context_p, /): # // Returns the *actual* size in bytes of the state -# // (rng, logits, embedding and kv_cache) +# // (logits, embedding and kv_cache) # // Only use when saving the state, not when restoring it, otherwise the size may be too small. # LLAMA_API size_t llama_state_get_size(struct llama_context * ctx); @ctypes_function("llama_state_get_size", [llama_context_p_ctypes], ctypes.c_size_t) @@ -2332,30 +2252,26 @@ def llama_state_seq_load_file( # // -# // Return batch for single sequence of tokens starting at pos_0 +# // Return batch for single sequence of tokens +# // The sequence ID will be fixed to 0 +# // The position of the tokens will be tracked automatically by llama_decode # // # // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it # // # LLAMA_API struct llama_batch llama_batch_get_one( # llama_token * tokens, -# int32_t n_tokens, -# llama_pos pos_0, -# llama_seq_id seq_id); +# int32_t n_tokens); @ctypes_function( "llama_batch_get_one", [ llama_token_p, - ctypes.c_int, - llama_pos, - llama_seq_id, + ctypes.c_int32, ], llama_batch, ) def llama_batch_get_one( tokens: CtypesArray[llama_token], n_tokens: Union[ctypes.c_int, int], - pos_0: Union[llama_pos, int], - seq_id: llama_seq_id, /, ) -> llama_batch: """Return batch for single sequence of tokens starting at pos_0 @@ -2602,7 +2518,8 @@ def llama_get_embeddings_ith( # // Get the embeddings for a sequence id # // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE -# // shape: [n_embd] (1-dimensional) +# // when pooling_type == LLAMA_POOLING_TYPE_RANK, returns float[1] with the rank of the sequence +# // otherwise: float[n_embd] (1-dimensional) # LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); @ctypes_function( "llama_get_embeddings_seq", @@ -2692,6 +2609,13 @@ def llama_token_eos(model: llama_model_p, /) -> int: ... +# LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn +@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token) +def llama_token_eot(model: llama_model_p, /) -> int: + """end-of-turn""" + ... + + # LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification @ctypes_function("llama_token_cls", [llama_model_p_ctypes], llama_token) def llama_token_cls(model: llama_model_p, /) -> int: @@ -2726,34 +2650,60 @@ def llama_add_eos_token(model: llama_model_p, /) -> bool: # // Codellama infill tokens -# LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix +# DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead"); @ctypes_function("llama_token_prefix", [llama_model_p_ctypes], llama_token) def llama_token_prefix(model: llama_model_p) -> int: """codellama infill tokens""" ... -# LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle +# DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead"); @ctypes_function("llama_token_middle", [llama_model_p_ctypes], llama_token) def llama_token_middle(model: llama_model_p, /) -> int: ... 
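# A minimal sketch of the simplified llama_batch_get_one signature above:
# only the token array and its length are passed now; the sequence id defaults
# to 0 and positions are tracked by llama_decode. Token ids are placeholders.
from nexa.gguf.llama import llama_cpp

token_ids = [1, 15043, 29871]  # hypothetical token ids
tokens = (llama_cpp.llama_token * len(token_ids))(*token_ids)
batch = llama_cpp.llama_batch_get_one(tokens, len(token_ids))
# `batch` can then be passed to llama_decode(ctx, batch)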
-# LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix +# DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead"); @ctypes_function("llama_token_suffix", [llama_model_p_ctypes], llama_token) def llama_token_suffix(model: llama_model_p, /) -> int: ... -# LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle -@ctypes_function("llama_token_eot", [llama_model_p_ctypes], llama_token) -def llama_token_eot(model: llama_model_p, /) -> int: +# LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model); +@ctypes_function("llama_token_fim_pre", [llama_model_p_ctypes], llama_token) +def llama_token_fim_pre(model: llama_model_p, /) -> int: ... +# LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model); +@ctypes_function("llama_token_fim_suf", [llama_model_p_ctypes], llama_token) +def llama_token_fim_suf(model: llama_model_p, /) -> int: + ... + +# LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model); +@ctypes_function("llama_token_fim_mid", [llama_model_p_ctypes], llama_token) +def llama_token_fim_mid(model: llama_model_p, /) -> int: + ... + +# LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model); +@ctypes_function("llama_token_fim_pad", [llama_model_p_ctypes], llama_token) +def llama_token_fim_pad(model: llama_model_p, /) -> int: + ... + +# LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model); +@ctypes_function("llama_token_fim_rep", [llama_model_p_ctypes], llama_token) +def llama_token_fim_rep(model: llama_model_p, /) -> int: + ... + +# LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model); +@ctypes_function("llama_token_fim_sep", [llama_model_p_ctypes], llama_token) +def llama_token_fim_sep(model: llama_model_p, /) -> int: + ... # // # // Tokenization # // +# // The API is thread-safe. +# // # /// @details Convert the provided text into tokens. @@ -2860,6 +2810,23 @@ def llama_token_to_piece( ... +# # // check if token0 is contained as a prefix in token1 +# # LLAMA_API bool llama_token_is_prefix( +# # const struct llama_model * model, +# # llama_token token0, +# # llama_token token1); +# @ctypes_function( +# "llama_token_is_prefix", +# [llama_model_p_ctypes, llama_token, llama_token], +# ctypes.c_bool, +# ) +# def llama_token_is_prefix( +# model: llama_model_p, token0: Union[llama_token, int], token1: Union[llama_token, int], / +# ) -> bool: +# """Check if token0 is contained as a prefix in token1""" +# ... + + # /// @details Convert the provided tokens into text (inverse of llama_tokenize()). # /// @param text The char pointer must be large enough to hold the resulting text. # /// @return Returns the number of chars/bytes on success, no more than text_len_max. 
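# A hedged sketch of assembling a fill-in-the-middle prompt with the FIM token
# getters above. `model` is assumed to be a loaded llama_model_p and
# prefix_ids / suffix_ids already-tokenized id lists; the exact layout is
# model-specific, this follows the common pre / suf / mid ordering.
from nexa.gguf.llama import llama_cpp

def build_fim_tokens(model, prefix_ids, suffix_ids):
    return (
        [llama_cpp.llama_token_fim_pre(model)]
        + list(prefix_ids)
        + [llama_cpp.llama_token_fim_suf(model)]
        + list(suffix_ids)
        + [llama_cpp.llama_token_fim_mid(model)]
    )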
@@ -2954,413 +2921,315 @@ def llama_chat_apply_template( # // -# // Grammar +# // Sampling API +# // +# // Sample usage: +# // +# // // prepare the sampling chain at the start +# // auto sparams = llama_sampler_chain_default_params(); +# // +# // llama_sampler * smpl = llama_sampler_chain_init(sparams); +# // +# // llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50)); +# // llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1)); +# // llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8)); +# // +# // // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat" +# // // this sampler will be responsible to select the actual token +# // llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed)); +# // +# // ... +# // +# // // decoding loop: +# // while (...) { +# // ... +# // +# // llama_decode(ctx, batch); +# // +# // // sample from the logits of the last token in the batch +# // const llama_token id = llama_sampler_sample(smpl, ctx, -1); +# // +# // // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.) +# // llama_sampler_accept(smpl, id); +# // ... +# // } +# // +# // llama_sampler_free(smpl); +# // +# // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). +# // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab # // +# typedef void * llama_sampler_context_t; +llama_sampler_context_t = ctypes.c_void_p -# LLAMA_API struct llama_grammar * llama_grammar_init( -# const llama_grammar_element ** rules, -# size_t n_rules, -# size_t start_rule_index); -@ctypes_function( - "llama_grammar_init", - [ - ctypes.POINTER(llama_grammar_element_p), - ctypes.c_size_t, - ctypes.c_size_t, - ], - llama_grammar_p, -) -def llama_grammar_init( - rules: CtypesArray[ - CtypesPointer[llama_grammar_element] - ], # NOTE: This might be wrong type sig - n_rules: Union[ctypes.c_size_t, int], - start_rule_index: Union[ctypes.c_size_t, int], - /, -) -> Optional[llama_grammar_p]: - """Initialize a grammar from a set of rules.""" + +# // user code can implement the interface below in order to create custom llama_sampler +# struct llama_sampler_i { +# const char * (*name) (const struct llama_sampler * smpl); // can be NULL +# void (*accept)( struct llama_sampler * smpl, llama_token token); // can be NULL +# void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p); // required +# void (*reset) ( struct llama_sampler * smpl); // can be NULL +# struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL +# void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL +# +# // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph +# //void (*apply_ggml) (struct llama_sampler * smpl, ...); +# }; +class llama_sampler_i(ctypes.Structure): ... 
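# A Python mirror of the C usage sketch above, using the sampler-chain
# bindings declared in this module (llama_sampler_sample is assumed to be
# bound further below); `ctx` would be an initialized llama_context_p.
from nexa.gguf.llama import llama_cpp

sparams = llama_cpp.llama_sampler_chain_default_params()
smpl = llama_cpp.llama_sampler_chain_init(sparams)
llama_cpp.llama_sampler_chain_add(smpl, llama_cpp.llama_sampler_init_top_k(50))
llama_cpp.llama_sampler_chain_add(smpl, llama_cpp.llama_sampler_init_top_p(0.9, 1))
llama_cpp.llama_sampler_chain_add(smpl, llama_cpp.llama_sampler_init_dist(1234))
# decode loop: llama_decode(ctx, batch), then
#   token = llama_cpp.llama_sampler_sample(smpl, ctx, -1)
llama_cpp.llama_sampler_free(smpl)  # frees the chain and every sampler it owns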
-# LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); -@ctypes_function( - "llama_grammar_free", - [llama_grammar_p], - None, +# struct llama_sampler { +# struct llama_sampler_i * iface; +# llama_sampler_context_t ctx; +# }; +class llama_sampler(ctypes.Structure): + _fields_ = [ + ("iface", ctypes.POINTER(llama_sampler_i)), + ("ctx", llama_sampler_context_t), + ] + + +if TYPE_CHECKING: + llama_sampler_p = CtypesPointer[llama_sampler] + +llama_sampler_p_ctypes = ctypes.POINTER(llama_sampler) + +llama_sampler_i_name = ctypes.CFUNCTYPE(ctypes.c_char_p, llama_sampler_p_ctypes) +llama_sampler_i_accept = ctypes.CFUNCTYPE(None, llama_sampler_p_ctypes, llama_token) +llama_sampler_i_apply = ctypes.CFUNCTYPE( + None, llama_sampler_p_ctypes, llama_token_data_array_p ) -def llama_grammar_free(grammar: llama_grammar_p, /): - """Free a grammar.""" - ... +llama_sampler_i_reset = ctypes.CFUNCTYPE(None, llama_sampler_p_ctypes) +llama_sampler_i_clone = ctypes.CFUNCTYPE(llama_sampler_p_ctypes, llama_sampler_p_ctypes) +llama_sampler_i_free = ctypes.CFUNCTYPE(None, llama_sampler_p_ctypes) + +llama_sampler_i._fields_ = [ + ("name", llama_sampler_i_name), + ("accept", llama_sampler_i_accept), + ("apply", llama_sampler_i_apply), + ("reset", llama_sampler_i_reset), + ("clone", llama_sampler_i_clone), + ("free", llama_sampler_i_free), +] -# LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar); +# // mirror of llama_sampler_i: +# LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); @ctypes_function( - "llama_grammar_copy", - [llama_grammar_p], - llama_grammar_p, + "llama_sampler_name", + [llama_sampler_p_ctypes], + ctypes.c_char_p, ) -def llama_grammar_copy(grammar: llama_grammar_p, /) -> llama_grammar_p: - """Copy a grammar.""" +def llama_sampler_name(smpl: llama_sampler_p, /) -> bytes: ... -# /// @details Apply constraints from grammar -# LLAMA_API void llama_grammar_sample( -# const struct llama_grammar * grammar, -# const struct llama_context * ctx, -# llama_token_data_array * candidates); +# LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); @ctypes_function( - "llama_grammar_sample", - [ - llama_grammar_p, - llama_context_p_ctypes, - llama_token_data_array_p, - ], + "llama_sampler_accept", + [llama_sampler_p_ctypes, llama_token], None, ) -def llama_grammar_sample( - grammar: llama_grammar_p, - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - /, -): - """Apply constraints from grammar""" +def llama_sampler_accept(smpl: llama_sampler_p, token: Union[llama_token, int], /): ... 
-# LLAMA_API DEPRECATED(void llama_sample_grammar( -# struct llama_context * ctx, -# llama_token_data_array * candidates, -# const struct llama_grammar * grammar), -# "use llama_grammar_sample instead"); +# LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); @ctypes_function( - "llama_sample_grammar", - [llama_context_p_ctypes, llama_token_data_array_p, llama_grammar_p], + "llama_sampler_apply", + [llama_sampler_p_ctypes, llama_token_data_array_p], None, ) -def llama_sample_grammar( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - grammar, # type: llama_grammar_p - /, +def llama_sampler_apply( + smpl: llama_sampler_p, cur_p: CtypesArray[llama_token_data_array], / ): - """Apply constraints from grammar - - Parameters: - candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. - grammar: A grammar object containing the rules and constraints to apply to the generated text. - """ ... -# /// @details Accepts the sampled token into the grammar -# LLAMA_API void llama_grammar_accept_token( -# struct llama_grammar * grammar, -# struct llama_context * ctx, -# llama_token token); +# LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); @ctypes_function( - "llama_grammar_accept_token", - [llama_grammar_p, llama_context_p_ctypes, llama_token], + "llama_sampler_reset", + [llama_sampler_p_ctypes], None, ) -def llama_grammar_accept_token( - grammar: llama_grammar_p, - ctx: llama_context_p, - token: Union[llama_token, int], - /, -): - """Accepts the sampled token into the grammar""" +def llama_sampler_reset(smpl: llama_sampler_p, /): ... -# // -# // Sampling functions -# // +# LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); +@ctypes_function( + "llama_sampler_clone", + [llama_sampler_p_ctypes], + llama_sampler_p_ctypes, +) +def llama_sampler_clone(smpl: llama_sampler_p, /) -> llama_sampler_p: + ... -# // Sets the current rng seed. -# LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); +# // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) +# LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); @ctypes_function( - "llama_set_rng_seed", - [llama_context_p_ctypes, ctypes.c_uint32], + "llama_sampler_free", + [llama_sampler_p_ctypes], None, ) -def llama_set_rng_seed(ctx: llama_context_p, seed: Union[ctypes.c_uint32, int], /): - """Sets the current rng seed.""" +def llama_sampler_free(smpl: llama_sampler_p, /): ... -# /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. -# /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. 
-# LLAMA_API void llama_sample_repetition_penalties( -# struct llama_context * ctx, -# llama_token_data_array * candidates, -# const llama_token * last_tokens, -# size_t penalty_last_n, -# float penalty_repeat, -# float penalty_freq, -# float penalty_present); -@ctypes_function( - "llama_sample_repetition_penalties", - [ - llama_context_p_ctypes, - llama_token_data_array_p, - llama_token_p, - ctypes.c_size_t, - ctypes.c_float, - ctypes.c_float, - ctypes.c_float, - ], - None, +# // llama_sampler_chain +# // a type of llama_sampler that can chain multiple samplers one after another +# +# LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params); +@ctypes_function( + "llama_sampler_chain_init", + [llama_sampler_chain_params], + llama_sampler_p_ctypes, ) -def llama_sample_repetition_penalties( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - last_tokens_data: CtypesArray[llama_token], - penalty_last_n: Union[ctypes.c_size_t, int], - penalty_repeat: Union[ctypes.c_float, float], - penalty_freq: Union[ctypes.c_float, float], - penalty_present: Union[ctypes.c_float, float], - /, -): - """Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - """ +def llama_sampler_chain_init(params: llama_sampler_chain_params, /) -> llama_sampler_p: ... -# /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 -# /// @param logits Logits extracted from the original generation context. -# /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. -# /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. -# LLAMA_API void llama_sample_apply_guidance( -# struct llama_context * ctx, -# float * logits, -# float * logits_guidance, -# float scale); +# // important: takes ownership of the sampler object and will free it when llama_sampler_free is called +# LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl); @ctypes_function( - "llama_sample_apply_guidance", - [ - llama_context_p_ctypes, - ctypes.POINTER(ctypes.c_float), - ctypes.POINTER(ctypes.c_float), - ctypes.c_float, - ], + "llama_sampler_chain_add", + [llama_sampler_p_ctypes, llama_sampler_p_ctypes], None, ) -def llama_sample_apply_guidance( - ctx: llama_context_p, - logits: CtypesArray[ctypes.c_float], - logits_guidance: CtypesArray[ctypes.c_float], - scale: Union[ctypes.c_float, float], - /, -): - """Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806""" +def llama_sampler_chain_add(chain: llama_sampler_p, smpl: llama_sampler_p, /): ... -# /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. 
-# LLAMA_API void llama_sample_softmax( -# struct llama_context * ctx, -# llama_token_data_array * candidates); +# LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); @ctypes_function( - "llama_sample_softmax", - [llama_context_p_ctypes, llama_token_data_array_p], - None, + "llama_sampler_chain_get", + [llama_sampler_p_ctypes, ctypes.c_int32], + llama_sampler_p_ctypes, ) -def llama_sample_softmax( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - /, -): - """Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.""" +def llama_sampler_chain_get( + chain: llama_sampler_p, i: Union[ctypes.c_int32, int], / +) -> llama_sampler_p: ... -# /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 -# LLAMA_API void llama_sample_top_k( -# struct llama_context * ctx, -# llama_token_data_array * candidates, -# int32_t k, -# size_t min_keep); +# LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); @ctypes_function( - "llama_sample_top_k", - [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_int32, ctypes.c_size_t], - None, + "llama_sampler_chain_n", + [llama_sampler_p_ctypes], + ctypes.c_int, ) -def llama_sample_top_k( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - k: Union[ctypes.c_int, int], - min_keep: Union[ctypes.c_size_t, int], - /, -): - """Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751""" +def llama_sampler_chain_n(chain: llama_sampler_p, /) -> int: ... -# /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 -# LLAMA_API void llama_sample_top_p( -# struct llama_context * ctx, -# llama_token_data_array * candidates, -# float p, -# size_t min_keep); +# // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed +# LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i); @ctypes_function( - "llama_sample_top_p", - [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t], - None, + "llama_sampler_chain_remove", + [llama_sampler_p_ctypes, ctypes.c_int32], + llama_sampler_p_ctypes, ) -def llama_sample_top_p( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - p: Union[ctypes.c_float, float], - min_keep: Union[ctypes.c_size_t, int], - /, -): - """Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751""" +def llama_sampler_chain_remove( + chain: llama_sampler_p, i: Union[ctypes.c_int32, int], / +) -> llama_sampler_p: ... -# /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 -# LLAMA_API void llama_sample_min_p( -# struct llama_context * ctx, -# llama_token_data_array * candidates, -# float p, -# size_t min_keep); +# // available samplers: +# +# LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); +@ctypes_function("llama_sampler_init_greedy", [], llama_sampler_p_ctypes) +def llama_sampler_init_greedy() -> llama_sampler_p: + ... 
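
Editor's aside (not part of this diff): the chain bindings introduced above replace the old per-call `llama_sample_*` helpers, and a short usage sketch may help show how they fit together. This is a minimal sketch under stated assumptions: it assumes `llama_sampler_chain_default_params()` is also bound in this module (it is declared in llama.h but does not appear in this hunk), that the `llama_sampler_init_*` constructors and `llama_sampler_sample` bound further down in this file are importable from `nexa.gguf.llama.llama_cpp`, and that `ctx` is a `llama_context` whose logits have already been computed via `llama_decode`.

```python
from nexa.gguf.llama.llama_cpp import (  # module path taken from this diff; symbol list assumed
    llama_sampler_chain_default_params,  # assumed bound elsewhere in this module
    llama_sampler_chain_init,
    llama_sampler_chain_add,
    llama_sampler_init_top_k,
    llama_sampler_init_temp,
    llama_sampler_init_dist,
    llama_sampler_sample,
    llama_sampler_free,
)


def sample_one_token(ctx):
    # Build a chain: top-k -> temperature -> seeded distribution sampling.
    chain = llama_sampler_chain_init(llama_sampler_chain_default_params())
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40))
    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.8))
    llama_sampler_chain_add(chain, llama_sampler_init_dist(1234))
    try:
        # idx = -1 samples from the logits of the last evaluated token.
        return llama_sampler_sample(chain, ctx, -1)
    finally:
        # Freeing the chain also frees every sampler added to it,
        # per the ownership note on llama_sampler_chain_add above.
        llama_sampler_free(chain)
```

The chain owns the samplers added to it, so only the chain itself is freed here; freeing an added sampler individually would double-free, as the header comments above warn.
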
+ + +# LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); +@ctypes_function("llama_sampler_init_dist", [ctypes.c_uint32], llama_sampler_p_ctypes) +def llama_sampler_init_dist(seed: int) -> llama_sampler_p: + ... + + +# /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. +# /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first. +# DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void), +# "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)"); +@ctypes_function("llama_sampler_init_softmax", [], llama_sampler_p_ctypes) +def llama_sampler_init_softmax() -> llama_sampler_p: + ... + + +# /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 +# LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); +@ctypes_function("llama_sampler_init_top_k", [ctypes.c_int32], llama_sampler_p_ctypes) +def llama_sampler_init_top_k(k: int) -> llama_sampler_p: + ... + + +# /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 +# LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep); @ctypes_function( - "llama_sample_min_p", - [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t], - None, + "llama_sampler_init_top_p", + [ctypes.c_float, ctypes.c_size_t], + llama_sampler_p_ctypes, ) -def llama_sample_min_p( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - p: Union[ctypes.c_float, float], - min_keep: Union[ctypes.c_size_t, int], - /, -): - """Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841""" +def llama_sampler_init_top_p(p: float, min_keep: int) -> llama_sampler_p: ... -# /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. -# LLAMA_API void llama_sample_tail_free( -# struct llama_context * ctx, -# llama_token_data_array * candidates, -# float z, -# size_t min_keep); +# /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 +# LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); @ctypes_function( - "llama_sample_tail_free", - [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t], - None, + "llama_sampler_init_min_p", + [ctypes.c_float, ctypes.c_size_t], + llama_sampler_p_ctypes, ) -def llama_sample_tail_free( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - z: Union[ctypes.c_float, float], - min_keep: Union[ctypes.c_size_t, int], - /, -): - """Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.""" +def llama_sampler_init_min_p(p: float, min_keep: int) -> llama_sampler_p: ... # /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. 
-# LLAMA_API void llama_sample_typical( -# struct llama_context * ctx, -# llama_token_data_array * candidates, -# float p, -# size_t min_keep); +# LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep); @ctypes_function( - "llama_sample_typical", - [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float, ctypes.c_size_t], - None, + "llama_sampler_init_typical", + [ctypes.c_float, ctypes.c_size_t], + llama_sampler_p_ctypes, ) -def llama_sample_typical( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - p: Union[ctypes.c_float, float], - min_keep: Union[ctypes.c_size_t, int], - /, -): - """Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.""" +def llama_sampler_init_typical(p: float, min_keep: int) -> llama_sampler_p: ... -# /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. -# LLAMA_API void llama_sample_entropy( -# struct llama_context * ctx, -# llama_token_data_array * candidates_p, -# float min_temp, -# float max_temp, -# float exponent_val); +# LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); +@ctypes_function("llama_sampler_init_temp", [ctypes.c_float], llama_sampler_p_ctypes) +def llama_sampler_init_temp(t: float) -> llama_sampler_p: + ... + + +# /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. +# LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent); @ctypes_function( - "llama_sample_entropy", - [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_float, - ctypes.c_float, - ctypes.c_float, - ], - None, + "llama_sampler_init_temp_ext", + [ctypes.c_float, ctypes.c_float, ctypes.c_float], + llama_sampler_p_ctypes, ) -def llama_sample_entropy( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - min_temp: Union[ctypes.c_float, float], - max_temp: Union[ctypes.c_float, float], - exponent_val: Union[ctypes.c_float, float], - /, -): - """Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.""" +def llama_sampler_init_temp_ext( + t: float, delta: float, exponent: float +) -> llama_sampler_p: ... -# LLAMA_API void llama_sample_temp( -# struct llama_context * ctx, -# llama_token_data_array * candidates, -# float temp); +# /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335 +# LLAMA_API struct llama_sampler * llama_sampler_init_xtc (float p, float t, size_t min_keep, uint32_t seed); @ctypes_function( - "llama_sample_temp", - [llama_context_p_ctypes, llama_token_data_array_p, ctypes.c_float], - None, + "llama_sampler_init_xtc", + [ctypes.c_float, ctypes.c_float, ctypes.c_size_t, ctypes.c_uint32], + llama_sampler_p_ctypes, ) -def llama_sample_temp( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - temp: Union[ctypes.c_float, float], - /, -): - """Temperature sampling described in academic paper "Generating Long Sequences with Sparse Transformers" https://arxiv.org/abs/1904.10509 - - Parameters: - candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. 
- temp: The temperature value to use for the sampling. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - """ +def llama_sampler_init_xtc( + p: float, t: float, min_keep: int, seed: int, / +) -> llama_sampler_p: ... @@ -3370,45 +3239,20 @@ def llama_sample_temp( # /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. # /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. # /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -# LLAMA_API llama_token llama_sample_token_mirostat( -# struct llama_context * ctx, -# llama_token_data_array * candidates, -# float tau, -# float eta, -# int32_t m, -# float * mu); +# LLAMA_API struct llama_sampler * llama_sampler_init_mirostat( +# int32_t n_vocab, +# uint32_t seed, +# float tau, +# float eta, +# int32_t m); @ctypes_function( - "llama_sample_token_mirostat", - [ - llama_context_p_ctypes, - llama_token_data_array_p, - ctypes.c_float, - ctypes.c_float, - ctypes.c_int32, - ctypes.POINTER(ctypes.c_float), - ], - llama_token, + "llama_sampler_init_mirostat", + [ctypes.c_int32, ctypes.c_uint32, ctypes.c_float, ctypes.c_float, ctypes.c_int32], + llama_sampler_p_ctypes, ) -def llama_sample_token_mirostat( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - tau: Union[ctypes.c_float, float], - eta: Union[ctypes.c_float, float], - m: Union[ctypes.c_int, int], - mu: CtypesPointerOrRef[ctypes.c_float], - /, -) -> int: - """Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - - Parameters: - candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. - tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - m: The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. - mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
- """ +def llama_sampler_init_mirostat( + n_vocab: int, seed: int, tau: float, eta: float, m: int, / +) -> llama_sampler_p: ... @@ -3417,82 +3261,189 @@ def llama_sample_token_mirostat( # /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. # /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. # /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -# LLAMA_API llama_token llama_sample_token_mirostat_v2( -# struct llama_context * ctx, -# llama_token_data_array * candidates, -# float tau, -# float eta, -# float * mu); +# LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2( +# uint32_t seed, +# float tau, +# float eta); +@ctypes_function( + "llama_sampler_init_mirostat_v2", + [ctypes.c_uint32, ctypes.c_float, ctypes.c_float], + llama_sampler_p_ctypes, +) +def llama_sampler_init_mirostat_v2( + seed: int, tau: float, eta: float, / +) -> llama_sampler_p: + ... + + +# LLAMA_API struct llama_sampler * llama_sampler_init_grammar( +# const struct llama_model * model, +# const char * grammar_str, +# const char * grammar_root); @ctypes_function( - "llama_sample_token_mirostat_v2", + "llama_sampler_init_grammar", + [llama_model_p_ctypes, ctypes.c_char_p, ctypes.c_char_p], + llama_sampler_p_ctypes, +) +def llama_sampler_init_grammar( + model: llama_model_p, grammar_str: bytes, grammar_root: bytes, / +) -> llama_sampler_p: + ... + + +# LLAMA_API struct llama_sampler * llama_sampler_init_penalties( +# int32_t n_vocab, // llama_n_vocab() +# llama_token special_eos_id, // llama_token_eos() +# llama_token linefeed_id, // llama_token_nl() +# int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) +# float penalty_repeat, // 1.0 = disabled +# float penalty_freq, // 0.0 = disabled +# float penalty_present, // 0.0 = disabled +# bool penalize_nl, // consider newlines as a repeatable token +# bool ignore_eos); // ignore the end-of-sequence token +@ctypes_function( + "llama_sampler_init_penalties", [ - llama_context_p_ctypes, - llama_token_data_array_p, + ctypes.c_int32, + llama_token, + llama_token, + ctypes.c_int32, ctypes.c_float, ctypes.c_float, - ctypes.POINTER(ctypes.c_float), + ctypes.c_float, + ctypes.c_bool, + ctypes.c_bool, ], - llama_token, -) -def llama_sample_token_mirostat_v2( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] + llama_sampler_p_ctypes, +) +def llama_sampler_init_penalties( + n_vocab: int, + special_eos_id: int, + linefeed_id: int, + penalty_last_n: int, + penalty_repeat: float, + penalty_freq: float, + penalty_present: float, + penalize_nl: bool, + ignore_eos: bool, + /, +) -> llama_sampler_p: + ... 
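
Editor's aside (not part of this diff): the grammar and penalty samplers bound in this hunk compose through the same chain mechanism. The sketch below is illustrative only; `llama_sampler_chain_default_params`, `llama_n_vocab`, `llama_token_eos`, and `llama_token_nl` are assumed to be bound elsewhere in this module (the header comments above name them as the intended argument sources), and `model` is an already-loaded `llama_model`.

```python
from nexa.gguf.llama.llama_cpp import (  # symbol list assumed for illustration
    llama_sampler_chain_default_params,  # assumed bound elsewhere in this module
    llama_sampler_chain_init,
    llama_sampler_chain_add,
    llama_sampler_init_grammar,
    llama_sampler_init_penalties,
    llama_sampler_init_greedy,
    llama_n_vocab,                       # assumed bound elsewhere in this module
    llama_token_eos,                     # assumed bound elsewhere in this module
    llama_token_nl,                      # assumed bound elsewhere in this module
)


def build_constrained_chain(model):
    """Chain a GBNF grammar constraint with repetition penalties, then pick greedily."""
    chain = llama_sampler_chain_init(llama_sampler_chain_default_params())
    # Constrain generation to a toy grammar rooted at "root".
    llama_sampler_chain_add(
        chain,
        llama_sampler_init_grammar(model, b'root ::= "yes" | "no"', b"root"),
    )
    # Penalize repeats over the last 64 tokens; the 0.0 values leave the
    # frequency and presence penalties disabled, per the header comments.
    llama_sampler_chain_add(
        chain,
        llama_sampler_init_penalties(
            llama_n_vocab(model),
            llama_token_eos(model),
            llama_token_nl(model),
            64,     # penalty_last_n
            1.1,    # penalty_repeat
            0.0,    # penalty_freq
            0.0,    # penalty_present
            False,  # penalize_nl
            False,  # ignore_eos
        ),
    )
    llama_sampler_chain_add(chain, llama_sampler_init_greedy())
    return chain
```

The returned chain can then be driven with `llama_sampler_sample` (bound later in this file) and released with a single `llama_sampler_free` on the chain.
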
+ + +# /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 +# LLAMA_API struct llama_sampler * llama_sampler_init_dry( +# const struct llama_model * model, +# float dry_multiplier, +# float dry_base, +# int32_t dry_allowed_length, +# int32_t dry_penalty_last_n, +# const char ** seq_breakers, +# size_t num_breakers); +@ctypes_function( + "llama_sampler_init_dry", + [ + llama_model_p_ctypes, + ctypes.c_float, + ctypes.c_float, + ctypes.c_int32, + ctypes.c_int32, + ctypes.POINTER(ctypes.c_char_p), + ctypes.c_size_t, ], - tau: Union[ctypes.c_float, float], - eta: Union[ctypes.c_float, float], - mu: CtypesPointerOrRef[ctypes.c_float], + llama_sampler_p_ctypes, +) +def llama_sampler_init_dry( + model: llama_model_p, + dry_multiplier: float, + dry_base: float, + dry_allowed_length: int, + dry_penalty_last_n: int, + seq_breakers: CtypesArray[bytes], + num_breakers: int, /, -) -> int: - """Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +) -> llama_sampler_p: + ... - Parameters: - candidates: A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. - tau: The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - eta: The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - mu: Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - """ + +# LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( +# int32_t n_vocab, +# int32_t n_logit_bias, +# const llama_logit_bias * logit_bias); +@ctypes_function( + "llama_sampler_init_logit_bias", + [ctypes.c_int32, ctypes.c_int32, llama_logit_bias_p], + llama_sampler_p_ctypes, +) +def llama_sampler_init_logit_bias( + n_vocab: int, n_logit_bias: int, logit_bias: CtypesArray[llama_logit_bias], / +) -> llama_sampler_p: ... -# /// @details Selects the token with the highest probability. -# /// Does not compute the token probabilities. Use llama_sample_softmax() instead. -# LLAMA_API llama_token llama_sample_token_greedy( -# struct llama_context * ctx, -# llama_token_data_array * candidates); +# // this sampler is meant to be used for fill-in-the-middle infilling +# // it's supposed to be used after top_k + top_p sampling +# // +# // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG +# // 2. combine probs of tokens that have the same prefix +# // +# // example: +# // +# // - before: +# // "hel": 0.5 +# // "hell": 0.2 +# // "hello": 0.1 +# // "dummy": 0.1 +# // +# // - after: +# // "hel": 0.8 +# // "dummy": 0.1 +# // +# // 3. discard non-EOG tokens with low prob +# // 4. 
if no tokens are left -> pick EOT +# // +# LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model); @ctypes_function( - "llama_sample_token_greedy", - [llama_context_p_ctypes, llama_token_data_array_p], - llama_token, + "llama_sampler_init_infill", + [llama_model_p_ctypes], + llama_sampler_p_ctypes, ) -def llama_sample_token_greedy( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - /, -) -> int: - """Selects the token with the highest probability.""" +def llama_sampler_init_infill(model: llama_model_p, /) -> llama_sampler_p: + """This sampler is meant to be used for fill-in-the-middle infilling. + """ ... -# /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx. -# LLAMA_API llama_token llama_sample_token( -# struct llama_context * ctx, -# llama_token_data_array * candidates); +# // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise +# LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); @ctypes_function( - "llama_sample_token", - [llama_context_p_ctypes, llama_token_data_array_p], + "llama_sampler_get_seed", + [llama_sampler_p_ctypes], + ctypes.c_uint32, +) +def llama_sampler_get_seed(smpl: llama_sampler_p, /) -> int: + ... + + +# /// @details Sample and accept a token from the idx-th output of the last evaluation +# // +# // Shorthand for: +# // const auto * logits = llama_get_logits_ith(ctx, idx); +# // llama_token_data_array cur_p = { ... init from logits ... }; +# // llama_sampler_apply(smpl, &cur_p); +# // auto token = cur_p.data[cur_p.selected].id; +# // llama_sampler_accept(smpl, token); +# // return token; +# // Returns the sampled token +# LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); +@ctypes_function( + "llama_sampler_sample", + [llama_sampler_p_ctypes, llama_context_p_ctypes, ctypes.c_int32], llama_token, ) -def llama_sample_token( - ctx: llama_context_p, - candidates: Union[ - CtypesArray[llama_token_data_array], CtypesPointerOrRef[llama_token_data_array] - ], - /, +def llama_sampler_sample( + smpl: llama_sampler_p, ctx: llama_context_p, idx: int, / ) -> int: - """Randomly selects a token from the candidates based on their probabilities.""" ... @@ -3543,79 +3494,139 @@ def llama_split_prefix( ... -# Performance information +# // Print system information +# LLAMA_API const char * llama_print_system_info(void); +@ctypes_function("llama_print_system_info", [], ctypes.c_char_p) +def llama_print_system_info() -> bytes: + ... + + +# // Set callback for all future logging events. +# // If this is not called, or NULL is supplied, everything is output on stderr. +# LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); +@ctypes_function( + "llama_log_set", + [ctypes.c_void_p, ctypes.c_void_p], + None, +) +def llama_log_set( + log_callback: Optional[CtypesFuncPointer], + user_data: ctypes.c_void_p, + /, +): + """Set callback for all future logging events. + + If this is not called, or NULL is supplied, everything is output on stderr.""" + ... + + +# // +# // Performance utils +# // +# // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements. 
+# // + + +# struct llama_perf_context_data { +# double t_start_ms; +# double t_load_ms; +# double t_p_eval_ms; +# double t_eval_ms; +# +# int32_t n_p_eval; +# int32_t n_eval; +# }; +class llama_perf_context_data(ctypes.Structure): + _fields_ = [ + ("t_start_ms", ctypes.c_double), + ("t_load_ms", ctypes.c_double), + ("t_p_eval_ms", ctypes.c_double), + ("t_eval_ms", ctypes.c_double), + ("n_p_eval", ctypes.c_int32), + ("n_eval", ctypes.c_int32), + ] + + +# struct llama_perf_sampler_data { +# double t_sample_ms; +# +# int32_t n_sample; +# }; +class llama_perf_sampler_data(ctypes.Structure): + _fields_ = [ + ("t_sample_ms", ctypes.c_double), + ("n_sample", ctypes.c_int32), + ] -# LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); +# LLAMA_API struct llama_perf_context_data llama_perf_context (const struct llama_context * ctx); @ctypes_function( - "llama_get_timings", + "llama_perf_context", [llama_context_p_ctypes], - llama_timings, + llama_perf_context_data, ) -def llama_get_timings(ctx: llama_context_p, /) -> llama_timings: - """Get performance information""" +def llama_perf_context(ctx: llama_context_p, /) -> llama_perf_context_data: ... -# LLAMA_API void llama_print_timings(struct llama_context * ctx); +# LLAMA_API void llama_perf_context_print(const struct llama_context * ctx); @ctypes_function( - "llama_print_timings", + "llama_perf_context_print", [llama_context_p_ctypes], None, ) -def llama_print_timings(ctx: llama_context_p, /): - """Print performance information""" +def llama_perf_context_print(ctx: llama_context_p, /): ... -# LLAMA_API void llama_reset_timings(struct llama_context * ctx); +# LLAMA_API void llama_perf_context_reset( struct llama_context * ctx); @ctypes_function( - "llama_reset_timings", + "llama_perf_context_reset", [llama_context_p_ctypes], None, ) -def llama_reset_timings(ctx: llama_context_p, /): - """Reset performance information""" +def llama_perf_context_reset(ctx: llama_context_p, /): ... -# Print system information -# LLAMA_API const char * llama_print_system_info(void); +# // NOTE: the following work only with samplers constructed via llama_sampler_chain_init +# LLAMA_API struct llama_perf_sampler_data llama_perf_sampler (const struct llama_sampler * chain); @ctypes_function( - "llama_print_system_info", - [], - ctypes.c_char_p, + "llama_perf_sampler", + [llama_sampler_p_ctypes], + llama_perf_sampler_data, ) -def llama_print_system_info() -> bytes: - """Print system information""" +def llama_perf_sampler(chain: llama_sampler_p, /) -> llama_perf_sampler_data: ... -# NOTE: THIS IS CURRENTLY BROKEN AS ggml_log_callback IS NOT EXPOSED IN LLAMA.H -# // Set callback for all future logging events. -# // If this is not called, or NULL is supplied, everything is output on stderr. -# LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); +# LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); @ctypes_function( - "llama_log_set", - [ctypes.c_void_p, ctypes.c_void_p], + "llama_perf_sampler_print", + [llama_sampler_p_ctypes], None, ) -def llama_log_set( - log_callback: Optional[CtypesFuncPointer], - user_data: ctypes.c_void_p, - /, -): - """Set callback for all future logging events. +def llama_perf_sampler_print(chain: llama_sampler_p, /): + ... 
- If this is not called, or NULL is supplied, everything is output on stderr.""" + +# LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); +@ctypes_function( + "llama_perf_sampler_reset", + [llama_sampler_p_ctypes], + None, +) +def llama_perf_sampler_reset(chain: llama_sampler_p, /): ... -# LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx); +# LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx); @ctypes_function( - "llama_dump_timing_info_yaml", - [ctypes.c_void_p, llama_context_p_ctypes], + "llama_perf_dump_yaml", + [ctypes.POINTER(ctypes.c_void_p), llama_context_p_ctypes], None, ) -def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p, /): - ... \ No newline at end of file +def llama_perf_dump_yaml( + stream: ctypes.POINTER(ctypes.c_void_p), ctx: llama_context_p, / +): + ... diff --git a/nexa/gguf/llama/llama_grammar.py b/nexa/gguf/llama/llama_grammar.py index 2fc20d05..b95c77ab 100644 --- a/nexa/gguf/llama/llama_grammar.py +++ b/nexa/gguf/llama/llama_grammar.py @@ -2,11 +2,6 @@ # flake8: noqa from pathlib import Path -import sys -import ctypes -import enum -import typing -import dataclasses from itertools import groupby from typing import ( @@ -18,883 +13,18 @@ Union, ) -import nexa.gguf.llama.llama_cpp as llama_cpp - -class GrammarElementType(enum.IntEnum): - END = llama_cpp.LLAMA_GRETYPE_END - ALT = llama_cpp.LLAMA_GRETYPE_ALT - RULE_REF = llama_cpp.LLAMA_GRETYPE_RULE_REF - CHAR = llama_cpp.LLAMA_GRETYPE_CHAR - CHAR_NOT = llama_cpp.LLAMA_GRETYPE_CHAR_NOT - CHAR_RNG_UPPER = llama_cpp.LLAMA_GRETYPE_CHAR_RNG_UPPER - CHAR_ALT = llama_cpp.LLAMA_GRETYPE_CHAR_ALT - CHAR_ANY = llama_cpp.LLAMA_GRETYPE_CHAR_ANY - - -@dataclasses.dataclass -class GrammarElement: - type: GrammarElementType - value: int - - -@dataclasses.dataclass -class ParseState: - symbol_ids: typing.Dict[str, int] = dataclasses.field(default_factory=dict) - rules: typing.List[typing.List[GrammarElement]] = dataclasses.field(default_factory=list) - - -# static std::pair decode_utf8(const char * src) { -# static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; -# uint8_t first_byte = static_cast(*src); -# uint8_t highbits = first_byte >> 4; -# int len = lookup[highbits]; -# uint8_t mask = (1 << (8 - len)) - 1; -# uint32_t value = first_byte & mask; -# const char * end = src + len; // may overrun! 
-# const char * pos = src + 1; -# for ( ; pos < end && *pos; pos++) { -# value = (value << 6) + (static_cast(*pos) & 0x3F); -# } -# return std::make_pair(value, pos); -# } -def decode_utf8(src: str) -> typing.Tuple[int, str]: - lookup: list[int] = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4] - first_byte: int = ord(src[0]) - highbits: int = first_byte >> 4 - length: int = lookup[highbits] - mask: int = (1 << (8 - length)) - 1 - value: int = first_byte & mask - end: int = min(len(src), length) # Prevent overrun - - pos: int = 1 - for pos in range(1, end): - if not src[pos]: - break - value = (value << 6) + (ord(src[pos]) & 0x3F) - - return value, src[pos:] if pos < len(src) else "" - - -# static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { -# uint32_t next_id = static_cast(state.symbol_ids.size()); -# auto result = state.symbol_ids.emplace(std::string(src, len), next_id); -# return result.first->second; -# } -def get_symbol_id(state: ParseState, name: str) -> int: - next_id = len(state.symbol_ids) - return state.symbol_ids.setdefault(name, next_id) - - -# static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { -# uint32_t next_id = static_cast(state.symbol_ids.size()); -# state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; -# return next_id; -# } -def generate_symbol_id(state: ParseState, base_name: str) -> int: - next_id = len(state.symbol_ids) - state.symbol_ids[f"{base_name}_{next_id}"] = next_id - return next_id - - -# static void add_rule( -# parse_state & state, -# uint32_t rule_id, -# const std::vector & rule) { -# if (state.rules.size() <= rule_id) { -# state.rules.resize(rule_id + 1); -# } -# state.rules[rule_id] = rule; -# } -def add_rule(state: ParseState, rule_id: int, rule: typing.List[GrammarElement]) -> None: - if len(state.rules) <= rule_id: - state.rules.extend([[]] * (rule_id + 1 - len(state.rules))) - state.rules[rule_id] = rule - - -# static bool is_digit_char(char c) { -# return '0' <= c && c <= '9'; -# } -def is_digit_char(c: str) -> bool: - return "0" <= c <= "9" - - -# static bool is_word_char(char c) { -# return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); -# } -def is_word_char(c: str) -> bool: - return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or is_digit_char(c) - - -# static std::pair parse_hex(const char * src, int size) { -# const char * pos = src; -# const char * end = src + size; -# uint32_t value = 0; -# for ( ; pos < end && *pos; pos++) { -# value <<= 4; -# char c = *pos; -# if ('a' <= c && c <= 'f') { -# value += c - 'a' + 10; -# } else if ('A' <= c && c <= 'F') { -# value += c - 'A' + 10; -# } else if ('0' <= c && c <= '9') { -# value += c - '0'; -# } else { -# break; -# } -# } -# if (pos != end) { -# throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); -# } -# return std::make_pair(value, pos); -# } -def parse_hex(src: str, size: int) -> typing.Tuple[int, str]: - pos = 0 - value = 0 - for _ in range(size): - value <<= 4 - c = src[pos] - if "a" <= c <= "f": - value += ord(c) - ord("a") + 10 - elif "A" <= c <= "F": - value += ord(c) - ord("A") + 10 - elif "0" <= c <= "9": - value += ord(c) - ord("0") - else: - break - pos += 1 - if pos != size: - raise ValueError(f"expecting {size} hex chars at {src}") - return value, src[pos:] - - -# static const char * parse_space(const char * src, bool newline_ok) { -# const char * pos = src; -# while (*pos == ' ' || *pos == '\t' || *pos == '#' || -# 
(newline_ok && (*pos == '\r' || *pos == '\n'))) { -# if (*pos == '#') { -# while (*pos && *pos != '\r' && *pos != '\n') { -# pos++; -# } -# } else { -# pos++; -# } -# } -# return pos; -# } -def parse_space(src: str, newline_ok: bool) -> str: - pos = src - while pos and (pos[0] in (' ', '\t', '#') or (newline_ok and pos[0] in ('\r', '\n'))): - if pos[0] == "#": - while pos and pos[0] not in ("\r", "\n"): - pos = pos[1:] - else: - pos = pos[1:] - return pos - - -# static const char * parse_name(const char * src) { -# const char * pos = src; -# while (is_word_char(*pos)) { -# pos++; -# } -# if (pos == src) { -# throw std::runtime_error(std::string("expecting name at ") + src); -# } -# return pos; -# } -def parse_name(src: str) -> typing.Tuple[str, str]: - pos = src - while pos and is_word_char(pos[0]): - pos = pos[1:] - if pos == src: - raise ValueError(f"expecting name at {src}") - return src[:len(src) - len(pos)], pos - -# static const char * parse_int(const char * src) { -# const char * pos = src; -# while (is_digit_char(*pos)) { -# pos++; -# } -# if (pos == src) { -# throw std::runtime_error(std::string("expecting integer at ") + src); -# } -# return pos; -# } -def parse_int(src: str) -> typing.Tuple[int, str]: - pos = src - while pos and is_digit_char(pos[0]): - pos = pos[1:] - if pos == src: - raise ValueError(f"expecting integer at {src}") - return int(src[:len(src) - len(pos)]), pos - - -# static std::pair parse_char(const char * src) { -# if (*src == '\\') { -# switch (src[1]) { -# case 'x': return parse_hex(src + 2, 2); -# case 'u': return parse_hex(src + 2, 4); -# case 'U': return parse_hex(src + 2, 8); -# case 't': return std::make_pair('\t', src + 2); -# case 'r': return std::make_pair('\r', src + 2); -# case 'n': return std::make_pair('\n', src + 2); -# case '\\': -# case '"': -# case '[': -# case ']': -# return std::make_pair(src[1], src + 2); -# default: -# throw std::runtime_error(std::string("unknown escape at ") + src); -# } -# } else if (*src) { -# return decode_utf8(src); -# } -# throw std::runtime_error("unexpected end of input"); -# } -def parse_char(src: str) -> typing.Tuple[int, str]: - if not src: - raise ValueError("unexpected end of input") - if src[0] == "\\": - if src[1] == "x": - return parse_hex(src[2:], 2) - elif src[1] == "u": - return parse_hex(src[2:], 4) - elif src[1] == "U": - return parse_hex(src[2:], 8) - elif src[1] == "t": - return ord("\t"), src[2:] - elif src[1] == "r": - return ord("\r"), src[2:] - elif src[1] == "n": - return ord("\n"), src[2:] - elif src[1] in ('\\', '"', '[', ']'): - return ord(src[1]), src[2:] - else: - raise ValueError(f"unknown escape at {src}") - return decode_utf8(src) - -# static const char * parse_sequence( -# parse_state & state, -# const char * src, -# const std::string & rule_name, -# std::vector & out_elements, -# bool is_nested) { -# size_t last_sym_start = out_elements.size(); -# const char * pos = src; -# -# auto handle_repetitions = [&](int min_times, int max_times) { -# -# if (last_sym_start == out_elements.size()) { -# throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); -# } -# -# // apply transformation to previous symbol (last_sym_start to end) according to -# // the following rewrite rules: -# // S{m,n} --> S S S (m times) S'(n-m) -# // S'(x) ::= S S'(x-1) | -# // (... n-m definitions of these S' rules ...) 
-# // S'(1) ::= S | -# // S{m,} --> S S S (m times) S' -# // S' ::= S S' | -# // S* --> S{0,} -# // --> S' ::= S S' | -# // S+ --> S{1,} -# // --> S S' -# // S' ::= S S' | -# // S? --> S{0,1} -# // --> S' -# // S' ::= S | -# -# std::vector previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); -# if (min_times == 0) { -# out_elements.resize(last_sym_start); -# } else { -# // Repeat the previous elements (min_times - 1) times -# for (int i = 1; i < min_times; i++) { -# out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); -# } -# } -# -# uint32_t last_rec_rule_id = 0; -# auto n_opt = max_times < 0 ? 1 : max_times - min_times; -# -# std::vector rec_rule(previous_elements); -# for (int i = 0; i < n_opt; i++) { -# rec_rule.resize(previous_elements.size()); -# uint32_t rec_rule_id = generate_symbol_id(state, rule_name); -# if (i > 0 || max_times < 0) { -# rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); -# } -# rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); -# rec_rule.push_back({LLAMA_GRETYPE_END, 0}); -# add_rule(state, rec_rule_id, rec_rule); -# last_rec_rule_id = rec_rule_id; -# } -# if (n_opt > 0) { -# out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); -# } -# }; -# -# while (*pos) { -# if (*pos == '"') { // literal string -# pos++; -# last_sym_start = out_elements.size(); -# while (*pos != '"') { -# if (!*pos) { -# throw std::runtime_error("unexpected end of input"); -# } -# auto char_pair = parse_char(pos); -# pos = char_pair.second; -# out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); -# } -# pos = parse_space(pos + 1, is_nested); -# } else if (*pos == '[') { // char range(s) -# pos++; -# enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; -# if (*pos == '^') { -# pos++; -# start_type = LLAMA_GRETYPE_CHAR_NOT; -# } -# last_sym_start = out_elements.size(); -# while (*pos != ']') { -# if (!*pos) { -# throw std::runtime_error("unexpected end of input"); -# } -# auto char_pair = parse_char(pos); -# pos = char_pair.second; -# enum llama_gretype type = last_sym_start < out_elements.size() -# ? 
LLAMA_GRETYPE_CHAR_ALT -# : start_type; -# -# out_elements.push_back({type, char_pair.first}); -# if (pos[0] == '-' && pos[1] != ']') { -# if (!pos[1]) { -# throw std::runtime_error("unexpected end of input"); -# } -# auto endchar_pair = parse_char(pos + 1); -# pos = endchar_pair.second; -# out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); -# } -# } -# pos = parse_space(pos + 1, is_nested); -# } else if (is_word_char(*pos)) { // rule reference -# const char * name_end = parse_name(pos); -# uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); -# pos = parse_space(name_end, is_nested); -# last_sym_start = out_elements.size(); -# out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); -# } else if (*pos == '(') { // grouping -# // parse nested alternates into synthesized rule -# pos = parse_space(pos + 1, true); -# uint32_t sub_rule_id = generate_symbol_id(state, rule_name); -# pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); -# last_sym_start = out_elements.size(); -# // output reference to synthesized rule -# out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); -# if (*pos != ')') { -# throw std::runtime_error(std::string("expecting ')' at ") + pos); -# } -# pos = parse_space(pos + 1, is_nested); -# } else if (*pos == '.') { // any char -# last_sym_start = out_elements.size(); -# out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); -# pos = parse_space(pos + 1, is_nested); -# } else if (*pos == '*') { -# pos = parse_space(pos + 1, is_nested); -# handle_repetitions(0, -1); -# } else if (*pos == '+') { -# pos = parse_space(pos + 1, is_nested); -# handle_repetitions(1, -1); -# } else if (*pos == '?') { -# pos = parse_space(pos + 1, is_nested); -# handle_repetitions(0, 1); -# } else if (*pos == '{') { -# pos = parse_space(pos + 1, is_nested); -# -# if (!is_digit_char(*pos)) { -# throw std::runtime_error(std::string("expecting an int at ") + pos); -# } -# const char * int_end = parse_int(pos); -# int min_times = std::stoul(std::string(pos, int_end - pos)); -# pos = parse_space(int_end, is_nested); -# -# int max_times = -1; -# -# if (*pos == '}') { -# max_times = min_times; -# pos = parse_space(pos + 1, is_nested); -# } else if (*pos == ',') { -# pos = parse_space(pos + 1, is_nested); -# -# if (is_digit_char(*pos)) { -# const char * int_end = parse_int(pos); -# max_times = std::stoul(std::string(pos, int_end - pos)); -# pos = parse_space(int_end, is_nested); -# } -# -# if (*pos != '}') { -# throw std::runtime_error(std::string("expecting '}' at ") + pos); -# } -# pos = parse_space(pos + 1, is_nested); -# } else { -# throw std::runtime_error(std::string("expecting ',' at ") + pos); -# } -# handle_repetitions(min_times, max_times); -# } else { -# break; -# } -# } -# return pos; -# } -def parse_sequence(state: ParseState, src: str, rule_name: str, out_elements: typing.List[GrammarElement], is_nested: bool) -> str: - last_sym_start = len(out_elements) - pos = src - - def handle_repetitions(min_times: int, max_times: int) -> None: - nonlocal state, src, rule_name, out_elements, is_nested, last_sym_start, pos - - if last_sym_start == len(out_elements): - raise ValueError(f"expecting preceding item to */+/?/{{ at {pos}") - - previous_elements = out_elements[last_sym_start:] - if min_times == 0: - del out_elements[last_sym_start:] - else: - for i in range(1, min_times): - out_elements.extend(previous_elements) - - last_rec_rule_id = 0 - n_opt = 1 if max_times < 0 else max_times - min_times - - rec_rule = previous_elements[:] - for i 
in range(n_opt): - rec_rule = rec_rule[:len(previous_elements)] - rec_rule_id = generate_symbol_id(state, rule_name) - if i > 0 or max_times < 0: - rec_rule.append(GrammarElement(GrammarElementType.RULE_REF, rec_rule_id if max_times < 0 else last_rec_rule_id)) - rec_rule.append(GrammarElement(GrammarElementType.ALT, 0)) - rec_rule.append(GrammarElement(GrammarElementType.END, 0)) - add_rule(state, rec_rule_id, rec_rule) - last_rec_rule_id = rec_rule_id - if n_opt > 0: - out_elements.append(GrammarElement(GrammarElementType.RULE_REF, last_rec_rule_id)) - - while pos: - if pos[0] == '"': - pos = pos[1:] - last_sym_start = len(out_elements) - while not pos.startswith('"'): - if not pos: - raise ValueError("unexpected end of input") - char, pos = parse_char(pos) - out_elements.append(GrammarElement(GrammarElementType.CHAR, char)) - pos = parse_space(pos[1:], is_nested) - elif pos[0] == "[": - pos = pos[1:] - start_type = GrammarElementType.CHAR - if pos[0] == "^": - pos = pos[1:] - start_type = GrammarElementType.CHAR_NOT - last_sym_start = len(out_elements) - while pos[0] != "]": - if not pos: - raise ValueError("unexpected end of input") - char, pos = parse_char(pos) - type = GrammarElementType.CHAR_ALT if last_sym_start < len(out_elements) else start_type - out_elements.append(GrammarElement(type, char)) - if pos[0] == "-" and pos[1] != "]": - if not pos[1]: - raise ValueError("unexpected end of input") - endchar, pos = parse_char(pos[1:]) - out_elements.append(GrammarElement(GrammarElementType.CHAR_RNG_UPPER, endchar)) - pos = parse_space(pos[1:], is_nested) - elif pos and is_word_char(pos[0]): - name, rest = parse_name(pos) - ref_rule_id = get_symbol_id(state, name) - pos = parse_space(rest, is_nested) - last_sym_start = len(out_elements) - out_elements.append(GrammarElement(GrammarElementType.RULE_REF, ref_rule_id)) - elif pos.startswith("("): - pos = parse_space(pos[1:], newline_ok=True) - sub_rule_id = generate_symbol_id(state, rule_name) - pos = parse_alternates(state, pos, rule_name, sub_rule_id, is_nested=True) - last_sym_start = len(out_elements) - out_elements.append(GrammarElement(GrammarElementType.RULE_REF, sub_rule_id)) - if pos[0] != ")": - raise ValueError(f"expecting ')' at {pos}") - pos = parse_space(pos[1:], is_nested) - elif pos.startswith("."): - last_sym_start = len(out_elements) - out_elements.append(GrammarElement(GrammarElementType.CHAR_ANY, 0)) - pos = parse_space(pos[1:], is_nested) - elif pos.startswith("*"): - pos = parse_space(pos[1:], is_nested) - handle_repetitions(0, -1) - elif pos.startswith("+"): - pos = parse_space(pos[1:], is_nested) - handle_repetitions(1, -1) - elif pos.startswith("?"): - pos = parse_space(pos[1:], is_nested) - handle_repetitions(0, 1) - elif pos.startswith("{"): - pos = parse_space(pos[1:], is_nested) - - if not is_digit_char(pos): - raise ValueError(f"expecting an int at {pos}") - min_times, pos = parse_int(pos) - pos = parse_space(pos, is_nested) - - max_times = -1 - - if pos[0] == "}": - max_times = min_times - pos = parse_space(pos[1:], is_nested) - elif pos[0] == ",": - pos = parse_space(pos[1:], is_nested) - - if is_digit_char(pos): - max_times, pos = parse_int(pos) - pos = parse_space(pos, is_nested) - - if pos[0] != "}": - raise ValueError("expecting '}' at {}".format(pos)) - - pos = parse_space(pos[1:], is_nested) - else: - raise ValueError(f"expecting ',' at {pos}") - handle_repetitions(min_times, max_times) - else: - break - return pos - - -# const char * parse_alternates( -# parse_state & state, -# const char * src, -# 
const std::string & rule_name, -# uint32_t rule_id, -# bool is_nested) { -# std::vector rule; -# const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); -# while (*pos == '|') { -# rule.push_back({LLAMA_GRETYPE_ALT, 0}); -# pos = parse_space(pos + 1, true); -# pos = parse_sequence(state, pos, rule_name, rule, is_nested); -# } -# rule.push_back({LLAMA_GRETYPE_END, 0}); -# add_rule(state, rule_id, rule); -# return pos; -# } -def parse_alternates(state: ParseState, src: str, rule_name: str, rule_id: int, is_nested: bool) -> str: - rule = [] - pos = parse_sequence(state, src, rule_name, rule, is_nested) - while pos.startswith("|"): - rule.append(GrammarElement(GrammarElementType.ALT, 0)) - pos = parse_space(pos[1:], newline_ok=True) - pos = parse_sequence(state, pos, rule_name, rule, is_nested) - rule.append(GrammarElement(GrammarElementType.END, 0)) - add_rule(state, rule_id, rule) - return pos - - -# static const char * parse_rule(parse_state & state, const char * src) { -# const char * name_end = parse_name(src); -# const char * pos = parse_space(name_end, false); -# size_t name_len = name_end - src; -# uint32_t rule_id = get_symbol_id(state, src, name_len); -# const std::string name(src, name_len); -# -# if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { -# throw std::runtime_error(std::string("expecting ::= at ") + pos); -# } -# pos = parse_space(pos + 3, true); -# -# pos = parse_alternates(state, pos, name, rule_id, false); -# -# if (*pos == '\r') { -# pos += pos[1] == '\n' ? 2 : 1; -# } else if (*pos == '\n') { -# pos++; -# } else if (*pos) { -# throw std::runtime_error(std::string("expecting newline or end at ") + pos); -# } -# return parse_space(pos, true); -# } -def parse_rule(state: ParseState, src: str) -> str: - pos = src - name, pos = parse_name(pos) - pos = parse_space(pos, newline_ok=False) - rule_id = get_symbol_id(state, name) - - if not pos.startswith("::="): - raise ValueError(f"expecting ::= at {pos}") - - pos = parse_space(pos[3:], newline_ok=True) - - pos = parse_alternates(state, pos, name, rule_id, is_nested=False) - - if pos.startswith("\r"): - pos = pos[2:] if pos[1] == "\n" else pos[1:] - elif pos.startswith("\n"): - pos = pos[1:] - elif pos: - raise ValueError(f"expecting newline or end at {pos}") - return parse_space(pos, newline_ok=True) - - -# parse_state parse(const char * src) { -# try { -# parse_state state; -# const char * pos = parse_space(src, true); -# while (*pos) { -# pos = parse_rule(state, pos); -# } -# // Validate the state to ensure that all rules are defined -# for (const auto & rule : state.rules) { -# for (const auto & elem : rule) { -# if (elem.type == LLAMA_GRETYPE_RULE_REF) { -# // Ensure that the rule at that location exists -# if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) { -# // Get the name of the rule that is missing -# for (const auto & kv : state.symbol_ids) { -# if (kv.second == elem.value) { -# throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); -# } -# } -# } -# } -# } -# } -# return state; -# } catch (const std::exception & err) { -# fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); -# return parse_state(); -# } -# } -def parse(src: str) -> ParseState: - state = ParseState() - pos = src - pos = parse_space(pos, newline_ok=True) - while pos: - pos = parse_rule(state, pos) - # validate - for rule in state.rules: - for elem in rule: - if elem.type == GrammarElementType.RULE_REF: - if elem.value >= len(state.rules) or not 
state.rules[elem.value]: - for k, v in state.symbol_ids.items(): - if v == elem.value: - raise ValueError(f"Undefined rule identifier '{k}'") - return state - - -# static bool is_char_element(llama_grammar_element elem) { -# switch (elem.type) { -# case LLAMA_GRETYPE_CHAR: return true; -# case LLAMA_GRETYPE_CHAR_NOT: return true; -# case LLAMA_GRETYPE_CHAR_ALT: return true; -# case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; -# case LLAMA_GRETYPE_CHAR_ANY: return true; -# default: return false; -# } -# } -def is_char_element(elem: GrammarElement) -> bool: - return elem.type in ( - GrammarElementType.CHAR, - GrammarElementType.CHAR_NOT, - GrammarElementType.CHAR_ALT, - GrammarElementType.CHAR_RNG_UPPER, - GrammarElementType.CHAR_ANY - ) - - -def print_grammar_char(file: typing.TextIO, c: int) -> None: - if 0x20 <= c <= 0x7f: - print(chr(c), end="", file=file) - else: - print(f"", end="", file=file) - - -# static void print_rule( -# FILE * file, -# uint32_t rule_id, -# const std::vector & rule, -# const std::map & symbol_id_names) { -# if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { -# throw std::runtime_error( -# "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); -# } -# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); -# for (size_t i = 0, end = rule.size() - 1; i < end; i++) { -# llama_grammar_element elem = rule[i]; -# switch (elem.type) { -# case LLAMA_GRETYPE_END: -# throw std::runtime_error( -# "unexpected end of rule: " + std::to_string(rule_id) + "," + -# std::to_string(i)); -# case LLAMA_GRETYPE_ALT: -# fprintf(file, "| "); -# break; -# case LLAMA_GRETYPE_RULE_REF: -# fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); -# break; -# case LLAMA_GRETYPE_CHAR: -# fprintf(file, "["); -# print_grammar_char(file, elem.value); -# break; -# case LLAMA_GRETYPE_CHAR_NOT: -# fprintf(file, "[^"); -# print_grammar_char(file, elem.value); -# break; -# case LLAMA_GRETYPE_CHAR_RNG_UPPER: -# if (i == 0 || !is_char_element(rule[i - 1])) { -# throw std::runtime_error( -# "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + -# std::to_string(rule_id) + "," + std::to_string(i)); -# } -# fprintf(file, "-"); -# print_grammar_char(file, elem.value); -# break; -# case LLAMA_GRETYPE_CHAR_ALT: -# if (i == 0 || !is_char_element(rule[i - 1])) { -# throw std::runtime_error( -# "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + -# std::to_string(rule_id) + "," + std::to_string(i)); -# } -# print_grammar_char(file, elem.value); -# break; -# case LLAMA_GRETYPE_CHAR_ANY: -# fprintf(file, "."); -# break; -# } -# if (is_char_element(elem)) { -# switch (rule[i + 1].type) { -# case LLAMA_GRETYPE_CHAR_ALT: -# case LLAMA_GRETYPE_CHAR_RNG_UPPER: -# case LLAMA_GRETYPE_CHAR_ANY: -# break; -# default: -# fprintf(file, "] "); -# } -# } -# } -# fprintf(file, "\n"); -# } -def print_rule( - file: typing.TextIO, - rule_id: int, - rule: typing.List[GrammarElement], - symbol_id_names: typing.Dict[int, str], -) -> None: - if not rule or rule[-1].type != GrammarElementType.END: - raise ValueError(f"malformed rule, does not end with LLAMA_GRETYPE_END: {rule_id}") - - print(f"{symbol_id_names[rule_id]} ::=", end=" ", file=file) - - for i, elem in enumerate(rule[:-1]): - if elem.type == GrammarElementType.END: - raise ValueError(f"unexpected end of rule: {rule_id}, {i}") - if elem.type == GrammarElementType.ALT: - print("| ", end="", file=file) - elif elem.type == GrammarElementType.RULE_REF: - print(f"{symbol_id_names[elem.value]} ", end="", file=file) - elif 
elem.type == GrammarElementType.CHAR:
-            print("[", end="", file=file)
-            print_grammar_char(file, elem.value)
-        elif elem.type == GrammarElementType.CHAR_NOT:
-            print("[^", end="", file=file)
-            print_grammar_char(file, elem.value)
-        elif elem.type == GrammarElementType.CHAR_RNG_UPPER:
-            if i == 0 or not is_char_element(rule[i - 1]):
-                raise ValueError(f"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: {rule_id}, {i}")
-            print(f"-", end="", file=file)
-            print_grammar_char(file, elem.value)
-        elif elem.type == GrammarElementType.CHAR_ALT:
-            if i == 0 or not is_char_element(rule[i - 1]):
-                raise ValueError(f"LLAMA_GRETYPE_CHAR_ALT without preceding char: {rule_id}, {i}")
-            print_grammar_char(file, elem.value)
-        elif elem.type == GrammarElementType.CHAR_ANY:
-            print(".", end="", file=file)
-        if is_char_element(elem):
-            if rule[i + 1].type in (GrammarElementType.CHAR_ALT, GrammarElementType.CHAR_RNG_UPPER, GrammarElementType.CHAR_ANY):
-                continue
-            print("] ", end="", file=file)
-    print(file=file)
-
-
-def print_grammar(file: typing.TextIO, state: ParseState) -> None:
-    try:
-        symbol_id_names = {v: k for k, v in state.symbol_ids.items()}
-        for i, rule in enumerate(state.rules):
-            print_rule(file, i, rule, symbol_id_names)
-    except Exception as err:
-        print(f"\nerror printing grammar: {err}", file=file)
-        raise err
+LLAMA_GRAMMAR_DEFAULT_ROOT = "root"
 class LlamaGrammar:
-    def __init__(self, parse_state: ParseState):
-        self.parse_state = parse_state
-
-        self._grammar_rules = parse_state.rules
-        self._n_rules = len(self._grammar_rules)
-        self._start_rule_index = parse_state.symbol_ids["root"]
-
-        self._element_lists = [
-            [
-                llama_cpp.llama_grammar_element(ctypes.c_int(elem.type), ctypes.c_uint32(elem.value))
-                for elem in subvector
-            ]
-            for subvector in self._grammar_rules
-        ]
-
-        # Step 2: Convert each list to llama_grammar_element array and get pointer
-        self._element_arrays = [
-            (llama_cpp.llama_grammar_element * len(sublist))(*sublist)
-            for sublist in self._element_lists
-        ]
-
-        # Step 3: Get pointer of each array
-        self._element_array_pointers = [
-            ctypes.cast(subarray, llama_cpp.llama_grammar_element_p) for subarray in self._element_arrays
-        ]
-
-        # Step 4: Make array of these pointers and get its pointer
-        self._rules = (llama_cpp.llama_grammar_element_p * len(self._element_array_pointers))(
-            *self._element_array_pointers
-        )
-
-        self.grammar = None
-        self._init_grammar()
-
-
-    def _init_grammar(self):
-        grammar = llama_cpp.llama_grammar_init(
-            self._rules, ctypes.c_size_t(self._n_rules), ctypes.c_size_t(self._start_rule_index)
-        )
-
-        if grammar is None:
-            raise ValueError("Failed to create grammar")
-
-        self.grammar = grammar
-
-    def __del__(self):
-        if self.grammar is not None:
-            llama_cpp.llama_grammar_free(self.grammar)
-            self.grammar = None
-
-    def reset(self):
-        if self.grammar is not None:
-            llama_cpp.llama_grammar_free(self.grammar)
-        self._init_grammar()
+    def __init__(self, *args, _grammar: str, **kwargs):
+        self._grammar = _grammar
+        self._root = LLAMA_GRAMMAR_DEFAULT_ROOT
     @classmethod
     def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
-        parsed_grammar = parse(grammar)
-        if verbose:
-            print_grammar(file=sys.stdout, state=parsed_grammar)
-        return cls(parsed_grammar)
-
+        return cls(_grammar=grammar)
+
     @classmethod
     def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
         try:
@@ -1820,4 +950,4 @@ def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
         )
     schema = converter.resolve_refs(schema, "stdin")
     converter.visit(schema, "")
-    return converter.format_grammar()
\ No newline at end of file
+    return converter.format_grammar()
diff --git a/nexa/gguf/llama/llama_speculative.py b/nexa/gguf/llama/llama_speculative.py
index 6188cb26..39dfb903 100644
--- a/nexa/gguf/llama/llama_speculative.py
+++ b/nexa/gguf/llama/llama_speculative.py
@@ -61,4 +61,4 @@ def __call__(
             input_ids=input_ids,
             max_ngram_size=self.max_ngram_size,
             num_pred_tokens=self.num_pred_tokens,
-        )
\ No newline at end of file
+        )
diff --git a/nexa/gguf/llama/llama_tokenizer.py b/nexa/gguf/llama/llama_tokenizer.py
index f89fadd8..cefd3011 100644
--- a/nexa/gguf/llama/llama_tokenizer.py
+++ b/nexa/gguf/llama/llama_tokenizer.py
@@ -7,7 +7,7 @@
     Any,
 )
-import nexa.gguf.llama.llama_cpp as llama_cpp
+import nexa.gguf.llama
 from nexa.gguf.llama.llama_types import List
@@ -27,7 +27,10 @@ def tokenize(
     @abc.abstractmethod
     def detokenize(
-        self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
+        self,
+        tokens: List[int],
+        prev_tokens: Optional[List[int]] = None,
+        special: bool = False,
     ) -> bytes:
         """Detokenize the tokens into text.
@@ -49,7 +52,10 @@ def tokenize(
         return self._model.tokenize(text, add_bos=add_bos, special=special)
     def detokenize(
-        self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
+        self,
+        tokens: List[int],
+        prev_tokens: Optional[List[int]] = None,
+        special: bool = False,
     ) -> bytes:
         return self._model.detokenize(tokens, special=special)
@@ -80,19 +86,24 @@ def tokenize(
         )
     def detokenize(
-        self, tokens: List[int], prev_tokens: Optional[List[int]] = None, special: bool = False
+        self,
+        tokens: List[int],
+        prev_tokens: Optional[List[int]] = None,
+        special: bool = False,
     ) -> bytes:
-        skip_special_tokens = not special
+        skip_special_tokens = not special
         if prev_tokens is not None:
-            text = self.hf_tokenizer.decode(prev_tokens + tokens, skip_special_tokens=skip_special_tokens).encode(
-                "utf-8", errors="ignore"
-            )
-            prev_text = self.hf_tokenizer.decode(prev_tokens, skip_special_tokens=skip_special_tokens).encode(
-                "utf-8", errors="ignore"
-            )
+            text = self.hf_tokenizer.decode(
+                prev_tokens + tokens, skip_special_tokens=skip_special_tokens
+            ).encode("utf-8", errors="ignore")
+            prev_text = self.hf_tokenizer.decode(
+                prev_tokens, skip_special_tokens=skip_special_tokens
+            ).encode("utf-8", errors="ignore")
             return text[len(prev_text) :]
         else:
-            return self.hf_tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens).encode("utf-8", errors="ignore")
+            return self.hf_tokenizer.decode(
+                tokens, skip_special_tokens=skip_special_tokens
+            ).encode("utf-8", errors="ignore")
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
@@ -106,4 +117,4 @@ def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenize
         hf_tokenizer = AutoTokenizer.from_pretrained(
             pretrained_model_name_or_path=pretrained_model_name_or_path
         )
-        return cls(hf_tokenizer)
\ No newline at end of file
+        return cls(hf_tokenizer)
diff --git a/nexa/gguf/llama/llama_types.py b/nexa/gguf/llama/llama_types.py
index 3cc2122e..bbb58afc 100644
--- a/nexa/gguf/llama/llama_types.py
+++ b/nexa/gguf/llama/llama_types.py
@@ -295,4 +295,4 @@ class ChatCompletionNamedToolChoice(TypedDict):
 ChatCompletionChunk = CreateChatCompletionStreamResponse
 ChatCompletionStreamResponse = CreateChatCompletionStreamResponse
 ChatCompletionResponseFunction = ChatCompletionFunction
-ChatCompletionFunctionCall = ChatCompletionResponseFunctionCall
\ No newline at end of file
+ChatCompletionFunctionCall = ChatCompletionResponseFunctionCall
diff --git a/nexa/gguf/llama/llava_cpp.py b/nexa/gguf/llama/llava_cpp.py
index 9671eafb..945da826 100644
--- a/nexa/gguf/llama/llava_cpp.py
+++ b/nexa/gguf/llama/llava_cpp.py
@@ -1,9 +1,6 @@
 from __future__ import annotations
-import sys
 import os
-import ctypes
-import functools
 from ctypes import (
     c_bool,
     c_char_p,
@@ -17,68 +14,30 @@
 )
 import pathlib
 from typing import (
-    List,
     Union,
     NewType,
     Optional,
-    TypeVar,
-    Callable,
-    Any,
     TYPE_CHECKING,
-    Generic,
 )
-from typing_extensions import TypeAlias
 import nexa.gguf.llama.llama_cpp as llama_cpp
-from nexa.gguf.lib_utils import load_library
-
-# Specify the base name of the shared library to load
-_libllava_base_name = "llava_shared"
-
-# Load the library
-_libllava = load_library(_libllava_base_name)
+from nexa.gguf.llama._ctypes_extensions import (
+    ctypes_function_for_shared_library,
+)
-# ctypes helper
+from nexa.gguf.lib_utils import load_library
 if TYPE_CHECKING:
-    CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData)  # type: ignore
-
-    CtypesArray: TypeAlias = ctypes.Array[CtypesCData]  # type: ignore
-
-    CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData]  # type: ignore
-
-    CtypesVoidPointer: TypeAlias = ctypes.c_void_p
-
-    class CtypesRef(Generic[CtypesCData]):
-        pass
+    from nexa.gguf.llama._ctypes_extensions import (
+        CtypesArray,
+    )
-    CtypesPointerOrRef: TypeAlias = Union[
-        CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
-    ]
-
-    CtypesFuncPointer: TypeAlias = ctypes._FuncPointer  # type: ignore
-
-F = TypeVar("F", bound=Callable[..., Any])
-
-
-def ctypes_function_for_shared_library(lib: ctypes.CDLL):
-    def ctypes_function(
-        name: str, argtypes: List[Any], restype: Any, enabled: bool = True
-    ):
-        def decorator(f: F) -> F:
-            if enabled:
-                func = getattr(lib, name)
-                func.argtypes = argtypes
-                func.restype = restype
-                functools.wraps(f)(func)
-                return func
-            else:
-                return f
-
-        return decorator
-
-    return ctypes_function
+# Specify the base name of the shared library to load
+_libllava_base_name = "llava_shared"
+
+# Load the library
+_libllava = load_library(_libllava_base_name)
 ctypes_function = ctypes_function_for_shared_library(_libllava)
@@ -112,7 +71,8 @@ class llava_image_embed(Structure):
 )
 def llava_validate_embed_size(
     ctx_llama: llama_cpp.llama_context_p, ctx_clip: clip_ctx_p, /
-) -> bool: ...
+) -> bool:
+    ...
 # /** build an image embed from image file bytes */
@@ -128,7 +88,8 @@ def llava_image_embed_make_with_bytes(
     image_bytes: CtypesArray[c_uint8],
     image_bytes_length: Union[c_int, int],
     /,
-) -> "_Pointer[llava_image_embed]": ...
+) -> "_Pointer[llava_image_embed]":
+    ...
 # /** build an image embed from a path to an image filename */
@@ -140,13 +101,15 @@ def llava_image_embed_make_with_bytes(
 )
 def llava_image_embed_make_with_filename(
     ctx_clip: clip_ctx_p, n_threads: Union[c_int, int], image_path: bytes, /
-) -> "_Pointer[llava_image_embed]": ...
+) -> "_Pointer[llava_image_embed]":
+    ...
 # LLAVA_API void llava_image_embed_free(struct llava_image_embed * embed);
 # /** free an embedding made with llava_image_embed_make_* */
 @ctypes_function("llava_image_embed_free", [POINTER(llava_image_embed)], None)
-def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /): ...
+def llava_image_embed_free(embed: "_Pointer[llava_image_embed]", /):
+    ...
 # /** write the image represented by embed into the llama context with batch size n_batch, starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
@@ -167,7 +130,8 @@ def llava_eval_image_embed(
     n_batch: Union[c_int, int],
     n_past: "_Pointer[c_int]",
     /,
-) -> bool: ...
+) -> bool:
+    ...
 ################################################
@@ -180,10 +144,13 @@ def llava_eval_image_embed(
 @ctypes_function("clip_model_load", [c_char_p, c_int], clip_ctx_p_ctypes)
 def clip_model_load(
     fname: bytes, verbosity: Union[c_int, int], /
-) -> Optional[clip_ctx_p]: ...
+) -> Optional[clip_ctx_p]:
+    ...
 # /** free mmproj model */
 # CLIP_API void clip_free(struct clip_ctx * ctx);
 @ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
-def clip_free(ctx: clip_ctx_p, /): ...
\ No newline at end of file
+def clip_free(ctx: clip_ctx_p, /):
+    ...
+
diff --git a/nexa/gguf/nexa_inference_vlm_omni.py b/nexa/gguf/nexa_inference_vlm_omni.py
index bd5b6b29..4a76a4eb 100644
--- a/nexa/gguf/nexa_inference_vlm_omni.py
+++ b/nexa/gguf/nexa_inference_vlm_omni.py
@@ -40,7 +40,7 @@ def __init__(
         else:
             self.n_gpu_layers = 0
-        # Handle direct model file paths (e.g., omnivision:model-fp16)
+        # Handle direct model file paths (e.g., omniVLM:model-fp16)
         if model_path and ':model-' in model_path:
             base_name = model_path.split(':')[0]
             model_type = model_path.split('model-')[1]
diff --git a/pyproject.toml b/pyproject.toml
index 24b6ee35..ef2be1f6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -130,7 +130,7 @@ cmake.args = [
     "-DCMAKE_BUILD_PARALLEL_LEVEL=16",
     "-DSTABLE_DIFFUSION_BUILD=ON",
     "-DLLAMA_BUILD=ON",
-    "-DBARK_BUILD=ON"
+    "-DBARK_BUILD=OFF",
 ]
 [tool.scikit-build.metadata.version]
diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py
index 4a5109f6..33500c21 100644
--- a/tests/test_text_generation.py
+++ b/tests/test_text_generation.py
@@ -1,6 +1,7 @@
 from nexa.gguf import NexaTextInference
 from nexa.gguf.lib_utils import is_gpu_available
+
 model = NexaTextInference(
     model_path="gemma",
     local_path=None,