diff --git a/src/attention.zig b/src/attention.zig index 2065bbe..e9d7ea2 100644 --- a/src/attention.zig +++ b/src/attention.zig @@ -67,7 +67,7 @@ pub fn forward(self: Self, layer: usize, position: usize) !void { try key_weight.multiplyVector(self.input, multi_key); try value_weight.multiplyVector(self.input, multi_value); - self.computeRoPE(position, multi_key.values); + self.computeRoPE(position, multi_key.data); for (0..self.checkpoint.n_attention_heads) |head| { try self.computeGQA(layer, position, head); @@ -77,37 +77,37 @@ pub fn forward(self: Self, layer: usize, position: usize) !void { } // Rotary positional embeddings: https://arxiv.org/abs/2104.09864 -fn computeRoPE(self: Self, position: usize, multi_key_values: []f32) void { +fn computeRoPE(self: Self, position: usize, multi_key_data: []f32) void { @setFloatMode(.Optimized); - const multi_query_values = self.multi_query.values; + const multi_query_data = self.multi_query.data; - std.debug.assert(multi_query_values.len % multi_key_values.len == 0); + std.debug.assert(multi_query_data.len % multi_key_data.len == 0); var index: usize = 0; - while (index < multi_query_values.len) : (index += 2) { + while (index < multi_query_data.len) : (index += 2) { const head: f32 = @floatFromInt(index % self.head_size); const frequency = 1 / std.math.pow(f32, 10000, head / @as(f32, @floatFromInt(self.head_size))); const rotation_scaling_factor: f32 = @as(f32, @floatFromInt(position)) * frequency; - const real_rotation_value: f32 = std.math.cos(rotation_scaling_factor); - const imag_rotation_value: f32 = std.math.sin(rotation_scaling_factor); + const real_rotation: f32 = std.math.cos(rotation_scaling_factor); + const imag_rotation: f32 = std.math.sin(rotation_scaling_factor); - const q_0 = multi_query_values[index]; - const q_1 = multi_query_values[index + 1]; + const q_0 = multi_query_data[index]; + const q_1 = multi_query_data[index + 1]; - multi_query_values[index] = q_0 * real_rotation_value - q_1 * 
imag_rotation_value; - multi_query_values[index + 1] = q_0 * imag_rotation_value + q_1 * real_rotation_value; + multi_query_data[index] = q_0 * real_rotation - q_1 * imag_rotation; + multi_query_data[index + 1] = q_0 * imag_rotation + q_1 * real_rotation; - if (index < multi_key_values.len) { - const k_0 = multi_key_values[index]; - const k_1 = multi_key_values[index + 1]; + if (index < multi_key_data.len) { + const k_0 = multi_key_data[index]; + const k_1 = multi_key_data[index + 1]; - multi_key_values[index] = k_0 * real_rotation_value - k_1 * imag_rotation_value; - multi_key_values[index + 1] = k_0 * imag_rotation_value + k_1 * real_rotation_value; + multi_key_data[index] = k_0 * real_rotation - k_1 * imag_rotation; + multi_key_data[index + 1] = k_0 * imag_rotation + k_1 * real_rotation; } } } @@ -116,7 +116,7 @@ fn computeRoPE(self: Self, position: usize, multi_key_values: []f32) void { fn computeGQA(self: Self, layer: usize, current_position: usize, head: usize) !void { @setFloatMode(.Optimized); - const query_values = self.multi_query.values[head * self.head_size ..][0..self.head_size]; + const query_data = self.multi_query.data[head * self.head_size ..][0..self.head_size]; const query_group = head / (self.checkpoint.n_attention_heads / self.checkpoint.n_attention_query_groups); @@ -125,25 +125,25 @@ fn computeGQA(self: Self, layer: usize, current_position: usize, head: usize) !v for (0..next_position) |position| { const multi_key = self.key_cache[layer][position]; - const key_values = multi_key.values[query_group * self.head_size ..][0..self.head_size]; + const key_data = multi_key.data[query_group * self.head_size ..][0..self.head_size]; self.scores[position] = - try simd.computeScalarProduct(query_values, key_values) / self.head_size_sqrt; + try simd.computeScalarProduct(query_data, key_data) / self.head_size_sqrt; } math.softmax(self.scores[0..next_position]); - const attention_values = self.input.values[head * self.head_size ..][0..self.head_size]; + 
const attention_data = self.input.data[head * self.head_size ..][0..self.head_size]; - @memset(attention_values, 0); + @memset(attention_data, 0); for (0..next_position) |position| { const multi_value = self.value_cache[layer][position]; - const value_values = multi_value.values[query_group * self.head_size ..][0..self.head_size]; + const value_data = multi_value.data[query_group * self.head_size ..][0..self.head_size]; const weight = self.scores[position]; for (0..self.head_size) |index| { - attention_values[index] += value_values[index] * weight; + attention_data[index] += value_data[index] * weight; } } } diff --git a/src/chat.zig b/src/chat.zig index cec0fed..b16b9be 100644 --- a/src/chat.zig +++ b/src/chat.zig @@ -111,7 +111,7 @@ pub fn start(self: *Self, allocator: std.mem.Allocator) !void { user_prompt_tokens_index += 1; if (next_token == 0) { - next_token = self.sampler.sample(self.transformer.output.values); + next_token = self.sampler.sample(self.transformer.output.data); } if (next_token == eos_token) { diff --git a/src/ffn.zig b/src/ffn.zig index ba943a8..0e8a996 100644 --- a/src/ffn.zig +++ b/src/ffn.zig @@ -32,7 +32,7 @@ pub fn forward(self: Self, layer: usize) !void { try up_weight.multiplyVector(self.input, self.hidden); for (0..self.checkpoint.ffn_hidden_size) |index| { - self.hidden.values[index] *= swish(self.gate.values[index]); + self.hidden.data[index] *= swish(self.gate.data[index]); } try down_weight.multiplyVector(self.hidden, self.output); diff --git a/src/generator.zig b/src/generator.zig index f95aef9..cac9a4a 100644 --- a/src/generator.zig +++ b/src/generator.zig @@ -54,7 +54,7 @@ pub fn generate(self: *Self, writer: anytype) !void { next_token = self.prompt_tokens[prompt_tokens_index]; prompt_tokens_index += 1; } else { - next_token = self.sampler.sample(self.transformer.output.values); + next_token = self.sampler.sample(self.transformer.output.data); } if (next_token == bos_token or next_token == eos_token) { diff --git a/src/math.zig 
b/src/math.zig index 72d3e25..257ba18 100644 --- a/src/math.zig +++ b/src/math.zig @@ -1,33 +1,33 @@ const std = @import("std"); -pub fn argmax(values: []f32) usize { +pub fn argmax(data: []f32) usize { var max_index: usize = 0; - var max_value: f32 = values[max_index]; + var max_element: f32 = data[max_index]; - for (1..values.len) |index| { - const value = values[index]; + for (1..data.len) |index| { + const element = data[index]; - if (value > max_value) { + if (element > max_element) { max_index = index; - max_value = value; + max_element = element; } } return max_index; } -pub fn softmax(values: []f32) void { +pub fn softmax(data: []f32) void { @setFloatMode(.Optimized); - var max_value: f32 = std.mem.max(f32, values); + const max_element: f32 = std.mem.max(f32, data); var sum: f32 = 0; - for (values) |*value| { - value.* = std.math.exp(value.* - max_value); - sum += value.*; + for (data) |*element| { + element.* = std.math.exp(element.* - max_element); + sum += element.*; } - for (values) |*value| { - value.* /= sum; + for (data) |*element| { + element.* /= sum; } } diff --git a/src/matrix.zig b/src/matrix.zig index 86e368f..00f3cf2 100644 --- a/src/matrix.zig +++ b/src/matrix.zig @@ -40,21 +40,21 @@ const max_thread_count = 24; pub fn multiplyVector(self: Self, input: Vector, output: Vector) !void { if (self.thread_count == 0) { - try computeMatrixVectorMultiplication(self.rows, input, output.values); + try computeMatrixVectorMultiplication(self.rows, input, output.data); return; } const n_threads = @min(max_thread_count, self.thread_count); - const thread_chunk_size = output.values.len / n_threads; + const chunk_size = output.data.len / n_threads; var threads: [max_thread_count]std.Thread = undefined; for (threads[0..n_threads], 0..) 
|*thread, index| { thread.* = try std.Thread.spawn(.{}, computeMatrixVectorMultiplication, .{ - self.rows[index * thread_chunk_size ..][0..thread_chunk_size], + self.rows[index * chunk_size ..][0..chunk_size], input, - output.values[index * thread_chunk_size ..][0..thread_chunk_size], + output.data[index * chunk_size ..][0..chunk_size], }); } @@ -62,11 +62,11 @@ pub fn multiplyVector(self: Self, input: Vector, output: Vector) !void { thread.join(); } - if (output.values.len % n_threads > 0) { + if (output.data.len % n_threads > 0) { try computeMatrixVectorMultiplication( - self.rows[n_threads * thread_chunk_size ..], + self.rows[n_threads * chunk_size ..], input, - output.values[n_threads * thread_chunk_size ..], + output.data[n_threads * chunk_size ..], ); } } @@ -74,11 +74,11 @@ pub fn multiplyVector(self: Self, input: Vector, output: Vector) !void { fn computeMatrixVectorMultiplication( rows: []const Vector, input: Vector, - output_values: []f32, + output_data: []f32, ) !void { - std.debug.assert(rows.len == output_values.len); + std.debug.assert(rows.len == output_data.len); - for (output_values, 0..) |*value, index| { - value.* = try rows[index].computeScalarProduct(input); + for (output_data, 0..) 
|*element, index| { + element.* = try rows[index].computeScalarProduct(input); } } diff --git a/src/simd.zig b/src/simd.zig index 539e4ff..2f14bf0 100644 --- a/src/simd.zig +++ b/src/simd.zig @@ -1,41 +1,37 @@ const std = @import("std"); // Pre-normalization using RMSNorm: https://arxiv.org/abs/1910.07467 -pub fn computeRMSNorm( - input_values: []const f32, - weight_values: []const f32, - output_values: []f32, -) !void { +pub fn computeRMSNorm(input_data: []const f32, weight_data: []const f32, output_data: []f32) !void { @setFloatMode(.Optimized); - var scaling_factor = try computeScalarProduct(input_values, input_values); + var scaling_factor = try computeScalarProduct(input_data, input_data); - scaling_factor /= @floatFromInt(input_values.len); + scaling_factor /= @floatFromInt(input_data.len); scaling_factor += 1e-5; scaling_factor = 1 / std.math.sqrt(scaling_factor); - try computeVectorMultiplication(scaling_factor, input_values, weight_values, output_values); + try computeVectorMultiplication(scaling_factor, input_data, weight_data, output_data); } -pub fn computeScalarProduct(input_values_1: []const f32, input_values_2: []const f32) !f32 { +pub fn computeScalarProduct(input_data_1: []const f32, input_data_2: []const f32) !f32 { @setFloatMode(.Optimized); - std.debug.assert(input_values_1.len == input_values_2.len); + std.debug.assert(input_data_1.len == input_data_2.len); comptime var vector_len = std.atomic.cache_line / @sizeOf(f32); inline while (vector_len >= 4) : (vector_len /= 2) { - if (input_values_1.len % vector_len == 0) { - var output_values: @Vector(vector_len, f32) = @splat(0); + if (input_data_1.len % vector_len == 0) { + var output_data: @Vector(vector_len, f32) = @splat(0); var index: usize = 0; - while (index < input_values_1.len) : (index += vector_len) { - output_values += - @as(@Vector(vector_len, f32), input_values_1[index..][0..vector_len].*) * - @as(@Vector(vector_len, f32), input_values_2[index..][0..vector_len].*); + while (index < 
input_data_1.len) : (index += vector_len) { + output_data += + @as(@Vector(vector_len, f32), input_data_1[index..][0..vector_len].*) * + @as(@Vector(vector_len, f32), input_data_2[index..][0..vector_len].*); } - return @reduce(.Add, output_values); + return @reduce(.Add, output_data); } } @@ -43,25 +39,25 @@ pub fn computeScalarProduct(input_values_1: []const f32, input_values_2: []const } pub fn computeVectorAddition( - input_values_1: []const f32, - input_values_2: []const f32, - output_values: []f32, + input_data_1: []const f32, + input_data_2: []const f32, + output_data: []f32, ) !void { @setFloatMode(.Optimized); - std.debug.assert(input_values_1.len == input_values_2.len); - std.debug.assert(input_values_1.len == output_values.len); + std.debug.assert(input_data_1.len == input_data_2.len); + std.debug.assert(input_data_1.len == output_data.len); comptime var vector_len = std.atomic.cache_line / @sizeOf(f32); inline while (vector_len >= 4) : (vector_len /= 2) { - if (input_values_1.len % vector_len == 0) { + if (input_data_1.len % vector_len == 0) { var index: usize = 0; - while (index < input_values_1.len) : (index += vector_len) { - output_values[index..][0..vector_len].* = - @as(@Vector(vector_len, f32), input_values_1[index..][0..vector_len].*) + - @as(@Vector(vector_len, f32), input_values_2[index..][0..vector_len].*); + while (index < input_data_1.len) : (index += vector_len) { + output_data[index..][0..vector_len].* = + @as(@Vector(vector_len, f32), input_data_1[index..][0..vector_len].*) + + @as(@Vector(vector_len, f32), input_data_2[index..][0..vector_len].*); } return; @@ -73,28 +69,28 @@ pub fn computeVectorAddition( pub fn computeVectorMultiplication( scaling_factor: f32, - input_values_1: []const f32, - input_values_2: []const f32, - output_values: []f32, + input_data_1: []const f32, + input_data_2: []const f32, + output_data: []f32, ) !void { @setFloatMode(.Optimized); - std.debug.assert(input_values_1.len == input_values_2.len); - 
std.debug.assert(input_values_1.len == output_values.len); + std.debug.assert(input_data_1.len == input_data_2.len); + std.debug.assert(input_data_1.len == output_data.len); comptime var vector_len = std.atomic.cache_line / @sizeOf(f32); inline while (vector_len >= 4) : (vector_len /= 2) { - if (input_values_1.len % vector_len == 0) { + if (input_data_1.len % vector_len == 0) { const scaling_factors: @Vector(vector_len, f32) = @splat(scaling_factor); var index: usize = 0; - while (index < input_values_1.len) : (index += vector_len) { - output_values[index..][0..vector_len].* = + while (index < input_data_1.len) : (index += vector_len) { + output_data[index..][0..vector_len].* = scaling_factors * - @as(@Vector(vector_len, f32), input_values_1[index..][0..vector_len].*) * - @as(@Vector(vector_len, f32), input_values_2[index..][0..vector_len].*); + @as(@Vector(vector_len, f32), input_data_1[index..][0..vector_len].*) * + @as(@Vector(vector_len, f32), input_data_2[index..][0..vector_len].*); } return; diff --git a/src/transformer.zig b/src/transformer.zig index 3bcdadb..6bdae7f 100644 --- a/src/transformer.zig +++ b/src/transformer.zig @@ -34,7 +34,7 @@ pub fn createLeaky(allocator: std.mem.Allocator, args: anytype) !Self { pub fn forward(self: Self, token: usize, position: usize) !void { const embedding_weight = self.checkpoint.embedding_weights[token]; - @memcpy(self.hidden.values, embedding_weight.values); + @memcpy(self.hidden.data, embedding_weight.data); for (0..self.checkpoint.n_layers) |layer| { const attention_norm_weight = self.checkpoint.attention_norm_weights[layer]; diff --git a/src/vector.zig b/src/vector.zig index ccde174..1bbc663 100644 --- a/src/vector.zig +++ b/src/vector.zig @@ -3,31 +3,31 @@ const Self = @This(); const std = @import("std"); const simd = @import("simd.zig"); -values: []align(std.atomic.cache_line) f32, +data: []align(std.atomic.cache_line) f32, -pub fn createLeaky(allocator: std.mem.Allocator, n_values: usize) !Self { - return .{ 
.values = try allocator.alignedAlloc(f32, std.atomic.cache_line, n_values) }; +pub fn createLeaky(allocator: std.mem.Allocator, data_size: usize) !Self { + return .{ .data = try allocator.alignedAlloc(f32, std.atomic.cache_line, data_size) }; } pub fn createMultipleLeaky( allocator: std.mem.Allocator, n_vectors: usize, - n_values: usize, + data_size: usize, ) ![]Self { const vectors = try allocator.alloc(Self, n_vectors); for (vectors) |*vector| { - vector.* = try createLeaky(allocator, n_values); + vector.* = try createLeaky(allocator, data_size); } return vectors; } -pub fn readLeaky(allocator: std.mem.Allocator, file: std.fs.File, n_values: usize) !Self { - const vector = try createLeaky(allocator, n_values); - const bytes: [*]u8 = @ptrCast(vector.values); +pub fn readLeaky(allocator: std.mem.Allocator, file: std.fs.File, data_size: usize) !Self { + const vector = try createLeaky(allocator, data_size); + const data: [*]u8 = @ptrCast(vector.data); - try file.reader().readNoEof(bytes[0 .. vector.values.len * @sizeOf(f32)]); + try file.reader().readNoEof(data[0 .. 
vector.data.len * @sizeOf(f32)]); return vector; } @@ -36,25 +36,25 @@ pub fn readMultipleLeaky( allocator: std.mem.Allocator, file: std.fs.File, n_vectors: usize, - n_values: usize, + data_size: usize, ) ![]Self { const vectors = try allocator.alloc(Self, n_vectors); for (vectors) |*vector| { - vector.* = try readLeaky(allocator, file, n_values); + vector.* = try readLeaky(allocator, file, data_size); } return vectors; } pub fn addVector(self: Self, other: Self) !void { - try simd.computeVectorAddition(self.values, other.values, self.values); + try simd.computeVectorAddition(self.data, other.data, self.data); } pub fn computeRMSNorm(self: Self, weight: Self, output: Self) !void { - try simd.computeRMSNorm(self.values, weight.values, output.values); + try simd.computeRMSNorm(self.data, weight.data, output.data); } pub fn computeScalarProduct(self: Self, other: Self) !f32 { - return simd.computeScalarProduct(self.values, other.values); + return simd.computeScalarProduct(self.data, other.data); }