feat(download): Add support for segmented model urls #1735

Merged
merged 5 commits into from
Mar 29, 2024
5 changes: 4 additions & 1 deletion crates/tabby-common/src/registry.rs
@@ -12,7 +12,10 @@ pub struct ModelInfo {
     pub prompt_template: Option<String>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub chat_template: Option<String>,
-    pub urls: Vec<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub urls: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub partition_urls: Option<Vec<Vec<String>>>,
     pub sha256: String,
 }
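
The struct change above makes `urls` optional and adds `partition_urls`, and the download code later in this PR rejects entries that set both. A minimal sketch of that either/or rule follows. `ModelUrls`, `Source`, and `resolve` are hypothetical names that do not appear in the PR, the serde attributes are left out so the sketch needs only the standard library, and the error for an entry with neither field set is an assumption (the PR reports that case through its "Invalid mirror" error instead).

```rust
/// Hypothetical stand-in for the two optional URL fields on `ModelInfo`.
#[derive(Debug, Default)]
struct ModelUrls {
    urls: Option<Vec<String>>,
    partition_urls: Option<Vec<Vec<String>>>,
}

/// Hypothetical helper type: which download path a registry entry selects.
#[derive(Debug, PartialEq)]
enum Source<'a> {
    /// Candidate URLs for the whole model file, one per mirror.
    Single(&'a [String]),
    /// One list per mirror, each holding the ordered part URLs.
    Split(&'a [Vec<String>]),
}

/// Map the two optional fields onto exactly one download path.
fn resolve(info: &ModelUrls) -> Result<Source<'_>, String> {
    match (&info.urls, &info.partition_urls) {
        // Same rule as the PR: both fields together are rejected.
        (Some(_), Some(_)) => {
            Err("Cannot specify both `urls` and `partition_urls`".to_owned())
        }
        (Some(urls), None) => Ok(Source::Single(urls)),
        (None, Some(parts)) => Ok(Source::Split(parts)),
        // Assumption: this message is illustrative, not taken from the PR.
        (None, None) => Err("Neither `urls` nor `partition_urls` is set".to_owned()),
    }
}
```

Modelling the outcome as an enum keeps the "exactly one of the two" rule in one place; the PR itself checks the fields inline with `is_some()`.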

86 changes: 82 additions & 4 deletions crates/tabby-download/src/lib.rs
@@ -1,15 +1,23 @@
 //! Responsible for downloading ML models for use with tabby.
-use std::{fs, path::Path};
+use std::{
+    fs::{self, File, OpenOptions},
+    io::{BufRead, BufReader, Write},
+    path::Path,
+};
 
 use aim_downloader::{bar::WrappedBar, error::DownloadError, https};
 use anyhow::{anyhow, bail, Result};
-use tabby_common::registry::{parse_model_id, ModelRegistry};
+use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry};
 use tokio_retry::{
     strategy::{jitter, ExponentialBackoff},
     Retry,
 };
 use tracing::{info, warn};
 
+fn download_host() -> String {
+    std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned())
+}
+
 async fn download_model_impl(
     registry: &ModelRegistry,
     name: &str,

Codecov / codecov/patch: added lines #L17-L19 in crates/tabby-download/src/lib.rs were not covered by tests.
@@ -37,8 +45,17 @@
         }
     }
 
-    let registry = std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned());
-    let Some(model_url) = model_info.urls.iter().find(|x| x.contains(&registry)) else {
+    if model_info.partition_urls.is_some() {
+        return download_split_model(model_info, &model_path).await;
+    }
+
+    let registry = download_host();
+    let Some(model_url) = model_info
+        .urls
+        .iter()
+        .flatten()
+        .find(|x| x.contains(&registry))
+    else {
         return Err(anyhow!(
             "Invalid mirror <{}> for model urls: {:?}",
             registry,

Codecov / codecov/patch: added lines #L48-L57 in crates/tabby-download/src/lib.rs were not covered by tests.
@@ -52,6 +69,67 @@
     Ok(())
 }
 
+async fn download_split_model(model_info: &ModelInfo, model_path: &Path) -> Result<()> {
+    if model_info.urls.is_some() {
+        return Err(anyhow!(
+            "{}: Cannot specify both `urls` and `partition_urls`",
+            model_info.name
+        ));
+    }
+    let mut paths = vec![];
+    let partition_urls = model_info.partition_urls.clone().unwrap_or_default();
+    let mirror = download_host();
+
+    let Some(urls) = partition_urls
+        .iter()
+        .find(|urls| urls.iter().all(|url| url.contains(&mirror)))
+    else {
+        return Err(anyhow!(
+            "Invalid mirror <{}> for model urls: {:?}",
+            mirror,
+            partition_urls
+        ));
+    };
+
+    for (index, url) in urls.iter().enumerate() {
+        let ext = format!(
+            "{}.{}",
+            model_path.extension().unwrap_or_default().to_string_lossy(),
+            index
+        );
+        let path = model_path.with_extension(ext);
+        info!(
+            "Downloading {path:?} ({index} / {total})",
+            index = index + 1,
+            total = partition_urls.len()
+        );
+        let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
+        let download_job = Retry::spawn(strategy, || download_file(url, &path));
+        download_job.await?;
+        paths.push(path);
+    }
+    info!("Merging split model files...");
+    println!("{model_path:?}");
+    let mut file = OpenOptions::new()
+        .append(true)
+        .create(true)
+        .open(model_path)?;
+    for path in paths {
+        let mut reader = BufReader::new(File::open(&path)?);
+        loop {
+            let buffer = reader.fill_buf()?;
+            file.write_all(buffer)?;
+            let len = buffer.len();
+            reader.consume(len);
+            if len == 0 {
+                break;
+            }
+        }
+        std::fs::remove_file(path)?;
+    }
+    Ok(())
+}
+
 async fn download_file(url: &str, path: &Path) -> Result<()> {
     let dir = path
         .parent()

Codecov / codecov/patch: added lines #L72-L81, #L83-L85, #L87-L91, #L94-L109, #L111-L118, #L120-L126, #L128 and #L130-L131 in crates/tabby-download/src/lib.rs were not covered by tests.
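
The three steps of `download_split_model` (pick the URL list for the configured mirror, derive one part path per URL, concatenate the parts into the final file) can be sketched with the standard library alone, which may also be a starting point for the tests Codecov reports as missing. This is a sketch under assumptions, not the merged implementation: `select_mirror`, `part_path`, and `merge_parts` are hypothetical names, the network download is left out, and the merge swaps the PR's append-mode `fill_buf` loop for `File::create` plus `std::io::copy`, so a leftover file from an interrupted run is truncated rather than appended to.

```rust
use std::{
    fs::{self, File},
    io,
    path::{Path, PathBuf},
};

/// Pick the first URL list whose entries all contain `mirror`
/// (the same predicate the PR applies to `partition_urls`).
fn select_mirror<'a>(partition_urls: &'a [Vec<String>], mirror: &str) -> Option<&'a [String]> {
    partition_urls
        .iter()
        .find(|urls| urls.iter().all(|url| url.contains(mirror)))
        .map(Vec::as_slice)
}

/// Derive the path of part `index` by extending the model's extension,
/// e.g. `model.gguf` becomes `model.gguf.0`.
fn part_path(model_path: &Path, index: usize) -> PathBuf {
    let ext = format!(
        "{}.{}",
        model_path.extension().unwrap_or_default().to_string_lossy(),
        index
    );
    model_path.with_extension(ext)
}

/// Concatenate `parts` in order into `model_path`, deleting each part
/// once it has been copied. Returns the total number of bytes written.
fn merge_parts(model_path: &Path, parts: &[PathBuf]) -> io::Result<u64> {
    // Assumption: truncating is preferable to the PR's append mode here.
    let mut out = File::create(model_path)?;
    let mut total = 0;
    for part in parts {
        let mut input = File::open(part)?;
        total += io::copy(&mut input, &mut out)?;
        fs::remove_file(part)?;
    }
    Ok(total)
}
```

One detail a test along these lines might surface: the progress log in the PR takes `total` from `partition_urls.len()`, which counts mirror lists, whereas `urls.len()` would count the parts being downloaded.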