From 94d35f0c42921cd368bdcc209b7f3d428196f496 Mon Sep 17 00:00:00 2001 From: Eric Date: Sun, 9 Jun 2024 11:20:19 +0800 Subject: [PATCH] feat: adapt --chat-template parameter of llama-server (#2362) * feat: pass --chat-tempate, create ChatCompletionServer * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * improve impl * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- Cargo.lock | 24 ----- crates/llama-cpp-server/src/lib.rs | 67 +++++++++++++- crates/llama-cpp-server/src/supervisor.rs | 5 ++ crates/tabby/Cargo.toml | 1 - crates/tabby/src/services/model/chat.rs | 105 ---------------------- crates/tabby/src/services/model/mod.rs | 39 +++++--- 6 files changed, 98 insertions(+), 143 deletions(-) delete mode 100644 crates/tabby/src/services/model/chat.rs diff --git a/Cargo.lock b/Cargo.lock index fca9cb85f1a1..f9f9e9e6f300 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2762,12 +2762,6 @@ dependencies = [ "libc", ] -[[package]] -name = "memo-map" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374c335b2df19e62d4cb323103473cbc6510980253119180de862d89184f6a83" - [[package]] name = "metrics" version = "0.22.3" @@ -2826,17 +2820,6 @@ dependencies = [ "unicase", ] -[[package]] -name = "minijinja" -version = "1.0.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55e877d961d4f96ce13615862322df7c0b6d169d40cab71a7ef3f9b9e594451e" -dependencies = [ - "memo-map", - "self_cell", - "serde", -] - [[package]] name = "minimal-lexical" version = "0.2.1" @@ -4375,12 +4358,6 @@ dependencies = [ "thin-slice", ] -[[package]] -name = "self_cell" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d369a96f978623eb3dc28807c4852d6cc617fed53da5d3c400feff1ef34a714a" - [[package]] name = "semver" version = "1.0.23" @@ -5163,7 +5140,6 @@ dependencies = [ "insta", "lazy_static", "llama-cpp-server", - "minijinja", "nvml-wrapper", "openssl", "parse-git-url", diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs index fe0febc849a1..321fe2cfe431 100644 --- a/crates/llama-cpp-server/src/lib.rs +++ b/crates/llama-cpp-server/src/lib.rs @@ -7,10 +7,13 @@ use async_trait::async_trait; use futures::stream::BoxStream; use supervisor::LlamaCppSupervisor; use tabby_common::{ + api::chat::Message, config::{HttpModelConfigBuilder, ModelConfig}, registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH}, }; -use tabby_inference::{CompletionOptions, CompletionStream, Embedding}; +use tabby_inference::{ + ChatCompletionOptions, ChatCompletionStream, CompletionOptions, CompletionStream, Embedding, +}; fn api_endpoint(port: u16) -> String { format!("http://127.0.0.1:{port}") @@ -24,7 +27,7 @@ struct EmbeddingServer { impl EmbeddingServer { async fn new(num_gpu_layers: u16, model_path: &str, parallelism: u8) -> EmbeddingServer { - let server = LlamaCppSupervisor::new(num_gpu_layers, true, model_path, parallelism); + let server = LlamaCppSupervisor::new(num_gpu_layers, true, model_path, parallelism, None); server.start().await; let config = HttpModelConfigBuilder::default() @@ -55,7 +58,7 @@ struct CompletionServer { impl CompletionServer { async fn new(num_gpu_layers: u16, model_path: &str, parallelism: u8) -> Self { - let server = LlamaCppSupervisor::new(num_gpu_layers, false, model_path, parallelism); + let server = LlamaCppSupervisor::new(num_gpu_layers, 
false, model_path, parallelism, None); server.start().await; let config = HttpModelConfigBuilder::default() .api_endpoint(api_endpoint(server.port())) @@ -74,6 +77,64 @@ impl CompletionStream for CompletionServer { } } +struct ChatCompletionServer { + #[allow(unused)] + server: LlamaCppSupervisor, + chat_completion: Arc, +} + +impl ChatCompletionServer { + async fn new( + num_gpu_layers: u16, + model_path: &str, + parallelism: u8, + chat_template: String, + ) -> Self { + let server = LlamaCppSupervisor::new( + num_gpu_layers, + false, + model_path, + parallelism, + Some(chat_template), + ); + server.start().await; + let config = HttpModelConfigBuilder::default() + .api_endpoint(api_endpoint(server.port())) + .kind("openai/chat".to_string()) + .build() + .expect("Failed to create HttpModelConfig"); + let chat_completion = http_api_bindings::create_chat(&config).await; + Self { + server, + chat_completion, + } + } +} + +#[async_trait] +impl ChatCompletionStream for ChatCompletionServer { + async fn chat_completion( + &self, + messages: &[Message], + options: ChatCompletionOptions, + ) -> Result> { + self.chat_completion + .chat_completion(messages, options) + .await + } +} + +pub async fn create_chat_completion( + num_gpu_layers: u16, + model_path: &str, + parallelism: u8, + chat_template: String, +) -> Arc { + Arc::new( + ChatCompletionServer::new(num_gpu_layers, model_path, parallelism, chat_template).await, + ) +} + pub async fn create_completion( num_gpu_layers: u16, model_path: &str, diff --git a/crates/llama-cpp-server/src/supervisor.rs b/crates/llama-cpp-server/src/supervisor.rs index b8495641bacf..9c745ec0e3b1 100644 --- a/crates/llama-cpp-server/src/supervisor.rs +++ b/crates/llama-cpp-server/src/supervisor.rs @@ -21,6 +21,7 @@ impl LlamaCppSupervisor { embedding: bool, model_path: &str, parallelism: u8, + chat_template: Option, ) -> LlamaCppSupervisor { let Some(binary_name) = find_binary_name() else { panic!("Failed to locate llama-server binary, please make sure you have llama-server binary locates in the same directory as the current executable."); @@ -69,6 +70,10 @@ impl LlamaCppSupervisor { .arg(var("LLAMA_CPP_EMBEDDING_N_UBATCH_SIZE").unwrap_or("4096".into())); } + if let Some(chat_template) = chat_template.as_ref() { + command.arg("--chat-template").arg(chat_template); + } + let mut process = command.spawn().unwrap_or_else(|e| { panic!( "Failed to start llama-server with command {:?}: {}", diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml index 5d034797cf43..0cb3cc21725f 100644 --- a/crates/tabby/Cargo.toml +++ b/crates/tabby/Cargo.toml @@ -44,7 +44,6 @@ sysinfo = "0.29.8" nvml-wrapper = "0.9.0" http-api-bindings = { path = "../http-api-bindings" } async-stream = { workspace = true } -minijinja = { version = "1.0.8", features = ["loader"] } llama-cpp-server = { path = "../llama-cpp-server" } futures.workspace = true async-trait.workspace = true diff --git a/crates/tabby/src/services/model/chat.rs b/crates/tabby/src/services/model/chat.rs deleted file mode 100644 index 4508934fe5a7..000000000000 --- a/crates/tabby/src/services/model/chat.rs +++ /dev/null @@ -1,105 +0,0 @@ -use std::sync::Arc; - -use anyhow::Result; -use async_stream::stream; -use async_trait::async_trait; -use futures::stream::BoxStream; -use minijinja::{context, Environment}; -use tabby_common::api::chat::Message; -use tabby_inference::{ - ChatCompletionOptions, ChatCompletionStream, CompletionOptionsBuilder, CompletionStream, -}; - -struct ChatPromptBuilder { - env: Environment<'static>, -} - 
-impl ChatPromptBuilder { - pub fn new(prompt_template: String) -> Self { - let mut env = Environment::new(); - env.add_function("raise_exception", |e: String| panic!("{}", e)); - env.add_template_owned("prompt", prompt_template) - .expect("Failed to compile template"); - - Self { env } - } - - pub fn build(&self, messages: &[Message]) -> Result { - // System prompt is not supported for TextGenerationStream backed chat. - let messages = messages - .iter() - .filter(|x| x.role != "system") - .collect::>(); - Ok(self.env.get_template("prompt")?.render(context!( - messages => messages - ))?) - } -} - -struct ChatCompletionImpl { - engine: Arc, - prompt_builder: ChatPromptBuilder, -} - -#[async_trait] -impl ChatCompletionStream for ChatCompletionImpl { - async fn chat_completion( - &self, - messages: &[Message], - options: ChatCompletionOptions, - ) -> Result> { - let options = CompletionOptionsBuilder::default() - .max_input_length(2048) - .seed(options.seed) - .max_decoding_tokens(options.max_decoding_tokens) - .sampling_temperature(options.sampling_temperature) - .presence_penalty(options.presence_penalty) - .build()?; - - let prompt = self.prompt_builder.build(messages)?; - - let s = stream! { - for await content in self.engine.generate(&prompt, options).await { - yield content; - } - }; - - Ok(Box::pin(s)) - } -} - -pub fn make_chat_completion( - engine: Arc, - prompt_template: String, -) -> impl ChatCompletionStream { - ChatCompletionImpl { - engine, - prompt_builder: ChatPromptBuilder::new(prompt_template), - } -} - -#[cfg(test)] -mod tests { - use super::*; - static PROMPT_TEMPLATE : &str = "{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"; - - #[test] - fn test_it_works() { - let builder = ChatPromptBuilder::new(PROMPT_TEMPLATE.to_owned()); - let messages = vec![ - Message { - role: "user".to_owned(), - content: "What is tail recursion?".to_owned(), - }, - Message { - role: "assistant".to_owned(), - content: "It's a kind of optimization in compiler?".to_owned(), - }, - Message { - role: "user".to_owned(), - content: "Could you share more details?".to_owned(), - }, - ]; - assert_eq!(builder.build(&messages).unwrap(), "[INST] What is tail recursion? [/INST]It's a kind of optimization in compiler? [INST] Could you share more details? [/INST]") - } -} diff --git a/crates/tabby/src/services/model/mod.rs b/crates/tabby/src/services/model/mod.rs index ba36d9152dd6..5814573115bc 100644 --- a/crates/tabby/src/services/model/mod.rs +++ b/crates/tabby/src/services/model/mod.rs @@ -1,5 +1,3 @@ -mod chat; - use std::{fs, path::PathBuf, sync::Arc}; use serde::Deserialize; @@ -17,14 +15,35 @@ pub async fn load_chat_completion(chat: &ModelConfig) -> Arc http_api_bindings::create_chat(http).await, - ModelConfig::Local(_) => { - let (engine, PromptInfo { chat_template, .. 
}) = load_completion(chat).await; - - let Some(chat_template) = chat_template else { - fatal!("Chat model requires specifying prompt template"); - }; - - Arc::new(chat::make_chat_completion(engine, chat_template)) + ModelConfig::Local(llama) => { + if fs::metadata(&llama.model_id).is_ok() { + let path = PathBuf::from(&llama.model_id); + let model_path = path.join(GGML_MODEL_RELATIVE_PATH).display().to_string(); + let engine_info = PromptInfo::read(path.join("tabby.json")); + llama_cpp_server::create_chat_completion( + llama.num_gpu_layers, + &model_path, + llama.parallelism, + engine_info.chat_template.unwrap_or_else(|| { + fatal!("Chat model requires specifying prompt template") + }), + ) + .await + } else { + let (registry, name) = parse_model_id(&llama.model_id); + let registry = ModelRegistry::new(registry).await; + let model_path = registry.get_model_path(name).display().to_string(); + let model_info = registry.get_model_info(name); + llama_cpp_server::create_chat_completion( + llama.num_gpu_layers, + &model_path, + llama.parallelism, + model_info.chat_template.clone().unwrap_or_else(|| { + fatal!("Chat model requires specifying prompt template") + }), + ) + .await + } } } }
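
For context, the following is a minimal, hypothetical sketch of how a caller might drive the new llama_cpp_server::create_chat_completion helper added in crates/llama-cpp-server/src/lib.rs. The GPU-layer count, model path, and "chatml" template value are placeholders (in Tabby the template comes from the model registry's chat_template field); ChatCompletionOptions::default(), the anyhow/tokio scaffolding, and the String item type of the returned stream are assumptions not confirmed by this patch.

    use futures::StreamExt;
    use tabby_common::api::chat::Message;
    use tabby_inference::ChatCompletionOptions;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // Spawns llama-server with `--chat-template <value>` via LlamaCppSupervisor
        // and wraps its endpoint behind the OpenAI-style "openai/chat" binding,
        // as introduced in the lib.rs hunk above.
        let chat = llama_cpp_server::create_chat_completion(
            9999,                   // num_gpu_layers (placeholder)
            "/path/to/model.gguf",  // model_path (placeholder)
            1,                      // parallelism
            "chatml".to_string(),   // forwarded verbatim as --chat-template
        )
        .await;

        let messages = vec![Message {
            role: "user".to_owned(),
            content: "What is tail recursion?".to_owned(),
        }];

        // ChatCompletionOptions is assumed to provide a Default impl; otherwise
        // construct it from the fields visible in this patch (seed,
        // max_decoding_tokens, sampling_temperature, presence_penalty).
        let mut stream = chat
            .chat_completion(&messages, ChatCompletionOptions::default())
            .await?;

        // Each yielded item is assumed to be a text fragment of the reply.
        while let Some(chunk) = stream.next().await {
            print!("{chunk}");
        }

        Ok(())
    }

With this change, prompt rendering moves server-side: ChatCompletionServer hands the template to LlamaCppSupervisor, which appends --chat-template to the llama-server command line, replacing the minijinja-based ChatPromptBuilder removed from crates/tabby/src/services/model/chat.rs.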