add precompute kv_cache in Llama class #270

Merged 1 commit on Nov 21, 2024
39 changes: 39 additions & 0 deletions nexa/gguf/llama/llama.py
@@ -31,6 +31,7 @@

from nexa.gguf.llama.llama_types import *
from nexa.gguf.llama.llama_grammar import LlamaGrammar
from nexa.gguf.llama.llama_cache import BaseLlamaCache
from nexa.gguf.llama.llama_tokenizer import BaseLlamaTokenizer, LlamaTokenizer
import nexa.gguf.llama.llama_cpp as llama_cpp
import nexa.gguf.llama.llama_chat_format as llama_chat_format
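
The new import pulls in BaseLlamaCache, the abstract interface the rest of the diff programs against. As orientation, a minimal sketch of that interface, modeled on the llama_cache module in upstream llama-cpp-python (which this vendored file appears to track); the exact definitions in nexa/gguf/llama/llama_cache.py may differ:

from abc import ABC, abstractmethod
from typing import Sequence


class BaseLlamaCache(ABC):
    """Maps token sequences to saved Llama states (sketch, not the nexa source)."""

    def __init__(self, capacity_bytes: int = 2 << 30):
        self.capacity_bytes = capacity_bytes

    @abstractmethod
    def __getitem__(self, key: Sequence[int]):
        """Return the saved state whose key best matches `key`, else raise KeyError."""
        raise NotImplementedError

    @abstractmethod
    def __contains__(self, key: Sequence[int]) -> bool:
        raise NotImplementedError

    @abstractmethod
    def __setitem__(self, key: Sequence[int], value) -> None:
        """Store the object returned by Llama.save_state() under `key`."""
        raise NotImplementedError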
@@ -350,6 +351,8 @@ def __init__(
        # Sampling Params
        self.last_n_tokens_size = last_n_tokens_size

        self.cache: Optional[BaseLlamaCache] = None

        self.lora_base = lora_base
        self.lora_scale = lora_scale
        self.lora_path = lora_path
@@ -596,6 +599,14 @@ def detokenize(
            The detokenized string.
        """
        return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens, special=special)

    def set_cache(self, cache: Optional[BaseLlamaCache]):
        """Set the cache.

        Args:
            cache: The cache to set.
        """
        self.cache = cache

    def set_seed(self, seed: int):
        """Set the random seed.
@@ -1211,6 +1222,23 @@ def logit_bias_processor(
            raise ValueError(
                "logprobs is not supported for models created with logits_all=False"
            )

        if self.cache:
            try:
                cache_item = self.cache[prompt_tokens]
                cache_prefix_len = Llama.longest_token_prefix(
                    cache_item.input_ids.tolist(), prompt_tokens
                )
                eval_prefix_len = Llama.longest_token_prefix(
                    self._input_ids.tolist(), prompt_tokens
                )
                if cache_prefix_len > eval_prefix_len:
                    self.load_state(cache_item)
                    if self.verbose:
                        print("Llama._create_completion: cache hit", file=sys.stderr)
            except KeyError:
                if self.verbose:
                    print("Llama._create_completion: cache miss", file=sys.stderr)

        if seed is not None:
            self._ctx.set_rng_seed(seed)
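
The lookup restores a cached state only when it covers a longer prefix of the new prompt than the tokens already evaluated in the current context (cache_prefix_len > eval_prefix_len), so a hit never discards useful work. In upstream llama-cpp-python, Llama.longest_token_prefix is a small static helper that counts how many leading tokens two sequences share; a self-contained illustration of the comparison, using hypothetical token ids:

def longest_token_prefix(a, b):
    """Length of the longest common prefix of two token sequences."""
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

cached_ids = [1, 15, 7, 9, 4, 2]      # tokens covered by the cached state
evaluated  = [1, 15, 7]               # tokens already in the current context
prompt     = [1, 15, 7, 9, 4, 8, 3]   # tokens of the incoming prompt

assert longest_token_prefix(cached_ids, prompt) == 5  # cache_prefix_len
assert longest_token_prefix(evaluated, prompt) == 3   # eval_prefix_len
# 5 > 3, so load_state(cache_item) is worthwhile here.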
@@ -1552,8 +1580,19 @@ def logit_bias_processor(
                    }
                ],
            }
            if self.cache:
                if self.verbose:
                    print("Llama._create_completion: cache save", file=sys.stderr)
                self.cache[prompt_tokens + completion_tokens] = self.save_state()
                if self.verbose:
                    print("Llama._create_completion: cache saved", file=sys.stderr)
            return

        if self.cache:
            if self.verbose:
                print("Llama._create_completion: cache save", file=sys.stderr)
            self.cache[prompt_tokens + completion_tokens] = self.save_state()

        text_str = text.decode("utf-8", errors="ignore")

        if echo:
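
In both the streaming and non-streaming paths, the state is stored under prompt_tokens + completion_tokens, so a follow-up request whose prompt extends the previous exchange can still match the stored entry, provided lookups match on the longest shared token prefix rather than on exact keys. A toy in-memory cache below makes that keying scheme concrete; it is modeled on upstream llama-cpp-python's LlamaRAMCache and is only a sketch, not the implementation behind BaseLlamaCache in this repository:

from collections import OrderedDict
from typing import Optional, Sequence, Tuple


def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n


class ToyRAMCache:
    """Keys are token tuples; __getitem__ returns the entry sharing the longest prefix."""

    def __init__(self) -> None:
        self.cache_state: "OrderedDict[Tuple[int, ...], object]" = OrderedDict()

    def _find_longest_prefix_key(self, key: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
        best_key, best_len = None, 0
        for k in self.cache_state:
            n = longest_token_prefix(k, key)
            if n > best_len:
                best_key, best_len = k, n
        return best_key

    def __getitem__(self, key: Sequence[int]):
        k = self._find_longest_prefix_key(tuple(key))
        if k is None:
            raise KeyError("no cached state shares a prefix with this key")
        return self.cache_state[k]

    def __contains__(self, key: Sequence[int]) -> bool:
        return self._find_longest_prefix_key(tuple(key)) is not None

    def __setitem__(self, key: Sequence[int], value) -> None:
        # Storing under prompt + completion means the next prompt that begins
        # with this conversation will find this entry via prefix matching.
        self.cache_state[tuple(key)] = value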