From 7eda3aad3a37be14994547a1b306e7a879a7b8b5 Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Thu, 26 Dec 2024 21:35:58 -0500 Subject: [PATCH 1/2] Fix prompt usage when sharing the same session id (by stripping out the preamble) --- .../tjake/jlama/cli/commands/ChatCommand.java | 5 +- .../tjake/jlama/model/AbstractModel.java | 12 ++++- .../safetensors/prompt/PromptSupport.java | 32 +++++++++++- ...s.java => NativeSimdTensorOperations.java} | 10 ++-- .../tensor/operations/util/JarSupport.java | 10 ++-- .../jlama/net/openai/OpenAIChatService.java | 2 +- .../tjake/jlama/model/TestCorrectness.java | 49 ++++++++++++++++++- .../tensor/operations/TestOperations.java | 24 ++++----- 8 files changed, 115 insertions(+), 29 deletions(-) rename jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/{NativeTensorOperations.java => NativeSimdTensorOperations.java} (98%) diff --git a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java index 330da77e..78efd95b 100644 --- a/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java +++ b/jlama-cli/src/main/java/com/github/tjake/jlama/cli/commands/ChatCommand.java @@ -66,6 +66,7 @@ public void run() { UUID session = UUID.randomUUID(); PromptSupport promptSupport = m.promptSupport().get(); + PromptSupport.Builder builder = promptSupport.builder(); PrintWriter out = System.console().writer(); out.println("\nChatting with " + modelName + "...\n"); @@ -82,7 +83,6 @@ public void run() { break; } - PromptSupport.Builder builder = promptSupport.builder(); if (first && systemPrompt != null) { builder.addSystemMessage(systemPrompt); } @@ -97,6 +97,9 @@ public void run() { makeOutHandler() ); + // New prompt builder and strip out the preamble since we're continuing the conversation + builder = promptSupport.builder().stripPreamble(); + out.println( "\n\n" + statsColor.format( diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index aecfbd19..4a1f00fa 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -283,8 +283,16 @@ public AbstractTensor batchForward( KvBufferCache.KvBuffer kvbuf, Optional>> tensorReducer ) { - AbstractTensor embedding = embedInput.batchInputsToEmbeddings(token_ids, startPos); - return forward(embedding, startPos, kvbuf, tensorReducer); + AbstractTensor embedding = null; + + //Batch prompt into groups of 1024 + for (int i = 0; i < token_ids.length; i += 1024) { + int[] batch = Arrays.copyOfRange(token_ids, i, Math.min(token_ids.length, i + 1024)); + embedding = embedInput.batchInputsToEmbeddings(batch, startPos + i); + embedding = forward(embedding, startPos + i, kvbuf, tensorReducer); + } + + return embedding; } public AbstractTensor forward( diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/PromptSupport.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/PromptSupport.java index 7526be98..f22afc1b 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/PromptSupport.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/prompt/PromptSupport.java @@ -190,6 +190,7 @@ public static class Builder { private boolean addGenerationPrompt = true; private List messages = new ArrayList<>(2); + private boolean stripPreamble = false; private Builder(TokenizerModel m) { this.m = m; @@ -230,6 +231,11 @@ public Builder addAssistantMessage(String content) { return this; } + public Builder stripPreamble() { + stripPreamble = true; + return this; + } + public PromptContext build() { return build(Optional.empty()); } @@ -259,8 +265,29 @@ private PromptContext build(Optional> optionalTools) { "This model does not support tools, but tools are specified" ); - Map args = new HashMap<>(); + String preamble = ""; + if (stripPreamble) { + Map args = new HashMap<>(); + args.putAll( + Map.of( + "messages", + Map.of(), + "add_generation_prompt", + false, + "eos_token", + m.eosToken(), + "bos_token", + "" + ) + ); // We add the BOS ourselves + optionalTools.ifPresent(tools -> args.put("tools", tools)); + + RenderResult r = jinjava.renderForResult(template, args); + preamble = r.getOutput(); + } + + Map args = new HashMap<>(); args.putAll( Map.of( "messages", @@ -280,7 +307,8 @@ private PromptContext build(Optional> optionalTools) { if (r.hasErrors()) logger.debug("Prompt template errors: " + r.getErrors()); - return new PromptContext(r.getOutput(), optionalTools); + String output = r.getOutput(); + return new PromptContext(output.substring(preamble.length()), optionalTools); } } } diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeSimdTensorOperations.java similarity index 98% rename from jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java rename to jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeSimdTensorOperations.java index 700a0eff..5953e59b 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeTensorOperations.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/NativeSimdTensorOperations.java @@ -28,11 +28,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class NativeTensorOperations implements TensorOperations { - private static final Logger logger = LoggerFactory.getLogger(NativeTensorOperations.class); +public class NativeSimdTensorOperations implements TensorOperations { + private static final Logger logger = LoggerFactory.getLogger(NativeSimdTensorOperations.class); static { - if (!JarSupport.maybeLoadLibrary()) System.loadLibrary("jlama"); + if (!JarSupport.maybeLoadLibrary("jlama")) System.loadLibrary("jlama"); } public static final int HAS_F16C = NativeSimd.HAS_F16C(); @@ -52,7 +52,7 @@ public class NativeTensorOperations implements TensorOperations { final int flags; - public NativeTensorOperations() { + public NativeSimdTensorOperations() { int f = 0; if (RuntimeSupport.isLinux()) f |= HAS_F16C; @@ -63,7 +63,7 @@ public NativeTensorOperations() { checkLib(); } - NativeTensorOperations(int flags) { + NativeSimdTensorOperations(int flags) { this.flags = flags; } diff --git a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/util/JarSupport.java b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/util/JarSupport.java index 1bb4da47..41190e47 100644 --- a/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/util/JarSupport.java +++ b/jlama-native/src/main/java/com/github/tjake/jlama/tensor/operations/util/JarSupport.java @@ -31,16 +31,16 @@ public class JarSupport { private static final Logger logger = LoggerFactory.getLogger(JarSupport.class); - public static boolean maybeLoadLibrary() { + public static boolean maybeLoadLibrary(String libname) { String ext = RuntimeSupport.isMac() ? ".dylib" : RuntimeSupport.isWin() ? ".dll" : ".so"; - URL lib = JarSupport.class.getClassLoader().getResource("META-INF/native/lib/libjlama" + ext); + URL lib = JarSupport.class.getClassLoader().getResource("META-INF/native/lib/lib" + libname + ext); if (lib != null) { try { final File libpath = Files.createTempDirectory("jlama").toFile(); libpath.deleteOnExit(); // just in case - File libfile = Paths.get(libpath.getAbsolutePath(), "libjlama" + ext).toFile(); + File libfile = Paths.get(libpath.getAbsolutePath(), "lib" + libname + ext).toFile(); libfile.deleteOnExit(); // just in case final InputStream in = lib.openStream(); @@ -53,10 +53,10 @@ public static boolean maybeLoadLibrary() { out.close(); in.close(); System.load(libfile.getAbsolutePath()); - logger.debug("Loaded jlama-native library: {}", libfile.getAbsolutePath()); + logger.debug("Loaded {}-native library: {}", libname, libfile.getAbsolutePath()); return true; } catch (IOException e) { - logger.warn("Error loading jlama-native library"); + logger.warn("Error loading {}-native library", libname); } } diff --git a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java index f3d546aa..3b08ebac 100644 --- a/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java +++ b/jlama-net/src/main/java/com/github/tjake/jlama/net/openai/OpenAIChatService.java @@ -109,7 +109,7 @@ Object createChatCompletion(@RequestHeader Map headers, @Valid @ if (request.getStream() != null && request.getStream()) { SseEmitter emitter = new SseEmitter(-1L); CompletableFuture.supplyAsync( - () -> model.generate(sessionId, builder.build(), temperature, maxTokens, (t, f) -> CompletableFuture.supplyAsync(() -> { + () -> model. generate(sessionId, builder.build(), temperature, maxTokens, (t, f) -> CompletableFuture.supplyAsync(() -> { try { emitter.send( new CreateChatCompletionStreamResponse().id(sessionId.toString()) diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java index f63454e4..0bc13195 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestCorrectness.java @@ -345,7 +345,7 @@ public void testPromptSupportWithTools() { @Test public void testMistralTools() { - String modelPrefix = "../models/Mistral-7B-Instruct-v0.3"; + String modelPrefix = "../models/tjake_Mistral-7B-Instruct-v0.3-JQ4"; Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); Tokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); @@ -420,4 +420,51 @@ public void testToolParse() throws JsonProcessingException { Assert.assertEquals(2, toolCalls.size()); } + + @Test + public void testPromptBuilderSession() { + String modelPrefix = "../models/Qwen_Qwen2.5-0.5B-Instruct-JQ4"; + Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); + + Tokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix)); + PromptSupport.Builder builder = tokenizer.promptSupport().get().builder(); + builder.addSystemMessage("You always respond as a pirate"); + builder.addUserMessage("What is the weather in paris right now?"); + builder.addGenerationPrompt(true); + + Tool t = Tool.from( + Function.builder() + .name("get_current_temperature") + .description("Simulates getting the current temperature at a location.") + .addParameter("location", "string", "The location to get the temperature for, in the format \"City, Country\".", true) + .addParameter("unit", "string", "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", true) + .build() + ); + + PromptContext prompt = builder.build(t); + Assert.assertEquals( + "<|im_start|>system\n" + "You always respond as a pirate\n" + + "\n" + + "# Tools\n" + + "\n" + + "You may call one or more functions to assist with the user query.\n" + + "\n" + + "You are provided with function signatures within XML tags:\n" + + "\n" + + "{\"type\": \"function\", \"function\": {\"name\": \"get_current_temperature\", \"description\": \"Simulates getting the current temperature at a location.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"location\": {\"type\": \"string\", \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\".\"}, \"unit\": {\"type\": \"string\", \"description\": \"The unit to return the temperature in (e.g., \\\"celsius\\\", \\\"fahrenheit\\\").\"}}, \"required\": [\"location\", \"unit\"]}}}\n" + + "\n" + + "\n" + + "For each function call, return a json object with function name and arguments within XML tags:\n" + + "\n" + + "{\"name\": , \"arguments\": }\n" + + "<|im_end|>\n" + + "<|im_start|>user\n" + + "What is the weather in paris right now?<|im_end|>\n" + + "<|im_start|>assistant\n", + prompt.getPrompt()); + + prompt = tokenizer.promptSupport().get().builder().addUserMessage("This is a test").stripPreamble().build(); + Assert.assertEquals( + "<|im_start|>user\n" + "This is a test<|im_end|>\n" + "<|im_start|>assistant\n", prompt.getPrompt()); + } } diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java b/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java index cf263ee6..51380884 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/tensor/operations/TestOperations.java @@ -15,7 +15,7 @@ */ package com.github.tjake.jlama.tensor.operations; -import static com.github.tjake.jlama.tensor.operations.NativeTensorOperations.*; +import static com.github.tjake.jlama.tensor.operations.NativeSimdTensorOperations.*; import com.github.tjake.jlama.math.VectorMath; import com.github.tjake.jlama.safetensors.DType; @@ -59,19 +59,19 @@ public static void init() { opTypes.add(new PanamaTensorOperations(MachineSpec.Type.AVX_256)); opTypes.add(new PanamaTensorOperations(MachineSpec.Type.ARM_128)); - if (globalOps instanceof NativeTensorOperations) { - opTypes.add(new NativeTensorOperations()); - opTypes.add(new NativeTensorOperations(0)); + if (globalOps instanceof NativeSimdTensorOperations) { + opTypes.add(new NativeSimdTensorOperations()); + opTypes.add(new NativeSimdTensorOperations(0)); - if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeTensorOperations(HAS_AVX2)); + if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeSimdTensorOperations(HAS_AVX2)); if (RuntimeSupport.isLinux() || RuntimeSupport.isWin()) { - opTypes.add(new NativeTensorOperations(HAS_F16C)); - if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeTensorOperations(HAS_F16C | HAS_AVX2)); + opTypes.add(new NativeSimdTensorOperations(HAS_F16C)); + if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeSimdTensorOperations(HAS_F16C | HAS_AVX2)); } if (RuntimeSupport.isArm()) { - opTypes.add(new NativeTensorOperations(MachineSpec.Type.ARM_128.ctag)); + opTypes.add(new NativeSimdTensorOperations(MachineSpec.Type.ARM_128.ctag)); } } @@ -198,7 +198,7 @@ public void testSplitDotProduct() { @Test public void testNativeDotProduct() { - Assume.assumeTrue(globalOps instanceof NativeTensorOperations); + Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations); AbstractTensor a = makeTensor(SIZE); AbstractTensor b = makeTensor(SIZE); @@ -476,7 +476,7 @@ public void testBatchDotProductWithResultOffset() { @Test public void testNativeBatchDotProduct() { // M == BATCH, N == ROWS, K == SIZE - Assume.assumeTrue(globalOps instanceof NativeTensorOperations); + Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations); FloatBufferTensor c = new FloatBufferTensor(BATCH, ROWS); FloatBufferTensor c1 = new FloatBufferTensor(BATCH, ROWS); @@ -512,7 +512,7 @@ public void testNativeBatchDotProduct() { @Test public void testNativeBatchDotProductWithOffsets() { // M == BATCH, N == ROWS, K == SIZE - Assume.assumeTrue(globalOps instanceof NativeTensorOperations); + Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations); FloatBufferTensor c = new FloatBufferTensor(BATCH, ROWS); FloatBufferTensor c1 = new FloatBufferTensor(BATCH, ROWS); @@ -548,7 +548,7 @@ public void testNativeBatchDotProductWithOffsets() { @Test public void testNativeDotProductFast() { // M == BATCH, N == ROWS, K == SIZE - Assume.assumeTrue(globalOps instanceof NativeTensorOperations); + Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations); FloatBufferTensor c = new FloatBufferTensor(1, SIZE); FloatBufferTensor c1 = new FloatBufferTensor(1, SIZE); From 1569d442532c435aef99d014544c1f3ac6a65966 Mon Sep 17 00:00:00 2001 From: Jake Luciani Date: Thu, 26 Dec 2024 21:48:16 -0500 Subject: [PATCH 2/2] Renamed Native to NativeSimd --- .../tjake/jlama/tensor/operations/TensorOperationsProvider.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperationsProvider.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperationsProvider.java index c79b1aba..6bcfbd01 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperationsProvider.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/operations/TensorOperationsProvider.java @@ -53,7 +53,7 @@ private TensorOperations pickFastestImplementation() { if (!forcePanama) { try { Class nativeClazz = (Class) Class.forName( - "com.github.tjake.jlama.tensor.operations.NativeTensorOperations" + "com.github.tjake.jlama.tensor.operations.NativeSimdTensorOperations" ); pick = nativeClazz.getConstructor().newInstance(); // This will throw if no shared lib found