
Commit

chore: add llama-cpp-server to embed llama-server directly
wsxiaoys committed May 13, 2024
1 parent 05c40e5 commit bc2d09b
Showing 4 changed files with 92 additions and 11 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

11 changes: 8 additions & 3 deletions crates/llama-cpp-server/Cargo.toml
@@ -8,7 +8,12 @@ homepage.workspace = true
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-http-api-bindings = { version = "0.12.0-dev.0", path = "../http-api-bindings" }
+futures.workspace = true
+http-api-bindings = { path = "../http-api-bindings" }
+reqwest.workspace = true
 serde_json.workspace = true
-tabby-inference = { version = "0.12.0-dev.0", path = "../tabby-inference" }
-tokio.workspace = true
+tabby-inference = { path = "../tabby-inference" }
+tokio = { workspace = true, features = ["process"] }
+
+[dev-dependencies]
+tabby-common = { path = "../tabby-common" }
87 changes: 80 additions & 7 deletions crates/llama-cpp-server/src/lib.rs
@@ -1,22 +1,25 @@
-use std::{process::Stdio, sync::Arc};
+use std::{
+    process::{ExitStatus, Stdio},
+    sync::Arc,
+};
 
 use serde_json::json;
 use tabby_inference::{ChatCompletionStream, CompletionStream, Embedding};
+use tokio::task::JoinHandle;
 
 struct LlamaCppServer {
-    process: tokio::process::Child,
+    handle: JoinHandle<()>,
 }
 
 const SERVER_PORT: u16 = 30888;
 
 impl LlamaCppServer {
     pub fn new(model_path: &str, use_gpu: bool, parallelism: u8) -> Self {
-        let mut num_gpu_layers = std::env::var("LLAMA_CPP_N_GPU_LAYERS")
-            .unwrap_or("9999".into());
+        let mut num_gpu_layers = std::env::var("LLAMA_CPP_N_GPU_LAYERS").unwrap_or("9999".into());
         if !use_gpu {
             num_gpu_layers = "0".to_string();
         }
-        let process = tokio::process::Command::new("llama-cpp-server")
+        let mut process = tokio::process::Command::new("llama-server")
             .arg("-m")
             .arg(model_path)
             .arg("--port")
@@ -26,16 +29,41 @@ impl LlamaCppServer {
.arg("-np")
.arg(parallelism.to_string())
.kill_on_drop(true)
.stderr(Stdio::null())
.stdout(Stdio::null())
.spawn()
.expect("Failed to spawn llama-cpp-server");

Self { process }
let handle = tokio::spawn(async move {
let status_code = process
.wait()
.await
.ok()
.and_then(|s| s.code())
.unwrap_or(-1);
println!("Exist with exit code {}", status_code);
});

Self { handle }
}

async fn wait_for_health(&self) {
let client = reqwest::Client::new();
loop {
let Ok(resp) = client.get(api_endpoint() + "/health").send().await else {
continue;
};

if resp.status().is_success() {
return;
}
}
}

pub fn completion(&self, prompt_template: String) -> Arc<dyn CompletionStream> {
let model_spec: String = serde_json::to_string(&json!({
"kind": "llama",
"api_endpoint": format!("http://localhost:{SERVER_PORT}"),
"api_endpoint": api_endpoint(),
"prompt_template": prompt_template,
}))
.expect("Failed to serialize model spec");
@@ -61,3 +89,48 @@ impl LlamaCppServer {
         http_api_bindings::create_embedding(&model_spec)
     }
 }
+
+fn api_endpoint() -> String {
+    format!("http://localhost:{SERVER_PORT}")
+}
+
+#[cfg(test)]
+mod tests {
+    use futures::StreamExt;
+    use tabby_common::registry::{parse_model_id, ModelRegistry};
+    use tabby_inference::CompletionOptionsBuilder;
+
+    use super::*;
+
+    #[tokio::test]
+    #[ignore = "Should only be run in local manual testing"]
+    async fn test_create_completion() {
+        let model_id = "StarCoder-1B";
+        let (registry, name) = parse_model_id(model_id);
+        let registry = ModelRegistry::new(registry).await;
+        let model_path = registry.get_model_path(name).display().to_string();
+        let model_info = registry.get_model_info(name);
+
+        let server = LlamaCppServer::new(&model_path, false, 1);
+        server.wait_for_health().await;
+
+        let completion = server.completion(model_info.prompt_template.clone().unwrap());
+        let s = completion
+            .generate(
+                "def fib(n):",
+                CompletionOptionsBuilder::default()
+                    .max_decoding_tokens(7)
+                    .max_input_length(1024)
+                    .sampling_temperature(0.0)
+                    .seed(12345)
+                    .build()
+                    .unwrap(),
+            )
+            .await;
+
+        let content: Vec<String> = s.collect().await;
+
+        let content = content.join("");
+        assert_eq!(content, "\n if n <= 1:")
+    }
+}
2 changes: 1 addition & 1 deletion ee/tabby-webserver/Cargo.toml
@@ -37,7 +37,7 @@ tabby-schema = { path = "../../ee/tabby-schema" }
 tabby-db = { path = "../../ee/tabby-db" }
 tarpc = { version = "0.33.0", features = ["serde-transport"] }
 thiserror.workspace = true
-tokio = { workspace = true, features = ["fs", "process"] }
+tokio = { workspace = true, features = ["fs"] }
 tokio-tungstenite = "0.21"
 tower = { version = "0.4", features = ["util", "limit"] }
 tower-http = { workspace = true, features = ["fs", "trace"] }
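
Taken together, the change boils down to one pattern: spawn llama-server as a child process and poll its /health endpoint until it answers before routing requests to it; the tokio "process" feature moves from the webserver crate into llama-cpp-server accordingly. The sketch below restates that flow in isolation. It is not the crate's exact code: the model path is a placeholder, it assumes a llama-server binary on PATH plus the tokio and reqwest crates, and it sleeps between health checks, which the committed wait_for_health does not do.

// Minimal sketch of the embed-llama-server pattern introduced by this commit.
use std::time::Duration;

#[tokio::main]
async fn main() {
    let port = 30888;
    // Spawn llama-server as a child process; kill it when the handle is dropped.
    let _child = tokio::process::Command::new("llama-server")
        .arg("-m")
        .arg("/path/to/model.gguf") // placeholder model path
        .arg("--port")
        .arg(port.to_string())
        .kill_on_drop(true)
        .spawn()
        .expect("Failed to spawn llama-server");

    // Poll the health endpoint until the server responds successfully.
    let client = reqwest::Client::new();
    let health = format!("http://localhost:{port}/health");
    loop {
        if let Ok(resp) = client.get(&health).send().await {
            if resp.status().is_success() {
                break;
            }
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
    println!("llama-server is ready at http://localhost:{port}");
}

Beyond this core loop, the lib.rs hunk above also silences the child's stdout/stderr and points the existing http-api-bindings completion, chat, and embedding clients at the same local endpoint via api_endpoint().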

