diff --git a/crates/llama-cpp-server/build.rs b/crates/llama-cpp-server/build.rs
index c96926dbd578..f91c9c535b3f 100644
--- a/crates/llama-cpp-server/build.rs
+++ b/crates/llama-cpp-server/build.rs
@@ -57,8 +57,7 @@ fn main() {
     let out = config.build();
     let server_binary = make_output_binary(&out, "server");

-    let renamed_server_binary =
-        if cfg!(target_os = "macos") {
+    let renamed_server_binary = if cfg!(target_os = "macos") {
         make_output_binary(&out, "llama-server-metal")
     } else if cfg!(feature = "cuda") {
         make_output_binary(&out, "llama-server-cuda")
@@ -70,7 +69,7 @@ fn main() {
         make_output_binary(&out, "llama-server")
     };

-    std::fs::rename(&server_binary, &renamed_server_binary)
+    std::fs::rename(server_binary, &renamed_server_binary)
         .expect("Failed to rename server binary");
     copy_to_output(&renamed_server_binary)
         .expect("Failed to copy server binary to output directory");
diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs
index b365b4b45b17..b0945e777515 100644
--- a/crates/llama-cpp-server/src/lib.rs
+++ b/crates/llama-cpp-server/src/lib.rs
@@ -6,7 +6,7 @@ use futures::stream::BoxStream;
 use serde_json::json;
 use tabby_inference::{CompletionOptions, CompletionStream, Embedding};
 use tokio::task::JoinHandle;
-use tracing::warn;
+use tracing::{warn};

 pub struct LlamaCppServer {
     port: u16,
@@ -32,12 +32,9 @@ impl Embedding for LlamaCppServer {
 impl LlamaCppServer {
     pub fn new(device: &str, model_path: &str, parallelism: u8) -> Self {
         let use_gpu = device != "cpu";
-        let mut binary_name = "llama-server".to_owned();
-        if cfg!(target_os = "macos") {
-            binary_name = binary_name + "-metal";
-        } else if device != "cpu" {
-            binary_name = binary_name + "-" + device;
-        }
+        let Some(binary_name) = find_binary_name(Some(device)) else {
+            panic!("Failed to find llama-server binary for device {device}; please make sure the corresponding llama-server binary is located in the same directory as the current executable");
+        };

         let model_path = model_path.to_owned();
         let port = get_available_port();
@@ -120,6 +117,29 @@ impl LlamaCppServer {
     }
 }

+fn find_binary_name(suffix: Option<&str>) -> Option<String> {
+    let current_exe = std::env::current_exe().expect("Failed to get current executable path");
+    let binary_dir = current_exe
+        .parent()
+        .expect("Failed to get parent directory");
+    let binary_name = if let Some(suffix) = suffix {
+        format!("llama-server-{}", suffix)
+    } else {
+        "llama-server".to_owned()
+    };
+    std::fs::read_dir(binary_dir)
+        .expect("Failed to read directory")
+        .filter_map(|entry| entry.ok())
+        .filter(|entry| {
+            entry
+                .file_name()
+                .to_string_lossy()
+                .starts_with(&binary_name)
+        })
+        .map(|entry| entry.path().display().to_string())
+        .next()
+}
+
 fn make_completion(port: u16) -> Arc<dyn CompletionStream> {
     let model_spec: String = serde_json::to_string(&json!({
         "kind": "llama",
@@ -157,4 +177,4 @@ impl Drop for LlamaCppServer {

 fn api_endpoint(port: u16) -> String {
     format!("http://localhost:{port}")
-}
\ No newline at end of file
+}
diff --git a/crates/tabby/Cargo.toml b/crates/tabby/Cargo.toml
index 60400b777e0c..e0c4b5bd9e1f 100644
--- a/crates/tabby/Cargo.toml
+++ b/crates/tabby/Cargo.toml
@@ -6,7 +6,7 @@ authors.workspace = true
 homepage.workspace = true

 [features]
-default = ["ee", "dep:color-eyre"]
+default = ["ee"]
 ee = ["dep:tabby-webserver"]
 cuda = []
 rocm = []
@@ -57,7 +57,7 @@ axum-prometheus = "0.6"
 uuid.workspace = true
 cached = { workspace = true, features = ["async"] }
 parse-git-url = "0.5.1"
-color-eyre = { version = "0.6.3", optional = true }
+color-eyre = { version = "0.6.3" }

 [dependencies.openssl]
 optional = true
diff --git a/crates/tabby/src/main.rs b/crates/tabby/src/main.rs
index 807973af0a10..65bd4e8ea0a0 100644
--- a/crates/tabby/src/main.rs
+++ b/crates/tabby/src/main.rs
@@ -89,7 +89,6 @@ impl Device {

 #[tokio::main]
 async fn main() {
-    #[cfg(feature = "dep:color-eyre")]
     color_eyre::install().expect("Must be able to install color_eyre");

     let cli = Cli::parse();
diff --git a/crates/tabby/src/services/model/mod.rs b/crates/tabby/src/services/model/mod.rs
index 5a70839787ca..1b0466245ae3 100644
--- a/crates/tabby/src/services/model/mod.rs
+++ b/crates/tabby/src/services/model/mod.rs
@@ -114,8 +114,7 @@ async fn create_ggml_engine(
     }

     let device_str = device.to_string().to_lowercase();
-    let server =
-        llama_cpp_server::LlamaCppServer::new(&device_str, model_path, parallelism);
+    let server = llama_cpp_server::LlamaCppServer::new(&device_str, model_path, parallelism);
     server.start().await;
     Arc::new(server)
 }
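
For context, a minimal standalone sketch of the contract the first two files establish: `build.rs` renames the compiled `llama-server` binary after the acceleration backend, and `find_binary_name` later locates it by prefix next to the running executable, so platform-specific extensions such as `.exe` still match. The helper names (`backend_binary_name`, `locate`) and the `main` driver are illustrative assumptions, not part of the patch, and the branch order mirrors the removed lib.rs logic rather than any one hunk verbatim.

```rust
use std::path::Path;

// Backend-to-name mapping, mirroring the naming scheme in the build.rs
// hunks and the logic removed from lib.rs: one server binary per backend.
fn backend_binary_name(device: &str) -> String {
    if cfg!(target_os = "macos") {
        "llama-server-metal".to_owned()
    } else if device != "cpu" {
        format!("llama-server-{device}") // e.g. llama-server-cuda
    } else {
        "llama-server".to_owned()
    }
}

// Runtime lookup, mirroring find_binary_name: matching by prefix lets
// "llama-server-cuda" also find "llama-server-cuda.exe" on Windows.
fn locate(dir: &Path, prefix: &str) -> Option<String> {
    std::fs::read_dir(dir)
        .ok()?
        .filter_map(|entry| entry.ok())
        .find(|entry| entry.file_name().to_string_lossy().starts_with(prefix))
        .map(|entry| entry.path().display().to_string())
}

fn main() {
    let exe = std::env::current_exe().expect("failed to get current executable path");
    let dir = exe.parent().expect("executable has no parent directory");
    // With a CUDA build of the patch above, this would print the path of the
    // llama-server-cuda binary sitting next to the executable, or None.
    println!("{:?}", locate(dir, &backend_binary_name("cuda")));
}
```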