From b674eb7926c5afedc7626964d4aca20dcca3f0e6 Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Mon, 13 May 2024 17:04:48 -0700
Subject: [PATCH] refactor(core): embed llama.cpp's server binary directly for
 LLM inference (#2113)

* chore: add llama-cpp-server sub crate
* chore: add llama-cpp-server to embed llama-server directly
* update
* cleanup
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
* update
---
 Cargo.lock                             |  47 ++++++-
 Cargo.toml                             |   2 +-
 crates/llama-cpp-bindings/llama.cpp    |   2 +-
 crates/llama-cpp-server/Cargo.toml     |  26 ++++
 crates/llama-cpp-server/build.rs       |  79 +++++++++++
 crates/llama-cpp-server/src/lib.rs     | 180 +++++++++++++++++++++++++
 crates/tabby/Cargo.toml                |  12 +-
 crates/tabby/src/main.rs               |  36 +----
 crates/tabby/src/services/model/mod.rs |  26 ++--
 ee/tabby-webserver/Cargo.toml          |   2 +-
 10 files changed, 356 insertions(+), 56 deletions(-)
 create mode 100644 crates/llama-cpp-server/Cargo.toml
 create mode 100644 crates/llama-cpp-server/build.rs
 create mode 100644 crates/llama-cpp-server/src/lib.rs

diff --git a/Cargo.lock b/Cargo.lock
index c81a9d81317d..6d3d0850e53c 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -141,9 +141,9 @@ dependencies = [
 
 [[package]]
 name = "anyhow"
-version = "1.0.71"
+version = "1.0.83"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8"
+checksum = "25bdb32cbbdce2b519a9cd7df3a678443100e265d5e25ca763b7572a5104f5f3"
 
 [[package]]
 name = "apalis"
@@ -585,6 +585,12 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "build-target"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "832133bbabbbaa9fbdba793456a2827627a7d2b8fb96032fa1e7666d7895832b"
+
 [[package]]
 name = "bumpalo"
 version = "3.13.0"
@@ -2773,6 +2779,23 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "llama-cpp-server"
+version = "0.12.0-dev.0"
+dependencies = [
+ "anyhow",
+ "async-trait",
+ "cmake",
+ "futures",
+ "http-api-bindings",
+ "omnicopy_to_output",
+ "reqwest 0.12.4",
+ "serde_json",
+ "tabby-inference",
+ "tokio",
+ "tracing",
+]
+
 [[package]]
 name = "lock_api"
 version = "0.4.10"
@@ -3337,6 +3360,18 @@ dependencies = [
  "url",
 ]
 
+[[package]]
+name = "omnicopy_to_output"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "10aff4d07c3656c416a997301d51ed83be62cbb256b421f86b014931217f2393"
+dependencies = [
+ "anyhow",
+ "build-target",
+ "fs_extra",
+ "project-root",
+]
+
 [[package]]
 name = "once_cell"
 version = "1.19.0"
@@ -3835,6 +3870,12 @@ dependencies = [
  "unicode-ident",
 ]
 
+[[package]]
+name = "project-root"
+version = "0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8bccbff07d5ed689c4087d20d7307a52ab6141edeedf487c3876a55b86cf63df"
+
 [[package]]
 name = "psm"
 version = "0.1.21"
@@ -5327,7 +5368,7 @@ dependencies = [
  "hyper 1.3.1",
  "insta",
  "lazy_static",
- "llama-cpp-bindings",
+ "llama-cpp-server",
  "minijinja",
  "nvml-wrapper",
  "openssl",
diff --git a/Cargo.toml b/Cargo.toml
index c63ab8bc12d1..c07031371262 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,7 +13,7 @@ members = [
     "ee/tabby-db",
     "ee/tabby-db-macros",
     "ee/tabby-search",
-    "ee/tabby-schema",
+    "ee/tabby-schema", "crates/llama-cpp-server",
 ]
 
 [workspace.package]
diff --git a/crates/llama-cpp-bindings/llama.cpp b/crates/llama-cpp-bindings/llama.cpp
index b4e4b8a9351d..9aa672490c84 160000
--- a/crates/llama-cpp-bindings/llama.cpp
+++ b/crates/llama-cpp-bindings/llama.cpp
@@ -1 +1 @@
-Subproject commit b4e4b8a9351d918a56831c73cf9f25c1837b80d1
+Subproject commit 9aa672490c848e45eaa704a554e0f1f6df995fc8
diff --git a/crates/llama-cpp-server/Cargo.toml b/crates/llama-cpp-server/Cargo.toml
new file mode 100644
index 000000000000..5cfcf64846dc
--- /dev/null
+++ b/crates/llama-cpp-server/Cargo.toml
@@ -0,0 +1,26 @@
+[package]
+name = "llama-cpp-server"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+[features]
+cuda = []
+rocm = []
+vulkan = []
+
+[dependencies]
+futures.workspace = true
+http-api-bindings = { path = "../http-api-bindings" }
+reqwest.workspace = true
+serde_json.workspace = true
+tabby-inference = { path = "../tabby-inference" }
+tracing.workspace = true
+async-trait.workspace = true
+tokio = { workspace = true, features = ["process"] }
+anyhow.workspace = true
+
+[build-dependencies]
+cmake = "0.1"
+omnicopy_to_output = "0.1.1"
\ No newline at end of file
diff --git a/crates/llama-cpp-server/build.rs b/crates/llama-cpp-server/build.rs
new file mode 100644
index 000000000000..b59a76f9fc27
--- /dev/null
+++ b/crates/llama-cpp-server/build.rs
@@ -0,0 +1,79 @@
+use std::{env, path::Path};
+
+use cmake::Config;
+use omnicopy_to_output::copy_to_output;
+
+fn main() {
+    let mut config = Config::new("../llama-cpp-bindings/llama.cpp");
+    config.profile("Release");
+    config.define("LLAMA_NATIVE", "OFF");
+    config.define("INS_ENB", "ON");
+
+    if cfg!(target_os = "macos") {
+        config.define("LLAMA_METAL", "ON");
+        config.define("LLAMA_METAL_EMBED_LIBRARY", "ON");
+        println!("cargo:rustc-link-lib=framework=Foundation");
+        println!("cargo:rustc-link-lib=framework=Accelerate");
+        println!("cargo:rustc-link-lib=framework=Metal");
+        println!("cargo:rustc-link-lib=framework=MetalKit");
+    }
+    if cfg!(feature = "cuda") {
+        config.define("LLAMA_CUBLAS", "ON");
+        config.define("CMAKE_POSITION_INDEPENDENT_CODE", "ON");
+    }
+    if cfg!(feature = "rocm") {
+        let amd_gpu_targets: Vec<&str> = vec![
+            "gfx803",
+            "gfx900",
+            "gfx906:xnack-",
+            "gfx908:xnack-",
+            "gfx90a:xnack+",
+            "gfx90a:xnack-",
+            "gfx940",
+            "gfx941",
+            "gfx942",
+            "gfx1010",
+            "gfx1012",
+            "gfx1030",
+            "gfx1031",
+            "gfx1100",
+            "gfx1101",
+            "gfx1102",
+            "gfx1103",
+        ];
+
+        let rocm_root = env::var("ROCM_ROOT").unwrap_or("/opt/rocm".to_string());
+        config.define("LLAMA_HIPBLAS", "ON");
+        config.define("CMAKE_C_COMPILER", format!("{}/llvm/bin/clang", rocm_root));
+        config.define(
+            "CMAKE_CXX_COMPILER",
+            format!("{}/llvm/bin/clang++", rocm_root),
+        );
+        config.define("AMDGPU_TARGETS", amd_gpu_targets.join(";"));
+    }
+    if cfg!(feature = "vulkan") {
+        config.define("LLAMA_VULKAN", "ON");
+    }
+
+    let out = config.build();
+    let server_binary = make_output_binary(&out, "server");
+    let renamed_server_binary = if cfg!(target_os = "macos") {
+        make_output_binary(&out, "llama-server-metal")
+    } else if cfg!(feature = "cuda") {
+        make_output_binary(&out, "llama-server-cuda")
+    } else if cfg!(feature = "rocm") {
+        make_output_binary(&out, "llama-server-rocm")
+    } else if cfg!(feature = "vulkan") {
+        make_output_binary(&out, "llama-server-vulkan")
+    } else {
+        make_output_binary(&out, "llama-server")
+    };
+
+    std::fs::rename(server_binary, &renamed_server_binary).expect("Failed to rename server binary");
+    copy_to_output(&renamed_server_binary)
+        .expect("Failed to copy server binary to output directory");
+}
+
+fn make_output_binary(out: &Path, name: &str) -> String {
+    out.join("bin").join(name).display().to_string() + env::consts::EXE_SUFFIX
+}
diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs
new file mode 100644
index 000000000000..f4250e90f3b1
--- /dev/null
+++ b/crates/llama-cpp-server/src/lib.rs
@@ -0,0 +1,180 @@
+use std::{net::TcpListener, process::Stdio, sync::Arc};
+
+use anyhow::Result;
+use async_trait::async_trait;
+use futures::stream::BoxStream;
+use serde_json::json;
+use tabby_inference::{CompletionOptions, CompletionStream, Embedding};
+use tokio::task::JoinHandle;
+use tracing::warn;
+
+pub struct LlamaCppServer {
+    port: u16,
+    handle: JoinHandle<()>,
+    completion: Arc<dyn CompletionStream>,
+    embedding: Arc<dyn Embedding>,
+}
+
+#[async_trait]
+impl CompletionStream for LlamaCppServer {
+    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
+        self.completion.generate(prompt, options).await
+    }
+}
+
+#[async_trait]
+impl Embedding for LlamaCppServer {
+    async fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
+        self.embedding.embed(prompt).await
+    }
+}
+
+impl LlamaCppServer {
+    pub fn new(device: &str, model_path: &str, parallelism: u8) -> Self {
+        let use_gpu = device != "cpu";
+        let Some(binary_name) = find_binary_name(Some(device)) else {
+            panic!("Failed to find llama-server binary for device {device}, please make sure you have corresponding llama-server binary locates in the same directory as the current executable.");
+        };
+
+        let model_path = model_path.to_owned();
+        let port = get_available_port();
+        let handle = tokio::spawn(async move {
+            loop {
+                let server_binary = std::env::current_exe()
+                    .expect("Failed to get current executable path")
+                    .parent()
+                    .expect("Failed to get parent directory")
+                    .join(&binary_name)
+                    .display()
+                    .to_string()
+                    + std::env::consts::EXE_SUFFIX;
+                let mut command = tokio::process::Command::new(server_binary);
+
+                command
+                    .arg("-m")
+                    .arg(&model_path)
+                    .arg("--port")
+                    .arg(port.to_string())
+                    .arg("-np")
+                    .arg(parallelism.to_string())
+                    .arg("--log-disable")
+                    .kill_on_drop(true)
+                    .stderr(Stdio::null())
+                    .stdout(Stdio::null());
+
+                if let Ok(n_threads) = std::env::var("LLAMA_CPP_N_THREADS") {
+                    command.arg("-t").arg(n_threads);
+                }
+
+                if use_gpu {
+                    let num_gpu_layers =
+                        std::env::var("LLAMA_CPP_N_GPU_LAYERS").unwrap_or("9999".into());
+                    command.arg("-ngl").arg(&num_gpu_layers);
+                }
+
+                let mut process = command.spawn().unwrap_or_else(|e| {
+                    panic!(
+                        "Failed to start llama-server with command {:?}: {}",
+                        command, e
+                    )
+                });
+
+                let status_code = process
+                    .wait()
+                    .await
+                    .ok()
+                    .and_then(|s| s.code())
+                    .unwrap_or(-1);
+
+                if status_code != 0 {
+                    warn!(
+                        "llama-server exited with status code {}, restarting...",
+                        status_code
+                    );
+                }
+            }
+        });
+
+        Self {
+            handle,
+            port,
+            completion: make_completion(port),
+            embedding: make_embedding(port),
+        }
+    }
+
+    pub async fn start(&self) {
+        let client = reqwest::Client::new();
+        loop {
+            let Ok(resp) = client.get(api_endpoint(self.port) + "/health").send().await else {
+                continue;
+            };
+
+            if resp.status().is_success() {
+                return;
+            }
+        }
+    }
+}
+
+fn find_binary_name(suffix: Option<&str>) -> Option<String> {
+    let current_exe = std::env::current_exe().expect("Failed to get current executable path");
+    let binary_dir = current_exe
+        .parent()
+        .expect("Failed to get parent directory");
+    let binary_name = if let Some(suffix) = suffix {
+        format!("llama-server-{}", suffix)
+    } else {
+        "llama-server".to_owned()
+    };
+    std::fs::read_dir(binary_dir)
+        .expect("Failed to read directory")
+        .filter_map(|entry| entry.ok())
+        .filter(|entry| {
+            entry
+                .file_name()
+                .to_string_lossy()
+                .starts_with(&binary_name)
+        })
+        .map(|entry| entry.path().display().to_string())
+        .next()
+}
+
+fn make_completion(port: u16) -> Arc<dyn CompletionStream> {
+    let model_spec: String = serde_json::to_string(&json!({
+        "kind": "llama",
+        "api_endpoint": api_endpoint(port),
+    }))
+    .expect("Failed to serialize model spec");
+    let (engine, _, _) = http_api_bindings::create(&model_spec);
+    engine
+}
+
+pub fn make_embedding(port: u16) -> Arc<dyn Embedding> {
+    let model_spec: String = serde_json::to_string(&json!({
+        "kind": "llama",
+        "api_endpoint": api_endpoint(port),
+    }))
+    .expect("Failed to serialize model spec");
+    http_api_bindings::create_embedding(&model_spec)
+}
+
+fn get_available_port() -> u16 {
+    (30888..40000)
+        .find(|port| port_is_available(*port))
+        .expect("Failed to find available port")
+}
+
+fn port_is_available(port: u16) -> bool {
+    TcpListener::bind(("127.0.0.1", port)).is_ok()
+}
+
+impl Drop for LlamaCppServer {
+    fn drop(&mut self) {
+        self.handle.abort();
+    }
+}
+
+fn api_endpoint(port: u16) -> String {
+    format!("http://localhost:{port}")
+}
diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml
index 6169e51c80fc..e0c4b5bd9e1f 100644
--- a/crates/tabby/Cargo.toml
+++ b/crates/tabby/Cargo.toml
@@ -6,11 +6,11 @@ authors.workspace = true
 homepage.workspace = true
 
 [features]
-default = ["ee", "dep:color-eyre"]
+default = ["ee"]
 ee = ["dep:tabby-webserver"]
-cuda = ["llama-cpp-bindings/cuda"]
-rocm = ["llama-cpp-bindings/rocm"]
-vulkan = ["llama-cpp-bindings/vulkan"]
+cuda = []
+rocm = []
+vulkan = []
 # If compiling on a system without OpenSSL installed, or cross-compiling for a different
 # architecture, enable this feature to compile OpenSSL as part of the build.
 # See https://docs.rs/openssl/#vendored for more.
@@ -47,7 +47,7 @@ async-stream = { workspace = true }
 minijinja = { version = "1.0.8", features = ["loader"] }
 textdistance = "1.0.2"
 regex.workspace = true
-llama-cpp-bindings = { path = "../llama-cpp-bindings" }
+llama-cpp-server = { path = "../llama-cpp-server" }
 futures.workspace = true
 async-trait.workspace = true
 tabby-webserver = { path = "../../ee/tabby-webserver", optional = true }
@@ -57,7 +57,7 @@ axum-prometheus = "0.6"
 uuid.workspace = true
 cached = { workspace = true, features = ["async"] }
 parse-git-url = "0.5.1"
-color-eyre = { version = "0.6.3", optional = true }
+color-eyre = { version = "0.6.3" }
 
 [dependencies.openssl]
 optional = true
diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs
index ddc5bc67bc06..65bd4e8ea0a0 100644
--- a/crates/tabby/src/main.rs
+++ b/crates/tabby/src/main.rs
@@ -61,19 +61,15 @@ pub enum Device {
     #[strum(serialize = "cpu")]
     Cpu,
 
-    #[cfg(feature = "cuda")]
     #[strum(serialize = "cuda")]
     Cuda,
 
-    #[cfg(feature = "rocm")]
     #[strum(serialize = "rocm")]
     Rocm,
 
-    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
     #[strum(serialize = "metal")]
     Metal,
 
-    #[cfg(feature = "vulkan")]
     #[strum(serialize = "vulkan")]
     Vulkan,
 
@@ -83,40 +79,16 @@ pub enum Device {
 }
 
 impl Device {
-    #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
     pub fn ggml_use_gpu(&self) -> bool {
-        *self == Device::Metal
-    }
-
-    #[cfg(feature = "cuda")]
-    pub fn ggml_use_gpu(&self) -> bool {
-        *self == Device::Cuda
-    }
-
-    #[cfg(feature = "rocm")]
-    pub fn ggml_use_gpu(&self) -> bool {
-        *self == Device::Rocm
-    }
-
-    #[cfg(feature = "vulkan")]
-    pub fn ggml_use_gpu(&self) -> bool {
-        *self == Device::Vulkan
-    }
-
-    #[cfg(not(any(
-        all(target_os = "macos", target_arch = "aarch64"),
-        feature = "cuda",
-        feature = "rocm",
-        feature = "vulkan",
-    )))]
-    pub fn ggml_use_gpu(&self) -> bool {
-        false
+        match self {
+            Device::Metal | Device::Vulkan | Device::Cuda | Device::Rocm => true,
+            Device::Cpu | Device::ExperimentalHttp => false,
+        }
     }
 }
 
 #[tokio::main]
 async fn main() {
-    #[cfg(feature = "dep:color-eyre")]
     color_eyre::install().expect("Must be able to install color_eyre");
 
     let cli = Cli::parse();
diff --git a/crates/tabby/src/services/model/mod.rs b/crates/tabby/src/services/model/mod.rs
index 655fba5ba561..1b0466245ae3 100644
--- a/crates/tabby/src/services/model/mod.rs
+++ b/crates/tabby/src/services/model/mod.rs
@@ -64,17 +64,18 @@ async fn load_completion(
             device,
             model_path.display().to_string().as_str(),
             parallelism,
-        );
+        )
+        .await;
         let engine_info = PromptInfo::read(path.join("tabby.json"));
-        (Arc::new(engine), engine_info)
+        (engine, engine_info)
     } else {
         let (registry, name) = parse_model_id(model_id);
         let registry = ModelRegistry::new(registry).await;
         let model_path = registry.get_model_path(name).display().to_string();
         let model_info = registry.get_model_info(name);
-        let engine = create_ggml_engine(device, &model_path, parallelism);
+        let engine = create_ggml_engine(device, &model_path, parallelism).await;
         (
-            Arc::new(engine),
+            engine,
             PromptInfo {
                 prompt_template: model_info.prompt_template.clone(),
                 chat_template: model_info.chat_template.clone(),
@@ -96,7 +97,11 @@ impl PromptInfo {
     }
 }
 
-fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl CompletionStream {
+async fn create_ggml_engine(
+    device: &Device,
+    model_path: &str,
+    parallelism: u8,
+) -> Arc<dyn CompletionStream> {
     if !device.ggml_use_gpu() {
         InfoMessage::new(
             "CPU Device",
@@ -107,14 +112,11 @@ fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> imp
             ],
         );
     }
-    let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
-        .model_path(model_path.to_owned())
-        .use_gpu(device.ggml_use_gpu())
-        .parallelism(parallelism)
-        .build()
-        .expect("Failed to create llama text generation options");
 
-    llama_cpp_bindings::LlamaTextGeneration::new(options)
+    let device_str = device.to_string().to_lowercase();
+    let server = llama_cpp_server::LlamaCppServer::new(&device_str, model_path, parallelism);
+    server.start().await;
+    Arc::new(server)
 }
 
 pub async fn download_model_if_needed(model: &str) {
diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml
index 938091abc781..bdc5603628df 100644
--- a/ee/tabby-webserver/Cargo.toml
+++ b/ee/tabby-webserver/Cargo.toml
@@ -37,7 +37,7 @@ tabby-schema = { path = "../../ee/tabby-schema" }
 tabby-db = { path = "../../ee/tabby-db" }
 tarpc = { version = "0.33.0", features = ["serde-transport"] }
 thiserror.workspace = true
-tokio = { workspace = true, features = ["fs", "process"] }
+tokio = { workspace = true, features = ["fs"] }
 tokio-tungstenite = "0.21"
 tower = { version = "0.4", features = ["util", "limit"] }
 tower-http = { workspace = true, features = ["fs", "trace"] }
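
Usage note (outside the patch): a minimal sketch of how the new llama-cpp-server crate is driven by a caller, mirroring what create_ggml_engine does in crates/tabby/src/services/model/mod.rs above. The "metal" device string, the model path, and the parallelism value of 1 are placeholder assumptions, not values taken from the patch.

    // Sketch only; device string, model path, and parallelism are placeholders.
    use std::sync::Arc;

    use tabby_inference::CompletionStream;

    #[tokio::main]
    async fn main() {
        // new() spawns the bundled llama-server-metal binary found next to the
        // current executable, on a free port picked from the 30888..40000 range.
        let server = llama_cpp_server::LlamaCppServer::new("metal", "/path/to/model.gguf", 1);
        // start() polls GET /health on the child server until it responds successfully.
        server.start().await;
        // Completion (and embedding) calls are proxied to the child server over HTTP.
        let engine: Arc<dyn CompletionStream> = Arc::new(server);
        let _ = engine;
    }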