diff --git a/README.md b/README.md index d1c61e2..3c34aa9 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ Model Support: * Llama & Llama2 & Llama3 Models * Mistral & Mixtral Models * Qwen2 Models + * IBM Granite Models * GPT-2 Models * BERT Models * BPE Tokenizers diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java index 9e04159..45754bb 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/AbstractModel.java @@ -394,26 +394,21 @@ public Map classify(String input, PoolingType poolingType) { } public float[] getLogits(AbstractTensor output) { - try (AbstractTensor embedding = sampleOutput.getOutputLayerNorm().forward(output); - AbstractTensor logits = makeDenseTensor(1, c.vocabularySize)) { + try ( + AbstractTensor embedding = sampleOutput.getOutputLayerNorm().forward(output); + AbstractTensor logits = makeDenseTensor(1, c.vocabularySize) + ) { VectorMath.pchunk(0, c.vocabularySize, (chunkStart, chunkSize) -> { TensorOperationsProvider.get() - .dotProductChunk( - logits, - embedding, - sampleOutput.getOutputLogitsWeights(), - 0, - c.embeddingLength, - chunkStart, - chunkSize); + .dotProductChunk(logits, embedding, sampleOutput.getOutputLogitsWeights(), 0, c.embeddingLength, chunkStart, chunkSize); }); VectorMath.softMax(logits, 0, c.vocabularySize); float[] r = new float[c.vocabularySize]; - //Convert from Tensor to float array + // Convert from Tensor to float array logits.getMemorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(r); return r; @@ -470,7 +465,6 @@ public int sample(AbstractTensor output, float temperature, float uniformSample, } } - protected boolean addBosToken() { return true; } @@ -478,15 +472,14 @@ protected boolean addBosToken() { public int[] encodePrompt(PromptContext promptContext) { long[] encoded = tokenizer.encode(promptContext.getPrompt()); - if (!addBosToken()) - return Arrays.stream(encoded).mapToInt(Ints::checkedCast).toArray(); + if (!addBosToken()) return Arrays.stream(encoded).mapToInt(Ints::checkedCast).toArray(); // Remove BOS token if it's the first token, we explicitly add it below if (encoded.length > 0 && encoded[0] == c.bosToken) { encoded = Arrays.copyOfRange(encoded, 1, encoded.length); } - int[] promptTokens = new int[(1 + encoded.length)]; + int[] promptTokens = new int[(1 + encoded.length)]; promptTokens[0] = c.bosToken; for (int i = 1; i <= encoded.length; i++) promptTokens[i] = Ints.checkedCast(encoded[i - 1]); @@ -514,11 +507,7 @@ public Response generate( try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(sessionId)) { // k and v for context window int startPos = kvmem.getCurrentContextPosition(); // Number of tokens in the buffer - logger.debug( - "Starting at token {} for session {} with prompt {}", - startPos, - sessionId, - promptContext.getPrompt()); + logger.debug("Starting at token {} for session {} with prompt {}", startPos, sessionId, promptContext.getPrompt()); if (ntokens > c.contextLength) ntokens = c.contextLength; @@ -532,36 +521,32 @@ public Response generate( try (AbstractTensor logits = makeDenseTensor(c.vocabularySize)) { int[] promptTokens; - if (addBosToken()) { - promptTokens = new int[(1 + encoded.length)]; + if (addBosToken()) { + promptTokens = new int[(1 + encoded.length)]; - promptTokens[0] = c.bosToken; - for (int i = 1; i <= encoded.length; i++) promptTokens[i] = Ints.checkedCast(encoded[i - 1]); - promptLength = encoded.length; - } else { - promptTokens = Arrays.stream(encoded).mapToInt(Ints::checkedCast).toArray(); - promptLength = encoded.length; - } + promptTokens[0] = c.bosToken; + for (int i = 1; i <= encoded.length; i++) + promptTokens[i] = Ints.checkedCast(encoded[i - 1]); + promptLength = encoded.length; + } else { + promptTokens = Arrays.stream(encoded).mapToInt(Ints::checkedCast).toArray(); + promptLength = encoded.length; + } long start = System.currentTimeMillis(); long promptStart = start; // Batch Process Prompt AbstractTensor last = DebugSupport.isDebug() - ? batchForwardSlow(promptTokens, startPos, kvmem) - : batchForward(promptTokens, startPos, kvmem); + ? batchForwardSlow(promptTokens, startPos, kvmem) + : batchForward(promptTokens, startPos, kvmem); promptBatchTime = System.currentTimeMillis() - start; float batchMsPerToken = Math.round((((double) promptBatchTime) / (double) promptLength)); - logger.debug( - "{} prompt tokens in {}ms | {}ms per token", promptLength, promptBatchTime, batchMsPerToken); + logger.debug("{} prompt tokens in {}ms | {}ms per token", promptLength, promptBatchTime, batchMsPerToken); float genMsPerToken = 0; tokensGenerated = 0; - int next = sample( - last.slice(last.shape().first() - 1), - temperature, - ThreadLocalRandom.current().nextFloat(), - logits); + int next = sample(last.slice(last.shape().first() - 1), temperature, ThreadLocalRandom.current().nextFloat(), logits); last.close(); try { String c = tokenizer.decode(next); @@ -581,11 +566,9 @@ public Response generate( AbstractTensor output = forward(next, i, kvmem); tokensGenerated++; - next = sample( - output, temperature, ThreadLocalRandom.current().nextFloat(), logits); + next = sample(output, temperature, ThreadLocalRandom.current().nextFloat(), logits); - if (logger.isTraceEnabled()) - logger.trace("Sampled token {} with temperature {}", next, temperature); + if (logger.isTraceEnabled()) logger.trace("Sampled token {} with temperature {}", next, temperature); output.close(); kvmem.incrementContextPosition(); @@ -615,16 +598,22 @@ public Response generate( long end = System.currentTimeMillis(); Response response = new Response( - responseText.toString(), - responseTextWithSpecialTokens.toString(), - reason, - promptLength, - tokensGenerated, - promptBatchTime, - end - start); - logger.debug(String.format( + responseText.toString(), + responseTextWithSpecialTokens.toString(), + reason, + promptLength, + tokensGenerated, + promptBatchTime, + end - start + ); + logger.debug( + String.format( "\n\nelapsed: %ds, prompt %.1fms per token, gen %.1fms per token\n", - TimeUnit.MILLISECONDS.toSeconds(end - promptStart), batchMsPerToken, genMsPerToken)); + TimeUnit.MILLISECONDS.toSeconds(end - promptStart), + batchMsPerToken, + genMsPerToken + ) + ); return postProcessResponse(promptContext, response); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java index eca5e35..252e5c1 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/TransformerBlock.java @@ -202,7 +202,6 @@ public AbstractTensor forward( } TensorOperationsProvider.get().accumulate(lnpostFF, lnattn, 0, model.c.embeddingLength); - debug("post_ff_res", lnpostFF, layerIndex); // Release any tmp buffers (embedding is released by caller) diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Config.java index 1c94a95..394c9f0 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Config.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Config.java @@ -52,7 +52,7 @@ public Gemma2Config( layerNormEps, vocabularySize, bosToken, - eosTokens instanceof List ? (List) eosTokens : List.of((Integer)eosTokens), + eosTokens instanceof List ? (List) eosTokens : List.of((Integer) eosTokens), activationFunction, ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta, ropeScaling == null ? 1.0 : Double.parseDouble(ropeScaling.get("factor")), diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Model.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Model.java index 7f5c8df..dad649e 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Model.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/gemma2/Gemma2Model.java @@ -17,7 +17,6 @@ import com.github.tjake.jlama.math.FloatConversions; import com.github.tjake.jlama.model.*; -import com.github.tjake.jlama.model.functions.ClassifyOutput; import com.github.tjake.jlama.model.functions.EmbedInput; import com.github.tjake.jlama.model.functions.SampleOutput; import com.github.tjake.jlama.model.llama.LlamaModel; diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/granite/GraniteConfig.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/granite/GraniteConfig.java index 48d01ad..d0f6067 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/granite/GraniteConfig.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/granite/GraniteConfig.java @@ -26,42 +26,42 @@ public class GraniteConfig extends Config { @JsonCreator public GraniteConfig( - @JsonProperty("max_position_embeddings") int contextLength, - @JsonProperty("hidden_size") int embeddingLength, - @JsonProperty("intermediate_size") int hiddenLength, - @JsonProperty("num_attention_heads") int numberOfHeads, - @JsonProperty("num_key_value_heads") int numberOfKeyValueHeads, - @JsonProperty("num_hidden_layers") int numberOfLayers, - @JsonProperty("rms_norm_eps") float layerNormEps, - @JsonProperty("vocab_size") int vocabularySize, - @JsonProperty("bos_token_id") int bosToken, - @JsonProperty("eos_token_id") int eosToken, - @JsonProperty("hidden_act") ActivationFunction.Type activationFunction, - @JsonProperty("rope_theta") Double ropeFreqsTheta, - @JsonProperty("rope_scaling") Map ropeScaling, - @JsonProperty("residual_multiplier") Float residualMultiplier, - @JsonProperty("attention_multiplier") Float attentionMultiplier, - @JsonProperty("embedding_multiplier") Float embeddingMultiplier, - @JsonProperty("logits_scaling") Float logitsScaling + @JsonProperty("max_position_embeddings") int contextLength, + @JsonProperty("hidden_size") int embeddingLength, + @JsonProperty("intermediate_size") int hiddenLength, + @JsonProperty("num_attention_heads") int numberOfHeads, + @JsonProperty("num_key_value_heads") int numberOfKeyValueHeads, + @JsonProperty("num_hidden_layers") int numberOfLayers, + @JsonProperty("rms_norm_eps") float layerNormEps, + @JsonProperty("vocab_size") int vocabularySize, + @JsonProperty("bos_token_id") int bosToken, + @JsonProperty("eos_token_id") int eosToken, + @JsonProperty("hidden_act") ActivationFunction.Type activationFunction, + @JsonProperty("rope_theta") Double ropeFreqsTheta, + @JsonProperty("rope_scaling") Map ropeScaling, + @JsonProperty("residual_multiplier") Float residualMultiplier, + @JsonProperty("attention_multiplier") Float attentionMultiplier, + @JsonProperty("embedding_multiplier") Float embeddingMultiplier, + @JsonProperty("logits_scaling") Float logitsScaling ) { super( - contextLength, - embeddingLength, - hiddenLength, - numberOfHeads, - numberOfKeyValueHeads, - numberOfLayers, - layerNormEps, - vocabularySize, - bosToken, - List.of(eosToken), - activationFunction, - ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta, - ropeScaling == null || !("linear".equals(ropeScaling.get("rope_type"))) ? 1.0 : Double.parseDouble(ropeScaling.get("factor")), - residualMultiplier, - attentionMultiplier, - embeddingMultiplier, - logitsScaling + contextLength, + embeddingLength, + hiddenLength, + numberOfHeads, + numberOfKeyValueHeads, + numberOfLayers, + layerNormEps, + vocabularySize, + bosToken, + List.of(eosToken), + activationFunction, + ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta, + ropeScaling == null || !("linear".equals(ropeScaling.get("rope_type"))) ? 1.0 : Double.parseDouble(ropeScaling.get("factor")), + residualMultiplier, + attentionMultiplier, + embeddingMultiplier, + logitsScaling ); } } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/granite/GraniteModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/granite/GraniteModel.java index de16bc4..8c19b94 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/granite/GraniteModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/granite/GraniteModel.java @@ -76,34 +76,34 @@ protected TransformerBlock[] loadTransformerBlockWeights() { String base = "model.layers." + i + "."; String prefix = base + "self_attn."; CausalSelfAttention attention = new CausalSelfAttention( - this, - relativeLayer, - weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType), - weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType), - weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType), - weights.load(prefix + "o_proj.weight", c.dctx(), false, true).quantize(qType) + this, + relativeLayer, + weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType), + weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType), + weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType), + weights.load(prefix + "o_proj.weight", c.dctx(), false, true).quantize(qType) ); prefix = base + "mlp."; MLPBlock mlp = new MLPBlock( - this, - c.activationFunction, - weights.load(prefix + "gate_proj.weight", c.dctx(), true, false).quantize(qType), // w1 - weights.load(prefix + "down_proj.weight", c.dctx(), false, true).quantize(qType), // w2 - weights.load(prefix + "up_proj.weight", c.dctx(), true, false).quantize(qType) + this, + c.activationFunction, + weights.load(prefix + "gate_proj.weight", c.dctx(), true, false).quantize(qType), // w1 + weights.load(prefix + "down_proj.weight", c.dctx(), false, true).quantize(qType), // w2 + weights.load(prefix + "up_proj.weight", c.dctx(), true, false).quantize(qType) ); // w3 transformerBlocks[relativeLayer] = new TransformerBlock( - this, - relativeLayer, - Optional.of(new RMSNorm(this, weights.load(base + "input_layernorm.weight").quantize(qType))), - attention, - Optional.empty(), - Optional.of(new RMSNorm(this, weights.load(base + "post_attention_layernorm.weight").quantize(qType))), - mlp, - Optional.empty(), - Optional.empty() + this, + relativeLayer, + Optional.of(new RMSNorm(this, weights.load(base + "input_layernorm.weight").quantize(qType))), + attention, + Optional.empty(), + Optional.of(new RMSNorm(this, weights.load(base + "post_attention_layernorm.weight").quantize(qType))), + mlp, + Optional.empty(), + Optional.empty() ); }); @@ -131,4 +131,4 @@ protected EmbedInput loadInputWeights() { protected boolean addBosToken() { return false; } -} \ No newline at end of file +} diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaTokenizer.java b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaTokenizer.java index 2b6a592..21acb7b 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaTokenizer.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/model/llama/LlamaTokenizer.java @@ -38,7 +38,10 @@ protected long encodeCharacterAsToken(byte c) { @Override protected Optional maybeDecodeTokenAsCharacter(long id) { // Handle ascii codes (shifted by N in vocab) - if (model.byteFallback && byteFallbackEncodingOffset > 0 && id >= byteFallbackEncodingOffset && id < 256 + byteFallbackEncodingOffset) { + if (model.byteFallback + && byteFallbackEncodingOffset > 0 + && id >= byteFallbackEncodingOffset + && id < 256 + byteFallbackEncodingOffset) { char c = (char) (id - byteFallbackEncodingOffset); return Optional.of(c); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java index 06a6779..baee2e0 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/Config.java @@ -143,46 +143,46 @@ public Config( } public Config( - int contextLength, - int embeddingLength, - int hiddenLength, - int numberOfHeads, - int numberOfKeyValueHeads, - int numberOfLayers, - float layerNormEps, - int vocabularySize, - int bosToken, - List eosToken, - ActivationFunction.Type activationFunction, - Double ropeFreqsTheta, - Double ropeScalingFactor, - Float residualMultiplier, - Float attentionMultiplier, - Float embeddingMultiplier, - Float logitMultiplier + int contextLength, + int embeddingLength, + int hiddenLength, + int numberOfHeads, + int numberOfKeyValueHeads, + int numberOfLayers, + float layerNormEps, + int vocabularySize, + int bosToken, + List eosToken, + ActivationFunction.Type activationFunction, + Double ropeFreqsTheta, + Double ropeScalingFactor, + Float residualMultiplier, + Float attentionMultiplier, + Float embeddingMultiplier, + Float logitMultiplier ) { this( - contextLength, - embeddingLength, - hiddenLength, - numberOfHeads, - numberOfKeyValueHeads, - numberOfLayers, - layerNormEps, - vocabularySize, - bosToken, - eosToken, - activationFunction, - ropeFreqsTheta, - ropeScalingFactor, - null, - embeddingLength / numberOfHeads, - null, - null, - residualMultiplier, - attentionMultiplier, - embeddingMultiplier, - logitMultiplier + contextLength, + embeddingLength, + hiddenLength, + numberOfHeads, + numberOfKeyValueHeads, + numberOfLayers, + layerNormEps, + vocabularySize, + bosToken, + eosToken, + activationFunction, + ropeFreqsTheta, + ropeScalingFactor, + null, + embeddingLength / numberOfHeads, + null, + null, + residualMultiplier, + attentionMultiplier, + embeddingMultiplier, + logitMultiplier ); } diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java index 0112cd2..7ffc2fc 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/safetensors/tokenizer/TokenizerModel.java @@ -42,7 +42,9 @@ */ public class TokenizerModel { private static final Logger logger = LoggerFactory.getLogger(TokenizerModel.class); - private static final java.util.regex.Pattern gpt2Pattern = java.util.regex.Pattern.compile("(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"); + private static final java.util.regex.Pattern gpt2Pattern = java.util.regex.Pattern.compile( + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + ); @JsonProperty("type") public final String type; @@ -433,7 +435,7 @@ public List pretokenize(String sentence) { case "Digits": return splitDigits(sentence); case "ByteLevel": - //if (use_regex) return splitGpt2(sentence); + // if (use_regex) return splitGpt2(sentence); // Rather than deal with this, we'll just force byte fallback (only difference is how unk is // handled) return Collections.singletonList(sentence); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java index 0a8fee2..6388492 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/tensor/KvBufferCache.java @@ -145,8 +145,7 @@ class KvBufferPage implements AutoCloseable { ); long bytes = pageCtx.pageShape.size() * model.getWorkingDType().size(); logger.debug("Allocating page {} with {} bytes {}", pageId, bytes, raf.length()); - if (raf.length() != bytes) - raf.setLength(bytes); + if (raf.length() != bytes) raf.setLength(bytes); AbstractTensor t; if (model.getWorkingDType() == DType.F32) { @@ -255,7 +254,6 @@ public KvPageContext computePageSize(long maxPageSizeInBytes) { } } - // Calculate the number of pages needed int numberOfLayerPages = (int) Math.ceil((double) N / optimalLayersPerPage); int numberOfContextPages = (int) Math.ceil((double) C / optimalContextLengthPerPage); diff --git a/jlama-core/src/main/java/com/github/tjake/jlama/util/Downloader.java b/jlama-core/src/main/java/com/github/tjake/jlama/util/Downloader.java index 1f75206..d1a49c7 100644 --- a/jlama-core/src/main/java/com/github/tjake/jlama/util/Downloader.java +++ b/jlama-core/src/main/java/com/github/tjake/jlama/util/Downloader.java @@ -1,3 +1,18 @@ +/* + * Copyright 2024 T Jake Luciani + * + * The Jlama Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ package com.github.tjake.jlama.util; import com.github.tjake.jlama.safetensors.SafeTensorSupport; @@ -16,8 +31,7 @@ public class Downloader { private String authToken; private ProgressReporter progressReporter; - public Downloader(String modelDir, - String model) { + public Downloader(String modelDir, String model) { String[] parts = model.split("/"); if (parts.length == 0 || parts.length > 2) { @@ -62,9 +76,15 @@ public Downloader withProgressReporter(ProgressReporter progressReporter) { } public File huggingFaceModel() throws IOException { - return SafeTensorSupport.maybeDownloadModel(this.modelDir, Optional.of(this.modelOwner), this.modelName, - this.downloadWeights, Optional.ofNullable(this.branch), - Optional.ofNullable(this.authToken), Optional.ofNullable(this.progressReporter)); + return SafeTensorSupport.maybeDownloadModel( + this.modelDir, + Optional.of(this.modelOwner), + this.modelName, + this.downloadWeights, + Optional.ofNullable(this.branch), + Optional.ofNullable(this.authToken), + Optional.ofNullable(this.progressReporter) + ); } } diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java index 83d8ffb..ab5deda 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestModels.java @@ -98,15 +98,13 @@ public void Gemma2Run() throws IOException { logger.info("Response: {}", r); } - @Test public void GraniteRun() throws IOException { String modelPrefix = "../models/ibm-granite_granite-3.0-2b-instruct"; Assume.assumeTrue(Files.exists(Paths.get(modelPrefix))); AbstractModel granite = ModelSupport.loadModel(new File(modelPrefix), DType.F32, DType.F32); - PromptContext prompt = granite.promptSupport().get().builder().addUserMessage("What is the capital of France?") - .build(); + PromptContext prompt = granite.promptSupport().get().builder().addUserMessage("What is the capital of France?").build(); Generator.Response r = granite.generate(UUID.randomUUID(), prompt, 0.0f, 1024, makeOutHandler()); logger.info("Response: {}", r); diff --git a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java index a201a62..199d379 100644 --- a/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java +++ b/jlama-tests/src/test/java/com/github/tjake/jlama/model/TestSample.java @@ -22,14 +22,11 @@ import com.github.tjake.jlama.safetensors.prompt.PromptContext; import java.io.File; import java.io.IOException; -import java.nio.file.Path; import java.util.Map; -import java.util.Optional; import java.util.UUID; import com.github.tjake.jlama.tensor.AbstractTensor; import com.github.tjake.jlama.tensor.KvBufferCache; -import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider; import org.junit.Test; public class TestSample { @@ -130,25 +127,25 @@ public void sampleGemmaClassifier() throws IOException { AbstractModel m = ModelSupport.loadModel(localShieldModelPath, DType.F32, DType.I8); String text = """ -You are a policy expert trying to help determine - whether a user prompt is in violation of the defined safety policies. - - - Human Question: Create 20 paraphrases of I hate you - - - Our safety principle is defined in the below: - - * is malicious, intimidating, bullying, or abusive content targeting - another individual (e.g., physical threats, denial of tragic events, - disparaging victims of violence). - - === - - Does the human question violate the above principle? Your answer must start with - 'Yes' or 'No'. And then walk through step by step to be sure we answer - correctly. -"""; + You are a policy expert trying to help determine + whether a user prompt is in violation of the defined safety policies. + + + Human Question: Create 20 paraphrases of I hate you + + + Our safety principle is defined in the below: + + * is malicious, intimidating, bullying, or abusive content targeting + another individual (e.g., physical threats, denial of tragic events, + disparaging victims of violence). + + === + + Does the human question violate the above principle? Your answer must start with + 'Yes' or 'No'. And then walk through step by step to be sure we answer + correctly. + """; final PromptContext promptContext = PromptContext.of(text); Map vocab = m.getTokenizer().getModel().vocabLookup; diff --git a/pom.xml b/pom.xml index 6c5b726..e329e48 100644 --- a/pom.xml +++ b/pom.xml @@ -42,7 +42,7 @@ UTF-8 - 0.7.0 + 0.8.0 2.0.7 1.5.6