diff --git a/crates/tabby-common/src/registry.rs b/crates/tabby-common/src/registry.rs
index 24e9c2d0352d..40e11fbf15b7 100644
--- a/crates/tabby-common/src/registry.rs
+++ b/crates/tabby-common/src/registry.rs
@@ -12,7 +12,10 @@ pub struct ModelInfo {
     pub prompt_template: Option<String>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub chat_template: Option<String>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
     pub urls: Vec<String>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub segmented_urls: Vec<String>,
     pub sha256: String,
 }
 
diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs
index a746fdd45fd8..7da8fb6cacba 100644
--- a/crates/tabby-download/src/lib.rs
+++ b/crates/tabby-download/src/lib.rs
@@ -1,9 +1,13 @@
 //! Responsible for downloading ML models for use with tabby.
-use std::{fs, path::Path};
+use std::{
+    fs::{self, File},
+    io::{BufRead, BufReader, Write},
+    path::Path,
+};
 
 use aim_downloader::{bar::WrappedBar, error::DownloadError, https};
 use anyhow::{anyhow, bail, Result};
-use tabby_common::registry::{parse_model_id, ModelRegistry};
+use tabby_common::registry::{parse_model_id, ModelInfo, ModelRegistry};
 use tokio_retry::{
     strategy::{jitter, ExponentialBackoff},
     Retry,
@@ -37,6 +41,10 @@ async fn download_model_impl(
         }
     }
 
+    if !model_info.segmented_urls.is_empty() {
+        return download_split_model(&model_info, &model_path).await;
+    }
+
     let registry = std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned());
     let Some(model_url) = model_info.urls.iter().find(|x| x.contains(&registry)) else {
         return Err(anyhow!(
@@ -52,6 +60,44 @@
     Ok(())
 }
 
+async fn download_split_model(model_info: &ModelInfo, model_path: &Path) -> Result<()> {
+    if !model_info.urls.is_empty() {
+        return Err(anyhow!(
+            "{}: Cannot specify both `urls` and `segmented_urls`",
+            model_info.name
+        ));
+    }
+    let mut paths = vec![];
+    for (index, url) in model_info.segmented_urls.iter().enumerate() {
+        let ext = format!(
+            "{}.{}",
+            model_path.extension().unwrap_or_default().to_string_lossy(),
+            index
+        );
+        let path = model_path.with_extension(ext);
+        let strategy = ExponentialBackoff::from_millis(100).map(jitter).take(2);
+        let download_job = Retry::spawn(strategy, || download_file(url, &path));
+        download_job.await?;
+        paths.push(path);
+    }
+    info!("Merging split model files...");
+    let mut file = File::create(model_path)?;
+    for path in paths {
+        let mut reader = BufReader::new(File::open(&path)?);
+        loop {
+            let buffer = reader.fill_buf()?;
+            file.write_all(buffer)?;
+            let len = buffer.len();
+            reader.consume(len);
+            if len == 0 {
+                break;
+            }
+        }
+        std::fs::remove_file(path)?;
+    }
+    Ok(())
+}
+
 async fn download_file(url: &str, path: &Path) -> Result<()> {
     let dir = path
         .parent()
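
Note: as a standalone sanity check of the merge step above, here is a minimal sketch of the same segment-naming and fill_buf/consume concatenation pattern using only the standard library. The helper names segment_path and merge_segments, and the example paths, are illustrative only and not part of this patch; the final file must be created for writing (File::create), since File::open is read-only.

    use std::{
        fs::{remove_file, File},
        io::{self, BufRead, BufReader, Write},
        path::{Path, PathBuf},
    };

    /// Path of the index-th segment, e.g. "model.gguf" -> "model.gguf.0".
    fn segment_path(model_path: &Path, index: usize) -> PathBuf {
        let ext = format!(
            "{}.{}",
            model_path.extension().unwrap_or_default().to_string_lossy(),
            index
        );
        model_path.with_extension(ext)
    }

    /// Concatenate already-downloaded segments into the final model file,
    /// removing each segment once it has been copied.
    fn merge_segments(model_path: &Path, num_segments: usize) -> io::Result<()> {
        // Create (not open) the target so it is writable and starts empty.
        let mut out = File::create(model_path)?;
        for index in 0..num_segments {
            let path = segment_path(model_path, index);
            let mut reader = BufReader::new(File::open(&path)?);
            loop {
                let chunk = reader.fill_buf()?;
                if chunk.is_empty() {
                    break; // end of this segment
                }
                out.write_all(chunk)?;
                let len = chunk.len();
                reader.consume(len);
            }
            remove_file(path)?;
        }
        Ok(())
    }

With three downloaded parts model.gguf.0 through model.gguf.2, merge_segments(Path::new("model.gguf"), 3) rebuilds model.gguf in place, mirroring what download_split_model does after its download loop.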