Commit: update

wsxiaoys committed May 13, 2024
1 parent c47d3a9 commit 7d34ac1
Showing 5 changed files with 33 additions and 16 deletions.
crates/llama-cpp-server/build.rs (2 additions, 3 deletions)

@@ -57,8 +57,7 @@ fn main() {
 
     let out = config.build();
     let server_binary = make_output_binary(&out, "server");
-    let renamed_server_binary =
-        if cfg!(target_os = "macos") {
+    let renamed_server_binary = if cfg!(target_os = "macos") {
         make_output_binary(&out, "llama-server-metal")
     } else if cfg!(feature = "cuda") {
         make_output_binary(&out, "llama-server-cuda")
@@ -70,7 +69,7 @@ fn main() {
         make_output_binary(&out, "llama-server")
     };
 
-    std::fs::rename(&server_binary, &renamed_server_binary)
+    std::fs::rename(server_binary, &renamed_server_binary)
         .expect("Failed to rename server binary");
     copy_to_output(&renamed_server_binary)
         .expect("Failed to copy server binary to output directory");
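The renaming logic above picks a per-backend binary name at build time. The `make_output_binary` helper it calls is defined elsewhere in build.rs and is not part of this diff, so the following is only a sketch of what such a helper could look like; the body is an assumption, not the commit's code:

```rust
use std::path::Path;

// Hypothetical reconstruction of `make_output_binary` (the real definition
// is outside this diff): join the cmake output directory with bin/<name>,
// appending ".exe" on Windows via EXE_SUFFIX.
fn make_output_binary(out: &Path, name: &str) -> String {
    format!(
        "{}/bin/{}{}",
        out.display(),
        name,
        std::env::consts::EXE_SUFFIX // "" on Unix, ".exe" on Windows
    )
}

fn main() {
    // Example: on Linux this prints "target/llama-build/bin/server"
    println!("{}", make_output_binary(Path::new("target/llama-build"), "server"));
}
```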
crates/llama-cpp-server/src/lib.rs (28 additions, 8 deletions)

@@ -6,7 +6,7 @@ use futures::stream::BoxStream;
 use serde_json::json;
 use tabby_inference::{CompletionOptions, CompletionStream, Embedding};
 use tokio::task::JoinHandle;
-use tracing::warn;
+use tracing::{warn};
 
 pub struct LlamaCppServer {
     port: u16,
@@ -32,12 +32,9 @@ impl Embedding for LlamaCppServer {
 impl LlamaCppServer {
     pub fn new(device: &str, model_path: &str, parallelism: u8) -> Self {
         let use_gpu = device != "cpu";
-        let mut binary_name = "llama-server".to_owned();
-        if cfg!(target_os = "macos") {
-            binary_name = binary_name + "-metal";
-        } else if device != "cpu" {
-            binary_name = binary_name + "-" + device;
-        }
+        let Some(binary_name) = find_binary_name(Some(device)) else {
+            panic!("Failed to find llama-server binary for device {device}; please make sure the corresponding llama-server binary is located in the same directory as the current executable");
+        };
 
         let model_path = model_path.to_owned();
         let port = get_available_port();
@@ -120,6 +117,29 @@ impl LlamaCppServer {
     }
 }
 
+fn find_binary_name(suffix: Option<&str>) -> Option<String> {
+    let current_exe = std::env::current_exe().expect("Failed to get current executable path");
+    let binary_dir = current_exe
+        .parent()
+        .expect("Failed to get parent directory");
+    let binary_name = if let Some(suffix) = suffix {
+        format!("llama-server-{}", suffix)
+    } else {
+        "llama-server".to_owned()
+    };
+    std::fs::read_dir(binary_dir)
+        .expect("Failed to read directory")
+        .filter_map(|entry| entry.ok())
+        .filter(|entry| {
+            entry
+                .file_name()
+                .to_string_lossy()
+                .starts_with(&binary_name)
+        })
+        .map(|entry| entry.path().display().to_string())
+        .next()
+}
+
 fn make_completion(port: u16) -> Arc<dyn CompletionStream> {
     let model_spec: String = serde_json::to_string(&json!({
         "kind": "llama",
@@ -157,4 +177,4 @@ impl Drop for LlamaCppServer {
 
 fn api_endpoint(port: u16) -> String {
     format!("http://localhost:{port}")
-}
+}
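The new `find_binary_name` helper resolves which llama-server variant to launch by scanning the directory of the current executable for a file whose name starts with `llama-server-<suffix>`. A self-contained sketch of the same lookup pattern, generalized to an arbitrary prefix (the generalization and names here are ours, not the commit's):

```rust
use std::path::PathBuf;

// Find the first file next to the current executable whose name starts with
// the given prefix, mirroring the commit's `find_binary_name`.
fn find_sibling_with_prefix(prefix: &str) -> Option<PathBuf> {
    let exe = std::env::current_exe().ok()?;
    let dir = exe.parent()?;
    std::fs::read_dir(dir)
        .ok()?
        .filter_map(|entry| entry.ok())
        .find(|entry| entry.file_name().to_string_lossy().starts_with(prefix))
        .map(|entry| entry.path())
}

fn main() {
    match find_sibling_with_prefix("llama-server-cuda") {
        Some(path) => println!("found: {}", path.display()),
        None => println!("no matching binary next to this executable"),
    }
}
```

A prefix match rather than an exact match means `llama-server-cuda.exe` on Windows is found by the same lookup, with no platform-specific suffix handling at the call site.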
crates/tabby/Cargo.toml (2 additions, 2 deletions)

@@ -6,7 +6,7 @@ authors.workspace = true
 homepage.workspace = true
 
 [features]
-default = ["ee", "dep:color-eyre"]
+default = ["ee"]
 ee = ["dep:tabby-webserver"]
 cuda = []
 rocm = []
@@ -57,7 +57,7 @@ axum-prometheus = "0.6"
 uuid.workspace = true
 cached = { workspace = true, features = ["async"] }
 parse-git-url = "0.5.1"
-color-eyre = { version = "0.6.3", optional = true }
+color-eyre = { version = "0.6.3" }
 
 [dependencies.openssl]
 optional = true
crates/tabby/src/main.rs (0 additions, 1 deletion)

@@ -89,7 +89,6 @@ impl Device {
 
 #[tokio::main]
 async fn main() {
-    #[cfg(feature = "dep:color-eyre")]
     color_eyre::install().expect("Must be able to install color_eyre");
 
     let cli = Cli::parse();
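The deleted cfg line explains the Cargo.toml change above: `dep:color-eyre` in a feature list enables the optional dependency, but `dep:`-prefixed strings are dependency references, not feature names, so `#[cfg(feature = "dep:color-eyre")]` can never match and the install call was silently compiled out. The commit fixes this by making color-eyre a required dependency. A hypothetical alternative that keeps it optional would gate on a plain feature name; this is a sketch of that pattern, not what the commit does:

```rust
// Sketch only: keep color-eyre optional behind a real feature name.
//
// Cargo.toml (illustrative):
//   [features]
//   default = ["ee", "color-eyre"]
//   color-eyre = ["dep:color-eyre"]

#[tokio::main]
async fn main() {
    // Matches when the plain `color-eyre` feature is enabled; the removed
    // `#[cfg(feature = "dep:color-eyre")]` could never match.
    #[cfg(feature = "color-eyre")]
    color_eyre::install().expect("Must be able to install color_eyre");
}
```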
crates/tabby/src/services/model/mod.rs (1 addition, 2 deletions)

@@ -114,8 +114,7 @@ async fn create_ggml_engine(
     }
 
     let device_str = device.to_string().to_lowercase();
-    let server =
-        llama_cpp_server::LlamaCppServer::new(&device_str, model_path, parallelism);
+    let server = llama_cpp_server::LlamaCppServer::new(&device_str, model_path, parallelism);
     server.start().await;
     Arc::new(server)
 }
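For orientation, the call site above is the whole lifecycle the service layer needs: construct, start, share. A minimal sketch assuming `llama_cpp_server` is available as a dependency; the surrounding function is a simplified stand-in, not tabby's real service code:

```rust
use std::sync::Arc;

use llama_cpp_server::LlamaCppServer;

// Simplified stand-in for the pattern in create_ggml_engine: the device
// string is lowercased so it lines up with the llama-server-<device> binary
// naming from build.rs, then the server is started and shared via Arc.
async fn create_engine(device: &str, model_path: &str, parallelism: u8) -> Arc<LlamaCppServer> {
    let device_str = device.to_lowercase();
    let server = LlamaCppServer::new(&device_str, model_path, parallelism);
    server.start().await; // returns once the spawned llama-server is usable
    Arc::new(server)
}
```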
