
refactor(core): embed llama.cpp's server binary directly for LLM inference #2113

Merged: 16 commits, May 14, 2024
47 changes: 44 additions & 3 deletions Cargo.lock


2 changes: 1 addition & 1 deletion Cargo.toml
@@ -13,7 +13,7 @@ members = [
"ee/tabby-db",
"ee/tabby-db-macros",
"ee/tabby-search",
"ee/tabby-schema",
"ee/tabby-schema", "crates/llama-cpp-server",
]

[workspace.package]
2 changes: 1 addition & 1 deletion crates/llama-cpp-bindings/llama.cpp
Submodule llama.cpp updated 203 files
26 changes: 26 additions & 0 deletions crates/llama-cpp-server/Cargo.toml
@@ -0,0 +1,26 @@
[package]
name = "llama-cpp-server"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true

[features]
cuda = []
rocm = []
vulkan = []

[dependencies]
futures.workspace = true
http-api-bindings = { path = "../http-api-bindings" }
reqwest.workspace = true
serde_json.workspace = true
tabby-inference = { path = "../tabby-inference" }
tracing.workspace = true
async-trait.workspace = true
tokio = { workspace = true, features = ["process"] }
anyhow.workspace = true

[build-dependencies]
cmake = "0.1"
omnicopy_to_output = "0.1.1"
79 changes: 79 additions & 0 deletions crates/llama-cpp-server/build.rs
@@ -0,0 +1,79 @@
use std::{env, path::Path};

use cmake::Config;
use omnicopy_to_output::copy_to_output;

fn main() {
    // Build llama.cpp's `server` target via CMake. LLAMA_NATIVE is disabled so
    // the resulting binary is not tied to the build machine's CPU features.
    let mut config = Config::new("../llama-cpp-bindings/llama.cpp");
    config.profile("Release");
    config.define("LLAMA_NATIVE", "OFF");
    config.define("INS_ENB", "ON");

if cfg!(target_os = "macos") {
config.define("LLAMA_METAL", "ON");
config.define("LLAMA_METAL_EMBED_LIBRARY", "ON");
println!("cargo:rustc-link-lib=framework=Foundation");
println!("cargo:rustc-link-lib=framework=Accelerate");
println!("cargo:rustc-link-lib=framework=Metal");
println!("cargo:rustc-link-lib=framework=MetalKit");
}
if cfg!(feature = "cuda") {
config.define("LLAMA_CUBLAS", "ON");
config.define("CMAKE_POSITION_INDEPENDENT_CODE", "ON");
}
if cfg!(feature = "rocm") {
let amd_gpu_targets: Vec<&str> = vec![
"gfx803",
"gfx900",
"gfx906:xnack-",
"gfx908:xnack-",
"gfx90a:xnack+",
"gfx90a:xnack-",
"gfx940",
"gfx941",
"gfx942",
"gfx1010",
"gfx1012",
"gfx1030",
"gfx1031",
"gfx1100",
"gfx1101",
"gfx1102",
"gfx1103",
];

let rocm_root = env::var("ROCM_ROOT").unwrap_or("/opt/rocm".to_string());
config.define("LLAMA_HIPBLAS", "ON");
config.define("CMAKE_C_COMPILER", format!("{}/llvm/bin/clang", rocm_root));
config.define(
"CMAKE_CXX_COMPILER",
format!("{}/llvm/bin/clang++", rocm_root),
);
config.define("AMDGPU_TARGETS", amd_gpu_targets.join(";"));
}
if cfg!(feature = "vulkan") {
config.define("LLAMA_VULKAN", "ON");
}

    let out = config.build();
    // Rename the generic `server` binary per acceleration backend so multiple
    // variants can coexist, then copy it into the build output directory next
    // to the tabby executable.
    let server_binary = make_output_binary(&out, "server");
let renamed_server_binary = if cfg!(target_os = "macos") {
make_output_binary(&out, "llama-server-metal")
} else if cfg!(feature = "cuda") {
make_output_binary(&out, "llama-server-cuda")
} else if cfg!(feature = "rocm") {
make_output_binary(&out, "llama-server-rocm")
} else if cfg!(feature = "vulkan") {
make_output_binary(&out, "llama-server-vulkan")
} else {
make_output_binary(&out, "llama-server")
};

std::fs::rename(server_binary, &renamed_server_binary).expect("Failed to rename server binary");
copy_to_output(&renamed_server_binary)
.expect("Failed to copy server binary to output directory");
}

fn make_output_binary(out: &Path, name: &str) -> String {
out.join("bin").join(name).display().to_string() + env::consts::EXE_SUFFIX
}
180 changes: 180 additions & 0 deletions crates/llama-cpp-server/src/lib.rs
@@ -0,0 +1,180 @@
use std::{net::TcpListener, process::Stdio, sync::Arc};

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde_json::json;
use tabby_inference::{CompletionOptions, CompletionStream, Embedding};
use tokio::task::JoinHandle;
use tracing::warn;

/// Manages a spawned llama.cpp `server` process and exposes it through the
/// `CompletionStream` and `Embedding` traits by proxying to its HTTP API.
pub struct LlamaCppServer {
    port: u16,
    handle: JoinHandle<()>,
    completion: Arc<dyn CompletionStream>,
    embedding: Arc<dyn Embedding>,
}

#[async_trait]
impl CompletionStream for LlamaCppServer {
async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
self.completion.generate(prompt, options).await
}
}

#[async_trait]
impl Embedding for LlamaCppServer {
async fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
self.embedding.embed(prompt).await
}
}

impl LlamaCppServer {
pub fn new(device: &str, model_path: &str, parallelism: u8) -> Self {
let use_gpu = device != "cpu";
let Some(binary_name) = find_binary_name(Some(device)) else {
panic!("Failed to find llama-server binary for device {device}, please make sure you have corresponding llama-server binary locates in the same directory as the current executable.");
};

let model_path = model_path.to_owned();
let port = get_available_port();
        // Supervisor task: spawn the llama-server process and restart it
        // whenever it exits.
        let handle = tokio::spawn(async move {
            loop {
                // `find_binary_name` already resolved the absolute path of the
                // llama-server binary (including any platform suffix) next to
                // the current executable, so it can be used directly.
                let server_binary = binary_name.clone();
let mut command = tokio::process::Command::new(server_binary);

command
.arg("-m")
.arg(&model_path)
.arg("--port")
.arg(port.to_string())
.arg("-np")
.arg(parallelism.to_string())
.arg("--log-disable")
.kill_on_drop(true)
.stderr(Stdio::null())
.stdout(Stdio::null());

if let Ok(n_threads) = std::env::var("LLAMA_CPP_N_THREADS") {
command.arg("-t").arg(n_threads);
}

if use_gpu {
let num_gpu_layers =
std::env::var("LLAMA_CPP_N_GPU_LAYERS").unwrap_or("9999".into());
command.arg("-ngl").arg(&num_gpu_layers);
}

let mut process = command.spawn().unwrap_or_else(|e| {
panic!(
"Failed to start llama-server with command {:?}: {}",
command, e
)
});

let status_code = process
.wait()
.await
.ok()
.and_then(|s| s.code())
.unwrap_or(-1);

if status_code != 0 {
warn!(
"llama-server exited with status code {}, restarting...",
status_code
);
}
}
});

Self {
handle,
port,
completion: make_completion(port),
embedding: make_embedding(port),
}
}

    pub async fn start(&self) {
        // Block until the spawned server's /health endpoint responds successfully.
        let client = reqwest::Client::new();
loop {
let Ok(resp) = client.get(api_endpoint(self.port) + "/health").send().await else {
continue;
};

if resp.status().is_success() {
return;
}
}
}
}

fn find_binary_name(suffix: Option<&str>) -> Option<String> {
let current_exe = std::env::current_exe().expect("Failed to get current executable path");
let binary_dir = current_exe
.parent()
.expect("Failed to get parent directory");
let binary_name = if let Some(suffix) = suffix {
format!("llama-server-{}", suffix)
} else {
"llama-server".to_owned()
};
std::fs::read_dir(binary_dir)
.expect("Failed to read directory")
.filter_map(|entry| entry.ok())
.filter(|entry| {
entry
.file_name()
.to_string_lossy()
.starts_with(&binary_name)
})
.map(|entry| entry.path().display().to_string())
.next()
}

fn make_completion(port: u16) -> Arc<dyn CompletionStream> {
let model_spec: String = serde_json::to_string(&json!({
"kind": "llama",
"api_endpoint": api_endpoint(port),
}))
.expect("Failed to serialize model spec");
let (engine, _, _) = http_api_bindings::create(&model_spec);
engine
}

pub fn make_embedding(port: u16) -> Arc<dyn Embedding> {
let model_spec: String = serde_json::to_string(&json!({
"kind": "llama",
"api_endpoint": api_endpoint(port),
}))
.expect("Failed to serialize model spec");
http_api_bindings::create_embedding(&model_spec)
}

fn get_available_port() -> u16 {
(30888..40000)
.find(|port| port_is_available(*port))
.expect("Failed to find available port")
}

fn port_is_available(port: u16) -> bool {
TcpListener::bind(("127.0.0.1", port)).is_ok()
}

impl Drop for LlamaCppServer {
fn drop(&mut self) {
self.handle.abort();
}
}

fn api_endpoint(port: u16) -> String {
format!("http://localhost:{port}")
}
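
For reference, a minimal usage sketch of the new crate (not part of this diff). The device string, model path, and the tokio entry point are illustrative assumptions; it relies only on the `LlamaCppServer` API and the `tabby-inference` traits shown above.

use anyhow::Result;
use llama_cpp_server::LlamaCppServer;
use tabby_inference::Embedding;

#[tokio::main]
async fn main() -> Result<()> {
    // Spawn the bundled llama-server binary (hypothetical model path) and
    // block until its /health endpoint reports ready.
    let server = LlamaCppServer::new("cuda", "/models/example.gguf", 1);
    server.start().await;

    // The running server now backs both completion and embedding requests.
    let embedding = server.embed("fn fib(n: u64) -> u64 {").await?;
    println!("embedding dimensions: {}", embedding.len());
    Ok(())
}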