Inference Llama 2 in one file of pure Zig

llama2.zig

Cute Llama

This is the Zig version of llama2.c by Andrej Karpathy. It runs inference for the llama2 model architecture recently published by Meta.

It currently supports:

  • Inference of llama2 model checkpoints
  • Temperature control
  • Top-p (nucleus) sampling
  • Prompt handling (BPE tokenization)
  • Sequence length control
  • Custom tokenizers
  • Multiquery support
  • Running really fast

The ultimate goal is to create a fast, portable, and user-friendly implementation of the llama2 model architecture. The code prioritizes simplicity and readability without sacrificing performance. Certain core functions have SIMD implementations using the Zig @Vector feature, which provide a ~5x speed increase. For more details and comparisons to other implementations, please refer to the Performance section.
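
To illustrate the @Vector approach, here is a minimal SIMD matrix-vector multiply in the same spirit. It is a sketch only: the vector width, the function name, and the signature are assumptions for illustration, not this repo's actual code.

// Assumed SIMD width: 8 x f32 fills one 256-bit register.
const vec_len = 8;
const F32Vec = @Vector(vec_len, f32);

// out = W * x, where W holds d rows of length n in row-major order.
// Each row's dot product is accumulated across SIMD lanes and reduced
// once; a scalar loop handles the tail when n is not a multiple of vec_len.
fn matmul(out: []f32, x: []const f32, w: []const f32, n: usize, d: usize) void {
    for (0..d) |i| {
        const row = w[i * n .. (i + 1) * n];
        var acc = @as(F32Vec, @splat(0.0));
        var j: usize = 0;
        while (j + vec_len <= n) : (j += vec_len) {
            const wv: F32Vec = row[j..][0..vec_len].*;
            const xv: F32Vec = x[j..][0..vec_len].*;
            acc += wv * xv; // multiply-accumulate across all lanes
        }
        var sum: f32 = @reduce(.Add, acc);
        while (j < n) : (j += 1) sum += row[j] * x[j]; // scalar tail
        out[i] = sum;
    }
}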

The stories15M.bin file is a model checkpoint for a 15M parameter model that was trained on the tiny stories dataset. The method for generating this file can be found in the llama2.c repo.

Usage

After cloning the repo, run the following command for inference:

zig build -Doptimize=ReleaseFast
zig-out/bin/llama2 stories15M.bin

A prompt can be provided as an argument to the program:

llama2 stories15M.bin -i "Once upon a time"

For all of the options, run:

$ llama2 --help
Usage:   llama2 <checkpoint> [options]
Example: llama2 checkpoint.bin -n 256 -i "Once upon a time"
Options:
 -h, --help                print this help message
 -t, --temperature <float> temperature, default 1.0 (0.0, 1]
 -p, --top-p <float>       p value in top-p (nucleus) sampling. default 1.0, 0 || 1 = off
 -n, --seq-len <int>       number of steps to run for, default 256. 0 = max_seq_len
 -i, --input <string>      input text for the prompt, default ""
 -v, --verbose             print model info and tokens/s
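
For context on how -t and -p interact: temperature divides the logits before the softmax (lower values sharpen the distribution; 0 falls back to greedy argmax), and top-p restricts sampling to the smallest set of tokens whose cumulative probability exceeds p. The sketch below shows one common way to combine the two; it is illustrative only, assumes 0 < temperature and 0 < top-p <= 1, and its names and structure are not taken from this repo's sampler.

const std = @import("std");

const ProbIndex = struct { prob: f32, index: usize };

fn byProbDesc(_: void, a: ProbIndex, b: ProbIndex) bool {
    return a.prob > b.prob;
}

// Turn raw logits into a sampled token id: apply temperature, softmax,
// then nucleus (top-p) truncation. `buf` must be logits.len long.
fn sample(logits: []f32, temperature: f32, top_p: f32, buf: []ProbIndex, rng: std.rand.Random) usize {
    // Temperature: divide logits before softmax; lower values sharpen the
    // distribution, t -> 0 approaches greedy argmax.
    for (logits) |*l| l.* /= temperature;

    // Softmax with max subtraction for numerical stability.
    var max_logit = logits[0];
    for (logits[1..]) |l| max_logit = @max(max_logit, l);
    var sum: f32 = 0.0;
    for (logits, 0..) |l, i| {
        const e = @exp(l - max_logit);
        buf[i] = .{ .prob = e, .index = i };
        sum += e;
    }
    for (buf) |*pi| pi.prob /= sum;

    // Top-p: sort descending and keep the smallest prefix whose cumulative
    // probability exceeds top_p.
    std.sort.heap(ProbIndex, buf, {}, byProbDesc);
    var cumulative: f32 = 0.0;
    var cutoff: usize = buf.len;
    for (buf, 0..) |pi, i| {
        cumulative += pi.prob;
        if (cumulative > top_p) {
            cutoff = i + 1;
            break;
        }
    }

    // Sample within the kept prefix, renormalized by its cumulative mass.
    const r = rng.float(f32) * cumulative;
    var acc: f32 = 0.0;
    for (buf[0..cutoff]) |pi| {
        acc += pi.prob;
        if (acc >= r) return pi.index;
    }
    return buf[cutoff - 1].index;
}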

Performance

The benchmarks provided below were executed on an AMD Ryzen 9 5900X 12-Core Processor. All speed measurements were taken using the stories15M.bin checkpoint file.

If you have an implementation you want to add to the table, please open an issue and I'll be happy to add it. Please only submit single-language implementations (no OpenBLAS, etc.).

Single-threaded

Argmax sampling

  • Temperature 0.0
  • 256 tokens
Implementation                        Tokens/s
llama2.zig (this repo)                660
llama2.c make runfast -march=native   548
llama2.zig                            496
llama2.c make run -march=native       122
llama2.rs                             115

Top-P sampling

  • Temperature 1.0
  • Top-P 0.9
  • 256 tokens
Implementation                        Tokens/s
llama2.zig (this repo)                579
llama2.c make runfast -march=native   463

Multi-threaded

This implementation does not yet support multithreading, so it is not included in the table below. These runs use temperature 0.9 and no top-p.

Implementation        Tokens/s
llama2.c make runomp  1564
llama2.rs             441

llama2.zig (this repo)

The single largest speed increase came from writing a SIMD version of matrix multiplication using the Zig @Vector feature. This was an immediate jump from around 115 tokens/s to 430 tokens/s. Notable speed increases also came from:

  • comptime magic to generate fused matrix multiplication (see the sketch below)
  • Vector-aligned memory allocation
  • Using SIMD versions of other core functions
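
The fused matmul and aligned allocation look roughly like the sketch below. This is an illustration under assumptions (vector width, names, signatures), not the repo's code: the comptime count parameter specializes one function that multiplies several weight matrices (for example the q, k, and v projections) against the same input in a single pass, so x is streamed through the cache once, and std.mem.Allocator.alignedAlloc keeps float buffers on @Vector-friendly boundaries.

const std = @import("std");

// Assumed SIMD width for the sketch.
const vec_len = 8;
const F32Vec = @Vector(vec_len, f32);

// Allocate f32 buffers aligned to the vector type so slices can be loaded
// into @Vector values without unaligned accesses.
fn allocAligned(allocator: std.mem.Allocator, n: usize) ![]align(@alignOf(F32Vec)) f32 {
    return allocator.alignedAlloc(f32, @alignOf(F32Vec), n);
}

// Fused matmul: `count` weight matrices (each d rows of length n, row-major)
// are multiplied against the same input x in one pass, so x only makes a
// single trip through the cache instead of `count` trips.
fn matmulFused(
    comptime count: usize,
    outs: [count][]f32,
    ws: [count][]const f32,
    x: []const f32,
    n: usize,
    d: usize,
) void {
    for (0..d) |i| {
        var accs: [count]F32Vec = undefined;
        for (&accs) |*acc| acc.* = @as(F32Vec, @splat(0.0));
        var j: usize = 0;
        while (j + vec_len <= n) : (j += vec_len) {
            const xv: F32Vec = x[j..][0..vec_len].*;
            for (0..count) |m| {
                const wv: F32Vec = ws[m][(i * n + j)..][0..vec_len].*;
                accs[m] += wv * xv;
            }
        }
        for (0..count) |m| {
            var sum: f32 = @reduce(.Add, accs[m]);
            var k = j;
            while (k < n) : (k += 1) sum += ws[m][i * n + k] * x[k]; // scalar tail
            outs[m][i] = sum;
        }
    }
}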

Benchmark commands (zig version 0.11.0-dev.4315+f5239677e):

llama2 stories15M.bin -t 0
llama2 stories15M.bin -t 1.0 -p 0.9
llama2 stories15M.bin -t 1.0 -p 0.9 -i "Once upon a time"

llama2.c

Benchmark commands:

 ./run stories15M.bin -t 1.0 -p 0.9
 ./run stories15M.bin -t 0.0

Makefile targets:

CC = gcc

.PHONY: runfast
runfast: run.c
	$(CC) -Ofast -o run run.c -lm -march=native


.PHONY: run
run: run.c
	$(CC) -O3 -o run run.c -lm -march=native


.PHONY: runomp
runomp: run.c
	$(CC) -Ofast -fopenmp -march=native run.c  -lm  -o run -march=native

llama2.rs

Benchmark commands:

 RUSTFLAGS="-C target-cpu=native" cargo run -r -- stories15M.bin 0.9
 RUSTFLAGS="-C target-cpu=native" cargo run -r -F parallel -- stories15M.bin 0.9

Cargo release profile:

[profile.release]
codegen-units = 1
lto = true
panic = "abort"
strip = "symbols"

Todo

  • Parallelize multi-head attention process
  • Add support for multi-threading (this is not going well)
  • Try top-p sampling by doing multiple linear scans to avoid sorting
  • Binary search the token encoder (probably not necessary)
  • Add benchmarks for smaller model and tokenizer
  • On very small models (~100K params) this is slower than the original llama2.c; figure out why. expf seems to take a lot of time, and GCC does much better than Clang, likely because floating-point reordering helps there.

Contributing

Any form of contribution is welcome. Feel free to open an issue or create a pull request. If you are contributing optimizations, please provide benchmarks and/or performance comparisons as well as the code to reproduce them.

Credits

  • Andrej Karpathy for the original llama2.c implementation
  • Foundation42 for opening a PR on the llama2.c repo that was the inspiration for aligned memory allocation and fused matrix multiplication.
  • jrudolph for the top-p sampling optimization PR
