diff --git a/README.md b/README.md
index 0c4664f..81ed924 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ Lily wanted to play with the ball, but it was too high up in the sky. She tried

 Lily found a stick and tried to hit the ball. But the stick was too short. She tried again and again, but she couldn't reach it. She felt sad. Suddenly, a kind man came by and saw Lily. He asked her what was wrong. Lily told him about the ball. The man smiled and said, "I have a useful idea!" He took out a long stick and used it to knock the ball down. Lily was so happy! She thanked the man and they played together in the sunshine.

-achieved: 719.870 tok/s
+achieved: 724.590 tok/s
 ```

 ## Run Llama 2 7B from Hugging Face
@@ -47,7 +47,12 @@ Build and run `llama2-generator`:

 ```sh
 zig build -Doptimize=ReleaseFast
-./zig-out/bin/llama2-generator models/llama2_7b_hf --temperature 0 --sequence_length 28 --prompt "Once Upon a Time" --verbose
+./zig-out/bin/llama2-generator models/llama2_7b_hf \
+--prompt "Once Upon a Time" \
+--sequence_length 28 \
+--temperature 0 \
+--thread_count 8 \
+--verbose
 ```

 The output on an Apple M1 Pro with 32 GB of memory:
@@ -55,7 +60,7 @@ The output on an Apple M1 Pro with 32 GB of memory:
 ```
 Once Upon a Time in Hollywood is a 2019 American comedy-drama film written and directed by Quentin Tarantino

-achieved: 1.800 tok/s
+achieved: 3.482 tok/s
 ```

 ## Run Llama 2 7B Chat from Hugging Face
@@ -79,7 +84,7 @@ Build and run `llama2-chat`:

 ```sh
 zig build -Doptimize=ReleaseFast
-./zig-out/bin/llama2-chat models/llama2_7b_chat_hf
+./zig-out/bin/llama2-chat models/llama2_7b_chat_hf --temperature 0 --thread_count 8
 ```

 The output on an Apple M1 Pro with 32 GB of memory:
@@ -99,11 +104,12 @@ User: ...
 Usage: llama2-generator [options]

 Options:
- --temperature = 1.0
- --top_p = 0.9
+ --prompt = ""
  --random_seed = 
  --sequence_length = 
- --prompt = ""
+ --temperature = 1.0
+ --thread_count = 0
+ --top_p = 0.9
  --verbose
  --help
 ```
@@ -114,11 +120,12 @@ Usage: llama2-chat [options]

 Options:
- --temperature = 1.0
- --top_p = 0.9
  --random_seed = 
  --sequence_length = 
  --system_prompt = ""
