Commit 3a1ec63: Next release
tjake committed Oct 30, 2024
1 parent: fed0222
Showing 15 changed files with 190 additions and 184 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -16,6 +16,7 @@ Model Support:
* Llama & Llama2 & Llama3 Models
* Mistral & Mixtral Models
* Qwen2 Models
+ * IBM Granite Models
* GPT-2 Models
* BERT Models
* BPE Tokenizers
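With Granite added to the supported model list, trying it follows the standard Jlama loading flow. The sketch below is hypothetical: the model name, the loader entry points (`SafeTensorSupport.maybeDownloadModel`, `ModelSupport.loadModel`), and the callback shape are assumptions based on the project README, not part of this diff.

```java
import com.github.tjake.jlama.model.AbstractModel;
import com.github.tjake.jlama.model.ModelSupport;
import com.github.tjake.jlama.model.functions.Generator;
import com.github.tjake.jlama.safetensors.DType;
import com.github.tjake.jlama.safetensors.SafeTensorSupport;
import com.github.tjake.jlama.safetensors.prompt.PromptContext;

import java.io.File;
import java.util.UUID;

public class GraniteDemo {
    public static void main(String[] args) throws Exception {
        // Model name and local cache dir are made up for illustration.
        File local = SafeTensorSupport.maybeDownloadModel("./models", "ibm-granite/granite-3.0-2b-instruct");
        AbstractModel m = ModelSupport.loadModel(local, DType.F32, DType.I8);

        // Use the model's chat template if it ships one; otherwise fall back to a raw prompt.
        PromptContext ctx = m.promptSupport().isPresent()
                ? m.promptSupport().get().builder().addUserMessage("Tell me a joke.").build()
                : PromptContext.of("Tell me a joke.");

        // Stream tokens as they are sampled; low temperature, 256-token cap.
        Generator.Response r = m.generate(UUID.randomUUID(), ctx, 0.3f, 256, (token, ms) -> System.out.print(token));
        System.out.println(r.responseText);
    }
}
```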
@@ -394,26 +394,21 @@ public Map<String, Float> classify(String input, PoolingType poolingType) {
}

public float[] getLogits(AbstractTensor output) {
- try (AbstractTensor embedding = sampleOutput.getOutputLayerNorm().forward(output);
- AbstractTensor logits = makeDenseTensor(1, c.vocabularySize)) {
+ try (
+ AbstractTensor embedding = sampleOutput.getOutputLayerNorm().forward(output);
+ AbstractTensor logits = makeDenseTensor(1, c.vocabularySize)
+ ) {

VectorMath.pchunk(0, c.vocabularySize, (chunkStart, chunkSize) -> {
TensorOperationsProvider.get()
- .dotProductChunk(
- logits,
- embedding,
- sampleOutput.getOutputLogitsWeights(),
- 0,
- c.embeddingLength,
- chunkStart,
- chunkSize);
+ .dotProductChunk(logits, embedding, sampleOutput.getOutputLogitsWeights(), 0, c.embeddingLength, chunkStart, chunkSize);
});

VectorMath.softMax(logits, 0, c.vocabularySize);

float[] r = new float[c.vocabularySize];

- //Convert from Tensor to float array
+ // Convert from Tensor to float array
logits.getMemorySegment().asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(r);

return r;
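Worth noting: since `softMax` runs before the copy, `getLogits` actually returns normalized probabilities over the vocabulary rather than raw logits. A caller doing greedy decoding could take the argmax directly; a minimal sketch (hypothetical helper, not part of this change):

```java
// Greedy pick: index of the highest probability in the array
// returned by getLogits(). Values sum to ~1.0 after softMax.
static int argmax(float[] probs) {
    int best = 0;
    for (int i = 1; i < probs.length; i++) {
        if (probs[i] > probs[best]) {
            best = i;
        }
    }
    return best; // token id with the highest probability
}
```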
@@ -470,23 +465,21 @@ public int sample(AbstractTensor output, float temperature, float uniformSample,
}
}


protected boolean addBosToken() {
return true;
}

public int[] encodePrompt(PromptContext promptContext) {
long[] encoded = tokenizer.encode(promptContext.getPrompt());

- if (!addBosToken())
- return Arrays.stream(encoded).mapToInt(Ints::checkedCast).toArray();
+ if (!addBosToken()) return Arrays.stream(encoded).mapToInt(Ints::checkedCast).toArray();

// Remove BOS token if it's the first token, we explicitly add it below
if (encoded.length > 0 && encoded[0] == c.bosToken) {
encoded = Arrays.copyOfRange(encoded, 1, encoded.length);
}

- int[] promptTokens = new int[(1 + encoded.length)];
+ int[] promptTokens = new int[(1 + encoded.length)];
promptTokens[0] = c.bosToken;
for (int i = 1; i <= encoded.length; i++)
promptTokens[i] = Ints.checkedCast(encoded[i - 1]);
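The strip-then-prepend dance above guards against doubled BOS tokens when the tokenizer already emits one. A self-contained illustration with made-up token ids:

```java
import java.util.Arrays;

public class BosDemo {
    public static void main(String[] args) {
        int bosToken = 1;                  // made-up BOS id
        long[] encoded = { 1L, 5L, 9L };   // tokenizer output that already starts with BOS

        // Strip the leading BOS if present, exactly as encodePrompt does...
        if (encoded.length > 0 && encoded[0] == bosToken) {
            encoded = Arrays.copyOfRange(encoded, 1, encoded.length);
        }

        // ...then prepend it explicitly, so it appears exactly once.
        int[] promptTokens = new int[1 + encoded.length];
        promptTokens[0] = bosToken;
        for (int i = 1; i <= encoded.length; i++) {
            promptTokens[i] = (int) encoded[i - 1];
        }
        System.out.println(Arrays.toString(promptTokens)); // [1, 5, 9], not [1, 1, 5, 9]
    }
}
```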
@@ -514,11 +507,7 @@ public Response generate(
try (KvBufferCache.KvBuffer kvmem = kvBufferCache.getKvBuffer(sessionId)) { // k and v for context window
int startPos = kvmem.getCurrentContextPosition(); // Number of tokens in the buffer

- logger.debug(
- "Starting at token {} for session {} with prompt {}",
- startPos,
- sessionId,
- promptContext.getPrompt());
+ logger.debug("Starting at token {} for session {} with prompt {}", startPos, sessionId, promptContext.getPrompt());

if (ntokens > c.contextLength) ntokens = c.contextLength;

@@ -532,36 +521,32 @@ public Response generate(
try (AbstractTensor logits = makeDenseTensor(c.vocabularySize)) {
int[] promptTokens;

- if (addBosToken()) {
- promptTokens = new int[(1 + encoded.length)];
+ if (addBosToken()) {
+ promptTokens = new int[(1 + encoded.length)];

- promptTokens[0] = c.bosToken;
- for (int i = 1; i <= encoded.length; i++) promptTokens[i] = Ints.checkedCast(encoded[i - 1]);
- promptLength = encoded.length;
- } else {
- promptTokens = Arrays.stream(encoded).mapToInt(Ints::checkedCast).toArray();
- promptLength = encoded.length;
- }
+ promptTokens[0] = c.bosToken;
+ for (int i = 1; i <= encoded.length; i++)
+ promptTokens[i] = Ints.checkedCast(encoded[i - 1]);
+ promptLength = encoded.length;
+ } else {
+ promptTokens = Arrays.stream(encoded).mapToInt(Ints::checkedCast).toArray();
+ promptLength = encoded.length;
+ }

long start = System.currentTimeMillis();
long promptStart = start;
// Batch Process Prompt
AbstractTensor last = DebugSupport.isDebug()
- ? batchForwardSlow(promptTokens, startPos, kvmem)
- : batchForward(promptTokens, startPos, kvmem);
+ ? batchForwardSlow(promptTokens, startPos, kvmem)
+ : batchForward(promptTokens, startPos, kvmem);

promptBatchTime = System.currentTimeMillis() - start;
float batchMsPerToken = Math.round((((double) promptBatchTime) / (double) promptLength));
- logger.debug(
- "{} prompt tokens in {}ms | {}ms per token", promptLength, promptBatchTime, batchMsPerToken);
+ logger.debug("{} prompt tokens in {}ms | {}ms per token", promptLength, promptBatchTime, batchMsPerToken);

float genMsPerToken = 0;
tokensGenerated = 0;
- int next = sample(
- last.slice(last.shape().first() - 1),
- temperature,
- ThreadLocalRandom.current().nextFloat(),
- logits);
+ int next = sample(last.slice(last.shape().first() - 1), temperature, ThreadLocalRandom.current().nextFloat(), logits);
last.close();
try {
String c = tokenizer.decode(next);
@@ -581,11 +566,9 @@ public Response generate(
AbstractTensor output = forward(next, i, kvmem);
tokensGenerated++;

- next = sample(
- output, temperature, ThreadLocalRandom.current().nextFloat(), logits);
+ next = sample(output, temperature, ThreadLocalRandom.current().nextFloat(), logits);

- if (logger.isTraceEnabled())
- logger.trace("Sampled token {} with temperature {}", next, temperature);
+ if (logger.isTraceEnabled()) logger.trace("Sampled token {} with temperature {}", next, temperature);
output.close();

kvmem.incrementContextPosition();
@@ -615,16 +598,22 @@ public Response generate(
long end = System.currentTimeMillis();

Response response = new Response(
- responseText.toString(),
- responseTextWithSpecialTokens.toString(),
- reason,
- promptLength,
- tokensGenerated,
- promptBatchTime,
- end - start);
- logger.debug(String.format(
+ responseText.toString(),
+ responseTextWithSpecialTokens.toString(),
+ reason,
+ promptLength,
+ tokensGenerated,
+ promptBatchTime,
+ end - start
+ );
+ logger.debug(
+ String.format(
"\n\nelapsed: %ds, prompt %.1fms per token, gen %.1fms per token\n",
- TimeUnit.MILLISECONDS.toSeconds(end - promptStart), batchMsPerToken, genMsPerToken));
+ TimeUnit.MILLISECONDS.toSeconds(end - promptStart),
+ batchMsPerToken,
+ genMsPerToken
+ )
+ );

return postProcessResponse(promptContext, response);
}
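As a worked example of the timing reported above: a 50-token prompt batched in 400 ms gives batchMsPerToken of about 8, and if 100 tokens are then generated over 5 seconds, genMsPerToken is about 50, so the debug line would read roughly `elapsed: 5s, prompt 8.0ms per token, gen 50.0ms per token`.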
@@ -202,7 +202,6 @@ public AbstractTensor forward(
}
TensorOperationsProvider.get().accumulate(lnpostFF, lnattn, 0, model.c.embeddingLength);


debug("post_ff_res", lnpostFF, layerIndex);

// Release any tmp buffers (embedding is released by caller)
@@ -52,7 +52,7 @@ public Gemma2Config(
layerNormEps,
vocabularySize,
bosToken,
- eosTokens instanceof List ? (List<Integer>) eosTokens : List.of((Integer)eosTokens),
+ eosTokens instanceof List ? (List<Integer>) eosTokens : List.of((Integer) eosTokens),
activationFunction,
ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta,
ropeScaling == null ? 1.0 : Double.parseDouble(ropeScaling.get("factor")),
@@ -17,7 +17,6 @@

import com.github.tjake.jlama.math.FloatConversions;
import com.github.tjake.jlama.model.*;
import com.github.tjake.jlama.model.functions.ClassifyOutput;
import com.github.tjake.jlama.model.functions.EmbedInput;
import com.github.tjake.jlama.model.functions.SampleOutput;
import com.github.tjake.jlama.model.llama.LlamaModel;
@@ -26,42 +26,42 @@
public class GraniteConfig extends Config {
@JsonCreator
public GraniteConfig(
@JsonProperty("max_position_embeddings") int contextLength,
@JsonProperty("hidden_size") int embeddingLength,
@JsonProperty("intermediate_size") int hiddenLength,
@JsonProperty("num_attention_heads") int numberOfHeads,
@JsonProperty("num_key_value_heads") int numberOfKeyValueHeads,
@JsonProperty("num_hidden_layers") int numberOfLayers,
@JsonProperty("rms_norm_eps") float layerNormEps,
@JsonProperty("vocab_size") int vocabularySize,
@JsonProperty("bos_token_id") int bosToken,
@JsonProperty("eos_token_id") int eosToken,
@JsonProperty("hidden_act") ActivationFunction.Type activationFunction,
@JsonProperty("rope_theta") Double ropeFreqsTheta,
@JsonProperty("rope_scaling") Map<String, String> ropeScaling,
@JsonProperty("residual_multiplier") Float residualMultiplier,
@JsonProperty("attention_multiplier") Float attentionMultiplier,
@JsonProperty("embedding_multiplier") Float embeddingMultiplier,
@JsonProperty("logits_scaling") Float logitsScaling
@JsonProperty("max_position_embeddings") int contextLength,
@JsonProperty("hidden_size") int embeddingLength,
@JsonProperty("intermediate_size") int hiddenLength,
@JsonProperty("num_attention_heads") int numberOfHeads,
@JsonProperty("num_key_value_heads") int numberOfKeyValueHeads,
@JsonProperty("num_hidden_layers") int numberOfLayers,
@JsonProperty("rms_norm_eps") float layerNormEps,
@JsonProperty("vocab_size") int vocabularySize,
@JsonProperty("bos_token_id") int bosToken,
@JsonProperty("eos_token_id") int eosToken,
@JsonProperty("hidden_act") ActivationFunction.Type activationFunction,
@JsonProperty("rope_theta") Double ropeFreqsTheta,
@JsonProperty("rope_scaling") Map<String, String> ropeScaling,
@JsonProperty("residual_multiplier") Float residualMultiplier,
@JsonProperty("attention_multiplier") Float attentionMultiplier,
@JsonProperty("embedding_multiplier") Float embeddingMultiplier,
@JsonProperty("logits_scaling") Float logitsScaling
) {
super(
- contextLength,
- embeddingLength,
- hiddenLength,
- numberOfHeads,
- numberOfKeyValueHeads,
- numberOfLayers,
- layerNormEps,
- vocabularySize,
- bosToken,
- List.of(eosToken),
- activationFunction,
- ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta,
- ropeScaling == null || !("linear".equals(ropeScaling.get("rope_type"))) ? 1.0 : Double.parseDouble(ropeScaling.get("factor")),
- residualMultiplier,
- attentionMultiplier,
- embeddingMultiplier,
- logitsScaling
+ contextLength,
+ embeddingLength,
+ hiddenLength,
+ numberOfHeads,
+ numberOfKeyValueHeads,
+ numberOfLayers,
+ layerNormEps,
+ vocabularySize,
+ bosToken,
+ List.of(eosToken),
+ activationFunction,
+ ropeFreqsTheta == null ? 10000.0 : ropeFreqsTheta,
+ ropeScaling == null || !("linear".equals(ropeScaling.get("rope_type"))) ? 1.0 : Double.parseDouble(ropeScaling.get("factor")),
+ residualMultiplier,
+ attentionMultiplier,
+ embeddingMultiplier,
+ logitsScaling
);
}
}
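For reference, these `@JsonProperty` bindings map one-to-one to keys in a Hugging Face `config.json`. A minimal sketch of what GraniteConfig would parse follows; every value below is made up for illustration and is not taken from a real Granite checkpoint:

```json
{
  "max_position_embeddings": 4096,
  "hidden_size": 2048,
  "intermediate_size": 8192,
  "num_attention_heads": 32,
  "num_key_value_heads": 8,
  "num_hidden_layers": 40,
  "rms_norm_eps": 1e-05,
  "vocab_size": 49152,
  "bos_token_id": 0,
  "eos_token_id": 0,
  "hidden_act": "silu",
  "rope_theta": 10000.0,
  "residual_multiplier": 0.22,
  "attention_multiplier": 0.0078125,
  "embedding_multiplier": 12.0,
  "logits_scaling": 8.0
}
```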
@@ -76,34 +76,34 @@ protected TransformerBlock[] loadTransformerBlockWeights() {
String base = "model.layers." + i + ".";
String prefix = base + "self_attn.";
CausalSelfAttention attention = new CausalSelfAttention(
- this,
- relativeLayer,
- weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType),
- weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType),
- weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType),
- weights.load(prefix + "o_proj.weight", c.dctx(), false, true).quantize(qType)
+ this,
+ relativeLayer,
+ weights.load(prefix + "q_proj.weight", c.dctx(), true, false).quantize(qType),
+ weights.load(prefix + "k_proj.weight", c.dctx(), true, false).quantize(qType),
+ weights.load(prefix + "v_proj.weight", c.dctx(), true, false).quantize(qType),
+ weights.load(prefix + "o_proj.weight", c.dctx(), false, true).quantize(qType)
);

prefix = base + "mlp.";

MLPBlock mlp = new MLPBlock(
- this,
- c.activationFunction,
- weights.load(prefix + "gate_proj.weight", c.dctx(), true, false).quantize(qType), // w1
- weights.load(prefix + "down_proj.weight", c.dctx(), false, true).quantize(qType), // w2
- weights.load(prefix + "up_proj.weight", c.dctx(), true, false).quantize(qType)
+ this,
+ c.activationFunction,
+ weights.load(prefix + "gate_proj.weight", c.dctx(), true, false).quantize(qType), // w1
+ weights.load(prefix + "down_proj.weight", c.dctx(), false, true).quantize(qType), // w2
+ weights.load(prefix + "up_proj.weight", c.dctx(), true, false).quantize(qType)
); // w3

transformerBlocks[relativeLayer] = new TransformerBlock(
- this,
- relativeLayer,
- Optional.of(new RMSNorm(this, weights.load(base + "input_layernorm.weight").quantize(qType))),
- attention,
- Optional.empty(),
- Optional.of(new RMSNorm(this, weights.load(base + "post_attention_layernorm.weight").quantize(qType))),
- mlp,
- Optional.empty(),
- Optional.empty()
+ this,
+ relativeLayer,
+ Optional.of(new RMSNorm(this, weights.load(base + "input_layernorm.weight").quantize(qType))),
+ attention,
+ Optional.empty(),
+ Optional.of(new RMSNorm(this, weights.load(base + "post_attention_layernorm.weight").quantize(qType))),
+ mlp,
+ Optional.empty(),
+ Optional.empty()
);
});
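The loader resolves weights by safetensors key, built from the layer index. For layer 0 the loop above requests the following keys, matching the upstream Hugging Face checkpoint layout:

```
model.layers.0.input_layernorm.weight
model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.v_proj.weight
model.layers.0.self_attn.o_proj.weight
model.layers.0.post_attention_layernorm.weight
model.layers.0.mlp.gate_proj.weight
model.layers.0.mlp.down_proj.weight
model.layers.0.mlp.up_proj.weight
```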

@@ -131,4 +131,4 @@ protected EmbedInput loadInputWeights() {
protected boolean addBosToken() {
return false;
}
- }
+ }
@@ -38,7 +38,10 @@ protected long encodeCharacterAsToken(byte c) {
@Override
protected Optional<Character> maybeDecodeTokenAsCharacter(long id) {
// Handle ascii codes (shifted by N in vocab)
- if (model.byteFallback && byteFallbackEncodingOffset > 0 && id >= byteFallbackEncodingOffset && id < 256 + byteFallbackEncodingOffset) {
+ if (model.byteFallback
+ && byteFallbackEncodingOffset > 0
+ && id >= byteFallbackEncodingOffset
+ && id < 256 + byteFallbackEncodingOffset) {
char c = (char) (id - byteFallbackEncodingOffset);
return Optional.of(c);
}
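A quick walk-through of the branch above, assuming the conventional SentencePiece vocab layout where the byte tokens `<0x00>`..`<0xFF>` occupy a contiguous id range; the offset value here is an assumption for illustration:

```java
// Suppose byteFallbackEncodingOffset == 3, i.e. <0x00> sits at vocab id 3.
long id = 75;                          // the token <0x48>
long offset = 3;
if (id >= offset && id < 256 + offset) {
    char c = (char) (id - offset);     // 75 - 3 == 72 == 'H'
    System.out.println(c);             // prints: H
}
```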