Skip to content

Commit

Permalink
feat(download): Add support for segmented models
Browse files Browse the repository at this point in the history
  • Loading branch information
boxbeam committed Mar 27, 2024
1 parent 3c32788 commit cdf760c
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
3 changes: 3 additions & 0 deletions crates/tabby-common/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ pub struct ModelInfo {
pub prompt_template: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub chat_template: Option<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub urls: Vec<String>,
#[serde(skip_serializing_if = "Vec::is_empty")]
pub segmented_urls: Vec<String>,
pub sha256: String,
}

Expand Down
50 changes: 48 additions & 2 deletions crates/tabby-download/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
//! Responsible for downloading ML models for use with tabby.
use std::{fs, path::Path};
use std::{
fs::{self, File},
io::{BufRead, BufReader, Write},
path::Path,
};

use aim_downloader::{bar::WrappedBar, error::DownloadError, https};
use anyhow::{anyhow, bail, Result};
use tabby_common::registry::{parse_model_id, ModelRegistry};
use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
Expand Down Expand Up @@ -37,6 +41,10 @@ async fn download_model_impl(
}
}

if !model_info.segmented_urls.is_empty() {
return download_split_model(&model_info, &model_path).await;
}

Check warning on line 47 in crates/tabby-download/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-download/src/lib.rs#L44-L47

Added lines #L44 - L47 were not covered by tests
let registry = std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned());
let Some(model_url) = model_info.urls.iter().find(|x| x.contains(&registry)) else {
return Err(anyhow!(
Expand All @@ -52,6 +60,44 @@ async fn download_model_impl(
Ok(())
}

async fn download_split_model(model_info: &ModelInfo, model_path: &Path) -> Result<()> {
if !model_info.urls.is_empty() {
return Err(anyhow!(
"{}: Cannot specify both `urls` and `segmented_urls`",
model_info.name
));
}
let mut paths = vec![];
for (index, url) in model_info.segmented_urls.iter().enumerate() {
let ext = format!(
"{}.{}",
model_path.extension().unwrap_or_default().to_string_lossy(),
index.to_string()
);
let path = model_path.with_extension(ext);
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
let download_job = Retry::spawn(strategy, || download_file(url, &path));
download_job.await?;
paths.push(path);

Check warning on line 81 in crates/tabby-download/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-download/src/lib.rs#L63-L81

Added lines #L63 - L81 were not covered by tests
}
info!("Merging split model files...");
let mut file = File::open(model_path)?;
for path in paths {
let mut reader = BufReader::new(File::open(&path)?);

Check warning on line 86 in crates/tabby-download/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-download/src/lib.rs#L83-L86

Added lines #L83 - L86 were not covered by tests
loop {
let buffer = reader.fill_buf()?;
file.write_all(buffer)?;
let len = buffer.len();
reader.consume(len);
if len == 0 {
break;
}

Check warning on line 94 in crates/tabby-download/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-download/src/lib.rs#L88-L94

Added lines #L88 - L94 were not covered by tests
}
std::fs::remove_file(path)?;

Check warning on line 96 in crates/tabby-download/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-download/src/lib.rs#L96

Added line #L96 was not covered by tests
}
Ok(())
}

Check warning on line 99 in crates/tabby-download/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/tabby-download/src/lib.rs#L98-L99

Added lines #L98 - L99 were not covered by tests

async fn download_file(url: &str, path: &Path) -> Result<()> {
let dir = path
.parent()
Expand Down

0 comments on commit cdf760c

Please sign in to comment.