
Commit

Merge pull request #100 from tjake/kv-cache-cleanup
Keep kv-cache so the actual data stays on disk but mmap is removed af…
tjake authored Oct 29, 2024
2 parents 0691f1e + 151df77 commit fed0222
Showing 5 changed files with 178 additions and 115 deletions.
AbstractModel.java
@@ -170,6 +170,11 @@ protected AbstractModel(
this.poolingLayer = inferenceType.isPooling ? Optional.ofNullable(loadPoolingWeights()) : Optional.empty();
}

@Override
public void close() {
kvBufferCache.close();
}

protected abstract EmbedInput loadInputWeights();

protected abstract TransformerBlock[] loadTransformerBlockWeights();
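
Note: the new close() override is a plain ownership/delegation pattern: the model owns its KvBufferCache, so closing the model closes the cache and releases whatever it holds (per the commit message, the mmap goes away while the data stays on disk). A minimal sketch of that pattern, with stand-in class names rather than the real jlama types:

import java.io.Closeable;

// Illustrative only: BufferCache and OwningModel are stand-ins, not jlama classes.
class BufferCache implements Closeable {
    @Override
    public void close() {
        // unmap / release the per-session buffers held by the cache
    }
}

abstract class OwningModel implements Closeable {
    private final BufferCache kvBufferCache = new BufferCache();

    @Override
    public void close() {
        kvBufferCache.close(); // closing the model releases everything the cache owns
    }
}

Callers only ever close the model itself; the Generator change further down makes that part of the public contract.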
@@ -305,7 +310,7 @@ public float[] embed(String input, PoolingType poolingType) {
Preconditions.checkArgument(encoded.length < c.contextLength);
float[] outputEmbedding = new float[c.embeddingLength];

try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(UUID.randomUUID())) {
try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getEphemeralKvBuffer()) {
int promptLength = encoded.length;
float avgp = 1.0f / promptLength;
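
Note: embed() previously registered a throwaway UUID in the shared cache via getKvBuffer(UUID.randomUUID()); it now asks for an ephemeral buffer instead. KvBufferCache itself is not part of this diff, so the following is only a guess at what such a helper might look like: a buffer keyed by a random id whose backing storage is dropped on close rather than kept on disk for later resumption.

// Hypothetical sketch only -- the real KvBufferCache.getEphemeralKvBuffer()
// is not shown in this commit and may be implemented differently.
public KvBuffer getEphemeralKvBuffer() {
    UUID id = UUID.randomUUID();                    // key is never handed out, so the buffer cannot be resumed
    return new KvBuffer(id, /* ephemeral */ true);  // ephemeral: discard backing storage on close
}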

@@ -506,22 +511,26 @@ public Response generate(

Preconditions.checkArgument(encoded.length < c.contextLength && encoded.length < ntokens, "Prompt exceeds max tokens");

KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(sessionId); // k and v for context window
int startPos = kvmem.getCurrentContextPosition(); // Number of tokens in the buffer
try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(sessionId)) { // k and v for context window
int startPos = kvmem.getCurrentContextPosition(); // Number of tokens in the buffer

logger.debug("Starting at token {} for session {} with prompt {}", startPos, sessionId, promptContext.getPrompt());
logger.debug(
"Starting at token {} for session {} with prompt {}",
startPos,
sessionId,
promptContext.getPrompt());

if (ntokens > c.contextLength) ntokens = c.contextLength;
if (ntokens > c.contextLength) ntokens = c.contextLength;

FinishReason reason = FinishReason.MAX_TOKENS;
int promptLength;
long promptBatchTime;
int tokensGenerated;
StringBuilder responseText = new StringBuilder();
StringBuilder responseTextWithSpecialTokens = new StringBuilder();
FinishReason reason = FinishReason.MAX_TOKENS;
int promptLength;
long promptBatchTime;
int tokensGenerated;
StringBuilder responseText = new StringBuilder();
StringBuilder responseTextWithSpecialTokens = new StringBuilder();

try (AbstractTensor logits = makeDenseTensor(c.vocabularySize)) {
int[] promptTokens;
try (AbstractTensor logits = makeDenseTensor(c.vocabularySize)) {
int[] promptTokens;

if (addBosToken()) {
promptTokens = new int[(1 + encoded.length)];
@@ -534,89 +543,91 @@ public Response generate(
promptLength = encoded.length;
}

long start = System.currentTimeMillis();
long promptStart = start;
// Batch Process Prompt
AbstractTensor last = DebugSupport.isDebug()
? batchForwardSlow(promptTokens, startPos, kvmem)
: batchForward(promptTokens, startPos, kvmem);

promptBatchTime = System.currentTimeMillis() - start;
float batchMsPerToken = Math.round((((double) promptBatchTime) / (double) promptLength));
logger.debug("{} prompt tokens in {}ms | {}ms per token", promptLength, promptBatchTime, batchMsPerToken);

float genMsPerToken = 0;
tokensGenerated = 0;
int next = sample(last.slice(last.shape().first() - 1), temperature, ThreadLocalRandom.current().nextFloat(), logits);
last.close();
try {
String c = tokenizer.decode(next);
if (tokenizer.getModel().isSpecialToken(next)) {
responseTextWithSpecialTokens.append(c);
} else {
onTokenWithTimings.accept(c, batchMsPerToken);
responseText.append(c);
responseTextWithSpecialTokens.append(c);
}
} catch (Exception e) {
logger.error("Failed to decode token {}", next, e);
}

start = System.currentTimeMillis();
for (int i = startPos + promptTokens.length; i < ntokens; i++) {
AbstractTensor output = forward(next, i, kvmem);
tokensGenerated++;

next = sample(output, temperature, ThreadLocalRandom.current().nextFloat(), logits);

if (logger.isTraceEnabled()) logger.trace("Sampled token {} with temperature {}", next, temperature);
output.close();

kvmem.incrementContextPosition();

// Model may tell us it's done
if (c.eosTokens.contains(next)) {
reason = FinishReason.STOP_TOKEN;
break;
}

long start = System.currentTimeMillis();
long promptStart = start;
// Batch Process Prompt
AbstractTensor last = DebugSupport.isDebug()
? batchForwardSlow(promptTokens, startPos, kvmem)
: batchForward(promptTokens, startPos, kvmem);

promptBatchTime = System.currentTimeMillis() - start;
float batchMsPerToken = Math.round((((double) promptBatchTime) / (double) promptLength));
logger.debug(
"{} prompt tokens in {}ms | {}ms per token", promptLength, promptBatchTime, batchMsPerToken);

float genMsPerToken = 0;
tokensGenerated = 0;
int next = sample(
last.slice(last.shape().first() - 1),
temperature,
ThreadLocalRandom.current().nextFloat(),
logits);
last.close();
try {
String c = tokenizer.decode(next);

if (tokenizer.getModel().isSpecialToken(next)) {
responseTextWithSpecialTokens.append(c);
} else {
genMsPerToken = (System.currentTimeMillis() - start) / (float) (tokensGenerated);
onTokenWithTimings.accept(c, genMsPerToken);
responseTextWithSpecialTokens.append(c);
onTokenWithTimings.accept(c, batchMsPerToken);
responseText.append(c);
responseTextWithSpecialTokens.append(c);
}
} catch (Exception e) {
logger.error("Failed to decode token {}", next, e);
}
}

long end = System.currentTimeMillis();

Response response = new Response(
responseText.toString(),
responseTextWithSpecialTokens.toString(),
reason,
promptLength,
tokensGenerated,
promptBatchTime,
end - start
);
logger.debug(
String.format(
"\n\nelapsed: %ds, prompt %.1fms per token, gen %.1fms per token\n",
TimeUnit.MILLISECONDS.toSeconds(end - promptStart),
batchMsPerToken,
genMsPerToken
)
);

return postProcessResponse(promptContext, response);
start = System.currentTimeMillis();
for (int i = startPos + promptTokens.length; i < ntokens; i++) {
AbstractTensor output = forward(next, i, kvmem);
tokensGenerated++;

next = sample(
output, temperature, ThreadLocalRandom.current().nextFloat(), logits);

if (logger.isTraceEnabled())
logger.trace("Sampled token {} with temperature {}", next, temperature);
output.close();

kvmem.incrementContextPosition();

// Model may tell us it's done
if (c.eosTokens.contains(next)) {
reason = FinishReason.STOP_TOKEN;
break;
}

try {
String c = tokenizer.decode(next);

if (tokenizer.getModel().isSpecialToken(next)) {
responseTextWithSpecialTokens.append(c);
} else {
genMsPerToken = (System.currentTimeMillis() - start) / (float) (tokensGenerated);
onTokenWithTimings.accept(c, genMsPerToken);
responseTextWithSpecialTokens.append(c);
responseText.append(c);
}
} catch (Exception e) {
logger.error("Failed to decode token {}", next, e);
}
}

long end = System.currentTimeMillis();

Response response = new Response(
responseText.toString(),
responseTextWithSpecialTokens.toString(),
reason,
promptLength,
tokensGenerated,
promptBatchTime,
end - start);
logger.debug(String.format(
"\n\nelapsed: %ds, prompt %.1fms per token, gen %.1fms per token\n",
TimeUnit.MILLISECONDS.toSeconds(end - promptStart), batchMsPerToken, genMsPerToken));

return postProcessResponse(promptContext, response);
}
}
}
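
Note: the functional change in generate() is the try-with-resources around the session's KvBuffer: the buffer (and its mmap) is released at the end of every call, while getCurrentContextPosition() still lets a later call with the same sessionId resume where the previous one stopped. A caller-side sketch of that multi-turn flow; the generate(...) parameter order, PromptContext.of(...), and the package names in the imports are assumptions based on the code above, not something this commit defines.

import java.util.UUID;
import com.github.tjake.jlama.model.AbstractModel;              // assumed package
import com.github.tjake.jlama.safetensors.prompt.PromptContext; // assumed package

static void chat(AbstractModel model) {
    UUID session = UUID.randomUUID();

    // Turn 1: the session buffer starts at position 0, so the whole prompt is batch-processed.
    model.generate(session, PromptContext.of("Hello, who are you?"),
            0.3f, 256, (token, ms) -> System.out.print(token));

    // Turn 2, same UUID: startPos comes from getCurrentContextPosition(), so earlier tokens
    // are not re-processed, even though the buffer's mmap was released when the first
    // call's try block ended.
    model.generate(session, PromptContext.of("And what can you do?"),
            0.3f, 256, (token, ms) -> System.out.print(token));
}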

Generator.java
@@ -21,13 +21,14 @@
import com.github.tjake.jlama.safetensors.prompt.ToolCall;
import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer;

import java.io.Closeable;
import java.util.*;
import java.util.function.BiConsumer;

/**
* Used to define a function that generates tokens from a prompt
*/
public interface Generator {
public interface Generator extends Closeable {

enum FinishReason {
MAX_TOKENS,
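
Note: with Generator extending Closeable, cleanup becomes part of the public contract, so a model can live in a try-with-resources block and its KV cache is torn down by the AbstractModel.close() override shown above. A caller-side sketch; ModelSupport.loadModel(...) and its argument types are taken from the project README and are assumptions here, not something this commit introduces.

import java.io.File;
import com.github.tjake.jlama.model.AbstractModel;   // assumed package
import com.github.tjake.jlama.model.ModelSupport;    // assumed package
import com.github.tjake.jlama.safetensors.DType;     // assumed package

public class CloseableModelExample {
    public static void main(String[] args) {
        File localModelPath = new File("models/some-downloaded-model"); // placeholder path
        try (AbstractModel model = ModelSupport.loadModel(localModelPath, DType.F32, DType.I8)) {
            // generate(), embed(), ... -- each call now opens and closes its own KvBuffer internally
        } // model.close() -> kvBufferCache.close(): mmaps released, KV data stays on disk
    }
}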
