diff --git a/Cargo.lock b/Cargo.lock index 4561d4bec611..fa4d90caad04 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2784,6 +2784,7 @@ dependencies = [ "tabby-common", "tabby-inference", "tokio", + "tracing", ] [[package]] diff --git a/crates/llama-cpp-server/Cargo.toml b/crates/llama-cpp-server/Cargo.toml index 9b46e5624482..f3d7f45e0895 100644 --- a/crates/llama-cpp-server/Cargo.toml +++ b/crates/llama-cpp-server/Cargo.toml @@ -13,6 +13,7 @@ http-api-bindings = { path = "../http-api-bindings" } reqwest.workspace = true serde_json.workspace = true tabby-inference = { path = "../tabby-inference" } +tracing.workspace = true tokio = { workspace = true, features = ["process"] } [dev-dependencies] diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs index 86276922d2a7..016f1be1c33e 100644 --- a/crates/llama-cpp-server/src/lib.rs +++ b/crates/llama-cpp-server/src/lib.rs @@ -6,6 +6,7 @@ use std::{ use serde_json::json; use tabby_inference::{ChatCompletionStream, CompletionStream, Embedding}; use tokio::task::JoinHandle; +use tracing::warn; struct LlamaCppServer { handle: JoinHandle<()>, @@ -19,29 +20,36 @@ impl LlamaCppServer { if !use_gpu { num_gpu_layers = "0".to_string(); } - let mut process = tokio::process::Command::new("llama-server") - .arg("-m") - .arg(model_path) - .arg("--port") - .arg(SERVER_PORT.to_string()) - .arg("-ngl") - .arg(num_gpu_layers) - .arg("-np") - .arg(parallelism.to_string()) - .kill_on_drop(true) - .stderr(Stdio::null()) - .stdout(Stdio::null()) - .spawn() - .expect("Failed to spawn llama-cpp-server"); + let model_path = model_path.to_owned(); let handle = tokio::spawn(async move { - let status_code = process - .wait() - .await - .ok() - .and_then(|s| s.code()) - .unwrap_or(-1); - println!("Exist with exit code {}", status_code); + loop { + let mut process = tokio::process::Command::new("llama-server") + .arg("-m") + .arg(&model_path) + .arg("--port") + .arg(SERVER_PORT.to_string()) + .arg("-ngl") + .arg(&num_gpu_layers) + .arg("-np") + .arg(parallelism.to_string()) + .kill_on_drop(true) + .stderr(Stdio::inherit()) + .stdout(Stdio::inherit()) + .spawn() + .expect("Failed to spawn llama-cpp-server"); + + let status_code = process + .wait() + .await + .ok() + .and_then(|s| s.code()) + .unwrap_or(-1); + + if status_code != 0 { + warn!("llama-server exited with status code {}, restarting...", status_code); + } + } }); Self { handle } @@ -90,6 +98,12 @@ impl LlamaCppServer { } } +impl Drop for LlamaCppServer { + fn drop(&mut self) { + self.handle.abort(); + } +} + fn api_endpoint() -> String { format!("http://localhost:{SERVER_PORT}") }