From 819f82928ce1619526a2006b81179aa01fba8995 Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Sun, 27 Oct 2024 20:45:57 -0400 Subject: [PATCH 1/2] Keep the kv-cache so that the actual data stays on disk but the mmap is removed after each generate. Also add an ephemeral kv-cache for embeddings etc. --- .../tjake/jlama/model/AbstractModel.java | 181 ++++++++++-------- .../jlama/model/functions/Generator.java | 3 +- .../tjake/jlama/tensor/KvBufferCache.java | 95 ++++++--- .../github/tjake/jlama/net/Coordinator.java | 5 + .../com/github/tjake/jlama/net/Worker.java | 6 + 5 files changed, 175 insertions(+), 115 deletions(-) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index 4956b94b..9e041598 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -170,6 +170,11 @@ protected AbstractModel( this.poolingLayer = inferenceType.isPooling ?
Optional.ofNullable(loadPoolingWeights()) : Optional.empty(); } + @Override + public void close() { + kvBufferCache.close(); + } + protected abstract EmbedInput loadInputWeights(); protected abstract TransformerBlock[] loadTransformerBlockWeights(); @@ -305,7 +310,7 @@ public float[] embed(String input, PoolingType poolingType) { Preconditions.checkArgument(encoded.length < c.contextLength); float[] outputEmbedding = new float[c.embeddingLength]; - try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(UUID.randomUUID())) { + try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getEphemeralKvBuffer()) { int promptLength = encoded.length; float avgp = 1.0f / promptLength; @@ -506,22 +511,26 @@ public Response generate( Preconditions.checkArgument(encoded.length < c.contextLength && encoded.length < ntokens, "Prompt exceeds max tokens"); - KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(sessionId); // k and v for context window - int startPos = kvmem.getCurrentContextPosition(); // Number of tokens in the buffer + try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(sessionId)) { // k and v for context window + int startPos = kvmem.getCurrentContextPosition(); // Number of tokens in the buffer - logger.debug("Starting at token {} for session {} with prompt {}", startPos, sessionId, promptContext.getPrompt()); + logger.debug( + "Starting at token {} for session {} with prompt {}", + startPos, + sessionId, + promptContext.getPrompt()); - if (ntokens > c.contextLength) ntokens = c.contextLength; + if (ntokens > c.contextLength) ntokens = c.contextLength; - FinishReason reason = FinishReason.MAX_TOKENS; - int promptLength; - long promptBatchTime; - int tokensGenerated; - StringBuilder responseText = new StringBuilder(); - StringBuilder responseTextWithSpecialTokens = new StringBuilder(); + FinishReason reason = FinishReason.MAX_TOKENS; + int promptLength; + long promptBatchTime; + int tokensGenerated; + StringBuilder responseText = new 
StringBuilder(); + StringBuilder responseTextWithSpecialTokens = new StringBuilder(); - try (AbstractTensor logits = makeDenseTensor(c.vocabularySize)) { - int[] promptTokens; + try (AbstractTensor logits = makeDenseTensor(c.vocabularySize)) { + int[] promptTokens; if (addBosToken()) { promptTokens = new int[(1 + encoded.length)]; @@ -534,89 +543,91 @@ public Response generate( promptLength = encoded.length; } - long start = System.currentTimeMillis(); - long promptStart = start; - // Batch Process Prompt - AbstractTensor last = DebugSupport.isDebug() - ? batchForwardSlow(promptTokens, startPos, kvmem) - : batchForward(promptTokens, startPos, kvmem); - - promptBatchTime = System.currentTimeMillis() - start; - float batchMsPerToken = Math.round((((double) promptBatchTime) / (double) promptLength)); - logger.debug("{} prompt tokens in {}ms | {}ms per token", promptLength, promptBatchTime, batchMsPerToken); - - float genMsPerToken = 0; - tokensGenerated = 0; - int next = sample(last.slice(last.shape().first() - 1), temperature, ThreadLocalRandom.current().nextFloat(), logits); - last.close(); - try { - String c = tokenizer.decode(next); - if (tokenizer.getModel().isSpecialToken(next)) { - responseTextWithSpecialTokens.append(c); - } else { - onTokenWithTimings.accept(c, batchMsPerToken); - responseText.append(c); - responseTextWithSpecialTokens.append(c); - } - } catch (Exception e) { - logger.error("Failed to decode token {}", next, e); - } - - start = System.currentTimeMillis(); - for (int i = startPos + promptTokens.length; i < ntokens; i++) { - AbstractTensor output = forward(next, i, kvmem); - tokensGenerated++; - - next = sample(output, temperature, ThreadLocalRandom.current().nextFloat(), logits); - - if (logger.isTraceEnabled()) logger.trace("Sampled token {} with temperature {}", next, temperature); - output.close(); - - kvmem.incrementContextPosition(); - - // Model may tell us it's done - if (c.eosTokens.contains(next)) { - reason = FinishReason.STOP_TOKEN; 
- break; - } - + long start = System.currentTimeMillis(); + long promptStart = start; + // Batch Process Prompt + AbstractTensor last = DebugSupport.isDebug() + ? batchForwardSlow(promptTokens, startPos, kvmem) + : batchForward(promptTokens, startPos, kvmem); + + promptBatchTime = System.currentTimeMillis() - start; + float batchMsPerToken = Math.round((((double) promptBatchTime) / (double) promptLength)); + logger.debug( + "{} prompt tokens in {}ms | {}ms per token", promptLength, promptBatchTime, batchMsPerToken); + + float genMsPerToken = 0; + tokensGenerated = 0; + int next = sample( + last.slice(last.shape().first() - 1), + temperature, + ThreadLocalRandom.current().nextFloat(), + logits); + last.close(); try { String c = tokenizer.decode(next); - if (tokenizer.getModel().isSpecialToken(next)) { responseTextWithSpecialTokens.append(c); } else { - genMsPerToken = (System.currentTimeMillis() - start) / (float) (tokensGenerated); - onTokenWithTimings.accept(c, genMsPerToken); - responseTextWithSpecialTokens.append(c); + onTokenWithTimings.accept(c, batchMsPerToken); responseText.append(c); + responseTextWithSpecialTokens.append(c); } } catch (Exception e) { logger.error("Failed to decode token {}", next, e); } - } - long end = System.currentTimeMillis(); - - Response response = new Response( - responseText.toString(), - responseTextWithSpecialTokens.toString(), - reason, - promptLength, - tokensGenerated, - promptBatchTime, - end - start - ); - logger.debug( - String.format( - "\n\nelapsed: %ds, prompt %.1fms per token, gen %.1fms per token\n", - TimeUnit.MILLISECONDS.toSeconds(end - promptStart), - batchMsPerToken, - genMsPerToken - ) - ); - - return postProcessResponse(promptContext, response); + start = System.currentTimeMillis(); + for (int i = startPos + promptTokens.length; i < ntokens; i++) { + AbstractTensor output = forward(next, i, kvmem); + tokensGenerated++; + + next = sample( + output, temperature, ThreadLocalRandom.current().nextFloat(), logits); + 
+ if (logger.isTraceEnabled()) + logger.trace("Sampled token {} with temperature {}", next, temperature); + output.close(); + + kvmem.incrementContextPosition(); + + // Model may tell us it's done + if (c.eosTokens.contains(next)) { + reason = FinishReason.STOP_TOKEN; + break; + } + + try { + String c = tokenizer.decode(next); + + if (tokenizer.getModel().isSpecialToken(next)) { + responseTextWithSpecialTokens.append(c); + } else { + genMsPerToken = (System.currentTimeMillis() - start) / (float) (tokensGenerated); + onTokenWithTimings.accept(c, genMsPerToken); + responseTextWithSpecialTokens.append(c); + responseText.append(c); + } + } catch (Exception e) { + logger.error("Failed to decode token {}", next, e); + } + } + + long end = System.currentTimeMillis(); + + Response response = new Response( + responseText.toString(), + responseTextWithSpecialTokens.toString(), + reason, + promptLength, + tokensGenerated, + promptBatchTime, + end - start); + logger.debug(String.format( + "\n\nelapsed: %ds, prompt %.1fms per token, gen %.1fms per token\n", + TimeUnit.MILLISECONDS.toSeconds(end - promptStart), batchMsPerToken, genMsPerToken)); + + return postProcessResponse(promptContext, response); + } } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java index bf9efeb3..66a124e2 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/functions/Generator.java @@ -21,13 +21,14 @@ import com.github.tjake.jlama.safetensors.prompt.ToolCall; import com.github.tjake.jlama.safetensors.tokenizer.Tokenizer; +import java.io.Closeable; import java.util.*; import java.util.function.BiConsumer; /** * Used to define a function that generates tokens from a prompt */ -public interface Generator { +public interface Generator extends Closeable { enum FinishReason { MAX_TOKENS, 
diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java index e2f26a48..5dbbdcf6 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java @@ -24,6 +24,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.Closeable; import java.io.IOError; import java.io.IOException; import java.io.RandomAccessFile; @@ -35,13 +36,14 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; /** * A cache for key-value buffers used in the model. * @see com.github.tjake.jlama.model.functions.Generator */ -public class KvBufferCache { +public class KvBufferCache implements Closeable { private static final Logger logger = LoggerFactory.getLogger(KvBufferCache.class); private final ConcurrentMap<UUID, KvBuffer> kvBufferCache; private final AbstractModel model; @@ -52,7 +54,19 @@ public KvBufferCache(AbstractModel model) { } public KvBuffer getKvBuffer(UUID session) { - return kvBufferCache.computeIfAbsent(session, s -> new KvBuffer(s, 1 << 24)); // 16MB per page + return kvBufferCache.computeIfAbsent(session, s -> new KvBuffer(s, 1 << 23, false)); // 8MB per page + } + + public KvBuffer getEphemeralKvBuffer() { + return new KvBuffer(UUID.randomUUID(), 1 << 20, true); + } + + @Override + public void close() { + for (KvBuffer kvBuffer : kvBufferCache.values()) { + kvBuffer.close(); + } + kvBufferCache.clear(); } class KvPageContext { @@ -107,15 +121,16 @@ class KvBufferPage implements AutoCloseable { private final KvPageContext pageCtx; private final String pageId; + private final AtomicBoolean closed = new AtomicBoolean(false); private final RandomAccessFile raf; - KvBufferPage(KvPageContext pageCtx, String
pageId) { + KvBufferPage(KvPageContext pageCtx, String pageId, boolean ephemeral) { this.pageCtx = pageCtx; this.pageId = pageId; - if (model.getConfig().workingDirectory().isEmpty()) { + if (model.getConfig().workingDirectory().isEmpty() || ephemeral) { this.raf = null; - this.tensor = AbstractTensor.make(model.getWorkingDType(), pageCtx.pageShape); + this.tensor = TensorCache.instance.get(model.getWorkingDType(), pageCtx.pageShape); } else { try { raf = new RandomAccessFile( @@ -126,7 +141,9 @@ class KvBufferPage implements AutoCloseable { "rw" ); long bytes = pageCtx.pageShape.size() * model.getWorkingDType().size(); - raf.setLength(bytes); + logger.debug("Allocating page {} with {} bytes {}", pageId, bytes, raf.length()); + if (raf.length() != bytes) + raf.setLength(bytes); AbstractTensor t; if (model.getWorkingDType() == DType.F32) { @@ -156,13 +173,21 @@ class KvBufferPage implements AutoCloseable { } public AbstractTensor getTensor() { + assert !closed.get() : "Page is closed"; return tensor; } + public boolean isClosed() { + return closed.get(); + } + @Override public void close() throws IOException { - if (raf != null) { - raf.close(); + if (closed.compareAndSet(false, true)) { + if (raf != null) { + raf.close(); + } + tensor.close(); } } } @@ -173,11 +198,13 @@ public class KvBuffer implements AutoCloseable { private final KvBufferPage[][] pages; private final KvPageContext pageContext; + private final boolean ephemeral; - KvBuffer(UUID session, int maxPageSizeInBytes) { + KvBuffer(UUID session, int maxPageSizeInBytes, boolean ephemeral) { this.session = session; this.pageContext = computePageSize(maxPageSizeInBytes); this.pages = new KvBufferPage[pageContext.numberOfLayerPages][pageContext.numberOfContextPages]; + this.ephemeral = ephemeral; } public int getCurrentContextPosition() { @@ -225,23 +252,6 @@ public KvPageContext computePageSize(long maxPageSizeInBytes) { } } - // Try partitioning by context length - for (int y = C; y >= 1; y--) { - long x = 
maxPageSizeInBytes / (y * s); - - if (x >= 1 && x <= N) { - long product = x * y; - - if (product > maxProduct) { - optimalLayersPerPage = (int) x; - optimalContextLengthPerPage = y; - maxProduct = product; - } - if (product < maxProduct) { - break; - } - } - } // Calculate the number of pages needed int numberOfLayerPages = (int) Math.ceil((double) N / optimalLayersPerPage); @@ -256,12 +266,33 @@ public KvPageContext computePageSize(long maxPageSizeInBytes) { ); } + logger.debug( + "Optimal page size: {} layers, {} context length, {} bytes, {} layer pages, {} length pages", + optimalLayersPerPage, + optimalContextLengthPerPage, + pageSize, + numberOfLayerPages, + numberOfContextPages + ); + return new KvPageContext(session, numberOfLayerPages, numberOfContextPages, optimalLayersPerPage, optimalContextLengthPerPage); } @Override public void close() { - + for (KvBufferPage[] layerPages : pages) { + if (layerPages != null) { + for (KvBufferPage page : layerPages) { + if (page != null) { + try { + page.close(); + } catch (IOException e) { + // error message + } + } + } + } + } } public AbstractTensor getKeyTensorForPosition(int layerIndex, int position) { @@ -280,8 +311,8 @@ private AbstractTensor getTensorForPosition(int layerIndex, int position, int in int relativeContextIndex = position % pageContext.contextLengthPerPage; KvBufferPage page = pages[layerPageIndex][contextPageIndex]; - if (page == null) { - page = new KvBufferPage(pageContext, "L" + layerPageIndex + "C" + contextPageIndex); + if (page == null || page.isClosed()) { + page = new KvBufferPage(pageContext, "L" + layerPageIndex + "C" + contextPageIndex, ephemeral); pages[layerPageIndex][contextPageIndex] = page; } @@ -307,6 +338,12 @@ private AbstractTensor[] getTensorsUptoPosition(int layerIndex, int index, int u for (int i = 0; i <= contextPageIndex; i++) { KvBufferPage page = layerPages[i]; + + if (page == null || page.isClosed()) { + page = new KvBufferPage(pageContext, "L" + layerPageIndex + "C" + 
i, ephemeral); + layerPages[i] = page; + } + tensors[i] = page.getTensor().slice(true, relativeLayerIndex, index); } diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java index 68eafe9c..9365382c 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/Coordinator.java @@ -228,4 +228,9 @@ public Generator.Response generate( return new Generator.Response("", "", FinishReason.ERROR, 0, 0, 0, 0); } } + + @Override + public void close() { + + } } diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java index 2c4fa5b7..676dbcbd 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/Worker.java @@ -251,6 +251,12 @@ public void processOutput(ByteString session, int startPosition, int batchSize, @Override public void close() { + try { + kvBufferCache.close(); + } catch (Exception e) { + logger.error("Error closing kvBufferCache", e); + } + ((ManagedChannel) client.getChannel()).shutdown(); } From 151df77278038bc6bcf36e78063724afc932c4eb Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Tue, 29 Oct 2024 19:14:31 -0400 Subject: [PATCH 2/2] review nits --- .../com/github/tjake/jlama/tensor/KvBufferCache.java | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java index 5dbbdcf6..0a8fee21 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java @@ -33,6 +33,8 @@ import java.nio.ShortBuffer; import java.nio.channels.FileChannel; import
java.nio.file.Paths; +import java.util.Iterator; +import java.util.Map; import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -63,10 +65,11 @@ public KvBuffer getEphemeralKvBuffer() { @Override public void close() { - for (KvBuffer kvBuffer : kvBufferCache.values()) { - kvBuffer.close(); + Iterator<Map.Entry<UUID, KvBuffer>> it = kvBufferCache.entrySet().iterator(); + while (it.hasNext()) { + it.next().getValue().close(); + it.remove(); } - kvBufferCache.clear(); } class KvPageContext { @@ -287,7 +290,7 @@ public void close() { try { page.close(); } catch (IOException e) { - // error message + logger.debug("Error closing page", e); } } }