Allow inner slice caching to speed up attention

tjake committed Nov 19, 2023
1 parent 661f440 commit 90c29d3

Showing 5 changed files with 21 additions and 20 deletions.

---- Changed file 1 of 5 ----
@@ -89,7 +89,7 @@ protected AbstractTensor forward(int token_id, int pos, AbstractTensor kvbuf) {
         TransformerBlock[] transformerBlocks = getTransformerBlocks();
 
         for (int i = 0; i < c.numberOfLayers; i++) {
-            AbstractTensor kvlayer = kvbuf.slice(i);
+            AbstractTensor kvlayer = kvbuf.slice(true, i);
             AbstractTensor ref = embedding; //reference so we can free
             embedding = transformerBlocks[i].forward(embedding, pos, kvlayer);
             ref.close();
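Reading the slice indices across this commit, the kv buffer appears to be laid out as [layer][position][key-or-value]. Below is a sketch of the full slicing chain condensed into one hypothetical helper; the layout reading is an inference from the indices used, not something the diff documents, and the helper assumes jlama's AbstractTensor on the classpath:

    // Illustrative helper combining the slicing done in the model's forward
    // loop above with the slicing in CausalSelfAttention below.
    static AbstractTensor keyAt(AbstractTensor kvbuf, int layer, int position) {
        AbstractTensor kvlayer = kvbuf.slice(true, layer);    // layer slice; result caches its inner slices
        AbstractTensor kvp = kvlayer.slice(true, position);   // this position's key/value pair
        return kvp.slice(0);                                  // 0 = key half, 1 = value half
    }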
---- Changed file 2 of 5 ----

@@ -9,6 +9,7 @@
 
 import com.google.common.base.Preconditions;
 
+import java.util.Arrays;
 import java.util.Optional;
 
 public class CausalSelfAttention {
@@ -32,6 +33,8 @@ public class CausalSelfAttention {
     private final int headSize;
     private final float attentionScale;
 
+    private final float[][] flashAttnHeads;
+
     public CausalSelfAttention(AbstractModel m, AbstractTensor queryAttnWeights, AbstractTensor keyAttnWeights, AbstractTensor valueAttnWeights,
                                AbstractTensor outputProjectionWeights, Optional<float[][]> ropeFrequencies)
     {
@@ -65,6 +68,7 @@ public CausalSelfAttention(AbstractModel m, Optional<AbstractTensor> queryAttnBi
         this.attentionScale = (float) (1.0 / StrictMath.sqrt(headSize));
 
         this.ropeFrequencies = ropeFrequencies;
+        this.flashAttnHeads = new float[c.contextLength][c.numberOfHeads];
     }
 
     public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor kvMem) {
@@ -77,8 +81,7 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
         {
 
             //This is our memory of the key and value vectors for each position
-            //This is our memory of the key and value vectors for each position
-            AbstractTensor kvp = kvMem.slice(position);
+            AbstractTensor kvp = kvMem.slice(true, position);
 
             AbstractTensor key = kvp.slice(0);
             AbstractTensor val = kvp.slice(1);
@@ -122,8 +125,8 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
 
             // with all key-value entries populated, compute attention
            // the softmax is incrementally aggregated using the flash attention technique
-            AbstractTensor k0 = kvMem.slice(0).slice(0);
-            AbstractTensor v0 = kvMem.slice(0).slice(1);
+            AbstractTensor k0 = kvMem.slice(true, 0).slice(0);
+            AbstractTensor v0 = kvMem.slice(true, 0).slice(1);
 
             // value is initially the first value for all heads
             value.copyFrom(v0, 0, 0, c.embeddingLength);
@@ -137,30 +140,28 @@ public AbstractTensor forward(AbstractTensor input, int position, AbstractTensor
 
             //POSITION > 0
             //This is where the context length gets expensive! We need to run this query token by all prior tokens.
-            float[][] flashAttnHeads = new float[position][c.numberOfHeads];
             VectorMath.pfor(0, position, i -> {
-                AbstractTensor kk = kvMem.slice(i + 1).slice(0);
+                //KEY OFFSET
+                AbstractTensor kk = kvMem.slice(true, i + 1).slice(0);
                 for(int h = 0; h < c.numberOfHeads; h++){
-                    //KEY OFFSET
                     flashAttnHeads[i][h] = TensorOperationsProvider.get().dotProduct(query, kk, h * headSize, h * headSize, headSize) * attentionScale;
                 }
             });
 
             //Now aggregate results per head
             for (int i = 0; i < position; i++) {
-                AbstractTensor kk = kvMem.slice(i + 1).slice(1);
+                //VALUE OFFSET
+                AbstractTensor vv = kvMem.slice(true, i + 1).slice(1);
                 for (int h = 0; h < c.numberOfHeads; h++) {
                     float a = flashAttnHeads[i][h];
                     if (a > flashAttn_m.get(h)) {
-                        //VALUE OFFSET (since cache is k + v)
                         float e = (float) Math.exp(flashAttn_m.get(h) - a);
-                        TensorOperationsProvider.get().sxpby(e, kk, value, (h * headSize), h * headSize, headSize);
+                        TensorOperationsProvider.get().sxpby(e, vv, value, (h * headSize), h * headSize, headSize);
                         flashAttn_l.set(1 + e * flashAttn_l.get(h), h);
                         flashAttn_m.set(a, h);
                     } else {
-                        //VALUE OFFSET (since cache is k + v)
                         float e = (float) Math.exp(a - flashAttn_m.get(h));
-                        TensorOperationsProvider.get().saxpy(e, kk, value, (h * headSize), h * headSize, headSize);
+                        TensorOperationsProvider.get().saxpy(e, vv, value, (h * headSize), h * headSize, headSize);
                         flashAttn_l.set(flashAttn_l.get(h) + e, h);
                     }
                 }
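As the comments above note, the softmax is aggregated incrementally in the style of flash attention. Here is a minimal, self-contained sketch of that online-softmax recurrence, with plain float arrays standing in for jlama's AbstractTensor and TensorOperationsProvider; the class and method names are illustrative only:

    import java.util.Arrays;

    class OnlineSoftmaxSketch {
        // Streaming softmax-weighted sum over one attention head.
        // m tracks the running max score, l the running normalizer, and acc
        // the running weighted sum of value vectors. The if branch mirrors
        // sxpby above (rescale the old accumulator, add the new value at
        // weight 1); the else branch mirrors saxpy (add the new value at
        // weight exp(a - m)). Assumes values.length == scores.length and
        // each values[i] has at least headSize entries.
        static float[] attend(float[] scores, float[][] values, int headSize) {
            float m = scores[0];                              // running max
            float l = 1.0f;                                   // running normalizer
            float[] acc = Arrays.copyOf(values[0], headSize); // running weighted value

            for (int i = 1; i < scores.length; i++) {
                float a = scores[i];
                if (a > m) {
                    // New max: rescale the old accumulator by exp(m - a).
                    float e = (float) Math.exp(m - a);
                    for (int j = 0; j < headSize; j++)
                        acc[j] = values[i][j] + e * acc[j];
                    l = 1 + e * l;
                    m = a;
                } else {
                    // Max unchanged: fold in the new value at weight exp(a - m).
                    float e = (float) Math.exp(a - m);
                    for (int j = 0; j < headSize; j++)
                        acc[j] += e * values[i][j];
                    l += e;
                }
            }
            for (int j = 0; j < headSize; j++)
                acc[j] /= l;                                  // normalize once at the end
            return acc;
        }
    }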
---- Changed file 3 of 5 ----

@@ -1,13 +1,9 @@
 package com.github.tjake.jlama.model;
 
-import com.github.tjake.jlama.math.VectorMath;
 import com.github.tjake.jlama.tensor.AbstractTensor;
-import com.github.tjake.jlama.tensor.operations.TensorOperations;
 import com.github.tjake.jlama.tensor.operations.TensorOperationsProvider;
 
-import java.util.Arrays;
 import java.util.Optional;
-import java.util.stream.IntStream;
 
 public class TransformerBlock {
     private final AbstractModel model;
---- Changed file 4 of 5 ----

@@ -75,8 +75,12 @@ final public int size() {
     /** Set a value at the given coordinates */
     public abstract void set(float v, int ...dims);
 
+    /** Get a slice of the tensor along the given dimension */
+    public AbstractTensor slice(int ...dims) {
+        return slice(false, dims);
+    }
+
     /** Get a slice of the tensor along the given dimension */
-    public AbstractTensor slice(int ...dims) {
+    public AbstractTensor slice(boolean cacheInnerSlice, int ...dims) {
         Preconditions.checkArgument(dims.length < shape.length, "Too many dimensions specified for tensor");
 
         if (dims.length == 1 && sliceCache != null && sliceCache[dims[0]] != null)
@@ -98,7 +102,7 @@ public AbstractTensor slice(int ...dims) {
         for (int i = 0; i < slicedShape.length; i++)
             length *= slicedShape[i];
 
-        AbstractTensor r = this.make(totalOffset, length, slicedShape, false);
+        AbstractTensor r = this.make(totalOffset, length, slicedShape, cacheInnerSlice);
         if (dims.length == 1 && sliceCache != null)
             sliceCache[dims[0]] = r;
 
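A short sketch of the caching contract this overload appears to introduce: passing cacheInnerSlice = true makes the returned slice carry its own sliceCache (via the make(...) call above), so repeated single-index sub-slices return the same memoized view rather than a fresh wrapper object per call. The helper and assertion below are illustrative, assuming jlama's AbstractTensor:

    // Hypothetical usage, mirroring the attention hot path above.
    static void sliceCachingSketch(AbstractTensor kvMem, int position) {
        AbstractTensor kvp = kvMem.slice(true, position); // built with inner-slice caching

        AbstractTensor key1 = kvp.slice(0); // first call stores the view in sliceCache[0]
        AbstractTensor key2 = kvp.slice(0); // later calls return the cached view

        assert key1 == key2; // same object, per the sliceCache lookup in the diff
    }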
---- Changed file 5 of 5 ----

@@ -67,7 +67,7 @@ public void GPT2Run() throws IOException {
 
     @Test
     public void LlamaRun() throws Exception {
-        String modelPrefix = "models/Llama-2-7b-chat-hf";
+        String modelPrefix = "models/Llama-2-7b-chat-hf-jlama-Q4";
         Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));
         try (WeightLoader weights = SafeTensorSupport.loadWeights(Path.of(modelPrefix).toFile())) {
             LlamaTokenizer tokenizer = new LlamaTokenizer(Paths.get(modelPrefix));