+ --temperature = 1.0
+ --thread_count = 0
+ --top_p = 0.9
  --user_prompt = ""
  --help
 ```
diff --git a/src/chat.zig b/src/chat.zig
index 1953973..a37f2ce 100644
--- a/src/chat.zig
+++ b/src/chat.zig
@@ -14,7 +14,13 @@ system_prompt: []const u8,
 user_prompt: []const u8,

 pub fn createLeaky(allocator: std.mem.Allocator, args: ChatArgs) !Self {
-    const transformer = try Transformer.createLeaky(allocator, args.model_path, args.sequence_length);
+    const transformer = try Transformer.createLeaky(
+        allocator,
+        args.model_path,
+        args.sequence_length,
+        args.thread_count,
+    );
+
     const vocab_size = transformer.checkpoint.vocab_size;

     return .{
diff --git a/src/chat_args.zig b/src/chat_args.zig
index b38e9c7..d1123e4 100644
--- a/src/chat_args.zig
+++ b/src/chat_args.zig
@@ -3,19 +3,21 @@ const Self = @This();
 const std = @import("std");

 model_path: []const u8,
-temperature: f32,
-top_p: f32,
 random_seed: u64,
 sequence_length: usize,
 system_prompt: []const u8,
+temperature: f32,
+thread_count: usize,
+top_p: f32,
 user_prompt: []const u8,

 const Option = enum {
-    temperature,
-    top_p,
     random_seed,
     sequence_length,
     system_prompt,
+    temperature,
+    thread_count,
+    top_p,
     user_prompt,
 };
@@ -27,25 +29,28 @@ pub fn createLeaky(allocator: std.mem.Allocator) !Self {
     const model_path = arg_iterator.next() orelse try help(1);

     var current_option: ?Option = null;
-    var temperature: ?f32 = null;
-    var top_p: ?f32 = null;
     var random_seed: ?u64 = null;
     var sequence_length: ?usize = null;
     var system_prompt: ?[]const u8 = null;
+    var temperature: ?f32 = null;
+    var thread_count: ?usize = null;
+    var top_p: ?f32 = null;
     var user_prompt: ?[]const u8 = null;

     while (arg_iterator.next()) |arg| {
         if (current_option) |option| {
-            if (option == .temperature and temperature == null) {
-                temperature = try std.fmt.parseFloat(f32, arg);
-            } else if (option == .top_p and top_p == null) {
-                top_p = try std.fmt.parseFloat(f32, arg);
-            } else if (option == .random_seed and random_seed == null) {
+            if (option == .random_seed and random_seed == null) {
                 random_seed = try std.fmt.parseInt(u64, arg, 10);
             } else if (option == .sequence_length and sequence_length == null) {
                 sequence_length = try std.fmt.parseInt(usize, arg, 10);
             } else if (option == .system_prompt and system_prompt == null) {
                 system_prompt = arg;
+            } else if (option == .temperature and temperature == null) {
+                temperature = try std.fmt.parseFloat(f32, arg);
+            } else if (option == .thread_count and thread_count == null) {
+                thread_count = try std.fmt.parseInt(usize, arg, 10);
+            } else if (option == .top_p and top_p == null) {
+                top_p = try std.fmt.parseFloat(f32, arg);
             } else if (option == .user_prompt and user_prompt == null) {
                 user_prompt = arg;
             } else {
@@ -53,16 +58,18 @@ pub fn createLeaky(allocator: std.mem.Allocator) !Self {
             }

             current_option = null;
-        } else if (std.mem.eql(u8, arg, "--temperature")) {
-            current_option = .temperature;
-        } else if (std.mem.eql(u8, arg, "--top_p")) {
-            current_option = .top_p;
         } else if (std.mem.eql(u8, arg, "--random_seed")) {
             current_option = .random_seed;
         } else if (std.mem.eql(u8, arg, "--sequence_length")) {
             current_option = .sequence_length;
         } else if (std.mem.eql(u8, arg, "--system_prompt")) {
             current_option = .system_prompt;
+        } else if (std.mem.eql(u8, arg, "--temperature")) {
+            current_option = .temperature;
+        } else if (std.mem.eql(u8, arg, "--thread_count")) {
+            current_option = .thread_count;
+        } else if (std.mem.eql(u8, arg, "--top_p")) {
+            current_option = .top_p;
         } else if (std.mem.eql(u8, arg, "--user_prompt")) {
             current_option = .user_prompt;
         } else {
@@ -76,11 +83,12 @@ pub fn createLeaky(allocator: std.mem.Allocator) !Self {

     return .{
         .model_path = model_path,
-        .temperature = @max(@min(temperature orelse 1, 1), 0),
-        .top_p = @max(@min(top_p orelse 0.9, 1), 0),
         .random_seed = random_seed orelse @intCast(std.time.milliTimestamp()),
         .sequence_length = sequence_length orelse 0,
         .system_prompt = system_prompt orelse "",
+        .temperature = @max(@min(temperature orelse 1, 1), 0),
+        .thread_count = thread_count orelse 0,
+        .top_p = @max(@min(top_p orelse 0.9, 1), 0),
         .user_prompt = user_prompt orelse "",
     };
 }
@@ -94,11 +102,12 @@ fn help(exit_status: u8) !noreturn {

     try console.print("Usage: llama2-chat [options]\n\n", .{});
     try console.print("Options:\n", .{});
-    try console.print(" --temperature = 1.0\n", .{});
-    try console.print(" --top_p = 0.9\n", .{});
     try console.print(" --random_seed = \n", .{});
     try console.print(" --sequence_length = \n", .{});
     try console.print(" --system_prompt = \"\"\n", .{});
+    try console.print(" --temperature = 1.0\n", .{});
+    try console.print(" --thread_count = 0\n", .{});
+    try console.print(" --top_p = 0.9\n", .{});
     try console.print(" --user_prompt = \"\"\n", .{});
     try console.print(" --help\n", .{});
diff --git a/src/checkpoint.zig b/src/checkpoint.zig
index b296146..a09ef7d 100644
--- a/src/checkpoint.zig
+++ b/src/checkpoint.zig
@@ -25,7 +25,7 @@ ffn_up_weights: []const Matrix,
 output_norm_weight: Vector,
 output_weight: Matrix,

-pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
+pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8, thread_count: usize) !Self {
     const path = try std.fs.path.join(
         allocator,
         &[_][]const u8{ model_path, "checkpoint_v1.bin" },
@@ -85,6 +85,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
         n_layers,
         embedding_size,
         embedding_size,
+        thread_count,
     );

     const attention_head_size: usize = embedding_size / n_attention_heads;
@@ -95,6 +96,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
         n_layers,
         n_attention_query_groups * attention_head_size,
         embedding_size,
+        thread_count,
     );

     const attention_value_weights = try Matrix.readMultipleLeaky(
@@ -103,6 +105,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
         n_layers,
         n_attention_query_groups * attention_head_size,
         embedding_size,
+        thread_count,
     );

     const attention_output_weights = try Matrix.readMultipleLeaky(
@@ -111,6 +114,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
         n_layers,
         embedding_size,
         embedding_size,
+        thread_count,
     );

     const ffn_gate_weights = try Matrix.readMultipleLeaky(
@@ -119,6 +123,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
         n_layers,
         ffn_hidden_size,
         embedding_size,
+        thread_count,
     );

     const ffn_down_weights = try Matrix.readMultipleLeaky(
@@ -127,6 +132,7 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
         n_layers,
         embedding_size,
         ffn_hidden_size,
+        thread_count,
     );

     const ffn_up_weights = try Matrix.readMultipleLeaky(
@@ -135,12 +141,13 @@ pub fn readLeaky(allocator: std.mem.Allocator, model_path: []const u8) !Self {
         n_layers,
         ffn_hidden_size,
         embedding_size,
+        thread_count,
     );

     const output_weight = if (shared_output_weight)
-        Matrix{ .rows = embedding_weights }
+        Matrix{ .rows = embedding_weights, .thread_count = thread_count }
     else
-        try Matrix.readLeaky(allocator, file, vocab_size, embedding_size);
+        try Matrix.readLeaky(allocator, file, vocab_size, embedding_size, thread_count);

     return .{
         .embedding_size = embedding_size,
diff --git a/src/generator.zig b/src/generator.zig
index 66f30aa..9cfd451 100644
--- a/src/generator.zig
+++ b/src/generator.zig
@@ -18,6 +18,7 @@ pub fn createLeaky(allocator: std.mem.Allocator, args: GeneratorArgs) !Self {
         allocator,
         args.model_path,
         args.sequence_length,
+        args.thread_count,
     );

     const vocab_size = transformer.checkpoint.vocab_size;
@@ -92,11 +93,12 @@ test "generate tiny story" {

     const args = GeneratorArgs{
         .model_path = "models/tinystories_260k",
-        .temperature = 1,
-        .top_p = 0.9,
+        .prompt = "There was",
         .random_seed = 42,
         .sequence_length = 10,
-        .prompt = "There was",
+        .temperature = 1,
+        .thread_count = 0,
+        .top_p = 0.9,
         .verbose = false,
     };
diff --git a/src/generator_args.zig b/src/generator_args.zig
index 57c93d0..4d566b5 100644
--- a/src/generator_args.zig
+++ b/src/generator_args.zig
@@ -3,14 +3,22 @@ const Self = @This();
 const std = @import("std");

 model_path: []const u8,
-temperature: f32,
-top_p: f32,
+prompt: []const u8,
 random_seed: u64,
 sequence_length: usize,
-prompt: []const u8,
+temperature: f32,
+thread_count: usize,
+top_p: f32,
 verbose: bool,

-const Option = enum { temperature, top_p, random_seed, sequence_length, prompt };
+const Option = enum {
+    prompt,
+    random_seed,
+    sequence_length,
+    temperature,
+    thread_count,
+    top_p,
+};

 pub fn createLeaky(allocator: std.mem.Allocator) !Self {
     var arg_iterator = try std.process.argsWithAllocator(allocator);
@@ -20,40 +28,45 @@ pub fn createLeaky(allocator: std.mem.Allocator) !Self {
     const model_path = arg_iterator.next() orelse try help(1);

     var current_option: ?Option = null;
-    var temperature: ?f32 = null;
-    var top_p: ?f32 = null;
+    var prompt: ?[]const u8 = null;
     var random_seed: ?u64 = null;
     var sequence_length: ?usize = null;
-    var prompt: ?[]const u8 = null;
+    var temperature: ?f32 = null;
+    var thread_count: ?usize = null;
+    var top_p: ?f32 = null;
     var verbose: bool = false;

     while (arg_iterator.next()) |arg| {
         if (current_option) |option| {
-            if (option == .temperature and temperature == null) {
-                temperature = try std.fmt.parseFloat(f32, arg);
-            } else if (option == .top_p and top_p == null) {
-                top_p = try std.fmt.parseFloat(f32, arg);
+            if (option == .prompt and prompt == null) {
+                prompt = arg;
             } else if (option == .random_seed and random_seed == null) {
                 random_seed = try std.fmt.parseInt(u64, arg, 10);
             } else if (option == .sequence_length and sequence_length == null) {
                 sequence_length = try std.fmt.parseInt(usize, arg, 10);
-            } else if (option == .prompt and prompt == null) {
-                prompt = arg;
+            } else if (option == .temperature and temperature == null) {
+                temperature = try std.fmt.parseFloat(f32, arg);
+            } else if (option == .thread_count and thread_count == null) {
+                thread_count = try std.fmt.parseInt(usize, arg, 10);
+            } else if (option == .top_p and top_p == null) {
+                top_p = try std.fmt.parseFloat(f32, arg);
             } else {
                 try help(1);
             }

             current_option = null;
-        } else if (std.mem.eql(u8, arg, "--temperature")) {
-            current_option = .temperature;
-        } else if (std.mem.eql(u8, arg, "--top_p")) {
-            current_option = .top_p;
+        } else if (std.mem.eql(u8, arg, "--prompt")) {
+            current_option = .prompt;
         } else if (std.mem.eql(u8, arg, "--random_seed")) {
             current_option = .random_seed;
         } else if (std.mem.eql(u8, arg, "--sequence_length")) {
             current_option = .sequence_length;
-        } else if (std.mem.eql(u8, arg, "--prompt")) {
-            current_option = .prompt;
+        } else if (std.mem.eql(u8, arg, "--temperature")) {
+            current_option = .temperature;
+        } else if (std.mem.eql(u8, arg, "--thread_count")) {
+            current_option = .thread_count;
+        } else if (std.mem.eql(u8, arg, "--top_p")) {
+            current_option = .top_p;
         } else if (std.mem.eql(u8, arg, "--verbose") and !verbose) {
             verbose = true;
         } else {
@@ -67,11 +80,12 @@ pub fn createLeaky(allocator: std.mem.Allocator) !Self {

     return .{
         .model_path = model_path,
-        .temperature = @max(@min(temperature orelse 1, 1), 0),
-        .top_p = @max(@min(top_p orelse 0.9, 1), 0),
+        .prompt = prompt orelse "",
         .random_seed = random_seed orelse @intCast(std.time.milliTimestamp()),
         .sequence_length = sequence_length orelse 0,
-        .prompt = prompt orelse "",
+        .temperature = @max(@min(temperature orelse 1, 1), 0),
+        .thread_count = thread_count orelse 0,
+        .top_p = @max(@min(top_p orelse 0.9, 1), 0),
         .verbose = verbose,
     };
 }
@@ -85,11 +99,12 @@ fn help(exit_status: u8) !noreturn {

     try console.print("Usage: llama2-generator [options]\n\n", .{});
     try console.print("Options:\n", .{});
-    try console.print(" --temperature = 1.0\n", .{});
-    try console.print(" --top_p = 0.9\n", .{});
+    try console.print(" --prompt = \"\"\n", .{});
     try console.print(" --random_seed = \n", .{});
     try console.print(" --sequence_length = \n", .{});
-    try console.print(" --prompt = \"\"\n", .{});
+    try console.print(" --temperature = 1.0\n", .{});
+    try console.print(" --thread_count = 0\n", .{});
+    try console.print(" --top_p = 0.9\n", .{});
     try console.print(" --verbose\n", .{});
     try console.print(" --help\n", .{});
diff --git a/src/matrix.zig b/src/matrix.zig
index ad372b5..8d8a9b6 100644
--- a/src/matrix.zig
+++ b/src/matrix.zig
@@ -4,14 +4,19 @@ const std = @import("std");
 const Vector = @import("vector.zig");

 rows: []const Vector,
+thread_count: usize,

 pub fn readLeaky(
     allocator: std.mem.Allocator,
     file: std.fs.File,
     m_rows: usize,
     n_cols: usize,
+    thread_count: usize,
 ) !Self {
-    return .{ .rows = try Vector.readMultipleLeaky(allocator, file, m_rows, n_cols) };
+    return .{
+        .rows = try Vector.readMultipleLeaky(allocator, file, m_rows, n_cols),
+        .thread_count = thread_count,
+    };
 }

 pub fn readMultipleLeaky(
@@ -20,20 +25,57 @@ pub fn readMultipleLeaky(
     n_matrices: usize,
     m_rows: usize,
     n_cols: usize,
+    thread_count: usize,
 ) ![]Self {
     const matrices = try allocator.alloc(Self, n_matrices);

     for (matrices) |*matrix| {
-        matrix.* = try readLeaky(allocator, file, m_rows, n_cols);
+        matrix.* = try readLeaky(allocator, file, m_rows, n_cols, thread_count);
     }

     return matrices;
 }

+const max_thread_count = 8;
+
 pub fn multiplyVector(self: Self, input: Vector, output: Vector) !void {
-    std.debug.assert(self.rows.len == output.values.len);
+    if (self.thread_count == 0) {
+        try computeMatrixVectorMultiplication(self.rows, input, output.values);
+
+        return;
+    }
+
+    const n_threads = @min(try std.Thread.getCpuCount(), max_thread_count, self.thread_count);
+
+    if (output.values.len % n_threads != 0) {
+        return error.UnsupportedThreadCount;
+    }
+
+    const partial_length = output.values.len / n_threads;
+
+    var threads: [max_thread_count]std.Thread = undefined;
+
+    for (threads[0..n_threads], 0..) |*thread, index| {
+        thread.* = try std.Thread.spawn(.{}, computeMatrixVectorMultiplication, .{
+            self.rows[index * partial_length .. (index + 1) * partial_length],
+            input,
+            output.values[index * partial_length .. (index + 1) * partial_length],
+        });
+    }
+
+    for (threads[0..n_threads]) |thread| {
+        thread.join();
+    }
+}
+
+fn computeMatrixVectorMultiplication(
+    rows: []const Vector,
+    input: Vector,
+    output_values: []f32,
+) !void {
+    std.debug.assert(rows.len == output_values.len);

-    for (output.values, 0..) |*value, index| {
-        value.* = try self.rows[index].computeScalarProduct(input);
+    for (output_values, 0..) |*value, index| {
+        value.* = try rows[index].computeScalarProduct(input);
     }
 }
diff --git a/src/transformer.zig b/src/transformer.zig
index 280bd3b..dd3e167 100644
--- a/src/transformer.zig
+++ b/src/transformer.zig
@@ -17,8 +17,9 @@ pub fn createLeaky(
     allocator: std.mem.Allocator,
     model_path: []const u8,
     custom_sequence_length: usize,
+    thread_count: usize,
 ) !Self {
-    const checkpoint = try Checkpoint.readLeaky(allocator, model_path);
+    const checkpoint = try Checkpoint.readLeaky(allocator, model_path, thread_count);

     const sequence_length = if (custom_sequence_length == 0)
         checkpoint.max_sequence_length