Commit: Merge pull request #8 from MurrellGroup/dim-swaps ("Dim swaps")
Showing 9 changed files with 497 additions and 321 deletions.
.gitignore
@@ -4,3 +4,4 @@
/docs/Manifest.toml
/docs/build/
/Manifest.toml
.CondaPkg
README.md
@@ -1,32 +1,136 @@
# Jjama3 - Hackable Llama3.1 and Llama3.2 (text) in Julia

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://MurrellGroup.github.io/Jjama3.jl/stable/)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/Jjama3.jl/dev/)
[![Build Status](https://github.com/MurrellGroup/Jjama3.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/Jjama3.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/MurrellGroup/Jjama3.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/Jjama3.jl)

# Installation
## Installation

We've split this into a few (unregistered) packages, so you'll need to add them all, and you need JSON3 for loading the configs:
```julia
] add JSON3
] add https://github.com/MurrellGroup/HuggingFaceTokenizers.jl
] add https://github.com/MurrellGroup/LowRankLayers.jl
] add https://github.com/MurrellGroup/LogitSamplers.jl
] add https://github.com/MurrellGroup/Jjama3.jl
```
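
If you prefer installing from a script rather than the Pkg REPL (the `]` prompt used above), the same installs can be done with the functional `Pkg` API. A sketch equivalent to the commands above:

```julia
using Pkg

# Registered dependency used for reading the model config
Pkg.add("JSON3")

# The unregistered packages are added directly by URL
Pkg.add(url="https://github.com/MurrellGroup/HuggingFaceTokenizers.jl")
Pkg.add(url="https://github.com/MurrellGroup/LowRankLayers.jl")
Pkg.add(url="https://github.com/MurrellGroup/LogitSamplers.jl")
Pkg.add(url="https://github.com/MurrellGroup/Jjama3.jl")
```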

# Quickstart
## Quickstart

Download a Llama3 model `config.json` and safetensor weights from Huggingface. Eg. [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct). You might need access permissions for this. Note: Huggingface use a different RoPE convention to the original Meta implementation, and their weights have been permuted. This package works with the Huggingface convention, so if you load from the original weights you'll need to permute them.
Download a Llama3 model `config.json`, `tokenizer.json`, and model safetensor weights from Hugging Face, e.g. [SmolLM2-360M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct/tree/main). Note: Hugging Face Llama3 models use a different RoPE convention from the original Meta implementation, and their weights have been permuted. This package works with the Hugging Face convention, so if you load the original Meta-Llama weights from a different source you'll need to do something horrible.

```julia
config = JSON3.read(read("Llama3_2_1B_instruct/config.json", String));
model = load_llama3_from_safetensors("Llama3_2_1B_instruct/model.safetensors", config);
tkn = llama3_tokenizer();
prompt = assistant_prompt("Why would anyone implement the llama3 LLM in Julia?", tkn);
ts = generate(model, prompt, max_new_tokens=500, encoder_for_printing=tkn);
using JSON3, Jjama3

config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String))
model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config)
tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json")

prompt = smollm2_assistant_prompt(tkn, "Tell me the two worst things about Python.")
generate(model, prompt,
    max_new_tokens=500,
    tokenizer_for_printing=tkn,
    end_token = encode(tkn, "<|im_end|>")[end]);
```
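
The `end_token` keyword tells `generate` which token id should stop generation: SmolLM2's chat format closes each turn with the `<|im_end|>` marker, and `encode(tkn, "<|im_end|>")[end]` picks out the id of that marker's final token.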

# Capability
## Capability

- Seems to generate reasonable text from Llama3.1 and Llama3.2 models, loaded from Huggingface safetensors.
- Sampling accelerated with KV caching, with argmax and top-p sampling supported.
- Gradients seem to work on CPU, using Flux and Zygote. Untested on GPU.
- Sampling (and forward passes) work with CUDA, where everything is much faster. Gradients untested.
- Metal acceleration for forward_inference and forward_loss. Gradients untested. Sampling works, but is much slower with Metal than with CPU.
- Metal acceleration for forward_inference and forward_loss. Gradients untested. Sampling works, but is slower with Metal than with CPU.

## Samplers

The transformer emits "logits" which control the probability of the next token. A sampler takes these logits and converts them into a probability distribution over the vocabulary, and then samples from this distribution. There are [a few samplers available](https://github.com/MurrellGroup/LogitSamplers.jl), including argmax, top-p, top-k, min-p, and top-nσ. These can substantially affect the output of the model.

```julia
prompt = smollm2_assistant_prompt(tkn, "Tell me the two worst things about Python.");

generate(model, prompt,
    max_new_tokens=500,
    tokenizer_for_printing=tkn,
    end_token = encode(tkn, "<|im_end|>")[end],
    sampler = top_nσ_sampler());
```
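
You can also roll your own sampler. Below is a minimal sketch of a temperature-scaled top-k sampler; it assumes (from the usage above) that a sampler is simply a function mapping a logits vector to a chosen token index, and `my_topk_sampler` is an illustrative helper, not part of the package:

```julia
using NNlib, StatsBase

# Minimal sketch: temperature scaling followed by top-k sampling.
# Assumes a sampler is any function mapping a logits vector to a token index.
function my_topk_sampler(; k = 40, temperature = 1.0f0)
    return logits -> begin
        scaled = vec(logits) ./ temperature
        topk = partialsortperm(scaled, 1:k, rev = true)   # indices of the k largest logits
        probs = softmax(scaled[topk])                      # normalize over the top k only
        topk[sample(1:k, Weights(probs))]                  # draw one of the k candidates
    end
end

# Hypothetical usage, mirroring the generate call above:
# generate(model, prompt, max_new_tokens=500, tokenizer_for_printing=tkn,
#          end_token = encode(tkn, "<|im_end|>")[end], sampler = my_topk_sampler(k = 50));
```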

## Structured Sampling

You can pass in a custom sampler that places additional constraints on the sampling process. As an example, `structured_choice` is a sampler that always selects from a set of predefined options:

```julia
question = "In a Bayesian model, what do we call the probability distribution of parameters given the data?"
choices = ["Prior",
           "Likelihood",
           "Marginal Likelihood",
           "Evidence",
           "Posterior"]

vocab = [decode(tkn, [i], skip_special_tokens = false) for i in 1:size(model.output.weight,1)]
eos = encode(tkn, "<|im_end|>")[end]

sysprompt = "You are an expert in Statistics and Probability Theory who answers questions in as few words as possible."
prompt = smollm2_instruct_prompt(tkn, sysprompt, question)

generate(model, prompt,
    max_new_tokens=100,
    tokenizer_for_printing=tkn,
    end_token = eos,
    sampler = structured_choice(choices, vocab, eos));
```

This strategy can be extended to force the model outputs to follow specific formats.
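
For instance, one could restrict generation to numeric output by masking the logits of every other token. The sketch below reuses `vocab` and `eos` from the example above; `digits_only_sampler` is a hypothetical helper (not part of the package), and the sampler interface (a function from logits to a token index) is assumed:

```julia
# Sketch: only allow tokens whose text is purely numeric, plus the end-of-turn token.
function digits_only_sampler(vocab, eos)
    is_numeric(s) = !isempty(strip(s)) && all(isdigit, strip(s))
    allowed = [i for i in eachindex(vocab) if is_numeric(vocab[i])]
    push!(allowed, eos)
    return logits -> begin
        masked = fill(-Inf32, length(logits))         # mask out every token...
        masked[allowed] .= vec(logits)[allowed]       # ...except the allowed ones
        argmax(masked)                                # greedy choice among allowed tokens
    end
end

# generate(model, prompt, max_new_tokens=20, tokenizer_for_printing=tkn,
#          end_token = eos, sampler = digits_only_sampler(vocab, eos));
```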

## Finetuning

Often we want to adjust the model parameters to better fit our specific use case by further training the model on a new dataset. This can be done on all the model weights, but we also provide low-rank finetuning via LoRA.
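In the usual LoRA setup the base weights stay fixed and only a small low-rank correction is trained, which keeps the number of trainable parameters low; the rank of that correction is set by the `lora_dim` argument in the example below.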

```julia
using Jjama3, JSON3, Flux

config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String))
tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json")
eos = encode(tkn, "<|im_end|>")[end]

# Add LoRA to Q and V matrices when loading the model
model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config,
    add_lora_to = [:Q, :V], lora_dim = 64)

# See how the model answers before finetuning
prompt = smollm2_assistant_prompt(tkn, "What language is the best for deep learning?");
generate(model, prompt, max_new_tokens=50, tokenizer_for_printing=tkn, end_token = eos);

# Set up a single, very silly, training example to finetune on
ugh = "Ugh, bruh, what a stupid question.<|im_end|>"
trainsample = decode(tkn, prompt, skip_special_tokens = false) * ugh;
train_toks = encode(tkn, trainsample);

# Set up the optimizer
opt_state = Flux.setup(AdamW(0.001f0), model);

# Train for 5 steps, monitoring the model's output as it tunes
for i in 1:5
    grads = Flux.gradient(model) do m
        forward_loss(m, train_toks[1:end-1,:], train_toks[2:end,:])
    end
    Flux.update!(opt_state, model, grads[1])
    println(i)
    generate(model, prompt,
        max_new_tokens=50,
        tokenizer_for_printing=tkn,
        end_token = eos)
    println()
end

# Ask the model an unrelated question to see how stupid we've made the model. Try this a few times.
prompt = smollm2_assistant_prompt(tkn, "Explain how tides work?");
generate(model, prompt,
    max_new_tokens=500,
    tokenizer_for_printing=tkn,
    end_token = eos,
    sampler = top_nσ_sampler());
```
src/Jjama3.jl
@@ -1,11 +1,44 @@
module Jjama3

using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib
using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib
using LogitSamplers, LowRankLayers
import HuggingFaceTokenizers

const tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained
const tokenizer_from_file = HuggingFaceTokenizers.from_file
const Tokenizer = HuggingFaceTokenizers.Tokenizer

const top_pk_sampler = LogitSamplers.top_pk_sampler
const argmax_sampler = LogitSamplers.argmax_sampler
const min_p_sampler = LogitSamplers.min_p_sampler
const top_nσ_sampler = LogitSamplers.top_nσ_sampler

include("layers.jl")
include("model.jl")
include("utils.jl")
include("sampling.jl")

export load_llama321B_from_safetensors, load_llama3_from_safetensors, llama3_tokenizer, assistant_prompt, format_llama32_instruction_prompt, generate, forward_loss, forward_inference, top_pk_sampler, argmax_sampler
export load_llama321B_from_safetensors,
    load_llama3_from_safetensors,
    generate,
    forward_loss,
    forward_inference,
    top_pk_sampler,
    argmax_sampler,
    top_nσ_sampler,
    min_p_sampler,
    tokenizer_from_repo,
    tokenizer_from_file,
    encode,
    decode,
    Tokenizer,
    llama3_instruct_prompt,
    llama3_assistant_prompt,
    smollm2_instruct_prompt,
    smollm2_assistant_prompt,
    structured_choice

end