
Commit

add precompute kv_cache in Llama class
Davidqian123 committed Nov 21, 2024
1 parent 6509ecc · commit 2738901
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions nexa/gguf/llama/llama.py
@@ -31,6 +31,7 @@

from nexa.gguf.llama.llama_types import *
from nexa.gguf.llama.llama_grammar import LlamaGrammar
from nexa.gguf.llama.llama_cache import BaseLlamaCache
from nexa.gguf.llama.llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
import nexa.gguf.llama.llama_cpp as llama_cpp
import nexa.gguf.llama.llama_chat_format as llama_chat_format
@@ -350,6 +351,8 @@ def __init__(
        # Sampling Params
        self.last_n_tokens_size = last_n_tokens_size

        self.cache: Optional[BaseLlamaCache] = None

        self.lora_base = lora_base
        self.lora_scale = lora_scale
        self.lora_path = lora_path
@@ -596,6 +599,14 @@ def detokenize(
            The detokenized string.
        """
        return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens, special=special)

    def set_cache(self, cache: Optional[BaseLlamaCache]):
        """Set the cache.
        Args:
            cache: The cache to set.
        """
        self.cache = cache

    def set_seed(self, seed: int):
        """Set the random seed.
@@ -1211,6 +1222,23 @@ def logit_bias_processor(
            raise ValueError(
                "logprobs is not supported for models created with logits_all=False"
            )

        if self.cache:
            try:
                cache_item = self.cache[prompt_tokens]
                cache_prefix_len = Llama.longest_token_prefix(
                    cache_item.input_ids.tolist(), prompt_tokens
                )
                eval_prefix_len = Llama.longest_token_prefix(
                    self._input_ids.tolist(), prompt_tokens
                )
                if cache_prefix_len > eval_prefix_len:
                    self.load_state(cache_item)
                    if self.verbose:
                        print("Llama._create_completion: cache hit", file=sys.stderr)
            except KeyError:
                if self.verbose:
                    print("Llama._create_completion: cache miss", file=sys.stderr)

        if seed is not None:
            self._ctx.set_rng_seed(seed)
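
The cache-hit branch above restores a saved state only when the cached entry matches a longer prefix of the new prompt than the tokens already evaluated in the current context; otherwise loading it would discard progress. A self-contained sketch of that comparison, using a hypothetical helper in place of the existing Llama.longest_token_prefix static method and made-up token ids:

    from typing import Sequence

    def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
        # Count how many leading token ids the two sequences share.
        n = 0
        for x, y in zip(a, b):
            if x != y:
                break
            n += 1
        return n

    prompt  = [1, 15043, 3186, 29991, 9508]      # new prompt (hypothetical ids)
    cached  = [1, 15043, 3186, 29991, 9508, 13]  # tokens behind a cached state
    current = [1, 15043]                         # tokens already in the context

    if longest_token_prefix(cached, prompt) > longest_token_prefix(current, prompt):
        # Restoring the cached KV state means only the tokens after the shared
        # prefix still need to be evaluated.
        print("cache hit: load_state, then evaluate the remaining prompt tokens")
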
@@ -1552,8 +1580,19 @@ def logit_bias_processor(
                    }
                ],
            }
            if self.cache:
                if self.verbose:
                    print("Llama._create_completion: cache save", file=sys.stderr)
                self.cache[prompt_tokens + completion_tokens] = self.save_state()
                if self.verbose:
                    print("Llama._create_completion: cache saved", file=sys.stderr)
            return

        if self.cache:
            if self.verbose:
                print("Llama._create_completion: cache save", file=sys.stderr)
            self.cache[prompt_tokens + completion_tokens] = self.save_state()

        text_str = text.decode("utf-8", errors="ignore")

        if echo:
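
Both save paths key the cache by the concatenated prompt and completion tokens and store the result of save_state(), while the lookup side treats a KeyError as a miss. A toy sketch of that mapping-style contract (illustrative only; the real BaseLlamaCache subclasses in upstream llama-cpp-python also enforce a capacity limit and match on the longest stored prefix rather than on exact keys, and this fork is assumed to behave the same):

    from typing import Dict, Sequence, Tuple

    class ToyStateCache:
        """Illustrative token-sequence -> saved-state store, not the real cache."""

        def __init__(self) -> None:
            self._store: Dict[Tuple[int, ...], object] = {}

        def __getitem__(self, tokens: Sequence[int]) -> object:
            # A miss raises KeyError, which the lookup path in the diff catches.
            return self._store[tuple(tokens)]

        def __setitem__(self, tokens: Sequence[int], state: object) -> None:
            # Used as cache[prompt_tokens + completion_tokens] = save_state().
            self._store[tuple(tokens)] = state
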
