feat: switch cuda backend to llama.cpp (#656)
* feat: switch cuda backend to llama.cpp

* fix

* fix
wsxiaoys authored Oct 27, 2023
1 parent 308681e commit 23bd542
Showing 9 changed files with 44 additions and 107 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -111,7 +111,7 @@ jobs:
- run: bash ./ci/prepare_build_environment.sh

- name: Build release binary
run: cargo build --no-default-features --release --target ${{ matrix.target }} --package tabby
run: cargo build --release --target ${{ matrix.target }} --package tabby

- name: Rename release binary
run: mv target/${{ matrix.target }}/release/tabby tabby_${{ matrix.target }}
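Note: with llama.cpp now the only bundled backend, the release build drops `--no-default-features`; GPU support is instead opted into through the new `cuda` cargo feature added later in this diff. A CUDA-enabled build of the same step would look roughly like this (the target triple is just an example):

cargo build --release --features cuda --target x86_64-unknown-linux-gnu --package tabby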
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@

* Switch cpu backend to llama.cpp: https://github.com/TabbyML/tabby/pull/638
* add `server.completion_timeout` to control the code completion interface timeout: https://github.com/TabbyML/tabby/pull/637
* Switch cuda backend to llama.cpp: https://github.com/TabbyML/tabby/pull/656

# v0.4.0

1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default.

18 changes: 11 additions & 7 deletions Dockerfile
@@ -1,8 +1,12 @@
FROM ghcr.io/opennmt/ctranslate2:3.20.0-ubuntu20.04-cuda11.2 as source
FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04 as builder
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG CUDA_VERSION=11.7.1
# Target the CUDA build image
ARG BASE_CUDA_DEV_CONTAINER=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
# Target the CUDA runtime image
ARG BASE_CUDA_RUN_CONTAINER=nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}

ENV CTRANSLATE2_ROOT=/opt/ctranslate2
COPY --from=source $CTRANSLATE2_ROOT $CTRANSLATE2_ROOT
FROM ${BASE_CUDA_DEV_CONTAINER} as build

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && \
@@ -30,10 +34,10 @@ RUN mkdir -p target

RUN --mount=type=cache,target=/usr/local/cargo/registry \
--mount=type=cache,target=/root/workspace/target \
cargo build --features link_shared --release && \
cargo build --features cuda --release && \
cp target/release/tabby /opt/tabby/bin/

FROM ghcr.io/opennmt/ctranslate2:3.20.0-ubuntu20.04-cuda11.2
FROM ${BASE_CUDA_RUN_CONTAINER} as runtime

RUN apt-get update && \
apt-get install -y --no-install-recommends \
@@ -51,7 +55,7 @@ RUN git config --system --add safe.directory "*"
RUN ln -s /usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1 \
/usr/lib/x86_64-linux-gnu/libnvidia-ml.so

COPY --from=builder /opt/tabby /opt/tabby
COPY --from=build /opt/tabby /opt/tabby

ENV TABBY_ROOT=/data

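The image now uses llama.cpp's usual two-stage CUDA layout: the `-devel` base compiles tabby with `--features cuda`, and the `-runtime` base carries only what is needed to run it. Because the CUDA and Ubuntu versions are plain build arguments, they can be overridden to match the container host, for example (the chosen combination must exist as an nvidia/cuda image tag):

docker build --build-arg CUDA_VERSION=11.8.0 --build-arg UBUNTU_VERSION=22.04 -t tabby-cuda .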
3 changes: 3 additions & 0 deletions crates/llama-cpp-bindings/Cargo.toml
Expand Up @@ -3,6 +3,9 @@ name = "llama-cpp-bindings"
version = "0.5.0-dev"
edition = "2021"

[features]
cuda = []

[build-dependencies]
cxx-build = "1.0"
cmake = "0.1"
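`cuda = []` is an empty feature: it adds no Rust code on its own and only exists so the build script can see whether it was requested. Cargo compiles build scripts with the package's enabled features, which is why the `cfg!(feature = "cuda")` check in the next file works; the documented alternative is the per-feature environment variable. A minimal sketch of that variant, assuming the same cmake setup as build.rs:

use cmake::Config;

fn main() {
    let mut config = Config::new("llama.cpp");
    // Cargo exports every enabled feature to the build script as CARGO_FEATURE_<NAME>,
    // so this check is equivalent to the cfg!(feature = "cuda") branch used in build.rs.
    if std::env::var("CARGO_FEATURE_CUDA").is_ok() {
        config.define("LLAMA_CUBLAS", "ON");
    }
    let _dst = config.build();
}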
20 changes: 10 additions & 10 deletions crates/llama-cpp-bindings/build.rs
@@ -1,25 +1,25 @@
use cmake::Config;

fn main() {
let mut config = Config::new("llama.cpp");
if cfg!(target_os = "macos") {
config.define("LLAMA_METAL", "ON");
}
let dst = config.build();

println!("cargo:rerun-if-changed=cc/*.h");
println!("cargo:rerun-if-changed=cc/*.cc");

println!("cargo:rustc-link-search=native={}/build", dst.display());
println!("cargo:rustc-link-lib=llama");
println!("cargo:rustc-link-lib=ggml_static");

let mut config = Config::new("llama.cpp");
if cfg!(target_os = "macos") {
config.define("LLAMA_METAL", "ON");
println!("cargo:rustc-link-lib=framework=Foundation");
println!("cargo:rustc-link-lib=framework=Accelerate");
println!("cargo:rustc-link-lib=framework=Metal");
println!("cargo:rustc-link-lib=framework=MetalKit");
}
if cfg!(feature = "cuda") {
config.define("LLAMA_CUBLAS", "ON");
}

let dst = config.build();
println!("cargo:rustc-link-search=native={}/build", dst.display());
println!("cargo:rustc-link-lib=llama");
println!("cargo:rustc-link-lib=ggml_static");

cxx_build::bridge("src/lib.rs")
.file("src/engine.cc")
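Read as a whole, the new build script emits the rerun hints first, then configures the llama.cpp CMake build (Metal plus the Apple frameworks on macOS, cuBLAS when the `cuda` feature is on), and only afterwards builds and emits the link directives. A rough reconstruction of the resulting file, assembled from the lines shown above (the trailing cxx_build section is cut off in this hunk):

use cmake::Config;

fn main() {
    println!("cargo:rerun-if-changed=cc/*.h");
    println!("cargo:rerun-if-changed=cc/*.cc");

    let mut config = Config::new("llama.cpp");
    if cfg!(target_os = "macos") {
        config.define("LLAMA_METAL", "ON");
        println!("cargo:rustc-link-lib=framework=Foundation");
        println!("cargo:rustc-link-lib=framework=Accelerate");
        println!("cargo:rustc-link-lib=framework=Metal");
        println!("cargo:rustc-link-lib=framework=MetalKit");
    }
    if cfg!(feature = "cuda") {
        config.define("LLAMA_CUBLAS", "ON");
    }

    let dst = config.build();
    println!("cargo:rustc-link-search=native={}/build", dst.display());
    println!("cargo:rustc-link-lib=llama");
    println!("cargo:rustc-link-lib=ggml_static");

    // cxx_build::bridge("src/lib.rs").file("src/engine.cc") ... continues past the shown hunk
}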
8 changes: 3 additions & 5 deletions crates/tabby/Cargo.toml
Expand Up @@ -3,6 +3,9 @@ name = "tabby"
version = "0.5.0-dev"
edition = "2021"

[features]
cuda = ["llama-cpp-bindings/cuda"]

[dependencies]
tabby-common = { path = "../tabby-common" }
tabby-scheduler = { path = "../tabby-scheduler" }
@@ -43,7 +46,6 @@ textdistance = "1.0.2"
regex.workspace = true
thiserror.workspace = true
llama-cpp-bindings = { path = "../llama-cpp-bindings" }
ctranslate2-bindings = { path = "../ctranslate2-bindings", optional = true }

[dependencies.uuid]
version = "1.3.3"
@@ -53,10 +55,6 @@ features = [
"macro-diagnostics", # Enable better diagnostics for compile-time UUIDs
]

[features]
link_shared = ["ctranslate2-bindings/link_shared"]
link_cuda_static = ["ctranslate2-bindings"]

[build-dependencies]
vergen = { version = "8.0.0", features = ["build", "git", "gitcl"] }

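The old `link_shared` / `link_cuda_static` features and the optional `ctranslate2-bindings` dependency are gone; the single `cuda` feature simply forwards to `llama-cpp-bindings/cuda`, so enabling it on the tabby package transitively switches the bindings crate to cuBLAS. One way to confirm the forwarding from the workspace root (exact output format varies by cargo version):

cargo tree -p tabby --features cuda -f "{p} {f}" | grep llama-cpp-bindings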
44 changes: 1 addition & 43 deletions crates/tabby/src/serve/engine.rs
@@ -13,7 +13,7 @@ pub fn create_engine(
if args.device != super::Device::ExperimentalHttp {
let model_dir = get_model_dir(model);
let metadata = read_metadata(&model_dir);
let engine = create_local_engine(args, &model_dir, &metadata);
let engine = create_ggml_engine(&args.device, &model_dir);
(
engine,
EngineInfo {
@@ -38,48 +38,6 @@ pub struct EngineInfo {
pub chat_template: Option<String>,
}

#[cfg(not(any(feature = "link_shared", feature = "link_cuda_static")))]
fn create_local_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
_metadata: &Metadata,
) -> Box<dyn TextGeneration> {
create_ggml_engine(&args.device, model_dir)
}

#[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
fn create_local_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
if args.device.use_ggml_backend() {
create_ggml_engine(&args.device, model_dir)
} else {
create_ctranslate2_engine(args, model_dir, metadata)
}
}

#[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
fn create_ctranslate2_engine(
args: &crate::serve::ServeArgs,
model_dir: &ModelDir,
metadata: &Metadata,
) -> Box<dyn TextGeneration> {
use ctranslate2_bindings::{CTranslate2Engine, CTranslate2EngineOptionsBuilder};

let device = format!("{}", args.device);
let options = CTranslate2EngineOptionsBuilder::default()
.model_path(model_dir.ctranslate2_dir())
.tokenizer_path(model_dir.tokenizer_file())
.device(device)
.model_type(metadata.auto_model.clone())
.device_indices(args.device_indices.clone())
.build()
.unwrap();
Box::new(CTranslate2Engine::create(options))
}

fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
.model_path(model_dir.ggml_q8_0_v2_file())
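With the ctranslate2 paths removed, create_ggml_engine is the only local engine constructor and the model metadata argument is no longer needed. The hunk is truncated here; a plausible reading of the rest of the function, with the names beyond the visible builder calls marked as assumptions rather than verbatim commit content:

fn create_ggml_engine(device: &super::Device, model_dir: &ModelDir) -> Box<dyn TextGeneration> {
    let options = llama_cpp_bindings::LlamaEngineOptionsBuilder::default()
        .model_path(model_dir.ggml_q8_0_v2_file())
        .use_gpu(device.ggml_use_gpu()) // assumed name; something must wire Device::ggml_use_gpu() into the engine
        .build()
        .unwrap();
    Box::new(llama_cpp_bindings::LlamaEngine::create(options)) // assumed constructor
}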
54 changes: 14 additions & 40 deletions crates/tabby/src/serve/mod.rs
@@ -74,7 +74,7 @@ pub enum Device {
#[strum(serialize = "cpu")]
Cpu,

#[cfg(any(feature = "link_shared", feature = "link_cuda_static"))]
#[cfg(feature = "cuda")]
Cuda,

#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
@@ -87,21 +87,16 @@ pub enum Device {

impl Device {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
fn use_ggml_backend(&self) -> bool {
*self == Device::Metal || *self == Device::Cpu
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
fn use_ggml_backend(&self) -> bool {
*self == Device::Cpu
fn ggml_use_gpu(&self) -> bool {
*self == Device::Metal
}

#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
#[cfg(feature = "cuda")]
fn ggml_use_gpu(&self) -> bool {
*self == Device::Metal
*self == Device::Cuda
}

#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
#[cfg(not(any(all(target_os = "macos", target_arch = "aarch64"), feature = "cuda")))]
fn ggml_use_gpu(&self) -> bool {
false
}
@@ -124,26 +119,19 @@ pub struct ServeArgs {
#[clap(long, default_value_t=Device::Cpu)]
device: Device,

/// GPU indices to run models, only applicable for CUDA.
#[clap(long, default_values_t=[0])]
device_indices: Vec<i32>,

/// DEPRECATED: Do not use.
#[deprecated(since = "0.5.0")]
#[clap(long, hide(true))]
num_replicas_per_device: Option<usize>,

/// DEPRECATED: Do not use.
#[clap(long, hide(true))]
compute_type: Option<String>,
device_indices: Vec<i32>,
}

pub async fn main(config: &Config, args: &ServeArgs) {
valid_args(args);

if args.device != Device::ExperimentalHttp {
download_model(&args.model, &args.device).await;
download_model(&args.model).await;
if let Some(chat_model) = &args.chat_model {
download_model(chat_model, &args.device).await;
download_model(chat_model).await;
}
} else {
warn!("HTTP device is unstable and does not comply with semver expectations.")
@@ -261,17 +249,8 @@ fn api_router(args: &ServeArgs, config: &Config) -> Router {
}

fn valid_args(args: &ServeArgs) {
if args.num_replicas_per_device.is_some() {
warn!("--num-replicas-per-device is deprecated and will be removed in future release.");
}

if args.device == Device::Cpu && (args.device_indices.len() != 1 || args.device_indices[0] != 0)
{
fatal!("CPU device only supports device indices = [0]");
}

if args.compute_type.is_some() {
warn!("--compute-type is deprecated and will be removed in future release.");
if !args.device_indices.is_empty() {
warn!("--device-indices is deprecated and will be removed in future release.");
}
}

@@ -285,15 +264,10 @@ fn start_heartbeat(args: &ServeArgs) {
});
}

async fn download_model(model: &str, device: &Device) {
async fn download_model(model: &str) {
let downloader = Downloader::new(model, /* prefer_local_file= */ true);
let handler = |err| fatal!("Failed to fetch model '{}' due to '{}'", model, err,);
let download_result = if device.use_ggml_backend() {
downloader.download_ggml_files().await
} else {
downloader.download_ctranslate2_files().await
};

let download_result = downloader.download_ggml_files().await;
download_result.unwrap_or_else(handler);
}

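End to end, serve now always downloads ggml weights and builds a llama.cpp engine: `--device cuda` exists only in binaries compiled with the `cuda` feature, and `--device-indices` is kept solely to print a deprecation warning. Assuming the Cuda variant serializes as "cuda" like the other strum variants, a typical invocation of such a build would be (model name is only an example):

tabby serve --model TabbyML/StarCoder-1B --device cuda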
