diff --git a/crates/llama-cpp-server/build.rs b/crates/llama-cpp-server/build.rs index 088fb9fe4fe2..17443cdf7a6b 100644 --- a/crates/llama-cpp-server/build.rs +++ b/crates/llama-cpp-server/build.rs @@ -1,4 +1,4 @@ -use std::env; +use std::{env, path::Path}; use cmake::Config; use omnicopy_to_output::copy_to_output; @@ -56,8 +56,23 @@ fn main() { } let out = config.build(); - let server_binary = - out.join("bin").join("server").display().to_string() + env::consts::EXE_SUFFIX; + let server_binary = make_output_binary(&out, "server"); + let renamed_server_binary = if cfg!(feature = "cuda") { + make_output_binary(&out, "llama-server-cuda") + } else if cfg!(feature = "rocm") { + make_output_binary(&out, "llama-server-rocm") + } else if cfg!(feature = "vulkan") { + make_output_binary(&out, "llama-server-vulkan") + } else { + make_output_binary(&out, "llama-server") + }; - copy_to_output(&server_binary).expect("Failed to copy server binary to output directory"); + std::fs::rename(&server_binary, &renamed_server_binary) + .expect("Failed to rename server binary"); + copy_to_output(&renamed_server_binary) + .expect("Failed to copy server binary to output directory"); +} + +fn make_output_binary(out: &Path, name: &str) -> String { + out.join("bin").join(name).display().to_string() + env::consts::EXE_SUFFIX } diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs index 64628bb33493..f33f54160553 100644 --- a/crates/llama-cpp-server/src/lib.rs +++ b/crates/llama-cpp-server/src/lib.rs @@ -30,7 +30,13 @@ impl Embedding for LlamaCppServer { } impl LlamaCppServer { - pub fn new(model_path: &str, use_gpu: bool, parallelism: u8) -> Self { + pub fn new(device: &str, model_path: &str, parallelism: u8) -> Self { + let use_gpu = device != "cpu"; + let mut binary_name = "llama-server".to_owned(); + if device != "cpu" && device != "metal" { + binary_name = binary_name + "-" + device; + } + let model_path = model_path.to_owned(); let port = get_available_port(); let handle = tokio::spawn(async move { @@ -39,7 +45,7 @@ impl LlamaCppServer { .expect("Failed to get current executable path") .parent() .expect("Failed to get parent directory") - .join("server") + .join(&binary_name) .display() .to_string() + std::env::consts::EXE_SUFFIX; @@ -67,7 +73,12 @@ impl LlamaCppServer { command.arg("-ngl").arg(&num_gpu_layers); } - let mut process = command.spawn().expect("Failed to spawn llama-cpp-server"); + let mut process = command.spawn().unwrap_or_else(|e| { + panic!( + "Failed to start llama-server with command {:?}: {}", + command, e + ) + }); let status_code = process .wait() @@ -162,7 +173,7 @@ mod tests { let registry = ModelRegistry::new(registry).await; let model_path = registry.get_model_path(name).display().to_string(); - let server = LlamaCppServer::new(&model_path, false, 1); + let server = LlamaCppServer::new("cpu", &model_path, false, 1); server.start().await; let s = server diff --git a/crates/tabby/src/services/model/mod.rs b/crates/tabby/src/services/model/mod.rs index 4e6362e88f6d..5a70839787ca 100644 --- a/crates/tabby/src/services/model/mod.rs +++ b/crates/tabby/src/services/model/mod.rs @@ -113,8 +113,9 @@ async fn create_ggml_engine( ); } + let device_str = device.to_string().to_lowercase(); let server = - llama_cpp_server::LlamaCppServer::new(model_path, device.ggml_use_gpu(), parallelism); + llama_cpp_server::LlamaCppServer::new(&device_str, model_path, parallelism); server.start().await; Arc::new(server) }