Skip to content

Commit

Permalink
feat(download): Add support for segmented model urls (#1735)
Browse files Browse the repository at this point in the history
* feat(download): Add support for segmented models

* Fix segmented downloads

* Apply suggestion

* Allow specifying multiple mirrors for partitioned models

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
boxbeam and autofix-ci[bot] authored Mar 29, 2024
1 parent 7b3f532 commit e45e509
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 5 deletions.
5 changes: 4 additions & 1 deletion crates/tabby-common/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ pub struct ModelInfo {
pub prompt_template: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub chat_template: Option<String>,
pub urls: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub urls: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub partition_urls: Option<Vec<Vec<String>>>,
pub sha256: String,
}

Expand Down
86 changes: 82 additions & 4 deletions crates/tabby-download/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
//! Responsible for downloading ML models for use with tabby.
use std::{fs, path::Path};
use std::{
fs::{self, File, OpenOptions},
io::{BufRead, BufReader, Write},
path::Path,
};

use aim_downloader::{bar::WrappedBar, error::DownloadError, https};
use anyhow::{anyhow, bail, Result};
use tabby_common::registry::{parse_model_id, ModelRegistry};
use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry};
use tokio_retry::{
strategy::{jitter, ExponentialBackoff},
Retry,
};
use tracing::{info, warn};

fn download_host() -> String {
std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned())
}

async fn download_model_impl(
registry: &ModelRegistry,
name: &str,
Expand Down Expand Up @@ -37,8 +45,17 @@ async fn download_model_impl(
}
}

let registry = std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned());
let Some(model_url) = model_info.urls.iter().find(|x| x.contains(&registry)) else {
if model_info.partition_urls.is_some() {
return download_split_model(model_info, &model_path).await;
}

let registry = download_host();
let Some(model_url) = model_info
.urls
.iter()
.flatten()
.find(|x| x.contains(&registry))
else {
return Err(anyhow!(
"Invalid mirror <{}> for model urls: {:?}",
registry,
Expand All @@ -52,6 +69,67 @@ async fn download_model_impl(
Ok(())
}

async fn download_split_model(model_info: &ModelInfo, model_path: &Path) -> Result<()> {
if model_info.urls.is_some() {
return Err(anyhow!(
"{}: Cannot specify both `urls` and `partition_urls`",
model_info.name
));
}
let mut paths = vec![];
let partition_urls = model_info.partition_urls.clone().unwrap_or_default();
let mirror = download_host();

let Some(urls) = partition_urls
.iter()
.find(|urls| urls.iter().all(|url| url.contains(&mirror)))
else {
return Err(anyhow!(
"Invalid mirror <{}> for model urls: {:?}",
mirror,
partition_urls
));
};

for (index, url) in urls.iter().enumerate() {
let ext = format!(
"{}.{}",
model_path.extension().unwrap_or_default().to_string_lossy(),
index
);
let path = model_path.with_extension(ext);
info!(
"Downloading {path:?} ({index} / {total})",
index = index + 1,
total = partition_urls.len()
);
let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
let download_job = Retry::spawn(strategy, || download_file(url, &path));
download_job.await?;
paths.push(path);
}
info!("Merging split model files...");
println!("{model_path:?}");
let mut file = OpenOptions::new()
.append(true)
.create(true)
.open(model_path)?;
for path in paths {
let mut reader = BufReader::new(File::open(&path)?);
loop {
let buffer = reader.fill_buf()?;
file.write_all(buffer)?;
let len = buffer.len();
reader.consume(len);
if len == 0 {
break;
}
}
std::fs::remove_file(path)?;
}
Ok(())
}

async fn download_file(url: &str, path: &Path) -> Result<()> {
let dir = path
.parent()
Expand Down

0 comments on commit e45e509

Please sign in to comment.