diff --git a/crates/tabby-common/src/registry.rs b/crates/tabby-common/src/registry.rs index f82388c7cf99..dd0fe59c8c08 100644 --- a/crates/tabby-common/src/registry.rs +++ b/crates/tabby-common/src/registry.rs @@ -15,7 +15,7 @@ pub struct ModelInfo { #[serde(skip_serializing_if = "Option::is_none")] pub urls: Option>, #[serde(skip_serializing_if = "Option::is_none")] - pub partition_urls: Option>, + pub partition_urls: Option>>, pub sha256: String, } diff --git a/crates/tabby-download/src/lib.rs b/crates/tabby-download/src/lib.rs index 85b3553d6878..502ac857e708 100644 --- a/crates/tabby-download/src/lib.rs +++ b/crates/tabby-download/src/lib.rs @@ -14,6 +14,10 @@ use tokio_retry::{ }; use tracing::{info, warn}; +fn download_host() -> String { + std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned()) +} + async fn download_model_impl( registry: &ModelRegistry, name: &str, @@ -45,7 +49,7 @@ async fn download_model_impl( return download_split_model(&model_info, &model_path).await; } - let registry = std::env::var("TABBY_DOWNLOAD_HOST").unwrap_or("huggingface.co".to_owned()); + let registry = download_host(); let Some(model_url) = model_info .urls .iter() @@ -74,7 +78,20 @@ async fn download_split_model(model_info: &ModelInfo, model_path: &Path) -> Resu } let mut paths = vec![]; let partition_urls = model_info.partition_urls.clone().unwrap_or_default(); - for (index, url) in partition_urls.iter().enumerate() { + let mirror = download_host(); + + let Some(urls) = partition_urls + .iter() + .find(|urls| urls.iter().all(|url| url.contains(&mirror))) + else { + return Err(anyhow!( + "Invalid mirror <{}> for model urls: {:?}", + mirror, + partition_urls + )); + }; + + for (index, url) in urls.iter().enumerate() { let ext = format!( "{}.{}", model_path.extension().unwrap_or_default().to_string_lossy(),