chore: add llama-cpp-server to embed llama-server directly #2112

Closed · wants to merge 5 commits
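A rough usage sketch (not part of the diff), based on the API and the ignored test in crates/llama-cpp-server/src/lib.rs below. The model path and prompt template are placeholders, and LlamaCppServer is not yet declared pub in this PR, so this mirrors in-crate usage only:

// Usage sketch only; assumes an async context inside the crate.
async fn example() {
    // Spawn the embedded llama-server (CPU-only, one parallel slot) and wait until /health answers.
    let server = LlamaCppServer::new("/path/to/model.gguf", false, 1);
    server.wait_for_health().await;

    // The returned engines are HTTP clients (via http-api-bindings) pointed at the embedded server.
    let _completion = server.completion("<prompt template>".to_owned());
    let _chat = server.chat();
}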
14 changes: 14 additions & 0 deletions Cargo.lock

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -13,7 +13,7 @@ members = [
"ee/tabby-db",
"ee/tabby-db-macros",
"ee/tabby-search",
"ee/tabby-schema",
"ee/tabby-schema", "crates/llama-cpp-server",
]

[workspace.package]
2 changes: 1 addition & 1 deletion crates/llama-cpp-bindings/llama.cpp
Submodule llama.cpp updated 203 files
20 changes: 20 additions & 0 deletions crates/llama-cpp-server/Cargo.toml
@@ -0,0 +1,20 @@
[package]
name = "llama-cpp-server"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
futures.workspace = true
http-api-bindings = { path = "../http-api-bindings" }
reqwest.workspace = true
serde_json.workspace = true
tabby-inference = { path = "../tabby-inference" }
tracing.workspace = true
tokio = { workspace = true, features = ["process"] }

[dev-dependencies]
tabby-common = { path = "../tabby-common" }
150 changes: 150 additions & 0 deletions crates/llama-cpp-server/src/lib.rs
@@ -0,0 +1,150 @@
use std::{process::Stdio, sync::Arc};

use serde_json::json;
use tabby_inference::{ChatCompletionStream, CompletionStream, Embedding};
use tokio::task::JoinHandle;
use tracing::warn;

/// Supervises an embedded `llama-server` child process and exposes engines backed by its HTTP API.
struct LlamaCppServer {
    handle: JoinHandle<()>,
}

const SERVER_PORT: u16 = 30888;

impl LlamaCppServer {
    pub fn new(model_path: &str, use_gpu: bool, parallelism: u8) -> Self {
        // Offload all layers to the GPU by default (overridable via LLAMA_CPP_N_GPU_LAYERS);
        // CPU-only mode forces zero GPU layers.
        let mut num_gpu_layers = std::env::var("LLAMA_CPP_N_GPU_LAYERS").unwrap_or("9999".into());
        if !use_gpu {
            num_gpu_layers = "0".to_string();
        }

        let model_path = model_path.to_owned();
        let handle = tokio::spawn(async move {
            // Supervision loop: spawn llama-server and respawn it whenever the child exits.
            loop {
                let mut process = tokio::process::Command::new("llama-server")
                    .arg("-m")
                    .arg(&model_path)
                    .arg("--port")
                    .arg(SERVER_PORT.to_string())
                    .arg("-ngl")
                    .arg(&num_gpu_layers)
                    .arg("-np")
                    .arg(parallelism.to_string())
                    .kill_on_drop(true)
                    .stderr(Stdio::inherit())
                    .stdout(Stdio::inherit())
                    .spawn()
                    .expect("Failed to spawn llama-cpp-server");

                let status_code = process
                    .wait()
                    .await
                    .ok()
                    .and_then(|s| s.code())
                    .unwrap_or(-1);

                if status_code != 0 {
                    warn!(
                        "llama-server exited with status code {}, restarting...",
                        status_code
                    );
                }
            }
        });

        Self { handle }
    }

    async fn wait_for_health(&self) {
        let client = reqwest::Client::new();
        // Poll /health until the spawned server starts answering successfully.
        loop {
            let Ok(resp) = client.get(api_endpoint() + "/health").send().await else {
                continue;
            };

            if resp.status().is_success() {
                return;
            }
        }
    }

    pub fn completion(&self, prompt_template: String) -> Arc<dyn CompletionStream> {
        let model_spec: String = serde_json::to_string(&json!({
            "kind": "llama",
            "api_endpoint": api_endpoint(),
            "prompt_template": prompt_template,
        }))
        .expect("Failed to serialize model spec");
        let (engine, _, _) = http_api_bindings::create(&model_spec);
        engine
    }

    pub fn chat(&self) -> Arc<dyn ChatCompletionStream> {
        let model_spec: String = serde_json::to_string(&json!({
            "kind": "openai-chat",
            "api_endpoint": format!("http://localhost:{SERVER_PORT}/v1"),
        }))
        .expect("Failed to serialize model spec");
        http_api_bindings::create_chat(&model_spec)
    }

    pub fn embedding(self) -> Arc<dyn Embedding> {
        let model_spec: String = serde_json::to_string(&json!({
            "kind": "llama",
            "api_endpoint": format!("http://localhost:{SERVER_PORT}"),
        }))
        .expect("Failed to serialize model spec");
        http_api_bindings::create_embedding(&model_spec)
    }

}

impl Drop for LlamaCppServer {
    fn drop(&mut self) {
        // Abort the supervisor task; kill_on_drop(true) then terminates the llama-server child as well.
        self.handle.abort();
    }
}

fn api_endpoint() -> String {
    format!("http://localhost:{SERVER_PORT}")
}

#[cfg(test)]
mod tests {
    use futures::StreamExt;
    use tabby_common::registry::{parse_model_id, ModelRegistry};
    use tabby_inference::CompletionOptionsBuilder;

    use super::*;

    #[tokio::test]
    #[ignore = "Should only be run in local manual testing"]
    async fn test_create_completion() {
        let model_id = "StarCoder-1B";
        let (registry, name) = parse_model_id(model_id);
        let registry = ModelRegistry::new(registry).await;
        let model_path = registry.get_model_path(name).display().to_string();
        let model_info = registry.get_model_info(name);

        let server = LlamaCppServer::new(&model_path, false, 1);
        server.wait_for_health().await;

        let completion = server.completion(model_info.prompt_template.clone().unwrap());
        let s = completion
            .generate(
                "def fib(n):",
                CompletionOptionsBuilder::default()
                    .max_decoding_tokens(7)
                    .max_input_length(1024)
                    .sampling_temperature(0.0)
                    .seed(12345)
                    .build()
                    .unwrap(),
            )
            .await;

        let content: Vec<String> = s.collect().await;

        let content = content.join("");
        assert_eq!(content, "\n if n <= 1:")
    }

}
2 changes: 1 addition & 1 deletion ee/tabby-webserver/Cargo.toml
@@ -37,7 +37,7 @@ tabby-schema = { path = "../../ee/tabby-schema" }
tabby-db = { path = "../../ee/tabby-db" }
tarpc = { version = "0.33.0", features = ["serde-transport"] }
thiserror.workspace = true
- tokio = { workspace = true, features = ["fs", "process"] }
+ tokio = { workspace = true, features = ["fs"] }
tokio-tungstenite = "0.21"
tower = { version = "0.4", features = ["util", "limit"] }
tower-http = { workspace = true, features = ["fs", "trace"] }