diff --git a/src/vector.zig b/src/vector.zig index 570e016..9c2bda3 100644 --- a/src/vector.zig +++ b/src/vector.zig @@ -56,11 +56,11 @@ pub fn dot(input_a: []const f32, input_b: []const f32) f32 { } // Pre-normalization using RMSNorm: https://arxiv.org/abs/1910.07467 -pub fn rmsnorm(input: []const f32, weight: []const f32, output: []f32) void { +pub fn rmsnorm(input: []const f32, weights: []const f32, output: []f32) void { @setFloatMode(.Optimized); std.debug.assert(output.len == input.len); - std.debug.assert(output.len == weight.len); + std.debug.assert(output.len == weights.len); var rms_scaling_factor: f32 = 0; @@ -73,7 +73,7 @@ pub fn rmsnorm(input: []const f32, weight: []const f32, output: []f32) void { rms_scaling_factor = 1 / std.math.sqrt(rms_scaling_factor); for (output, 0..) |*element, index| { - element.* = weight[index] * rms_scaling_factor * input[index]; + element.* = weights[index] * rms_scaling_factor * input[index]; } }