diff --git a/src/attention.zig b/src/attention.zig
index 654ce8c..28871cf 100644
--- a/src/attention.zig
+++ b/src/attention.zig
@@ -89,17 +89,17 @@ pub fn forward(self: *const Self, layer: usize, position: usize) void {
     key_matrix.computeMatrixVectorMultiplication(self.input_buffer, key_buffer);
     value_matrix.computeMatrixVectorMultiplication(self.input_buffer, value_buffer);
 
-    self.rope(position, key_buffer);
+    self.computeRoPE(position, key_buffer);
 
     for (0..self.checkpoint.n_attention_heads) |head| {
-        self.gqa(layer, position, head);
+        self.computeGQA(layer, position, head);
     }
 
     output_matrix.computeMatrixVectorMultiplication(self.input_buffer, self.output_buffer);
 }
 
 // Rotary positional embeddings: https://arxiv.org/abs/2104.09864
-fn rope(self: *const Self, position: usize, key_buffer: Tensor(2)) void {
+fn computeRoPE(self: *const Self, position: usize, key_buffer: Tensor(2)) void {
     @setFloatMode(.Optimized);
 
     std.debug.assert(self.query_buffer.values.len % key_buffer.values.len == 0);
@@ -133,7 +133,7 @@ fn rope(self: *const Self, position: usize, key_buffer: Tensor(2)) void {
 }
 
 // Grouped-query attention: https://arxiv.org/abs/2305.13245v1
-fn gqa(self: *const Self, layer: usize, current_position: usize, head: usize) void {
+fn computeGQA(self: *const Self, layer: usize, current_position: usize, head: usize) void {
     @setFloatMode(.Optimized);
 
     const query_vector = self.query_buffer.slice(head);
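For readers unfamiliar with the linked paper: `computeRoPE` rotates consecutive pairs of query/key components by a position-dependent angle, which is how rotary embeddings encode position. Below is a minimal, standalone sketch of that rotation, assuming a recent Zig version (0.11+ builtin syntax); `applyRotaryEmbedding`, its parameters, and the 10000.0 frequency base are illustrative assumptions, not code from this repository.

```zig
const std = @import("std");

// Sketch of the rotation computeRoPE is expected to apply
// (https://arxiv.org/abs/2104.09864). Operates on a single
// head's vector; names and constants here are hypothetical.
fn applyRotaryEmbedding(position: usize, values: []f32) void {
    std.debug.assert(values.len % 2 == 0);

    var index: usize = 0;

    while (index < values.len) : (index += 2) {
        // Pairs further into the vector rotate with lower frequency.
        const exponent = @as(f32, @floatFromInt(index)) /
            @as(f32, @floatFromInt(values.len));

        const frequency = 1.0 / std.math.pow(f32, 10000.0, exponent);
        const angle = @as(f32, @floatFromInt(position)) * frequency;

        // Rotate the (even, odd) pair as a 2-D point by `angle` radians.
        const real = values[index];
        const imag = values[index + 1];

        values[index] = real * @cos(angle) - imag * @sin(angle);
        values[index + 1] = real * @sin(angle) + imag * @cos(angle);
    }
}

test "rotation preserves pair length" {
    var values = [_]f32{ 1, 0, 0, 1 };

    applyRotaryEmbedding(42, &values);

    const norm = std.math.sqrt(values[0] * values[0] + values[1] * values[1]);

    try std.testing.expectApproxEqAbs(@as(f32, 1), norm, 1e-5);
}
```

Because each pair is a pure rotation, vector norms are preserved, which the test above checks; the real `computeRoPE` additionally applies the same rotation to both the query and key buffers.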