Commit 96b1b5c: Minor refactoring

clebert committed Oct 19, 2023 · 1 parent f560ee8
Showing 11 changed files with 33 additions and 41 deletions.
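The diff is mechanical: method receivers change from `self: *const Self` to by-value `self: Self` throughout. None of the affected methods mutate their receiver, and each struct is a small value type (slices, an optional allocator, a few scalars), so passing a copy is semantically equivalent, and Zig may still pass such parameters by const reference under the hood when that is cheaper. A minimal sketch of the before/after pattern, using a hypothetical `Buffer` struct rather than any type from this repository:

```zig
const std = @import("std");

const Buffer = struct {
    allocator: ?std.mem.Allocator,
    values: []f32,

    const Self = @This();

    // Before the refactoring this receiver would have been `self: *const Self`.
    // The struct is just a slice header plus an optional allocator, so a
    // by-value receiver copies a few machine words at most.
    pub fn sum(self: Self) f32 {
        var total: f32 = 0;
        for (self.values) |value| total += value;
        return total;
    }

    pub fn deinit(self: Self) void {
        if (self.allocator) |allocator| allocator.free(self.values);
    }
};

test "by-value receiver" {
    const allocator = std.testing.allocator;
    const values = try allocator.alloc(f32, 4);
    @memset(values, 1.5);

    const buffer = Buffer{ .allocator = allocator, .values = values };
    defer buffer.deinit();

    try std.testing.expectEqual(@as(f32, 6.0), buffer.sum());
}
```

Call sites are unaffected, since Zig resolves `buffer.sum()` the same way for pointer and value receivers, which is why the hunks below touch almost nothing but signature lines.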
2 changes: 1 addition & 1 deletion README.md
@@ -19,7 +19,7 @@ Lily wanted to play with the ball, but it was too high up in the sky. She tried
Lily found a stick and tried to hit the ball. But the stick was too short. She tried again and again, but she couldn't reach it. She felt sad.
Suddenly, a kind man came by and saw Lily. He asked her what was wrong. Lily told him about the ball. The man smiled and said, "I have a useful idea!" He took out a long stick and used it to knock the ball down. Lily was so happy! She thanked the man and they played together in the sunshine.
-achieved: 701.587 tok/s
+achieved: 712.903 tok/s
```

## Run Llama 2 from Hugging Face
8 changes: 4 additions & 4 deletions src/attention.zig
@@ -67,7 +67,7 @@ pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint, sequence_lengt
};
}

-pub fn deinit(self: *const Self) void {
+pub fn deinit(self: Self) void {
self.input_buffer.deinit();
self.output_buffer.deinit();
self.query_buffer.deinit();
@@ -76,7 +76,7 @@ pub fn deinit(self: *const Self) void {
self.allocator.free(self.scores);
}

-pub fn forward(self: *const Self, layer: usize, position: usize) void {
+pub fn forward(self: Self, layer: usize, position: usize) void {
const weights = self.checkpoint.weights;
const query_matrix = weights.attention_query_matrices.slice(layer);
const key_matrix = weights.attention_key_matrices.slice(layer);
@@ -99,7 +99,7 @@ pub fn forward(self: *const Self, layer: usize, position: usize) void {
}

// Rotary positional embeddings: https://arxiv.org/abs/2104.09864
-fn computeRoPE(self: *const Self, position: usize, key_buffer: Tensor(2)) void {
+fn computeRoPE(self: Self, position: usize, key_buffer: Tensor(2)) void {
@setFloatMode(.Optimized);

std.debug.assert(self.query_buffer.values.len % key_buffer.values.len == 0);
@@ -133,7 +133,7 @@ fn computeRoPE(self: *const Self, position: usize, key_buffer: Tensor(2)) void {
}

// Grouped-query attention: https://arxiv.org/abs/2305.13245v1
-fn computeGQA(self: *const Self, layer: usize, current_position: usize, head: usize) void {
+fn computeGQA(self: Self, layer: usize, current_position: usize, head: usize) void {
@setFloatMode(.Optimized);

const query_vector = self.query_buffer.slice(head);
2 changes: 1 addition & 1 deletion src/chat.zig
@@ -38,7 +38,7 @@ pub fn init(allocator: std.mem.Allocator, args: ChatArgs) !Self {
};
}

-pub fn deinit(self: *const Self) void {
+pub fn deinit(self: Self) void {
self.transformer.deinit();
self.tokenizer.deinit();
self.sampler.deinit();
4 changes: 2 additions & 2 deletions src/checkpoint.zig
@@ -59,7 +59,7 @@ pub fn init(allocator: std.mem.Allocator, model_path: []const u8) !Self {
}

// https://github.com/karpathy/llama2.c/blob/d9862069e7ef665fe6309e3c17398ded2f121bf5/export.py#L132
-pub fn writeV1(self: *const Self, allocator: std.mem.Allocator, model_path: []const u8) !void {
+pub fn writeV1(self: Self, allocator: std.mem.Allocator, model_path: []const u8) !void {
const path = try std.fs.path.join(
allocator,
&[_][]const u8{ model_path, "checkpoint_v1.bin" },
@@ -403,7 +403,7 @@ fn readLegacy(allocator: std.mem.Allocator, file: std.fs.File) !Self {
};
}

-pub fn deinit(self: *const Self) void {
+pub fn deinit(self: Self) void {
self.weights.token_embedding_vectors.deinit();
self.weights.attention_norm_vectors.deinit();
self.weights.attention_query_matrices.deinit();
4 changes: 2 additions & 2 deletions src/ffn.zig
@@ -38,15 +38,15 @@ pub fn init(allocator: std.mem.Allocator, checkpoint: Checkpoint) !Self {
};
}

-pub fn deinit(self: *const Self) void {
+pub fn deinit(self: Self) void {
self.input_buffer.deinit();
self.gate_buffer.deinit();
self.hidden_buffer.deinit();
self.output_buffer.deinit();
}

// SwiGLU activation function: https://arxiv.org/abs/2002.05202
-pub fn forward(self: *const Self, layer: usize) void {
+pub fn forward(self: Self, layer: usize) void {
@setFloatMode(.Optimized);

const weights = self.checkpoint.weights;
2 changes: 1 addition & 1 deletion src/generator.zig
@@ -40,7 +40,7 @@ pub fn init(allocator: std.mem.Allocator, args: GeneratorArgs) !Self {
};
}

-pub fn deinit(self: *const Self) void {
+pub fn deinit(self: Self) void {
self.transformer.deinit();
self.tokenizer.deinit();
self.sampler.deinit();
8 changes: 4 additions & 4 deletions src/quantized_tensor.zig
@@ -30,14 +30,14 @@ pub fn QuantizedTensor(comptime n_dims: comptime_int) type {
};
}

-pub fn deinit(self: *const Self) void {
+pub fn deinit(self: Self) void {
if (self.allocator) |allocator| {
allocator.free(self.values);
allocator.free(self.scaling_factors);
}
}

-pub fn slice(self: *const Self, index: usize) !QuantizedTensor(n_dims - 1) {
+pub fn slice(self: Self, index: usize) !QuantizedTensor(n_dims - 1) {
comptime if (n_dims < 2) @compileError("n_dims < 2");

const n_sub_values = @reduce(.Mul, @as(@Vector(n_dims - 1, usize), self.sub_dims));
@@ -58,7 +58,7 @@ pub fn QuantizedTensor(comptime n_dims: comptime_int) type {
}

pub fn computeMatrixVectorMultiplication(
-    self: *const Self,
+    self: Self,
input: anytype,
output: anytype,
) !void {
@@ -67,7 +67,7 @@ pub fn QuantizedTensor(comptime n_dims: comptime_int) type {
}
}

-fn computeScalarProduct(self: *const Self, other: anytype) !f32 {
+fn computeScalarProduct(self: Self, other: anytype) !f32 {
// https://github.com/karpathy/llama2.c/pull/312#issuecomment-1684140683
if (self.group_size == 32) {
return _computeScalarProduct(32, self, other);
2 changes: 1 addition & 1 deletion src/sampler.zig
@@ -23,7 +23,7 @@ pub fn init(allocator: std.mem.Allocator, args: anytype, vocab_size: usize) !Sel
};
}

-pub fn deinit(self: *const Self) void {
+pub fn deinit(self: Self) void {
self.allocator.free(self.probability_index_pairs_buffer);
}

22 changes: 9 additions & 13 deletions src/tensor.zig
@@ -20,25 +20,25 @@ pub fn Tensor(comptime n_dims: comptime_int) type {
};
}

-pub fn deinit(self: *const Self) void {
+pub fn deinit(self: Self) void {
if (self.allocator) |allocator| {
allocator.free(self.values);
}
}

-pub fn read(self: *const Self, file: std.fs.File) !void {
+pub fn read(self: Self, file: std.fs.File) !void {
const values: [*]u8 = @ptrCast(self.values);

try file.reader().readNoEof(values[0 .. self.values.len * @sizeOf(f32)]);
}

-pub fn write(self: *const Self, file: std.fs.File) !void {
+pub fn write(self: Self, file: std.fs.File) !void {
const values: [*]u8 = @ptrCast(self.values);

try file.writer().writeAll(values[0 .. self.values.len * @sizeOf(f32)]);
}

-pub fn slice(self: *const Self, index: usize) Tensor(n_dims - 1) {
+pub fn slice(self: Self, index: usize) Tensor(n_dims - 1) {
comptime if (n_dims < 2) @compileError("n_dims < 2");

const n_sub_values = @reduce(.Mul, @as(@Vector(n_dims - 1, usize), self.sub_dims));
@@ -50,7 +50,7 @@ pub fn Tensor(comptime n_dims: comptime_int) type {
};
}

-pub fn add(self: *const Self, other: anytype) void {
+pub fn add(self: Self, other: anytype) void {
@setFloatMode(.Optimized);

std.debug.assert(self.values.len == other.values.len);
@@ -60,17 +60,13 @@ pub fn Tensor(comptime n_dims: comptime_int) type {
}
}

-pub fn computeMatrixVectorMultiplication(
-    self: *const Self,
-    input: anytype,
-    output: anytype,
-) void {
+pub fn computeMatrixVectorMultiplication(self: Self, input: anytype, output: anytype) void {
for (output.values, 0..) |*value, index| {
-value.* = self.slice(index).computeScalarProduct(&input);
+value.* = self.slice(index).computeScalarProduct(input);
}
}

-pub fn computeScalarProduct(self: *const Self, other: anytype) f32 {
+pub fn computeScalarProduct(self: Self, other: anytype) f32 {
if (self.values.len % 32 == 0) {
return _computeScalarProduct(32, self, other);
}
@@ -87,7 +83,7 @@ pub fn Tensor(comptime n_dims: comptime_int) type {
}

// Pre-normalization using RMSNorm: https://arxiv.org/abs/1910.07467
-pub fn computeRMSNorm(self: *const Self, weight: anytype, output: anytype) void {
+pub fn computeRMSNorm(self: Self, weight: anytype, output: anytype) void {
@setFloatMode(.Optimized);

std.debug.assert(output.values.len == self.values.len);
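One hunk in src/tensor.zig goes beyond signatures: `computeMatrixVectorMultiplication` now passes `input` rather than `&input` to `computeScalarProduct`. Because that parameter is `anytype`, it binds to a value or a pointer equally well, and once the receiver is by-value there is no reason left to take an address. A small sketch of why both call forms behave identically (the `Vec` type and `dot` function here are hypothetical, not from this repo):

```zig
const std = @import("std");

const Vec = struct {
    values: []const f32,

    // `other: anytype` accepts a Vec by value or a *const Vec; field access
    // through a pointer auto-dereferences, so the body is the same either way.
    fn dot(self: Vec, other: anytype) f32 {
        var sum: f32 = 0;
        for (self.values, other.values) |a, b| sum += a * b;
        return sum;
    }
};

test "anytype accepts value or pointer" {
    const a = Vec{ .values = &[_]f32{ 1, 2, 3 } };
    const b = Vec{ .values = &[_]f32{ 4, 5, 6 } };

    try std.testing.expectEqual(@as(f32, 32), a.dot(b)); // pass by value
    try std.testing.expectEqual(@as(f32, 32), a.dot(&b)); // pass by pointer
}
```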
16 changes: 6 additions & 10 deletions src/tokenizer.zig
@@ -54,7 +54,7 @@ pub fn init(allocator: std.mem.Allocator, model_path: []const u8, vocab_size: us
};
}

-pub fn deinit(self: *const Self) void {
+pub fn deinit(self: Self) void {
for (self.vocab) |word| {
self.allocator.free(word);
}
@@ -64,11 +64,7 @@ pub fn deinit(self: *const Self) void {
self.allocator.free(self.sorted_vocab);
}

-pub fn encode(
-    self: *const Self,
-    allocator: std.mem.Allocator,
-    text: []const u8,
-) ![]usize {
+pub fn encode(self: Self, allocator: std.mem.Allocator, text: []const u8) ![]usize {
var double_word_buffer = try allocator.alloc(u8, self.max_word_length * 2);

defer allocator.free(double_word_buffer);
@@ -90,14 +86,14 @@ pub fn encode(
return merged_tokens_copy;
}

-pub fn decode(self: *const Self, token: usize, bos: bool) []const u8 {
+pub fn decode(self: Self, token: usize, bos: bool) []const u8 {
const word = self.vocab[token];

// https://github.com/karpathy/llama2.c/blob/7ac65cb2c2b169050747be92011b7bebdd1b4544/run.c#L425
return if (bos and std.ascii.isWhitespace(word[0])) word[1..] else word;
}

-fn encodeCodepoints(self: *const Self, allocator: std.mem.Allocator, text: []const u8) ![]usize {
+fn encodeCodepoints(self: Self, allocator: std.mem.Allocator, text: []const u8) ![]usize {
var tokens = std.ArrayList(usize).init(allocator);

errdefer tokens.deinit();
@@ -125,7 +121,7 @@ fn encodeCodepoints(self: *const Self, allocator: std.mem.Allocator, text: []con
return tokens.toOwnedSlice();
}

-fn mergeBestWordPair(self: *const Self, tokens: []usize, double_word_buffer: []u8) bool {
+fn mergeBestWordPair(self: Self, tokens: []usize, double_word_buffer: []u8) bool {
if (tokens.len < 1) {
return false;
}
@@ -168,7 +164,7 @@ fn mergeBestWordPair(self: *const Self, tokens: []usize, double_word_buffer: []u
return false;
}

-fn lookupToken(self: *const Self, word: []const u8) ?usize {
+fn lookupToken(self: Self, word: []const u8) ?usize {
var left: usize = 0;
var right = self.sorted_vocab.len;

4 changes: 2 additions & 2 deletions src/transformer.zig
@@ -55,15 +55,15 @@ pub fn init(
};
}

-pub fn deinit(self: *const Self) void {
+pub fn deinit(self: Self) void {
self.checkpoint.deinit();
self.attention.deinit();
self.ffn.deinit();
self.hidden_buffer.deinit();
self.output_buffer.deinit();
}

-pub fn forward(self: *const Self, token: usize, position: usize) void {
+pub fn forward(self: Self, token: usize, position: usize) void {
const weights = self.checkpoint.weights;

@memcpy(self.hidden_buffer.values, weights.token_embedding_vectors.slice(token).values);
Expand Down
