Merge pull request #136 from tjake/fix-prompt-with-session
Fix prompt usage when sharing the same session id (by stripping out t…
tjake authored Dec 27, 2024
2 parents 55e4694 + 1569d44 commit d4628bb
Showing 9 changed files with 116 additions and 30 deletions.
File 1 of 9
@@ -66,6 +66,7 @@ public void run() {
 
         UUID session = UUID.randomUUID();
         PromptSupport promptSupport = m.promptSupport().get();
+        PromptSupport.Builder builder = promptSupport.builder();
         PrintWriter out = System.console().writer();
 
         out.println("\nChatting with " + modelName + "...\n");
@@ -82,7 +83,6 @@ public void run() {
                 break;
             }
 
-            PromptSupport.Builder builder = promptSupport.builder();
             if (first && systemPrompt != null) {
                 builder.addSystemMessage(systemPrompt);
             }
@@ -97,6 +97,9 @@ public void run() {
                 makeOutHandler()
             );
 
+            // New prompt builder and strip out the preamble since we're continuing the conversation
+            builder = promptSupport.builder().stripPreamble();
+
            out.println(
                "\n\n"
                    + statsColor.format(
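
Taken together, the chat-loop change hoists builder creation out of the loop: the first build carries the system message and the template preamble, and every later turn uses a fresh builder with stripPreamble() so the session's shared KV cache never receives the preamble tokens twice. A minimal sketch of the intended call pattern (the generate() signature mirrors the REST handler later in this commit; the model setup, sampling values, and output handler are illustrative):

    // Sketch: one chat session reusing a single session id.
    UUID session = UUID.randomUUID();
    PromptSupport promptSupport = m.promptSupport().get();

    // First turn: full template, including system message and preamble.
    PromptSupport.Builder builder = promptSupport.builder();
    builder.addSystemMessage("You are a helpful assistant");
    builder.addUserMessage("Hello!");
    m.generate(session, builder.build(), 0.7f, 256, outHandler);

    // Later turns: new builder with the preamble stripped, since the
    // session already holds the preamble tokens in its KV cache.
    builder = promptSupport.builder().stripPreamble();
    builder.addUserMessage("And in more detail?");
    m.generate(session, builder.build(), 0.7f, 256, outHandler);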
File 2 of 9
@@ -283,8 +283,16 @@ public AbstractTensor batchForward(
         KvBufferCache.KvBuffer kvbuf,
         Optional<Consumer<List<AbstractTensor>>> tensorReducer
     ) {
-        AbstractTensor embedding = embedInput.batchInputsToEmbeddings(token_ids, startPos);
-        return forward(embedding, startPos, kvbuf, tensorReducer);
+        AbstractTensor embedding = null;
+
+        // Batch prompt into groups of 1024
+        for (int i = 0; i < token_ids.length; i += 1024) {
+            int[] batch = Arrays.copyOfRange(token_ids, i, Math.min(token_ids.length, i + 1024));
+            embedding = embedInput.batchInputsToEmbeddings(batch, startPos + i);
+            embedding = forward(embedding, startPos + i, kvbuf, tensorReducer);
+        }
+
+        return embedding;
     }
 
     public AbstractTensor forward(
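
The new batchForward() walks the prompt in fixed chunks of 1024 tokens, keeping absolute positions via startPos + i so the KV cache stays aligned; only the final chunk's output tensor is returned, presumably because only the last position's logits are needed to sample the next token. The slicing itself is plain Arrays.copyOfRange. A standalone sketch of the same loop (hypothetical method name, JDK only):

    import java.util.Arrays;

    // Process token ids in chunks of 1024, preserving absolute positions.
    static void forwardInChunks(int[] tokenIds, int startPos) {
        final int CHUNK = 1024;
        for (int i = 0; i < tokenIds.length; i += CHUNK) {
            int[] batch = Arrays.copyOfRange(tokenIds, i, Math.min(tokenIds.length, i + CHUNK));
            // here: embed `batch` and run the transformer at position startPos + i
            System.out.printf("chunk at position %d: %d tokens%n", startPos + i, batch.length);
        }
    }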
File 3 of 9
@@ -190,6 +190,7 @@ public static class Builder {
         private boolean addGenerationPrompt = true;
 
         private List<Message> messages = new ArrayList<>(2);
+        private boolean stripPreamble = false;
 
         private Builder(TokenizerModel m) {
             this.m = m;
@@ -230,6 +231,11 @@ public Builder addAssistantMessage(String content) {
             return this;
         }
 
+        public Builder stripPreamble() {
+            stripPreamble = true;
+            return this;
+        }
+
         public PromptContext build() {
             return build(Optional.empty());
         }
@@ -259,8 +265,29 @@ private PromptContext build(Optional<List<Tool>> optionalTools) {
                     "This model does not support tools, but tools are specified"
                 );
 
-            Map<String, Object> args = new HashMap<>();
+            String preamble = "";
+            if (stripPreamble) {
+                Map<String, Object> args = new HashMap<>();
+                args.putAll(
+                    Map.of(
+                        "messages",
+                        Map.of(),
+                        "add_generation_prompt",
+                        false,
+                        "eos_token",
+                        m.eosToken(),
+                        "bos_token",
+                        ""
+                    )
+                ); // We add the BOS ourselves
+                optionalTools.ifPresent(tools -> args.put("tools", tools));
+
+                RenderResult r = jinjava.renderForResult(template, args);
+                preamble = r.getOutput();
+            }
+
+            Map<String, Object> args = new HashMap<>();
             args.putAll(
                 Map.of(
                     "messages",
@@ -280,7 +307,8 @@ private PromptContext build(Optional<List<Tool>> optionalTools) {
 
            if (r.hasErrors()) logger.debug("Prompt template errors: " + r.getErrors());
 
-            return new PromptContext(r.getOutput(), optionalTools);
+            String output = r.getOutput();
+            return new PromptContext(output.substring(preamble.length()), optionalTools);
        }
    }
}
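
The preamble removal works by rendering the chat template twice: once with an empty message list (and an empty bos_token, since jlama adds BOS itself) to capture the fixed header the template always emits, then again with the real messages, after which the preamble-length prefix is cut off. A toy illustration of the idea with Jinjava (the template string here is invented; jlama renders the model's own chat template):

    import com.hubspot.jinjava.Jinjava;
    import java.util.List;
    import java.util.Map;

    Jinjava jinjava = new Jinjava();
    String template = "PREAMBLE\n{% for msg in messages %}{{ msg }}\n{% endfor %}";

    // Render 1: no messages -> just the fixed preamble.
    String preamble = jinjava.render(template, Map.of("messages", List.of()));

    // Render 2: real messages -> preamble + conversation.
    String full = jinjava.render(template, Map.of("messages", List.of("hi there")));

    // Drop the shared prefix; only the conversation remains.
    String stripped = full.substring(preamble.length()); // "hi there\n"

This assumes the empty-message render is a strict prefix of the full render, which holds for typical chat templates that emit a fixed header before iterating over messages.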
File 4 of 9
@@ -53,7 +53,7 @@ private TensorOperations pickFastestImplementation() {
         if (!forcePanama) {
             try {
                 Class<? extends TensorOperations> nativeClazz = (Class<? extends TensorOperations>) Class.forName(
-                    "com.github.tjake.jlama.tensor.operations.NativeTensorOperations"
+                    "com.github.tjake.jlama.tensor.operations.NativeSimdTensorOperations"
                 );
                 pick = nativeClazz.getConstructor().newInstance();
                 // This will throw if no shared lib found
File 5 of 9
@@ -28,11 +28,11 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-public class NativeTensorOperations implements TensorOperations {
-    private static final Logger logger = LoggerFactory.getLogger(NativeTensorOperations.class);
+public class NativeSimdTensorOperations implements TensorOperations {
+    private static final Logger logger = LoggerFactory.getLogger(NativeSimdTensorOperations.class);
 
     static {
-        if (!JarSupport.maybeLoadLibrary()) System.loadLibrary("jlama");
+        if (!JarSupport.maybeLoadLibrary("jlama")) System.loadLibrary("jlama");
     }
 
     public static final int HAS_F16C = NativeSimd.HAS_F16C();
@@ -52,7 +52,7 @@ public class NativeTensorOperations implements TensorOperations {
 
     final int flags;
 
-    public NativeTensorOperations() {
+    public NativeSimdTensorOperations() {
         int f = 0;
 
         if (RuntimeSupport.isLinux()) f |= HAS_F16C;
@@ -63,7 +63,7 @@ public NativeTensorOperations() {
         checkLib();
     }
 
-    NativeTensorOperations(int flags) {
+    NativeSimdTensorOperations(int flags) {
         this.flags = flags;
     }
 
File 6 of 9
@@ -31,16 +31,16 @@
 public class JarSupport {
     private static final Logger logger = LoggerFactory.getLogger(JarSupport.class);
 
-    public static boolean maybeLoadLibrary() {
+    public static boolean maybeLoadLibrary(String libname) {
         String ext = RuntimeSupport.isMac() ? ".dylib" : RuntimeSupport.isWin() ? ".dll" : ".so";
-        URL lib = JarSupport.class.getClassLoader().getResource("META-INF/native/lib/libjlama" + ext);
+        URL lib = JarSupport.class.getClassLoader().getResource("META-INF/native/lib/lib" + libname + ext);
 
         if (lib != null) {
             try {
                 final File libpath = Files.createTempDirectory("jlama").toFile();
                 libpath.deleteOnExit(); // just in case
 
-                File libfile = Paths.get(libpath.getAbsolutePath(), "libjlama" + ext).toFile();
+                File libfile = Paths.get(libpath.getAbsolutePath(), "lib" + libname + ext).toFile();
                 libfile.deleteOnExit(); // just in case
 
                 final InputStream in = lib.openStream();
@@ -53,10 +53,10 @@ public static boolean maybeLoadLibrary() {
                 out.close();
                 in.close();
                 System.load(libfile.getAbsolutePath());
-                logger.debug("Loaded jlama-native library: {}", libfile.getAbsolutePath());
+                logger.debug("Loaded {}-native library: {}", libname, libfile.getAbsolutePath());
                 return true;
             } catch (IOException e) {
-                logger.warn("Error loading jlama-native library");
+                logger.warn("Error loading {}-native library", libname);
             }
         }
 
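
maybeLoadLibrary() now takes the library name as a parameter, but the extract-and-load pattern is unchanged: locate the platform-specific binary on the classpath, copy it to a temp directory, and System.load() it by absolute path, with System.loadLibrary() as the caller's fallback. A condensed sketch of that flow (hypothetical helper, JDK APIs only):

    import java.io.InputStream;
    import java.net.URL;
    import java.nio.file.*;

    // Sketch: extract lib<name>.{so,dylib,dll} from the jar and load it.
    static boolean loadFromJar(String name) {
        String os = System.getProperty("os.name").toLowerCase();
        String ext = os.contains("mac") ? ".dylib" : os.contains("win") ? ".dll" : ".so";
        URL lib = Thread.currentThread().getContextClassLoader()
                .getResource("META-INF/native/lib/lib" + name + ext);
        if (lib == null) return false;
        try (InputStream in = lib.openStream()) {
            Path dir = Files.createTempDirectory(name);
            Path file = dir.resolve("lib" + name + ext);
            Files.copy(in, file, StandardCopyOption.REPLACE_EXISTING);
            System.load(file.toAbsolutePath().toString()); // System.load requires an absolute path
            return true;
        } catch (Exception e) {
            return false;
        }
    }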
File 7 of 9
@@ -109,7 +109,7 @@ Object createChatCompletion(@RequestHeader Map<String, String> headers, @Valid @
         if (request.getStream() != null && request.getStream()) {
             SseEmitter emitter = new SseEmitter(-1L);
             CompletableFuture.supplyAsync(
-                () -> model.generate(sessionId, builder.build(), temperature, maxTokens, (t, f) -> CompletableFuture.supplyAsync(() -> {
+                () -> model.generate(sessionId, builder.build(), temperature, maxTokens, (t, f) -> CompletableFuture.supplyAsync(() -> {
                     try {
                         emitter.send(
                             new CreateChatCompletionStreamResponse().id(sessionId.toString())
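
For streaming, the handler pushes each generated token through a Spring SseEmitter from inside a CompletableFuture, so the servlet thread is released while generation runs. A minimal sketch of that shape (the token source is invented; SseEmitter, send(), complete(), and completeWithError() are standard Spring Web MVC API):

    import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
    import java.util.List;
    import java.util.concurrent.CompletableFuture;

    static SseEmitter streamTokens(List<String> tokens) {
        SseEmitter emitter = new SseEmitter(-1L); // -1L disables the timeout
        CompletableFuture.runAsync(() -> {
            try {
                for (String t : tokens) emitter.send(t); // one SSE event per token
                emitter.complete();
            } catch (Exception e) {
                emitter.completeWithError(e);
            }
        });
        return emitter;
    }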
File 8 of 9
@@ -345,7 +345,7 @@ public void testPromptSupportWithTools() {
     @Test
     public void testMistralTools() {
 
-        String modelPrefix = "../models/Mistral-7B-Instruct-v0.3";
+        String modelPrefix = "../models/tjake_Mistral-7B-Instruct-v0.3-JQ4";
         Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));
 
         Tokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
@@ -420,4 +420,51 @@ public void testToolParse() throws JsonProcessingException {
 
         Assert.assertEquals(2, toolCalls.size());
     }
+
+    @Test
+    public void testPromptBuilderSession() {
+        String modelPrefix = "../models/Qwen_Qwen2.5-0.5B-Instruct-JQ4";
+        Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));
+
+        Tokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
+        PromptSupport.Builder builder = tokenizer.promptSupport().get().builder();
+        builder.addSystemMessage("You always respond as a pirate");
+        builder.addUserMessage("What is the weather in paris right now?");
+        builder.addGenerationPrompt(true);
+
+        Tool t = Tool.from(
+            Function.builder()
+                .name("get_current_temperature")
+                .description("Simulates getting the current temperature at a location.")
+                .addParameter("location", "string", "The location to get the temperature for, in the format \"City, Country\".", true)
+                .addParameter("unit", "string", "The unit to return the temperature in (e.g., \"celsius\", \"fahrenheit\").", true)
+                .build()
+        );
+
+        PromptContext prompt = builder.build(t);
+        Assert.assertEquals(
+            "<|im_start|>system\n" + "You always respond as a pirate\n"
+                + "\n"
+                + "# Tools\n"
+                + "\n"
+                + "You may call one or more functions to assist with the user query.\n"
+                + "\n"
+                + "You are provided with function signatures within <tools></tools> XML tags:\n"
+                + "<tools>\n"
+                + "{\"type\": \"function\", \"function\": {\"name\": \"get_current_temperature\", \"description\": \"Simulates getting the current temperature at a location.\", \"parameters\": {\"type\": \"object\", \"properties\": {\"location\": {\"type\": \"string\", \"description\": \"The location to get the temperature for, in the format \\\"City, Country\\\".\"}, \"unit\": {\"type\": \"string\", \"description\": \"The unit to return the temperature in (e.g., \\\"celsius\\\", \\\"fahrenheit\\\").\"}}, \"required\": [\"location\", \"unit\"]}}}\n"
+                + "</tools>\n"
+                + "\n"
+                + "For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n"
+                + "<tool_call>\n"
+                + "{\"name\": <function-name>, \"arguments\": <args-json-object>}\n"
+                + "</tool_call><|im_end|>\n"
+                + "<|im_start|>user\n"
+                + "What is the weather in paris right now?<|im_end|>\n"
+                + "<|im_start|>assistant\n",
+            prompt.getPrompt());
+
+        prompt = tokenizer.promptSupport().get().builder().addUserMessage("This is a test").stripPreamble().build();
+        Assert.assertEquals(
+            "<|im_start|>user\n" + "This is a test<|im_end|>\n" + "<|im_start|>assistant\n", prompt.getPrompt());
+    }
 }
File 9 of 9
@@ -15,7 +15,7 @@
  */
 package com.github.tjake.jlama.tensor.operations;
 
-import static com.github.tjake.jlama.tensor.operations.NativeTensorOperations.*;
+import static com.github.tjake.jlama.tensor.operations.NativeSimdTensorOperations.*;
 
 import com.github.tjake.jlama.math.VectorMath;
 import com.github.tjake.jlama.safetensors.DType;
@@ -59,19 +59,19 @@ public static void init() {
         opTypes.add(new PanamaTensorOperations(MachineSpec.Type.AVX_256));
         opTypes.add(new PanamaTensorOperations(MachineSpec.Type.ARM_128));
 
-        if (globalOps instanceof NativeTensorOperations) {
-            opTypes.add(new NativeTensorOperations());
-            opTypes.add(new NativeTensorOperations(0));
+        if (globalOps instanceof NativeSimdTensorOperations) {
+            opTypes.add(new NativeSimdTensorOperations());
+            opTypes.add(new NativeSimdTensorOperations(0));
 
-            if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeTensorOperations(HAS_AVX2));
+            if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeSimdTensorOperations(HAS_AVX2));
 
             if (RuntimeSupport.isLinux() || RuntimeSupport.isWin()) {
-                opTypes.add(new NativeTensorOperations(HAS_F16C));
-                if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeTensorOperations(HAS_F16C | HAS_AVX2));
+                opTypes.add(new NativeSimdTensorOperations(HAS_F16C));
+                if (MachineSpec.VECTOR_TYPE == MachineSpec.Type.AVX_512) opTypes.add(new NativeSimdTensorOperations(HAS_F16C | HAS_AVX2));
             }
 
             if (RuntimeSupport.isArm()) {
-                opTypes.add(new NativeTensorOperations(MachineSpec.Type.ARM_128.ctag));
+                opTypes.add(new NativeSimdTensorOperations(MachineSpec.Type.ARM_128.ctag));
             }
         }
 
@@ -198,7 +198,7 @@ public void testSplitDotProduct() {
 
     @Test
     public void testNativeDotProduct() {
-        Assume.assumeTrue(globalOps instanceof NativeTensorOperations);
+        Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations);
         AbstractTensor a = makeTensor(SIZE);
         AbstractTensor b = makeTensor(SIZE);
 
@@ -476,7 +476,7 @@ public void testBatchDotProductWithResultOffset() {
     @Test
     public void testNativeBatchDotProduct() {
         // M == BATCH, N == ROWS, K == SIZE
-        Assume.assumeTrue(globalOps instanceof NativeTensorOperations);
+        Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations);
 
         FloatBufferTensor c = new FloatBufferTensor(BATCH, ROWS);
         FloatBufferTensor c1 = new FloatBufferTensor(BATCH, ROWS);
@@ -512,7 +512,7 @@ public void testNativeBatchDotProduct() {
     @Test
     public void testNativeBatchDotProductWithOffsets() {
         // M == BATCH, N == ROWS, K == SIZE
-        Assume.assumeTrue(globalOps instanceof NativeTensorOperations);
+        Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations);
 
         FloatBufferTensor c = new FloatBufferTensor(BATCH, ROWS);
         FloatBufferTensor c1 = new FloatBufferTensor(BATCH, ROWS);
@@ -548,7 +548,7 @@ public void testNativeBatchDotProductWithOffsets() {
     @Test
     public void testNativeDotProductFast() {
         // M == BATCH, N == ROWS, K == SIZE
-        Assume.assumeTrue(globalOps instanceof NativeTensorOperations);
+        Assume.assumeTrue(globalOps instanceof NativeSimdTensorOperations);
 
         FloatBufferTensor c = new FloatBufferTensor(1, SIZE);
         FloatBufferTensor c1 = new FloatBufferTensor(1, SIZE);
