Skip to content

Commit

Permalink
Sync with the English version
Browse files Browse the repository at this point in the history
Signed-off-by: alabulei1 <[email protected]>
  • Loading branch information
alabulei1 authored Oct 30, 2023
1 parent 1461be0 commit 4c5cd77
Showing 1 changed file with 76 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ We also need to get the model. Here we use the llama-2-13b model.
```
curl -LO https://huggingface.co/wasmedge/llama2/blob/main/llama-2-13b-q5_k_m.gguf
```
Next, use WasmEdge to load the Codellama-instruct model and then ask the model to write code by chatting.
Next, use WasmEdge to load the llama-2-13b model and then ask the model to questions by input your .

```
wasmedge --dir .:. \
Expand Down Expand Up @@ -105,7 +105,7 @@ For example, the following command specifies a context length of 4k tokens, whic

```
LLAMA_LOG=1 LLAMA_N_CTX=4096 LLAMA_N_PREDICT=1024 wasmedge --dir .:. \
--nn-preload default:GGML:CPU:llama-2-7b-chat.Q5_K_M.gguf \
--nn-preload default:GGML:CPU:lllama-2-7b-chat-q5_k_m.gguf \
llama-chat.wasm default
llama_model_loader: loaded meta data with 19 key-value pairs and 291 tensors from llama-2-7b-chat.Q5_K_M.gguf (version GGUF V2 (latest))
Expand Down Expand Up @@ -139,11 +139,19 @@ wasmedge --dir .:. \
The [main.rs](https://github.com/second-state/llama-utils/blob/main/chat/src/main.rs
) is the full Rust code to create an interactive chatbot using a LLM. The Rust program manages the user input, tracks the conversation history, transforms the text into the llama2 and other model’s chat templates, and runs the inference operations using the WASI NN standard API.

First, let's parse command line arguments to customize the chatbot's behavior. It extracts the following parameters: `model_alias` (a list for the loaded model), `ctx_size` (the size of the chat context), and `prompt_template` (a template that guides the conversation).
First, let's parse command line arguments to customize the chatbot's behavior using `Command` struct. It extracts the following parameters: `prompt` (a prompt that guides the conversation), `model_alias` (a list for the loaded model), and `ctx_size` (the size of the chat context).

```
fn main() -> Result<(), String> {
let matches = Command::new("Llama API Server")
.arg(
Arg::new("prompt")
.short('p')
.long("prompt")
.value_name("PROMPT")
.help("Sets the prompt.")
.required(true),
)
.arg(
Arg::new("model_alias")
.short('m')
Expand All @@ -161,155 +169,78 @@ fn main() -> Result<(), String> {
.help("Sets the prompt context size")
.default_value(DEFAULT_CTX_SIZE),
)
.arg(
Arg::new("prompt_template")
.short('p')
.long("prompt-template")
.value_parser([
"llama-2-chat",
"codellama-instruct",
"mistral-instruct-v0.1",
"belle-llama-2-chat",
"vicuna-chat",
"chatml",
])
.value_name("TEMPLATE")
.help("Sets the prompt template.")
.required(true),
)
.get_matches();
// model alias
let model_name = matches
.get_one::<String>("model_alias")
.unwrap()
.to_string();
// prompt context size
let ctx_size = matches.get_one::<u32>("ctx_size").unwrap();
CTX_SIZE
.set(*ctx_size as usize)
.expect("Fail to parse prompt context size");
// prompt
let prompt = matches.get_one::<String>("prompt").unwrap().to_string();
```

After that, the program will initialize the context size based on the value provided as `ctx_size`.
After that, the program will create a new Graph using the `GraphBuilder` and loads the model specified by the `model_name` .

```
let ctx_size = matches.get_one::<u32>("ctx_size").unwrap();
if CTX_SIZE.set(*ctx_size as usize).is_err() {
return Err(String::from("Fail to parse prompt context size"));
}
println!("[INFO] Prompt context size: {size}", size = ctx_size);
// load the model to wasi-nn
let graph =
wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Ggml, wasi_nn::ExecutionTarget::AUTO)
.build_from_cache(&model_name)
.expect("Failed to load the model");
```

