From 27389014896541ca787595946dde7ee7c6b51c57 Mon Sep 17 00:00:00 2001 From: Davidqian123 Date: Thu, 21 Nov 2024 21:22:33 +0000 Subject: [PATCH] add precompute kv_cache in Llama class --- nexa/gguf/llama/llama.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/nexa/gguf/llama/llama.py b/nexa/gguf/llama/llama.py index d7b241e7..0007b515 100644 --- a/nexa/gguf/llama/llama.py +++ b/nexa/gguf/llama/llama.py @@ -31,6 +31,7 @@ from nexa.gguf.llama.llama_types import * from nexa.gguf.llama.llama_grammar import LlamaGrammar +from nexa.gguf.llama.llama_cache import BaseLlamaCache from nexa.gguf.llama.llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer import nexa.gguf.llama.llama_cpp as llama_cpp import nexa.gguf.llama.llama_chat_format as llama_chat_format @@ -350,6 +351,8 @@ def __init__( # Sampling Params self.last_n_tokens_size = last_n_tokens_size + self.cache: Optional[BaseLlamaCache] = None + self.lora_base = lora_base self.lora_scale = lora_scale self.lora_path = lora_path @@ -596,6 +599,14 @@ def detokenize( The detokenized string. """ return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens, special=special) + + def set_cache(self, cache: Optional[BaseLlamaCache]): + """Set the cache. + + Args: + cache: The cache to set. + """ + self.cache = cache def set_seed(self, seed: int): """Set the random seed. @@ -1211,6 +1222,23 @@ def logit_bias_processor( raise ValueError( "logprobs is not supported for models created with logits_all=False" ) + + if self.cache: + try: + cache_item = self.cache[prompt_tokens] + cache_prefix_len = Llama.longest_token_prefix( + cache_item.input_ids.tolist(), prompt_tokens + ) + eval_prefix_len = Llama.longest_token_prefix( + self._input_ids.tolist(), prompt_tokens + ) + if cache_prefix_len > eval_prefix_len: + self.load_state(cache_item) + if self.verbose: + print("Llama._create_completion: cache hit", file=sys.stderr) + except KeyError: + if self.verbose: + print("Llama._create_completion: cache miss", file=sys.stderr) if seed is not None: self._ctx.set_rng_seed(seed) @@ -1552,8 +1580,19 @@ def logit_bias_processor( } ], } + if self.cache: + if self.verbose: + print("Llama._create_completion: cache save", file=sys.stderr) + self.cache[prompt_tokens + completion_tokens] = self.save_state() + if self.verbose: + print("Llama._create_completion: cache saved", file=sys.stderr) return + if self.cache: + if self.verbose: + print("Llama._create_completion: cache save", file=sys.stderr) + self.cache[prompt_tokens + completion_tokens] = self.save_state() + text_str = text.decode("utf-8", errors="ignore") if echo: