diff --git a/src/chat.zig b/src/chat.zig
index a8327e2..8b75a81 100644
--- a/src/chat.zig
+++ b/src/chat.zig
@@ -11,8 +11,8 @@ allocator: std.mem.Allocator,
transformer: Transformer,
tokenizer: Tokenizer,
sampler: Sampler,
-user_prompt: []const u8,
system_prompt: []const u8,
+user_prompt: []const u8,
pub fn init(allocator: std.mem.Allocator, args: ChatArgs) !Self {
const transformer = try Transformer.init(allocator, args.model_path, args.n_steps);
@@ -33,8 +33,8 @@ pub fn init(allocator: std.mem.Allocator, args: ChatArgs) !Self {
.transformer = transformer,
.tokenizer = tokenizer,
.sampler = sampler,
- .user_prompt = args.prompt,
.system_prompt = args.system_prompt,
+ .user_prompt = args.user_prompt,
};
}
@@ -44,10 +44,10 @@ pub fn deinit(self: *const Self) void {
self.sampler.deinit();
}
-const user_prompt_template_start = "[INST] ";
-const user_prompt_template_close = " [/INST]";
const system_prompt_template_start = "<>\n";
const system_prompt_template_close = "\n<>\n\n";
+const user_prompt_template_start = "[INST] ";
+const user_prompt_template_close = " [/INST]";
const bos_token = 1; // beginning of sequence
const eos_token = 2; // end of sequence
diff --git a/src/chat_args.zig b/src/chat_args.zig
index 9a7a0d9..0f05cc2 100644
--- a/src/chat_args.zig
+++ b/src/chat_args.zig
@@ -8,10 +8,10 @@ temperature: f32,
top_p: f32,
random_seed: u64,
n_steps: usize,
-prompt: []const u8,
system_prompt: []const u8,
+user_prompt: []const u8,
-const Option = enum { temperature, top_p, random_seed, n_steps, prompt, system_prompt };
+const Option = enum { temperature, top_p, random_seed, n_steps, system_prompt, user_prompt };
pub fn init(allocator: std.mem.Allocator) !Self {
var arg_iterator = try std.process.argsWithAllocator(allocator);
@@ -27,8 +27,8 @@ pub fn init(allocator: std.mem.Allocator) !Self {
var top_p: ?f32 = null;
var random_seed: ?u64 = null;
var n_steps: ?usize = null;
- var prompt: ?[]const u8 = null;
var system_prompt: ?[]const u8 = null;
+ var user_prompt: ?[]const u8 = null;
while (arg_iterator.next()) |arg| {
if (current_option) |option| {
@@ -40,10 +40,10 @@ pub fn init(allocator: std.mem.Allocator) !Self {
random_seed = try std.fmt.parseInt(u64, arg, 10);
} else if (option == .n_steps and n_steps == null) {
n_steps = try std.fmt.parseInt(usize, arg, 10);
- } else if (option == .prompt and prompt == null) {
- prompt = arg;
} else if (option == .system_prompt and system_prompt == null) {
system_prompt = arg;
+ } else if (option == .user_prompt and user_prompt == null) {
+ user_prompt = arg;
} else {
try help(1);
}
@@ -57,10 +57,10 @@ pub fn init(allocator: std.mem.Allocator) !Self {
current_option = .random_seed;
} else if (std.mem.eql(u8, arg, "--n_steps")) {
current_option = .n_steps;
- } else if (std.mem.eql(u8, arg, "--prompt")) {
- current_option = .prompt;
} else if (std.mem.eql(u8, arg, "--system_prompt")) {
current_option = .system_prompt;
+ } else if (std.mem.eql(u8, arg, "--user_prompt")) {
+ current_option = .user_prompt;
} else {
try help(if (std.mem.eql(u8, arg, "--help")) 0 else 1);
}
@@ -77,8 +77,8 @@ pub fn init(allocator: std.mem.Allocator) !Self {
.top_p = @max(@min(top_p orelse 0.9, 1), 0),
.random_seed = random_seed orelse @intCast(std.time.milliTimestamp()),
.n_steps = n_steps orelse 0,
- .prompt = prompt orelse "",
.system_prompt = system_prompt orelse "",
+ .user_prompt = user_prompt orelse "",
};
}
@@ -99,8 +99,8 @@ fn help(exit_status: u8) !noreturn {
try console.print(" --top_p = 0.9\n", .{});
try console.print(" --random_seed = \n", .{});
try console.print(" --n_steps = \n", .{});
- try console.print(" --prompt = \"\"\n", .{});
try console.print(" --system_prompt = \"\"\n", .{});
+ try console.print(" --user_prompt = \"\"\n", .{});
try console.print(" --help\n", .{});
std.process.exit(exit_status);