Then, the program will parse and create the prompt template. The program parses the `prompt_template` parameter and converts it into an enum called `PromptTemplateType`. The code then uses the parsed PromptTemplateType to create an appropriate chat prompt template. This template is essential for generating prompts during the conversation. The prompt template is defined [here](https://github.com/second-state/llama-utils/blob/main/chat/src/main.rs#L193-L214).
Next, We create an execution context from the loaded Graph. The context is mutable because we will be changing it when we set the input tensor and execute the inference.

```
let prompt_template = matches
.get_one::<String>("prompt_template")
.unwrap()
.to_string();
let template_ty = match PromptTemplateType::from_str(&prompt_template) {
Ok(template) => template,
Err(e) => {
return Err(format!(
"Fail to parse prompt template type: {msg}",
msg = e.to_string()
))
}
};
println!("[INFO] Prompt template: {ty:?}", ty = &template_ty);
let template = create_prompt_template(template_ty);
let mut chat_request = ChatCompletionRequest::default();
```

Next step is to load the model. The program will load the model to `wasi-nn`. The model is identified by the `model_alias` provided via the command line.

```
let graph = match wasi_nn::GraphBuilder::new(
wasi_nn::GraphEncoding::Ggml,
wasi_nn::ExecutionTarget::CPU,
)
.build_from_cache(model_name.as_ref())
{
Ok(graph) => graph,
Err(e) => {
return Err(format!(
"Fail to load model into wasi-nn: {msg}",
msg = e.to_string()
))
}
};
```
Now we have finished the preparation work: loaded the model and prompt. Let's initiate a chat to chat with the model. The `read_input` function reads lines of text from the standard input until it receives a non-empty and non-whitespace line. The `chat_request` variable is an instance of a data structure that manages and stores information related to the ongoing conversation between the user and the AI assistant.

```
loop {
println!("[USER]:");
let user_message = read_input();
chat_request
.messages
.push(ChatCompletionRequestMessage::new(
ChatCompletionRole::User,
user_message,
));
// build prompt
let prompt = match template.build(&mut chat_request.messages) {
Ok(prompt) => prompt,
Err(e) => {
return Err(format!(
"Fail to build chat prompts: {msg}",
msg = e.to_string()
))
}
};
// read input tensor
let tensor_data = prompt.trim().as_bytes().to_vec();
if context
.set_input(0, wasi_nn::TensorType::U8, &[1], &tensor_data)
.is_err()
{
return Err(String::from("Fail to set input tensor"));
};
// execute the inference
if context.compute().is_err() {
return Err(String::from("Fail to execute model inference"));
}
// retrieve the output
let mut output_buffer = vec![0u8; *CTX_SIZE.get().unwrap()];
let mut output_size = match context.get_output(0, &mut output_buffer) {
Ok(size) => size,
Err(e) => {
return Err(format!(
"Fail to get output tensor: {msg}",
msg = e.to_string()
))
}
};
output_size = std::cmp::min(*CTX_SIZE.get().unwrap(), output_size);
let output = String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
println!("[ASSISTANT]:\n{}", output.trim());
// put the answer into the `messages` of chat_request
chat_request
.messages
.push(ChatCompletionRequestMessage::new(
ChatCompletionRole::Assistant,
output,
));
}
Ok(())
}
fn read_input() -> String {
loop {
let mut answer = String::new();
std::io::stdin()
.read_line(&mut answer)
.ok()
.expect("Failed to read line");
if !answer.is_empty() && answer != "\n" && answer != "\r\n" {
return answer;
}
}
}
```

For the reason why we need to run LLama2 model with WasmEdge, please check out [this article](https://medium.com/stackademic/fast-and-portable-llama2-inference-on-the-heterogeneous-edge-a62508e82359).
// initialize the execution context
let mut context = graph
.init_execution_context()
.expect("Failed to init context");
```
Next, The prompt is converted into bytes and set as the input tensor for the model inference.

```
// set input tensor
let tensor_data = prompt.as_str().as_bytes().to_vec();
context
.set_input(0, wasi_nn::TensorType::U8, &[1], &tensor_data)
.expect("Failed to set prompt as the input tensor");
```

Next, excute the model inference.

```
// execute the inference
context.compute().expect("Failed to complete inference");
```

After the inference is fiished, extract the result from the computation context and losing invalid UTF8 sequences handled by converting the output to a string using `String::from_utf8_lossy`.

```
let mut output_buffer = vec![0u8; *CTX_SIZE.get().unwrap()];
let mut output_size = context
.get_output(0, &mut output_buffer)
.expect("Failed to get output tensor");
output_size = std::cmp::min(*CTX_SIZE.get().unwrap(), output_size);
let output = String::from_utf8_lossy(&output_buffer[..output_size]).to_string();
```

Finally, print the prompt and the inference output to the console.

```
println!("\nprompt: {}", &prompt);
println!("\noutput: {}", output);
```

The code explanation above is simple one time chat with llama 2 model. But we have more!
* If you're looking for continuous conversations with llama 2 models, please check out the source code [here](https://github.com/second-state/llama-utils/tree/main/chat).
* If you want to construct OpenAI-compatible APIs specifically for your llama2 model, or the Llama2 model itself, please check out the surce code [here](https://github.com/second-state/llama-utils/tree/main/api-server).
* For the reason why we need to run LLama2 model with WasmEdge, please check out [this article](https://medium.com/stackademic/fast-and-portable-llama2-inference-on-the-heterogeneous-edge-a62508e82359).

0 comments on commit 4c5cd77

Please sign in to comment.