Skip to content

Commit

Permalink
Merge pull request #8 from MurrellGroup/dim-swaps
Browse files Browse the repository at this point in the history
Dim swaps
  • Loading branch information
murrellb authored Nov 26, 2024
2 parents 39dc60c + 15a3617 commit 5836186
Show file tree
Hide file tree
Showing 9 changed files with 497 additions and 321 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
/docs/Manifest.toml
/docs/build/
/Manifest.toml
.CondaPkg
9 changes: 9 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,30 @@ version = "1.0.0-DEV"
BytePairEncoding = "a4280ba5-8788-555a-8ca8-4a8c3d966a71"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
HuggingFaceTokenizers = "a6888d44-1185-43bb-bd0f-7806f9976d18"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogitSamplers = "1b30fcfc-0ee9-4be2-9cfe-b2289b43e041"
LowRankLayers = "b66182ab-a85c-43b0-99bd-d85cc47c5e50"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SafeTensors = "eeda0dda-7046-4914-a807-2495fc7abb89"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[sources]
HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"}
LogitSamplers = {rev = "main", url = "https://github.com/MurrellGroup/LogitSamplers.jl"}
LowRankLayers = {rev = "main", url = "https://github.com/MurrellGroup/LowRankLayers.jl"}

[extensions]
MetalExt = "Metal"

