goroutines
nikolaydubina committed Jul 29, 2023
1 parent 55bac4f commit cf293f0
Showing 2 changed files with 52 additions and 39 deletions.
12 changes: 6 additions & 6 deletions README.md
@@ -40,16 +40,16 @@ While they were eating, Timmy's dad came in and said, "Hey Timmy, do you want to
 
 ### Performance
 
-| model           | llama2.c         | llama2.go
-| --------------- | ---------------- | ----------------
-| stories42M.bin  | 32.759507 tok/s  | 19.951497 tok/s
-| stories110M.bin | 11.304695 tok/s  | 7.146943 tok/s
+| model           | llama2.c          | llama2.go
+| --------------- | ----------------- | ----------------
+| stories42M.bin  | 265.348595 tok/s  | 25.677383 tok/s
+| stories110M.bin | 101.837061 tok/s  | 10.474615 tok/s
 
 ### Related Work
 
 * https://github.com/karpathy/llama2.c
 * https://github.com/poudels14/llama2_rs
-* https://github.com/gotzmann/llama.go
-* https://github.com/saracen/llama2.go (there: `mmap`; no 3rd party; go routines; single file; prompt; profiling)
 * https://github.com/tmc/go-llama2 (there: fork; slices, no 3rd party; single file)
 * https://github.com/haormj/llama2.go (there: slices; 3rd party; cobra; makefile; single file)
+* https://github.com/saracen/llama2.go (there: `mmap`; no 3rd party; go routines; single file)
+* https://github.com/gotzmann/llama.go
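The tok/s figures above are throughput: tokens generated divided by wall-clock seconds. A minimal, hypothetical sketch of such a measurement (the `generateToken` stub is illustrative and not this repository's API):

```go
package main

import (
	"fmt"
	"time"
)

// generateToken is a hypothetical stand-in for one transformer forward pass plus sampling.
func generateToken() { time.Sleep(time.Millisecond) }

func main() {
	const steps = 256
	start := time.Now()
	for i := 0; i < steps; i++ {
		generateToken()
	}
	elapsed := time.Since(start).Seconds()
	// throughput = tokens produced / elapsed seconds
	fmt.Printf("achieved tok/s: %f\n", float64(steps)/elapsed)
}
```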
79 changes: 46 additions & 33 deletions llama2/transformer.go
@@ -2,6 +2,7 @@ package llama2
 
 import (
 	"math"
+	"sync"
 
 	"github.com/nikolaydubina/llama2.go/nn"
 )
@@ -33,6 +34,8 @@ type TransformerWeights struct {
 }
 
 func Transformer(token int, pos int, config Config, s RunState, w TransformerWeights) {
+	var wg sync.WaitGroup
+
 	// a few convenience variables
 	x := s.X
 	dim := config.Dim
@@ -52,9 +55,11 @@ func Transformer(token int, pos int, config Config, s RunState, w TransformerWeights) {
 		nn.RMSNorm(s.XB, x, w.RMSAttentionWeight[l*dim:((l+1)*dim)])
 
 		// qkv matmuls for this position
-		nn.MatMul(s.Q, s.XB, w.WQ[l*dim*dim:(l+1)*dim*dim])
-		nn.MatMul(s.K, s.XB, w.WK[l*dim*dim:(l+1)*dim*dim])
-		nn.MatMul(s.V, s.XB, w.WV[l*dim*dim:(l+1)*dim*dim])
+		wg.Add(3)
+		go func() { nn.MatMul(s.Q, s.XB, w.WQ[l*dim*dim:(l+1)*dim*dim]); wg.Done() }()
+		go func() { nn.MatMul(s.K, s.XB, w.WK[l*dim*dim:(l+1)*dim*dim]); wg.Done() }()
+		go func() { nn.MatMul(s.V, s.XB, w.WV[l*dim*dim:(l+1)*dim*dim]); wg.Done() }()
+		wg.Wait()
 
 		// apply RoPE rotation to the q and k vectors for each head
 		for h := 0; h < config.NumHeads; h++ {
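The hunk above fans the three independent Q/K/V projections out to goroutines and joins them with the shared `sync.WaitGroup`. A minimal standalone sketch of that fan-out/fan-in pattern, with an illustrative `matMul` helper (not the repo's `nn` package) and toy sizes:

```go
package main

import (
	"fmt"
	"sync"
)

// matMul is an illustrative dense matrix-vector product: out = W (rows x cols) * x.
func matMul(out, x, w []float32) {
	rows, cols := len(out), len(x)
	for r := 0; r < rows; r++ {
		var sum float32
		for c := 0; c < cols; c++ {
			sum += w[r*cols+c] * x[c]
		}
		out[r] = sum
	}
}

func main() {
	const dim = 4
	x := make([]float32, dim)
	wq := make([]float32, dim*dim)
	wk := make([]float32, dim*dim)
	wv := make([]float32, dim*dim)
	for i := range x {
		x[i] = 1
	}
	for i := range wq {
		wq[i], wk[i], wv[i] = 1, 2, 3
	}

	q := make([]float32, dim)
	k := make([]float32, dim)
	v := make([]float32, dim)

	// fan out: the three projections are independent, so each runs in its own goroutine
	var wg sync.WaitGroup
	wg.Add(3)
	go func() { matMul(q, x, wq); wg.Done() }()
	go func() { matMul(k, x, wk); wg.Done() }()
	go func() { matMul(v, x, wv); wg.Done() }()
	// fan in: wait for all three results before reading them
	wg.Wait()

	fmt.Println(q, k, v)
}
```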
@@ -82,41 +87,47 @@ func Transformer(token int, pos int, config Config, s RunState, w TransformerWeights) {
 		copy(valCacheRow, s.V)
 
 		// multithread attention. iterate over all heads
-		// C code had pragma here
+		// C code had pragma here, using goroutines
+		wg.Add(config.NumHeads)
 		for h := 0; h < config.NumHeads; h++ {
-			// get the query vector for this head
-			q := s.Q[(h * headSize):((h + 1) * headSize)]
-			// attention scores for this head
-			att := s.Att[(h * config.SeqLen):((h + 1) * config.SeqLen)]
-			// iterate over all timesteps, including the current one
-			for t := 0; t <= pos; t++ {
-				// get the key vector for this head and at this timestamp
-				k := s.KeyCache[(loff + t*dim + h*headSize):(loff + (t+1)*dim + h*headSize)]
-				// calculate the attention score as the dot product of q and k
-				var score float32
-				for i := 0; i < headSize; i++ {
-					score += q[i] * k[i]
+			go func(h int) {
+				defer wg.Done()
+
+				// get the query vector for this head
+				q := s.Q[(h * headSize):((h + 1) * headSize)]
+				// attention scores for this head
+				att := s.Att[(h * config.SeqLen):((h + 1) * config.SeqLen)]
+				// iterate over all timesteps, including the current one
+				for t := 0; t <= pos; t++ {
+					// get the key vector for this head and at this timestamp
+					k := s.KeyCache[(loff + t*dim + h*headSize):(loff + (t+1)*dim + h*headSize)]
+					// calculate the attention score as the dot product of q and k
+					var score float32
+					for i := 0; i < headSize; i++ {
+						score += q[i] * k[i]
+					}
+					score /= float32(math.Sqrt(float64(headSize)))
+					// save the score to the attention buffer
+					att[t] = score
 				}
-				score /= float32(math.Sqrt(float64(headSize)))
-				// save the score to the attention buffer
-				att[t] = score
-			}
 
-			// softmax the scores to get attention weights, from 0..pos inclusively
-			nn.SoftMax(att[:pos+1])
+				// softmax the scores to get attention weights, from 0..pos inclusively
+				nn.SoftMax(att[:pos+1])
 
-			// weighted sum of the values, store back into xb
-			// llama2.c uses memset. resetting to zero in loop is ok since it is next iterated over same slice anyways.
-			for i := 0; i < headSize; i++ {
-				s.XB[(h*headSize + i)] = 0
-			}
-			for t := 0; t <= pos; t++ {
-				a := att[t]
+				// weighted sum of the values, store back into xb
+				// llama2.c uses memset. resetting to zero in loop is ok since it is next iterated over same slice anyways.
 				for i := 0; i < headSize; i++ {
-					s.XB[((h * headSize) + i)] += a * s.ValCache[loff+t*dim+h*headSize+i]
+					s.XB[(h*headSize + i)] = 0
 				}
-			}
+				for t := 0; t <= pos; t++ {
+					a := att[t]
+					for i := 0; i < headSize; i++ {
+						s.XB[((h * headSize) + i)] += a * s.ValCache[loff+t*dim+h*headSize+i]
+					}
+				}
+			}(h)
 		}
+		wg.Wait()
 
 		// final matmul to get the output of the attention
 		nn.MatMul(s.XB2, s.XB, w.WO[l*dim*dim:(l+1)*dim*dim])
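The hunk above gives each attention head its own goroutine: every head reads and writes only its own windows of `s.Q`, `s.Att`, and `s.XB`, so no locking is needed, and `h` is passed as an argument so each goroutine captures its own copy of the loop variable (Go before 1.22 reuses it across iterations). A reduced sketch of that per-head fan-out over disjoint slices, with hypothetical sizes and a stand-in for the real per-head work:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	const (
		numHeads = 4
		headSize = 3
	)
	out := make([]float32, numHeads*headSize)

	var wg sync.WaitGroup
	wg.Add(numHeads)
	for h := 0; h < numHeads; h++ {
		// pass h as an argument so each goroutine gets its own copy;
		// each goroutine writes only to its own out[h*headSize:(h+1)*headSize] window
		go func(h int) {
			defer wg.Done()
			head := out[h*headSize : (h+1)*headSize]
			for i := range head {
				head[i] = float32(h) // stand-in for the real per-head attention math
			}
		}(h)
	}
	wg.Wait()

	fmt.Println(out) // [0 0 0 1 1 1 2 2 2 3 3 3]
}
```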
@@ -129,8 +140,10 @@ func Transformer(token int, pos int, config Config, s RunState, w TransformerWeights) {
 
 		// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
 		// first calculate self.w1(x) and self.w3(x)
-		nn.MatMul(s.HB, s.XB, w.W1[l*dim*hiddenDim:(l+1)*dim*hiddenDim])
-		nn.MatMul(s.HB2, s.XB, w.W3[l*dim*hiddenDim:(l+1)*dim*hiddenDim])
+		wg.Add(2)
+		go func() { nn.MatMul(s.HB, s.XB, w.W1[l*dim*hiddenDim:(l+1)*dim*hiddenDim]); wg.Done() }()
+		go func() { nn.MatMul(s.HB2, s.XB, w.W3[l*dim*hiddenDim:(l+1)*dim*hiddenDim]); wg.Done() }()
+		wg.Wait()
 
 		// F.silu; silu(x)=x*σ, where σ(x) is the logistic sigmoid
 		for i := 0; i < hiddenDim; i++ {
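For reference, after the two concurrent matmuls above the FFN combines them as w2(silu(w1(x)) * w3(x)), where silu(x) = x*σ(x) and σ is the logistic sigmoid. A small illustrative sketch of that elementwise gate (names and values here are hypothetical, not the repo's):

```go
package main

import (
	"fmt"
	"math"
)

// silu(x) = x * sigmoid(x), with sigmoid(x) = 1 / (1 + exp(-x))
func silu(x float32) float32 {
	return x * float32(1.0/(1.0+math.Exp(-float64(x))))
}

func main() {
	// hb = w1(x), hb2 = w3(x): the outputs of the two matmuls run concurrently above
	hb := []float32{0.5, -1.0, 2.0}
	hb2 := []float32{1.0, 2.0, 0.5}

	// elementwise gate: hb[i] = silu(hb[i]) * hb2[i]; the result then feeds the w2 matmul
	for i := range hb {
		hb[i] = silu(hb[i]) * hb2[i]
	}
	fmt.Println(hb)
}
```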
