feat: adapt --chat-template parameter of llama-server (#2362)
* feat: pass --chat-template, create ChatCompletionServer

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* improve impl

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
darknight and autofix-ci[bot] authored Jun 9, 2024
1 parent d26ab73 commit 94d35f0
Showing 6 changed files with 98 additions and 143 deletions.
24 changes: 0 additions & 24 deletions Cargo.lock

Some generated files are not rendered by default.

67 changes: 64 additions & 3 deletions crates/llama-cpp-server/src/lib.rs
@@ -7,10 +7,13 @@ use async_trait::async_trait;
 use futures::stream::BoxStream;
 use supervisor::LlamaCppSupervisor;
 use tabby_common::{
+    api::chat::Message,
     config::{HttpModelConfigBuilder, ModelConfig},
     registry::{parse_model_id, ModelRegistry, GGML_MODEL_RELATIVE_PATH},
 };
-use tabby_inference::{CompletionOptions, CompletionStream, Embedding};
+use tabby_inference::{
+    ChatCompletionOptions, ChatCompletionStream, CompletionOptions, CompletionStream, Embedding,
+};
 
 fn api_endpoint(port: u16) -> String {
     format!("http://127.0.0.1:{port}")
@@ -24,7 +27,7 @@ struct EmbeddingServer {
 
 impl EmbeddingServer {
     async fn new(num_gpu_layers: u16, model_path: &str, parallelism: u8) -> EmbeddingServer {
-        let server = LlamaCppSupervisor::new(num_gpu_layers, true, model_path, parallelism);
+        let server = LlamaCppSupervisor::new(num_gpu_layers, true, model_path, parallelism, None);
         server.start().await;
 
         let config = HttpModelConfigBuilder::default()
@@ -55,7 +58,7 @@ struct CompletionServer {
 
 impl CompletionServer {
     async fn new(num_gpu_layers: u16, model_path: &str, parallelism: u8) -> Self {
-        let server = LlamaCppSupervisor::new(num_gpu_layers, false, model_path, parallelism);
+        let server = LlamaCppSupervisor::new(num_gpu_layers, false, model_path, parallelism, None);
         server.start().await;
         let config = HttpModelConfigBuilder::default()
             .api_endpoint(api_endpoint(server.port()))
@@ -74,6 +77,64 @@ impl CompletionStream for CompletionServer {
     }
 }
 
+struct ChatCompletionServer {
+    #[allow(unused)]
+    server: LlamaCppSupervisor,
+    chat_completion: Arc<dyn ChatCompletionStream>,
+}
+
+impl ChatCompletionServer {
+    async fn new(
+        num_gpu_layers: u16,
+        model_path: &str,
+        parallelism: u8,
+        chat_template: String,
+    ) -> Self {
+        let server = LlamaCppSupervisor::new(
+            num_gpu_layers,
+            false,
+            model_path,
+            parallelism,
+            Some(chat_template),
+        );
+        server.start().await;
+        let config = HttpModelConfigBuilder::default()
+            .api_endpoint(api_endpoint(server.port()))
+            .kind("openai/chat".to_string())
+            .build()
+            .expect("Failed to create HttpModelConfig");
+        let chat_completion = http_api_bindings::create_chat(&config).await;
+        Self {
+            server,
+            chat_completion,
+        }
+    }
+}
+
+#[async_trait]
+impl ChatCompletionStream for ChatCompletionServer {
+    async fn chat_completion(
+        &self,
+        messages: &[Message],
+        options: ChatCompletionOptions,
+    ) -> Result<BoxStream<String>> {
+        self.chat_completion
+            .chat_completion(messages, options)
+            .await
+    }
+}
+
+pub async fn create_chat_completion(
+    num_gpu_layers: u16,
+    model_path: &str,
+    parallelism: u8,
+    chat_template: String,
+) -> Arc<dyn ChatCompletionStream> {
+    Arc::new(
+        ChatCompletionServer::new(num_gpu_layers, model_path, parallelism, chat_template).await,
+    )
+}
+
 pub async fn create_completion(
     num_gpu_layers: u16,
     model_path: &str,
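
For orientation, the sketch below shows how a caller might use the new create_chat_completion entry point defined above. The signature matches the diff, but the layer count, model path, parallelism, and template value are placeholders for illustration; the commit's real call sites are in crates/tabby/src/services/model/mod.rs further down.

// Hypothetical caller of the new API; every literal below is a placeholder.
use std::sync::Arc;

use tabby_inference::ChatCompletionStream;

async fn start_local_chat_backend() -> Arc<dyn ChatCompletionStream> {
    llama_cpp_server::create_chat_completion(
        32,                    // num_gpu_layers (placeholder)
        "/path/to/model.gguf", // model_path (placeholder)
        1,                     // parallelism (placeholder)
        "chatml".to_string(),  // chat_template string handed to llama-server
    )
    .await
}

Internally this spawns a llama-server process via LlamaCppSupervisor with Some(chat_template) and then talks to it through the existing openai/chat HTTP binding.
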
5 changes: 5 additions & 0 deletions crates/llama-cpp-server/src/supervisor.rs
@@ -21,6 +21,7 @@ impl LlamaCppSupervisor {
         embedding: bool,
         model_path: &str,
         parallelism: u8,
+        chat_template: Option<String>,
     ) -> LlamaCppSupervisor {
         let Some(binary_name) = find_binary_name() else {
             panic!("Failed to locate llama-server binary, please make sure you have llama-server binary locates in the same directory as the current executable.");
@@ -69,6 +70,10 @@
                 .arg(var("LLAMA_CPP_EMBEDDING_N_UBATCH_SIZE").unwrap_or("4096".into()));
         }
 
+        if let Some(chat_template) = chat_template.as_ref() {
+            command.arg("--chat-template").arg(chat_template);
+        }
+
         let mut process = command.spawn().unwrap_or_else(|e| {
             panic!(
                 "Failed to start llama-server with command {:?}: {}",
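
To make the new argument concrete, here is an illustrative sketch of how the optional template ends up on the spawned command line. Only the llama-server binary name and the --chat-template flag come from this diff; the other flags and values are assumptions added for the example.

// Sketch only: mirrors the conditional flag-forwarding shown in the diff above.
use std::process::Command;

fn build_llama_server_command(port: u16, model_path: &str, chat_template: Option<&str>) -> Command {
    let mut command = Command::new("llama-server");
    command
        .arg("-m")
        .arg(model_path) // assumed flag; not shown in this diff
        .arg("--port")
        .arg(port.to_string()); // assumed flag; not shown in this diff
    if let Some(chat_template) = chat_template {
        // New in this commit: forward the template so the spawned server can
        // format chat-style requests with it.
        command.arg("--chat-template").arg(chat_template);
    }
    command
}

Because the flag is only appended when a template is present, the embedding and plain-completion call sites in lib.rs can keep passing None without changing the command they spawn.
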
1 change: 0 additions & 1 deletion crates/tabby/Cargo.toml
@@ -44,7 +44,6 @@ sysinfo = "0.29.8"
 nvml-wrapper = "0.9.0"
 http-api-bindings = { path = "../http-api-bindings" }
 async-stream = { workspace = true }
-minijinja = { version = "1.0.8", features = ["loader"] }
 llama-cpp-server = { path = "../llama-cpp-server" }
 futures.workspace = true
 async-trait.workspace = true
105 changes: 0 additions & 105 deletions crates/tabby/src/services/model/chat.rs

This file was deleted.

39 changes: 29 additions & 10 deletions crates/tabby/src/services/model/mod.rs
@@ -1,5 +1,3 @@
-mod chat;
-
 use std::{fs, path::PathBuf, sync::Arc};
 
 use serde::Deserialize;
@@ -17,14 +15,35 @@ pub async fn load_chat_completion(chat: &ModelConfig) -> Arc<dyn ChatCompletionStream> {
     match chat {
         ModelConfig::Http(http) => http_api_bindings::create_chat(http).await,
 
-        ModelConfig::Local(_) => {
-            let (engine, PromptInfo { chat_template, .. }) = load_completion(chat).await;
-
-            let Some(chat_template) = chat_template else {
-                fatal!("Chat model requires specifying prompt template");
-            };
-
-            Arc::new(chat::make_chat_completion(engine, chat_template))
+        ModelConfig::Local(llama) => {
+            if fs::metadata(&llama.model_id).is_ok() {
+                let path = PathBuf::from(&llama.model_id);
+                let model_path = path.join(GGML_MODEL_RELATIVE_PATH).display().to_string();
+                let engine_info = PromptInfo::read(path.join("tabby.json"));
+                llama_cpp_server::create_chat_completion(
+                    llama.num_gpu_layers,
+                    &model_path,
+                    llama.parallelism,
+                    engine_info.chat_template.unwrap_or_else(|| {
+                        fatal!("Chat model requires specifying prompt template")
+                    }),
+                )
+                .await
+            } else {
+                let (registry, name) = parse_model_id(&llama.model_id);
+                let registry = ModelRegistry::new(registry).await;
+                let model_path = registry.get_model_path(name).display().to_string();
+                let model_info = registry.get_model_info(name);
+                llama_cpp_server::create_chat_completion(
+                    llama.num_gpu_layers,
+                    &model_path,
+                    llama.parallelism,
+                    model_info.chat_template.clone().unwrap_or_else(|| {
+                        fatal!("Chat model requires specifying prompt template")
+                    }),
+                )
+                .await
+            }
         }
     }
 }
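
PromptInfo itself is defined elsewhere in this module and is not shown in the diff. The sketch below gives the rough shape implied by its use here: a serde structure read from the model directory's tabby.json with an optional chat_template. The exact field set, the use of serde_json, and the error handling are assumptions, not text from this commit.

// Assumed shape of PromptInfo, inferred from its usage above; not part of this diff.
use std::{fs, path::PathBuf};

use serde::Deserialize;

#[derive(Deserialize)]
struct PromptInfo {
    prompt_template: Option<String>, // assumed sibling field
    chat_template: Option<String>,
}

impl PromptInfo {
    fn read(filepath: PathBuf) -> PromptInfo {
        // tabby.json sits next to the local model weights and carries the templates.
        let serialized = fs::read_to_string(&filepath)
            .unwrap_or_else(|_| panic!("Failed to read {filepath:?}"));
        serde_json::from_str(&serialized)
            .unwrap_or_else(|_| panic!("Failed to parse {filepath:?}"))
    }
}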
