refactor(core): embed llama.cpp's server binary directly for LLM inference (#2113)

* chore: add llama-cpp-server sub crate
* chore: add llama-cpp-server to embed llama-server directly
* cleanup
Showing 10 changed files with 356 additions and 56 deletions.
New file: llama-cpp-server crate manifest (Cargo.toml)

@@ -0,0 +1,26 @@
[package]
name = "llama-cpp-server"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true

[features]
cuda = []
rocm = []
vulkan = []

[dependencies]
futures.workspace = true
http-api-bindings = { path = "../http-api-bindings" }
reqwest.workspace = true
serde_json.workspace = true
tabby-inference = { path = "../tabby-inference" }
tracing.workspace = true
async-trait.workspace = true
tokio = { workspace = true, features = ["process"] }
anyhow.workspace = true

[build-dependencies]
cmake = "0.1"
omnicopy_to_output = "0.1.1"
New file: llama-cpp-server build script (build.rs)

@@ -0,0 +1,79 @@
use std::{env, path::Path};

use cmake::Config;
use omnicopy_to_output::copy_to_output;

fn main() {
    let mut config = Config::new("../llama-cpp-bindings/llama.cpp");
    config.profile("Release");
    config.define("LLAMA_NATIVE", "OFF");
    config.define("INS_ENB", "ON");

    if cfg!(target_os = "macos") {
        config.define("LLAMA_METAL", "ON");
        config.define("LLAMA_METAL_EMBED_LIBRARY", "ON");
        println!("cargo:rustc-link-lib=framework=Foundation");
        println!("cargo:rustc-link-lib=framework=Accelerate");
        println!("cargo:rustc-link-lib=framework=Metal");
        println!("cargo:rustc-link-lib=framework=MetalKit");
    }
    if cfg!(feature = "cuda") {
        config.define("LLAMA_CUBLAS", "ON");
        config.define("CMAKE_POSITION_INDEPENDENT_CODE", "ON");
    }
    if cfg!(feature = "rocm") {
        let amd_gpu_targets: Vec<&str> = vec![
            "gfx803",
            "gfx900",
            "gfx906:xnack-",
            "gfx908:xnack-",
            "gfx90a:xnack+",
            "gfx90a:xnack-",
            "gfx940",
            "gfx941",
            "gfx942",
            "gfx1010",
            "gfx1012",
            "gfx1030",
            "gfx1031",
            "gfx1100",
            "gfx1101",
            "gfx1102",
            "gfx1103",
        ];

        let rocm_root = env::var("ROCM_ROOT").unwrap_or("/opt/rocm".to_string());
        config.define("LLAMA_HIPBLAS", "ON");
        config.define("CMAKE_C_COMPILER", format!("{}/llvm/bin/clang", rocm_root));
        config.define(
            "CMAKE_CXX_COMPILER",
            format!("{}/llvm/bin/clang++", rocm_root),
        );
        config.define("AMDGPU_TARGETS", amd_gpu_targets.join(";"));
    }
    if cfg!(feature = "vulkan") {
        config.define("LLAMA_VULKAN", "ON");
    }

    let out = config.build();
    let server_binary = make_output_binary(&out, "server");
    let renamed_server_binary = if cfg!(target_os = "macos") {
        make_output_binary(&out, "llama-server-metal")
    } else if cfg!(feature = "cuda") {
        make_output_binary(&out, "llama-server-cuda")
    } else if cfg!(feature = "rocm") {
        make_output_binary(&out, "llama-server-rocm")
    } else if cfg!(feature = "vulkan") {
        make_output_binary(&out, "llama-server-vulkan")
    } else {
        make_output_binary(&out, "llama-server")
    };

    std::fs::rename(server_binary, &renamed_server_binary).expect("Failed to rename server binary");
    copy_to_output(&renamed_server_binary)
        .expect("Failed to copy server binary to output directory");
}

fn make_output_binary(out: &Path, name: &str) -> String {
    out.join("bin").join(name).display().to_string() + env::consts::EXE_SUFFIX
}
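
For context (not part of the diff): the build script renames llama.cpp's generic `server` binary to a backend-specific name and copies it into the build output, while the runtime crate below resolves that name again from the directory of the running executable. A minimal sketch of the shared naming convention, with `expected_binary_prefix` as a purely illustrative helper:

// Hedged sketch, not part of the commit: mirrors the naming produced by the
// build script above ("llama-server-metal", "llama-server-cuda", ...) and
// searched for by find_binary_name() in the runtime code below.
fn expected_binary_prefix(device: Option<&str>) -> String {
    match device {
        Some(d) => format!("llama-server-{d}"),
        None => "llama-server".to_owned(),
    }
}

fn main() {
    assert_eq!(expected_binary_prefix(Some("cuda")), "llama-server-cuda");
    assert_eq!(expected_binary_prefix(Some("metal")), "llama-server-metal");
    assert_eq!(expected_binary_prefix(None), "llama-server");
}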
New file: llama-cpp-server runtime wrapper (crate source)

@@ -0,0 +1,180 @@
use std::{net::TcpListener, process::Stdio, sync::Arc};

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde_json::json;
use tabby_inference::{CompletionOptions, CompletionStream, Embedding};
use tokio::task::JoinHandle;
use tracing::warn;

pub struct LlamaCppServer {
    port: u16,
    handle: JoinHandle<()>,
    completion: Arc<dyn CompletionStream>,
    embedding: Arc<dyn Embedding>,
}

#[async_trait]
impl CompletionStream for LlamaCppServer {
    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
        self.completion.generate(prompt, options).await
    }
}

#[async_trait]
impl Embedding for LlamaCppServer {
    async fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
        self.embedding.embed(prompt).await
    }
}

impl LlamaCppServer {
    pub fn new(device: &str, model_path: &str, parallelism: u8) -> Self {
        let use_gpu = device != "cpu";
        let Some(binary_name) = find_binary_name(Some(device)) else {
            panic!("Failed to find llama-server binary for device {device}, please make sure you have corresponding llama-server binary locates in the same directory as the current executable.");
        };

        let model_path = model_path.to_owned();
        let port = get_available_port();
        let handle = tokio::spawn(async move {
            loop {
                let server_binary = std::env::current_exe()
                    .expect("Failed to get current executable path")
                    .parent()
                    .expect("Failed to get parent directory")
                    .join(&binary_name)
                    .display()
                    .to_string()
                    + std::env::consts::EXE_SUFFIX;
                let mut command = tokio::process::Command::new(server_binary);

                command
                    .arg("-m")
                    .arg(&model_path)
                    .arg("--port")
                    .arg(port.to_string())
                    .arg("-np")
                    .arg(parallelism.to_string())
                    .arg("--log-disable")
                    .kill_on_drop(true)
                    .stderr(Stdio::null())
                    .stdout(Stdio::null());

                if let Ok(n_threads) = std::env::var("LLAMA_CPP_N_THREADS") {
                    command.arg("-t").arg(n_threads);
                }

                if use_gpu {
                    let num_gpu_layers =
                        std::env::var("LLAMA_CPP_N_GPU_LAYERS").unwrap_or("9999".into());
                    command.arg("-ngl").arg(&num_gpu_layers);
                }

                let mut process = command.spawn().unwrap_or_else(|e| {
                    panic!(
                        "Failed to start llama-server with command {:?}: {}",
                        command, e
                    )
                });

                let status_code = process
                    .wait()
                    .await
                    .ok()
                    .and_then(|s| s.code())
                    .unwrap_or(-1);

                if status_code != 0 {
                    warn!(
                        "llama-server exited with status code {}, restarting...",
                        status_code
                    );
                }
            }
        });

        Self {
            handle,
            port,
            completion: make_completion(port),
            embedding: make_embedding(port),
        }
    }

    pub async fn start(&self) {
        let client = reqwest::Client::new();
        loop {
            let Ok(resp) = client.get(api_endpoint(self.port) + "/health").send().await else {
                continue;
            };

            if resp.status().is_success() {
                return;
            }
        }
    }
}

fn find_binary_name(suffix: Option<&str>) -> Option<String> {
    let current_exe = std::env::current_exe().expect("Failed to get current executable path");
    let binary_dir = current_exe
        .parent()
        .expect("Failed to get parent directory");
    let binary_name = if let Some(suffix) = suffix {
        format!("llama-server-{}", suffix)
    } else {
        "llama-server".to_owned()
    };
    std::fs::read_dir(binary_dir)
        .expect("Failed to read directory")
        .filter_map(|entry| entry.ok())
        .filter(|entry| {
            entry
                .file_name()
                .to_string_lossy()
                .starts_with(&binary_name)
        })
        .map(|entry| entry.path().display().to_string())
        .next()
}

fn make_completion(port: u16) -> Arc<dyn CompletionStream> {
    let model_spec: String = serde_json::to_string(&json!({
        "kind": "llama",
        "api_endpoint": api_endpoint(port),
    }))
    .expect("Failed to serialize model spec");
    let (engine, _, _) = http_api_bindings::create(&model_spec);
    engine
}

pub fn make_embedding(port: u16) -> Arc<dyn Embedding> {
    let model_spec: String = serde_json::to_string(&json!({
        "kind": "llama",
        "api_endpoint": api_endpoint(port),
    }))
    .expect("Failed to serialize model spec");
    http_api_bindings::create_embedding(&model_spec)
}

fn get_available_port() -> u16 {
    (30888..40000)
        .find(|port| port_is_available(*port))
        .expect("Failed to find available port")
}

fn port_is_available(port: u16) -> bool {
    TcpListener::bind(("127.0.0.1", port)).is_ok()
}

impl Drop for LlamaCppServer {
    fn drop(&mut self) {
        self.handle.abort();
    }
}

fn api_endpoint(port: u16) -> String {
    format!("http://localhost:{port}")
}