[compat]
BytePairEncoding = "0.5"
Distributions = "0.25"
Flux = "0.14"
LowRankLayers = "1.0.0"
Metal = "1"
NNlib = "0.9"
SafeTensors = "1"
Expand Down
126 changes: 115 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,32 +1,136 @@
# Jjama3 - Hackable Llama3.1 and Llama3.2 (text) in Julia

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://MurrellGroup.github.io/Jjama3.jl/stable/)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://MurrellGroup.github.io/Jjama3.jl/dev/)
[![Build Status](https://github.com/MurrellGroup/Jjama3.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/MurrellGroup/Jjama3.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/MurrellGroup/Jjama3.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/MurrellGroup/Jjama3.jl)

# Installation
## Installation


We've split this into a few (unregistered) packages, so you'll need to add them all, and you need JSON3 for loading the configs:
```julia
] add JSON3
] add https://github.com/MurrellGroup/HuggingFaceTokenizers.jl
] add https://github.com/MurrellGroup/LowRankLayers.jl
] add https://github.com/MurrellGroup/LogitSamplers.jl
] add https://github.com/MurrellGroup/Jjama3.jl
```

# Quickstart
## Quickstart

Download a Llama3 model `config.json` and safetensor weights from Huggingface. Eg. [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct). You might need access permissions for this. Note: Huggingface use a different RoPE convention to the original Meta implementation, and their weights have been permuted. This package works with the Huggingface convention, so if you load from the original weights you'll need to permute them.
Download a Llama3 model `config.json`, `tokenizer.json`, and model safetensor weights from Hugging Face. Eg. [SmolLM2-360M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-360M-Instruct/tree/main). Note: Hugging Face Llama3 models use a different RoPE convention to the original Meta implementation, and their weights have been permuted. This package works with the Huggingface convention, so if you load from the original Meta-Llama weights from a different source you'll need to do something horrible.

```julia
config = JSON3.read(read("Llama3_2_1B_instruct/config.json", String));
model = load_llama3_from_safetensors("Llama3_2_1B_instruct/model.safetensors", config);
tkn = llama3_tokenizer();
prompt = assistant_prompt("Why would anyone implement the llama3 LLM in Julia?", tkn);
ts = generate(model, prompt, max_new_tokens=500, encoder_for_printing=tkn);
using JSON3, Jjama3

config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String))
model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config)
tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json")

prompt = smollm2_assistant_prompt(tkn,"Tell me the two worst things about Python.")
generate(model, prompt,
max_new_tokens=500,
tokenizer_for_printing=tkn,
end_token = encode(tkn, "<|im_end|>")[end]);
```

# Capability
## Capability

- Seems to generate reasonable text from Llama3.1 and Llama3.2 models, loaded from Huggingface safetensors.
- Sampling accelerated with KV caching, with argmax and top-p sampling supported.
- Gradients seem to work on CPU, using Flux and Zygote. Untested on GPU.
- Sampling (and forward passes) work with CUDA, where everything is much faster. Gradients untested.
- Metal acceleration for forward_inference and forward_loss. Gradients untested. Sampling works, but is much slower with Metal than with CPU.
- Metal acceleration for forward_inference and forward_loss. Gradients untested. Sampling works, but is slower with Metal than with CPU.


## Samplers

The transformer emits "logits" which control the probability of the next token. A sampler takes these logits and converts them into a probability distribution over the vocabulary, and then samples from this distribution. There are [a few samplers available](https://github.com/MurrellGroup/LogitSamplers.jl), including argmax, top-p, top-k, min-p, and top-nσ. These can substantially affect the output of the model.

```julia
prompt = smollm2_assistant_prompt(tkn,"Tell me the two worst things about Python.");

generate(model, prompt,
max_new_tokens=500,
tokenizer_for_printing=tkn,
end_token = encode(tkn, "<|im_end|>")[end],
sampler = top_nσ_sampler());
```

## Structured Sampling

You can pass in a custom sampler that places additional constraints on the sampling process. As an example, `structured_choice` is a sampler that always selects from a set of predefined options:

```julia
question = "In a Bayesian model, what do we call the probability distribution of parameters given the data?"
choices = ["Prior",
"Likelihood",
"Marginal Likelihood",
"Evidence",
"Posterior"]

vocab = [decode(tkn, [i], skip_special_tokens = false) for i in 1:size(model.output.weight,1)]
eos = encode(tkn, "<|im_end|>")[end]

sysprompt = "You are an expert in Statistics and Probability Theory who answers questions in as few words as possible."
prompt = smollm2_instruct_prompt(tkn, sysprompt, question)

generate(model, prompt,
max_new_tokens=100,
tokenizer_for_printing=tkn,
end_token = eos,
sampler = structured_choice(choices, vocab, eos));
```

This strategy can be extended to force the model outputs to follow specific formats.

## Finetuning

Often we want to adjust model parameters to better fit our specific use case, by further training the model on a new dataset. This can be done on all the model weights, but we also provide low-rank (via LoRA) finetuning.

```julia
using Jjama3, JSON3, Flux

config = JSON3.read(read("SmolLM2-360M-Instruct/config.json", String))
tkn = tokenizer_from_file(Tokenizer, "SmolLM2-360M-Instruct/tokenizer.json")
eos = encode(tkn, "<|im_end|>")[end]

#Add LoRA to Q and V matrices when loading the model
model = load_llama3_from_safetensors("SmolLM2-360M-Instruct/model.safetensors", config,
add_lora_to = [:Q, :V], lora_dim = 64)

#See how the model answers before finetuning
prompt = smollm2_assistant_prompt(tkn, "What language is the best for deep learning?");
generate(model, prompt, max_new_tokens=50, tokenizer_for_printing=tkn, end_token = eos);

#Set up a single, very silly, training example to finetune on
ugh = "Ugh, bruh, what a stupid question.<|im_end|>"
trainsample = decode(tkn, prompt, skip_special_tokens = false) * ugh;
train_toks = encode(tkn, trainsample);

#Set up the optimizer
opt_state = Flux.setup(AdamW(0.001f0), model);

#Train for 5 steps, monitoring the model's output as it tunes
for i in 1:5
grads = Flux.gradient(model) do m
forward_loss(m, train_toks[1:end-1,:], train_toks[2:end,:])
end
Flux.update!(opt_state, model, grads[1])
println(i)
generate(model, prompt,
max_new_tokens=50,
tokenizer_for_printing=tkn,
end_token = eos)
println()
end

#Ask the model an unrelated question to see how stupid we've made the model. Try this a few times.
prompt = smollm2_assistant_prompt(tkn, "Explain how tides work?");
generate(model, prompt,
max_new_tokens=500,
tokenizer_for_printing=tkn,
end_token = eos,
sampler = top_nσ_sampler());
```

11 changes: 10 additions & 1 deletion ext/MetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MetalExt
#Note: Metal speeds things up a little for forward_inference and forward_loss calls, but is VERY slow for sampling.
#It seems that each single Metal call has some constant overhead that kills it.

using Metal, Jjama3.NNlib
using Metal, NNlib

function NNlib.batched_mul(a::MtlArray, b::MtlArray)
a_shape = size(a)
Expand All @@ -15,4 +15,13 @@ function NNlib.batched_mul(a::MtlArray, b::MtlArray)
return reshape(res, a_shape[1], b_shape[2], a_shape[3:end]...)
end

function NNlib.PermutedDimsArray(a::MtlArray, perm)
return permutedims(a, perm)
end

function NNlib.batched_transpose(a::MtlArray)
dims = size(a)
return permutedims(a, (2,1,3:length(dims)...))
end

end
37 changes: 35 additions & 2 deletions src/Jjama3.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,44 @@
module Jjama3

using Flux, BytePairEncoding, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib
using Flux, SafeTensors, Distributions, LinearAlgebra, StatsBase, NNlib
using LogitSamplers, LowRankLayers
import HuggingFaceTokenizers


const tokenizer_from_repo = HuggingFaceTokenizers.from_pretrained
const tokenizer_from_file = HuggingFaceTokenizers.from_file
const Tokenizer = HuggingFaceTokenizers.Tokenizer

const top_pk_sampler = LogitSamplers.top_pk_sampler
const argmax_sampler = LogitSamplers.argmax_sampler
const min_p_sampler = LogitSamplers.min_p_sampler
const top_nσ_sampler = LogitSamplers.top_nσ_sampler



include("layers.jl")
include("model.jl")
include("utils.jl")
include("sampling.jl")

export load_llama321B_from_safetensors, load_llama3_from_safetensors, llama3_tokenizer, assistant_prompt, format_llama32_instruction_prompt, generate, forward_loss, forward_inference, top_pk_sampler, argmax_sampler
export load_llama321B_from_safetensors,
load_llama3_from_safetensors,
generate,
forward_loss,
forward_inference,
top_pk_sampler,
argmax_sampler,
top_nσ_sampler,
min_p_sampler,
tokenizer_from_repo,
tokenizer_from_file,
encode,
decode,
Tokenizer,
llama3_instruct_prompt,
llama3_assistant_prompt,
smollm2_instruct_prompt,
smollm2_assistant_prompt,
structured_choice

end
Loading

0 comments on commit 5836186

Please sign in to comment.