
Commit

refactor(core): embed llama.cpp's server binary directly for LLM inference (#2113)

* chore: add llama-cpp-server sub crate

* chore: add llama-cpp-server to embed llama-server directly

* update

* cleanup

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
wsxiaoys authored May 14, 2024
1 parent 0b1315c commit b674eb7
Showing 10 changed files with 356 additions and 56 deletions.
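
At a glance, the new llama-cpp-server crate builds llama.cpp's server target at compile time, ships the binary alongside the Tabby executable, and exposes it through the existing CompletionStream and Embedding traits. Below is a minimal usage sketch based on the API added in this diff; the model path, the StreamExt driver loop, and the assumption that CompletionOptions provides a Default impl are illustrative and not part of the commit.

use futures::StreamExt;
use llama_cpp_server::LlamaCppServer;
use tabby_inference::{CompletionOptions, CompletionStream, Embedding};

#[tokio::main]
async fn main() {
    // Spawn the bundled llama-server variant for the chosen device (here the
    // CUDA build is assumed to exist next to the executable) and wait until
    // its /health endpoint responds.
    let server = LlamaCppServer::new("cuda", "/path/to/model.gguf", 1);
    server.start().await;

    // Stream a completion through the child process's HTTP API.
    // CompletionOptions::default() is assumed for brevity.
    let mut stream = server
        .generate("fn fib(n: u32) -> u32 {", CompletionOptions::default())
        .await;
    while let Some(chunk) = stream.next().await {
        print!("{chunk}");
    }

    // The same process also serves embeddings.
    let embedding = server.embed("hello world").await.unwrap();
    println!("embedding dimensions: {}", embedding.len());
}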
47 changes: 44 additions & 3 deletions Cargo.lock

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -13,7 +13,7 @@ members = [
     "ee/tabby-db",
     "ee/tabby-db-macros",
     "ee/tabby-search",
-    "ee/tabby-schema",
+    "ee/tabby-schema", "crates/llama-cpp-server",
 ]

 [workspace.package]
2 changes: 1 addition & 1 deletion crates/llama-cpp-bindings/llama.cpp
Submodule llama.cpp updated 203 files
26 changes: 26 additions & 0 deletions crates/llama-cpp-server/Cargo.toml
@@ -0,0 +1,26 @@
[package]
name = "llama-cpp-server"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true

[features]
cuda = []
rocm = []
vulkan = []

[dependencies]
futures.workspace = true
http-api-bindings = { path = "../http-api-bindings" }
reqwest.workspace = true
serde_json.workspace = true
tabby-inference = { path = "../tabby-inference" }
tracing.workspace = true
async-trait.workspace = true
tokio = { workspace = true, features = ["process"] }
anyhow.workspace = true

[build-dependencies]
cmake = "0.1"
omnicopy_to_output = "0.1.1"
79 changes: 79 additions & 0 deletions crates/llama-cpp-server/build.rs
@@ -0,0 +1,79 @@
use std::{env, path::Path};

use cmake::Config;
use omnicopy_to_output::copy_to_output;

fn main() {
    // Build llama.cpp's `server` target via CMake.
    let mut config = Config::new("../llama-cpp-bindings/llama.cpp");
    config.profile("Release");
    config.define("LLAMA_NATIVE", "OFF");
    config.define("INS_ENB", "ON");

    if cfg!(target_os = "macos") {
        config.define("LLAMA_METAL", "ON");
        config.define("LLAMA_METAL_EMBED_LIBRARY", "ON");
        println!("cargo:rustc-link-lib=framework=Foundation");
        println!("cargo:rustc-link-lib=framework=Accelerate");
        println!("cargo:rustc-link-lib=framework=Metal");
        println!("cargo:rustc-link-lib=framework=MetalKit");
    }
    if cfg!(feature = "cuda") {
        config.define("LLAMA_CUBLAS", "ON");
        config.define("CMAKE_POSITION_INDEPENDENT_CODE", "ON");
    }
    if cfg!(feature = "rocm") {
        let amd_gpu_targets: Vec<&str> = vec![
            "gfx803",
            "gfx900",
            "gfx906:xnack-",
            "gfx908:xnack-",
            "gfx90a:xnack+",
            "gfx90a:xnack-",
            "gfx940",
            "gfx941",
            "gfx942",
            "gfx1010",
            "gfx1012",
            "gfx1030",
            "gfx1031",
            "gfx1100",
            "gfx1101",
            "gfx1102",
            "gfx1103",
        ];

        let rocm_root = env::var("ROCM_ROOT").unwrap_or("/opt/rocm".to_string());
        config.define("LLAMA_HIPBLAS", "ON");
        config.define("CMAKE_C_COMPILER", format!("{}/llvm/bin/clang", rocm_root));
        config.define(
            "CMAKE_CXX_COMPILER",
            format!("{}/llvm/bin/clang++", rocm_root),
        );
        config.define("AMDGPU_TARGETS", amd_gpu_targets.join(";"));
    }
    if cfg!(feature = "vulkan") {
        config.define("LLAMA_VULKAN", "ON");
    }

    let out = config.build();
    // Rename the built `server` binary per acceleration backend so multiple
    // variants can sit side by side next to the Tabby executable.
    let server_binary = make_output_binary(&out, "server");
    let renamed_server_binary = if cfg!(target_os = "macos") {
        make_output_binary(&out, "llama-server-metal")
    } else if cfg!(feature = "cuda") {
        make_output_binary(&out, "llama-server-cuda")
    } else if cfg!(feature = "rocm") {
        make_output_binary(&out, "llama-server-rocm")
    } else if cfg!(feature = "vulkan") {
        make_output_binary(&out, "llama-server-vulkan")
    } else {
        make_output_binary(&out, "llama-server")
    };

    std::fs::rename(server_binary, &renamed_server_binary).expect("Failed to rename server binary");
    copy_to_output(&renamed_server_binary)
        .expect("Failed to copy server binary to output directory");
}

fn make_output_binary(out: &Path, name: &str) -> String {
    out.join("bin").join(name).display().to_string() + env::consts::EXE_SUFFIX
}
180 changes: 180 additions & 0 deletions crates/llama-cpp-server/src/lib.rs
@@ -0,0 +1,180 @@
use std::{net::TcpListener, process::Stdio, sync::Arc};

use anyhow::Result;
use async_trait::async_trait;
use futures::stream::BoxStream;
use serde_json::json;
use tabby_inference::{CompletionOptions, CompletionStream, Embedding};
use tokio::task::JoinHandle;
use tracing::warn;

pub struct LlamaCppServer {
    port: u16,
    handle: JoinHandle<()>,
    completion: Arc<dyn CompletionStream>,
    embedding: Arc<dyn Embedding>,
}

#[async_trait]
impl CompletionStream for LlamaCppServer {
    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
        self.completion.generate(prompt, options).await
    }
}

#[async_trait]
impl Embedding for LlamaCppServer {
    async fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
        self.embedding.embed(prompt).await
    }
}

impl LlamaCppServer {
    pub fn new(device: &str, model_path: &str, parallelism: u8) -> Self {
        let use_gpu = device != "cpu";
        let Some(binary_name) = find_binary_name(Some(device)) else {
            panic!("Failed to find llama-server binary for device {device}; please make sure the corresponding llama-server binary is located in the same directory as the current executable.");
        };

        let model_path = model_path.to_owned();
        let port = get_available_port();
        // Supervise the child llama-server process: respawn it whenever it exits.
        let handle = tokio::spawn(async move {
            loop {
                let server_binary = std::env::current_exe()
                    .expect("Failed to get current executable path")
                    .parent()
                    .expect("Failed to get parent directory")
                    .join(&binary_name)
                    .display()
                    .to_string()
                    + std::env::consts::EXE_SUFFIX;
                let mut command = tokio::process::Command::new(server_binary);

                command
                    .arg("-m")
                    .arg(&model_path)
                    .arg("--port")
                    .arg(port.to_string())
                    .arg("-np")
                    .arg(parallelism.to_string())
                    .arg("--log-disable")
                    .kill_on_drop(true)
                    .stderr(Stdio::null())
                    .stdout(Stdio::null());

                if let Ok(n_threads) = std::env::var("LLAMA_CPP_N_THREADS") {
                    command.arg("-t").arg(n_threads);
                }

                if use_gpu {
                    let num_gpu_layers =
                        std::env::var("LLAMA_CPP_N_GPU_LAYERS").unwrap_or("9999".into());
                    command.arg("-ngl").arg(&num_gpu_layers);
                }

                let mut process = command.spawn().unwrap_or_else(|e| {
                    panic!(
                        "Failed to start llama-server with command {:?}: {}",
                        command, e
                    )
                });

                let status_code = process
                    .wait()
                    .await
                    .ok()
                    .and_then(|s| s.code())
                    .unwrap_or(-1);

                if status_code != 0 {
                    warn!(
                        "llama-server exited with status code {}, restarting...",
                        status_code
                    );
                }
            }
        });

        Self {
            handle,
            port,
            completion: make_completion(port),
            embedding: make_embedding(port),
        }
    }

    /// Poll the child server's /health endpoint until it responds successfully.
    pub async fn start(&self) {
        let client = reqwest::Client::new();
        loop {
            let Ok(resp) = client.get(api_endpoint(self.port) + "/health").send().await else {
                continue;
            };

            if resp.status().is_success() {
                return;
            }
        }
    }
}

fn find_binary_name(suffix: Option<&str>) -> Option<String> {
    // Look next to the current executable for a binary named
    // `llama-server-<suffix>` (e.g. llama-server-cuda) as staged by build.rs.
    let current_exe = std::env::current_exe().expect("Failed to get current executable path");
    let binary_dir = current_exe
        .parent()
        .expect("Failed to get parent directory");
    let binary_name = if let Some(suffix) = suffix {
        format!("llama-server-{}", suffix)
    } else {
        "llama-server".to_owned()
    };
    std::fs::read_dir(binary_dir)
        .expect("Failed to read directory")
        .filter_map(|entry| entry.ok())
        .filter(|entry| {
            entry
                .file_name()
                .to_string_lossy()
                .starts_with(&binary_name)
        })
        .map(|entry| entry.path().display().to_string())
        .next()
}

fn make_completion(port: u16) -> Arc<dyn CompletionStream> {
    let model_spec: String = serde_json::to_string(&json!({
        "kind": "llama",
        "api_endpoint": api_endpoint(port),
    }))
    .expect("Failed to serialize model spec");
    let (engine, _, _) = http_api_bindings::create(&model_spec);
    engine
}

pub fn make_embedding(port: u16) -> Arc<dyn Embedding> {
    let model_spec: String = serde_json::to_string(&json!({
        "kind": "llama",
        "api_endpoint": api_endpoint(port),
    }))
    .expect("Failed to serialize model spec");
    http_api_bindings::create_embedding(&model_spec)
}

fn get_available_port() -> u16 {
    (30888..40000)
        .find(|port| port_is_available(*port))
        .expect("Failed to find available port")
}

fn port_is_available(port: u16) -> bool {
    TcpListener::bind(("127.0.0.1", port)).is_ok()
}

impl Drop for LlamaCppServer {
    fn drop(&mut self) {
        // Abort the supervisor task; kill_on_drop terminates the child process.
        self.handle.abort();
    }
}

fn api_endpoint(port: u16) -> String {
    format!("http://localhost:{port}")
}