diff --git a/Cargo.lock b/Cargo.lock
index c81a9d81317d..fa4d90caad04 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2773,6 +2773,20 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "llama-cpp-server"
+version = "0.12.0-dev.0"
+dependencies = [
+ "futures",
+ "http-api-bindings",
+ "reqwest 0.12.4",
+ "serde_json",
+ "tabby-common",
+ "tabby-inference",
+ "tokio",
+ "tracing",
+]
+
 [[package]]
 name = "lock_api"
 version = "0.4.10"
diff --git a/Cargo.toml b/Cargo.toml
index c63ab8bc12d1..c07031371262 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,7 +13,7 @@ members = [
     "ee/tabby-db",
     "ee/tabby-db-macros",
     "ee/tabby-search",
-    "ee/tabby-schema",
+    "ee/tabby-schema",    "crates/llama-cpp-server",
 ]
 
 [workspace.package]
diff --git a/crates/llama-cpp-bindings/llama.cpp b/crates/llama-cpp-bindings/llama.cpp
index b4e4b8a9351d..9aa672490c84 160000
--- a/crates/llama-cpp-bindings/llama.cpp
+++ b/crates/llama-cpp-bindings/llama.cpp
@@ -1 +1 @@
-Subproject commit b4e4b8a9351d918a56831c73cf9f25c1837b80d1
+Subproject commit 9aa672490c848e45eaa704a554e0f1f6df995fc8
diff --git a/crates/llama-cpp-server/Cargo.toml b/crates/llama-cpp-server/Cargo.toml
new file mode 100644
index 000000000000..f3d7f45e0895
--- /dev/null
+++ b/crates/llama-cpp-server/Cargo.toml
@@ -0,0 +1,20 @@
+[package]
+name = "llama-cpp-server"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+futures.workspace = true
+http-api-bindings = { path = "../http-api-bindings" }
+reqwest.workspace = true
+serde_json.workspace = true
+tabby-inference = { path = "../tabby-inference" }
+tracing.workspace = true
+tokio = { workspace = true, features = ["process"] }
+
+[dev-dependencies]
+tabby-common = { path = "../tabby-common" }
\ No newline at end of file
diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs
new file mode 100644
index 000000000000..385f2dbc4df5
--- /dev/null
+++ b/crates/llama-cpp-server/src/lib.rs
@@ -0,0 +1,150 @@
+use std::{process::Stdio, sync::Arc};
+
+use serde_json::json;
+use tabby_inference::{ChatCompletionStream, CompletionStream, Embedding};
+use tokio::task::JoinHandle;
+use tracing::warn;
+
+struct LlamaCppServer {
+    handle: JoinHandle<()>,
+}
+
+const SERVER_PORT: u16 = 30888;
+
+impl LlamaCppServer {
+    pub fn new(model_path: &str, use_gpu: bool, parallelism: u8) -> Self {
+        let mut num_gpu_layers = std::env::var("LLAMA_CPP_N_GPU_LAYERS").unwrap_or("9999".into());
+        if !use_gpu {
+            num_gpu_layers = "0".to_string();
+        }
+
+        let model_path = model_path.to_owned();
+        let handle = tokio::spawn(async move {
+            loop {
+                let mut process = tokio::process::Command::new("llama-server")
+                    .arg("-m")
+                    .arg(&model_path)
+                    .arg("--port")
+                    .arg(SERVER_PORT.to_string())
+                    .arg("-ngl")
+                    .arg(&num_gpu_layers)
+                    .arg("-np")
+                    .arg(parallelism.to_string())
+                    .kill_on_drop(true)
+                    .stderr(Stdio::inherit())
+                    .stdout(Stdio::inherit())
+                    .spawn()
+                    .expect("Failed to spawn llama-cpp-server");
+
+                let status_code = process
+                    .wait()
+                    .await
+                    .ok()
+                    .and_then(|s| s.code())
+                    .unwrap_or(-1);
+
+                if status_code != 0 {
+                    warn!(
+                        "llama-server exited with status code {}, restarting...",
+                        status_code
+                    );
+                }
+            }
+        });
+
+        Self { handle }
+    }
+
+    async fn wait_for_health(&self) {
+        let client = reqwest::Client::new();
+        loop {
+            let Ok(resp) = client.get(api_endpoint() + "/health").send().await else {
+                continue;
+            };
+
+            if resp.status().is_success() {
+                return;
+            }
+        }
+    }
+
+    pub fn completion(&self, prompt_template: String) -> Arc<dyn CompletionStream> {
+        let model_spec: String = serde_json::to_string(&json!({
+            "kind": "llama",
+            "api_endpoint": api_endpoint(),
+            "prompt_template": prompt_template,
+        }))
+        .expect("Failed to serialize model spec");
+        let (engine, _, _) = http_api_bindings::create(&model_spec);
+        engine
+    }
+
+    pub fn chat(&self) -> Arc<dyn ChatCompletionStream> {
+        let model_spec: String = serde_json::to_string(&json!({
+            "kind": "openai-chat",
+            "api_endpoint": format!("http://localhost:{SERVER_PORT}/v1"),
+        }))
+        .expect("Failed to serialize model spec");
+        http_api_bindings::create_chat(&model_spec)
+    }
+
+    pub fn embedding(self) -> Arc<dyn Embedding> {
+        let model_spec: String = serde_json::to_string(&json!({
+            "kind": "llama",
+            "api_endpoint": format!("http://localhost:{SERVER_PORT}"),
+        }))
+        .expect("Failed to serialize model spec");
+        http_api_bindings::create_embedding(&model_spec)
+    }
+}
+
+impl Drop for LlamaCppServer {
+    fn drop(&mut self) {
+        self.handle.abort();
+    }
+}
+
+fn api_endpoint() -> String {
+    format!("http://localhost:{SERVER_PORT}")
+}
+
+#[cfg(test)]
+mod tests {
+    use futures::StreamExt;
+    use tabby_common::registry::{parse_model_id, ModelRegistry};
+    use tabby_inference::CompletionOptionsBuilder;
+
+    use super::*;
+
+    #[tokio::test]
+    #[ignore = "Should only be run in local manual testing"]
+    async fn test_create_completion() {
+        let model_id = "StarCoder-1B";
+        let (registry, name) = parse_model_id(model_id);
+        let registry = ModelRegistry::new(registry).await;
+        let model_path = registry.get_model_path(name).display().to_string();
+        let model_info = registry.get_model_info(name);
+
+        let server = LlamaCppServer::new(&model_path, false, 1);
+        server.wait_for_health().await;
+
+        let completion = server.completion(model_info.prompt_template.clone().unwrap());
+        let s = completion
+            .generate(
+                "def fib(n):",
+                CompletionOptionsBuilder::default()
+                    .max_decoding_tokens(7)
+                    .max_input_length(1024)
+                    .sampling_temperature(0.0)
+                    .seed(12345)
+                    .build()
+                    .unwrap(),
+            )
+            .await;
+
+        let content: Vec<String> = s.collect().await;
+
+        let content = content.join("");
+        assert_eq!(content, "\n    if n <= 1:")
+    }
+}
diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml
index 938091abc781..bdc5603628df 100644
--- a/ee/tabby-webserver/Cargo.toml
+++ b/ee/tabby-webserver/Cargo.toml
@@ -37,7 +37,7 @@ tabby-schema = { path = "../../ee/tabby-schema" }
 tabby-db = { path = "../../ee/tabby-db" }
 tarpc = { version = "0.33.0", features = ["serde-transport"] }
 thiserror.workspace = true
-tokio = { workspace = true, features = ["fs", "process"] }
+tokio = { workspace = true, features = ["fs"] }
 tokio-tungstenite = "0.21"
 tower = { version = "0.4", features = ["util", "limit"] }
 tower-http = { workspace = true, features = ["fs", "trace"] }
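
Note: below is a minimal usage sketch of the new crate, not part of this diff. It assumes `LlamaCppServer` and its methods are eventually made public and re-exported from `llama_cpp_server` (in this change the struct is still crate-private), and the model path and FIM prompt template are placeholder values; it mirrors the in-tree `test_create_completion` test.

    use std::sync::Arc;

    use llama_cpp_server::LlamaCppServer;
    use tabby_inference::{ChatCompletionStream, CompletionStream};

    #[tokio::main]
    async fn main() {
        // Spawn `llama-server` as a supervised child process (CPU only, one decoding slot).
        let server = LlamaCppServer::new("/path/to/model.gguf", false, 1);

        // Poll GET /health on the local port until the server accepts requests.
        server.wait_for_health().await;

        // Build HTTP-backed engines over the local server; the prompt template
        // here is an arbitrary fill-in-the-middle example.
        let completion: Arc<dyn CompletionStream> =
            server.completion("<PRE>{prefix}<SUF>{suffix}<MID>".into());
        let chat: Arc<dyn ChatCompletionStream> = server.chat();

        // These trait objects can then be handed to the serving layer in place of
        // engines constructed directly against llama-cpp-bindings.
        let _ = (completion, chat);
    }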