* Reactant compatibility
* Update Project.toml
* Refactor
* Refactor and fixes
* rm ReactantCore
* Fixes, conditional caching
Commit 1f44b58 (1 parent: 0b50d41). Showing 9 changed files with 276 additions and 356 deletions.
Project.toml
@@ -1,9 +1,10 @@
 name = "Jjama3"
 uuid = "1285d783-1a6d-4703-8f05-8ac83ef55592"
 authors = ["murrellb <[email protected]> and contributors"]
-version = "1.0.0-DEV"
+version = "1.1.0-DEV"

 [deps]
 Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18"
@@ -19,13 +20,13 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

 [sources]
 HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"}
 LogitSamplers = {rev = "main", url = "https://github.com/MurrellGroup/LogitSamplers.jl"}
 LowRankLayers = {rev = "main", url = "https://github.com/MurrellGroup/LowRankLayers.jl"}

 [extensions]
 MetalExt = "Metal"

 [compat]
 Accessors = "0.1.38"
 Distributions = "0.25"
 Flux = "0.14"
 LogitSamplers = "0.1"
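The [sources] section (a Pkg feature requiring Julia 1.11+) pins these dependencies to the main branch of their MurrellGroup repositories rather than to registered releases. A quick way to check that the pins resolve, assuming the project is checked out locally:

using Pkg

Pkg.activate(".")   # activate the Jjama3 project environment
Pkg.instantiate()   # fetches [sources] dependencies from their git URLs
Pkg.status()        # git-tracked deps are listed with their repository URLs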
src/Jjama3.jl
@@ -1,44 +1,53 @@
module Jjama3

using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib
using LogitSamplers, LowRankLayers
import HuggingFaceTokenizers
using Flux
using SafeTensors
using Distributions
using LinearAlgebra
using StatsBase
using NNlib
using LogitSamplers
using LowRankLayers

using HuggingFaceTokenizers: HuggingFaceTokenizers, Tokenizer

const tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained
const tokenizer_from_file = HuggingFaceTokenizers.from_file
const Tokenizer = HuggingFaceTokenizers.Tokenizer

const top_pk_sampler = LogitSamplers.top_pk_sampler
const argmax_sampler = LogitSamplers.argmax_sampler
const min_p_sampler = LogitSamplers.min_p_sampler
const top_nσ_sampler = LogitSamplers.top_nσ_sampler

include("cache.jl")
export KVCache

include("layers.jl")
export FeedForward
export RMSNorm
export RoPE
export Attention
export TransformerBlock
export Transformer

include("model.jl")
include("utils.jl")
export forward_loss
export forward_inference

include("sampling.jl")
export top_pk_sampler
export argmax_sampler
export top_nσ_sampler
export min_p_sampler
export generate
export tokenizer_from_repo
export tokenizer_from_file
export Tokenizer

export load_llama321B_from_safetensors,
    load_llama3_from_safetensors,
    generate,
    forward_loss,
    forward_inference,
    top_pk_sampler,
    argmax_sampler,
    top_nσ_sampler,
    min_p_sampler,
    tokenizer_from_repo,
    tokenizer_from_file,
    encode,
    decode,
    Tokenizer,
    llama3_instruct_prompt,
    llama3_assistant_prompt,
    smollm2_instruct_prompt,
    smollm2_assistant_prompt,
    structured_choice
include("utils.jl")
export encode
export decode
export load_llama321B_from_safetensors
export load_llama3_from_safetensors
export llama3_instruct_prompt
export llama3_assistant_prompt
export smollm2_instruct_prompt
export smollm2_assistant_prompt
export structured_choice

end
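For orientation, a hypothetical quick-start against the refactored export list. The names below come straight from the exports above, but every signature, argument, and repo id is an assumption for illustration; this diff does not confirm them:

using Jjama3

# Hypothetical usage: all signatures and arguments are assumed, not confirmed by this diff.
tkn    = tokenizer_from_repo("HuggingFaceTB/SmolLM2-135M-Instruct")     # assumed repo id
model  = load_llama3_from_safetensors("path/to/model.safetensors", tkn) # assumed signature
prompt = smollm2_assistant_prompt(tkn, "Why is the sky blue?")          # assumed signature
tokens = generate(model, prompt; sampler=top_pk_sampler(p=0.5f0, k=5))  # assumed keyword API
println(decode(tkn, tokens))                                            # assumed signature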
src/cache.jl
@@ -0,0 +1,37 @@
mutable struct KVCache{T,A<:AbstractArray{T,4}}
    cache_k::A
    cache_v::A
end

Flux.@layer KVCache

head_dim(cache::KVCache) = size(cache.cache_k, 1)
seq_length(cache::KVCache) = size(cache.cache_k, 2)
n_kv_heads(cache::KVCache) = size(cache.cache_k, 3)
batch_size(cache::KVCache) = size(cache.cache_k, 4)

function KVCache(T; head_dim, seq_length=0, n_kv_heads, batch_size=1)
    cache_k = zeros(T, head_dim, seq_length, n_kv_heads, batch_size)
    cache_v = zeros(T, head_dim, seq_length, n_kv_heads, batch_size)
    return KVCache(cache_k, cache_v)
end

function config!(cache::KVCache; seq_length=seq_length(cache), batch_size=batch_size(cache))
    cache.cache_k = similar(cache.cache_k, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0
    cache.cache_v = similar(cache.cache_v, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0
end

clear!(cache::KVCache) = config!(cache, seq_length=0)

function update!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray)
    if iszero(seq_length(cache))
        # Zero-length cache means caching is disabled: pass the inputs straight through.
        return xk, xv
    else
        seqlen = size(xk, 2)
        cache.cache_k[:, start_pos+1:start_pos+seqlen, :, :] .= xk
        cache.cache_v[:, start_pos+1:start_pos+seqlen, :, :] .= xv
        return cache.cache_k[:, 1:start_pos+seqlen, :, :],
               cache.cache_v[:, 1:start_pos+seqlen, :, :]
    end
end
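A minimal usage sketch based only on the definitions above; the dimension order (head_dim, seq_length, n_kv_heads, batch_size) comes from the accessor functions:

# Allocate a Float32 cache holding up to 16 positions.
cache = KVCache(Float32; head_dim=8, seq_length=16, n_kv_heads=4, batch_size=1)

# Append keys/values for 3 new positions, starting at position 0.
xk = randn(Float32, 8, 3, 4, 1)
xv = randn(Float32, 8, 3, 4, 1)
k, v = update!(cache, 0, xk, xv)  # copies of cached positions 1:3
size(k)                           # (8, 3, 4, 1)

# After clear! the cache has zero length, so update! passes inputs straight through.
clear!(cache)
k2, v2 = update!(cache, 0, xk, xv)
k2 === xk                         # true: caching is disabled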