diff --git a/src/checkpoint.zig b/src/checkpoint.zig
index af8a534..ca3d38b 100644
--- a/src/checkpoint.zig
+++ b/src/checkpoint.zig
@@ -13,7 +13,7 @@ n_heads: usize,
 n_query_groups: usize,
 vocab_size: usize,
 max_sequence_length: usize,
-shared_final_classifier_matrix: bool,
+shared_classifier_matrix: bool,

 weights: struct {
     token_embedding_vectors: Tensor(2),
@@ -22,12 +22,12 @@ weights: struct {
     attention_key_matrices: Tensor(3),
     attention_value_matrices: Tensor(3),
     attention_output_matrices: Tensor(3),
-    ffn_pre_norm_vectors: Tensor(2),
-    ffn_pre_activation_matrices: Tensor(3),
-    ffn_output_matrices: Tensor(3),
-    ffn_gate_matrices: Tensor(3),
-    final_norm_vector: Tensor(1),
-    final_classifier_matrix: Tensor(2),
+    feed_forward_pre_norm_vectors: Tensor(2),
+    feed_forward_pre_activation_matrices: Tensor(3),
+    feed_forward_output_matrices: Tensor(3),
+    feed_forward_gate_matrices: Tensor(3),
+    classifier_pre_norm_vector: Tensor(1),
+    classifier_matrix: Tensor(2),
 },

 pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
@@ -43,7 +43,7 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
     // https://github.com/karpathy/llama2.c/blob/35deb5e0fa55f0a257040bcf1624ed8386e63dc7/run.c#L153
     const signed_vocab_size = try file.reader().readIntLittle(i32);

-    const shared_final_classifier_matrix = signed_vocab_size > 0;
+    const shared_classifier_matrix = signed_vocab_size > 0;

     const vocab_size: usize = std.math.absCast(signed_vocab_size);
     const max_sequence_length: usize = @intCast(try file.reader().readIntLittle(i32));
@@ -98,56 +98,56 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
     errdefer attention_output_matrices.deinit();
     try attention_output_matrices.read(file);

-    const ffn_pre_norm_vectors = try Tensor(2).init(
+    const feed_forward_pre_norm_vectors = try Tensor(2).init(
         allocator,
         [_]usize{ n_layers, embedding_size },
     );

-    errdefer ffn_pre_norm_vectors.deinit();
-    try ffn_pre_norm_vectors.read(file);
+    errdefer feed_forward_pre_norm_vectors.deinit();
+    try feed_forward_pre_norm_vectors.read(file);

-    const ffn_pre_activation_matrices = try Tensor(3).init(
+    const feed_forward_pre_activation_matrices = try Tensor(3).init(
         allocator,
         [_]usize{ n_layers, hidden_size, embedding_size },
     );

-    errdefer ffn_pre_activation_matrices.deinit();
-    try ffn_pre_activation_matrices.read(file);
+    errdefer feed_forward_pre_activation_matrices.deinit();
+    try feed_forward_pre_activation_matrices.read(file);

-    const ffn_output_matrices = try Tensor(3).init(
+    const feed_forward_output_matrices = try Tensor(3).init(
         allocator,
         [_]usize{ n_layers, embedding_size, hidden_size },
     );

-    errdefer ffn_output_matrices.deinit();
-    try ffn_output_matrices.read(file);
+    errdefer feed_forward_output_matrices.deinit();
+    try feed_forward_output_matrices.read(file);

-    const ffn_gate_matrices = try Tensor(3).init(
+    const feed_forward_gate_matrices = try Tensor(3).init(
         allocator,
         [_]usize{ n_layers, hidden_size, embedding_size },
     );

-    errdefer ffn_gate_matrices.deinit();
-    try ffn_gate_matrices.read(file);
+    errdefer feed_forward_gate_matrices.deinit();
+    try feed_forward_gate_matrices.read(file);

-    const final_norm_vector = try Tensor(1).init(allocator, [_]usize{embedding_size});
+    const classifier_pre_norm_vector = try Tensor(1).init(allocator, [_]usize{embedding_size});

-    errdefer final_norm_vector.deinit();
-    try final_norm_vector.read(file);
+    errdefer classifier_pre_norm_vector.deinit();
+    try classifier_pre_norm_vector.read(file);

     try file.seekBy(@intCast(max_sequence_length * head_size * @sizeOf(f32)));

-    const final_classifier_matrix = if (shared_final_classifier_matrix)
+    const classifier_matrix = if (shared_classifier_matrix)
         token_embedding_vectors
     else
         try Tensor(2).init(allocator, [_]usize{ vocab_size, embedding_size });

-    errdefer if (!shared_final_classifier_matrix) {
-        final_classifier_matrix.deinit();
+    errdefer if (!shared_classifier_matrix) {
+        classifier_matrix.deinit();
     };

-    if (!shared_final_classifier_matrix) {
-        try final_classifier_matrix.read(file);
+    if (!shared_classifier_matrix) {
+        try classifier_matrix.read(file);
     }

     return Self{
@@ -159,7 +159,7 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
         .n_query_groups = n_query_groups,
         .vocab_size = vocab_size,
         .max_sequence_length = max_sequence_length,
-        .shared_final_classifier_matrix = shared_final_classifier_matrix,
+        .shared_classifier_matrix = shared_classifier_matrix,

         .weights = .{
             .token_embedding_vectors = token_embedding_vectors,
@@ -168,12 +168,12 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
             .attention_key_matrices = attention_key_matrices,
             .attention_value_matrices = attention_value_matrices,
             .attention_output_matrices = attention_output_matrices,
-            .ffn_pre_norm_vectors = ffn_pre_norm_vectors,
-            .ffn_pre_activation_matrices = ffn_pre_activation_matrices,
-            .ffn_output_matrices = ffn_output_matrices,
-            .ffn_gate_matrices = ffn_gate_matrices,
-            .final_norm_vector = final_norm_vector,
-            .final_classifier_matrix = final_classifier_matrix,
+            .feed_forward_pre_norm_vectors = feed_forward_pre_norm_vectors,
+            .feed_forward_pre_activation_matrices = feed_forward_pre_activation_matrices,
+            .feed_forward_output_matrices = feed_forward_output_matrices,
+            .feed_forward_gate_matrices = feed_forward_gate_matrices,
+            .classifier_pre_norm_vector = classifier_pre_norm_vector,
+            .classifier_matrix = classifier_matrix,
         },
     };
 }
@@ -185,13 +185,13 @@ pub fn deinit(self: *const Self) void {
     self.weights.attention_key_matrices.deinit();
     self.weights.attention_value_matrices.deinit();
     self.weights.attention_output_matrices.deinit();
-    self.weights.ffn_pre_norm_vectors.deinit();
-    self.weights.ffn_pre_activation_matrices.deinit();
-    self.weights.ffn_output_matrices.deinit();
-    self.weights.ffn_gate_matrices.deinit();
-    self.weights.final_norm_vector.deinit();
-
-    if (!self.shared_final_classifier_matrix) {
-        self.weights.final_classifier_matrix.deinit();
+    self.weights.feed_forward_pre_norm_vectors.deinit();
+    self.weights.feed_forward_pre_activation_matrices.deinit();
+    self.weights.feed_forward_output_matrices.deinit();
+    self.weights.feed_forward_gate_matrices.deinit();
+    self.weights.classifier_pre_norm_vector.deinit();
+
+    if (!self.shared_classifier_matrix) {
+        self.weights.classifier_matrix.deinit();
     }
 }
diff --git a/src/ffn.zig b/src/feed_forward.zig
similarity index 89%
rename from src/ffn.zig
rename to src/feed_forward.zig
index 22cb169..40989b8 100644
--- a/src/ffn.zig
+++ b/src/feed_forward.zig
@@ -50,9 +50,9 @@ pub fn forward(self: *const Self, layer: usize) !void {
     @setFloatMode(.Optimized);

     const weights = self.checkpoint.weights;
-    const pre_activation_matrix = weights.ffn_pre_activation_matrices.slice(layer);
-    const gate_matrix = weights.ffn_gate_matrices.slice(layer);
-    const output_matrix = weights.ffn_output_matrices.slice(layer);
+    const pre_activation_matrix = weights.feed_forward_pre_activation_matrices.slice(layer);
+    const gate_matrix = weights.feed_forward_gate_matrices.slice(layer);
+    const output_matrix = weights.feed_forward_output_matrices.slice(layer);

     pre_activation_matrix.multiplyVector(self.input_buffer, self.hidden_buffer);
     gate_matrix.multiplyVector(self.input_buffer, self.gate_buffer);
diff --git a/src/transformer.zig b/src/transformer.zig
index 1a94edd..4511efa 100644
--- a/src/transformer.zig
+++ b/src/transformer.zig
@@ -4,7 +4,7 @@ const std = @import("std");
 const Attention = @import("attention.zig");
 const Checkpoint = @import("checkpoint.zig");
 const Cli = @import("cli.zig");
-const Ffn = @import("ffn.zig");
+const FeedForward = @import("feed_forward.zig");
 const Tensor = @import("./tensor.zig").Tensor;
 const vector = @import("vector.zig");

@@ -12,7 +12,7 @@ allocator: std.mem.Allocator,
 checkpoint: Checkpoint,
 sequence_length: usize,
 attention: Attention,
-ffn: Ffn,
+feed_forward: FeedForward,

 hidden_buffer: Tensor(1),
 logits_buffer: Tensor(1),
@@ -26,9 +26,9 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {

     errdefer attention.deinit();

-    const ffn = try Ffn.init(allocator, checkpoint);
+    const feed_forward = try FeedForward.init(allocator, checkpoint);

-    errdefer ffn.deinit();
+    errdefer feed_forward.deinit();

     const hidden_buffer = try Tensor(1).init(allocator, [_]usize{checkpoint.embedding_size});

@@ -43,7 +43,7 @@ pub fn init(allocator: std.mem.Allocator, cli: *const Cli) !Self {
         .checkpoint = checkpoint,
         .sequence_length = sequence_length,
         .attention = attention,
-        .ffn = ffn,
+        .feed_forward = feed_forward,
         .hidden_buffer = hidden_buffer,
         .logits_buffer = logits_buffer,
     };
 }
@@ -52,7 +52,7 @@ pub fn deinit(self: *const Self) void {
     self.checkpoint.deinit();
     self.attention.deinit();
-    self.ffn.deinit();
+    self.feed_forward.deinit();
     self.hidden_buffer.deinit();
     self.logits_buffer.deinit();
 }
@@ -64,7 +64,7 @@ pub fn forward(self: *const Self, token: usize, position: usize) !void {

     for (0..self.checkpoint.n_layers) |layer| {
         const attention_pre_norm_vector = weights.attention_pre_norm_vectors.slice(layer);
-        const ffn_pre_norm_vector = weights.ffn_pre_norm_vectors.slice(layer);
+        const feed_forward_pre_norm_vector = weights.feed_forward_pre_norm_vectors.slice(layer);

         vector.rmsnorm(
             self.hidden_buffer.data,
@@ -76,18 +76,22 @@ pub fn forward(self: *const Self, token: usize, position: usize) !void {

         vector.add(self.hidden_buffer.data, self.attention.output_buffer.data);

-        vector.rmsnorm(self.hidden_buffer.data, ffn_pre_norm_vector.data, self.ffn.input_buffer.data);
+        vector.rmsnorm(
+            self.hidden_buffer.data,
+            feed_forward_pre_norm_vector.data,
+            self.feed_forward.input_buffer.data,
+        );

-        try self.ffn.forward(layer);
+        try self.feed_forward.forward(layer);

-        vector.add(self.hidden_buffer.data, self.ffn.output_buffer.data);
+        vector.add(self.hidden_buffer.data, self.feed_forward.output_buffer.data);
     }

     vector.rmsnorm(
         self.hidden_buffer.data,
-        weights.final_norm_vector.data,
+        weights.classifier_pre_norm_vector.data,
         self.hidden_buffer.data,
     );

-    weights.final_classifier_matrix.multiplyVector(self.hidden_buffer, self.logits_buffer);
+    weights.classifier_matrix.multiplyVector(self.hidden_buffer, self.logits_buffer);
 }