diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs
index 9b9bcb18db55..ef1ba84804d4 100644
--- a/crates/llama-cpp-server/src/lib.rs
+++ b/crates/llama-cpp-server/src/lib.rs
@@ -10,7 +10,7 @@ use serde::Deserialize;
 use supervisor::LlamaCppSupervisor;
 use tabby_common::{
     config::{HttpModelConfigBuilder, LocalModelConfig, ModelConfig},
-    registry::{parse_model_id, ModelRegistry},
+    registry::{parse_model_id, ModelRegistry, GGML_MODEL_PARTITIONED_PREFIX},
 };
 use tabby_inference::{ChatCompletionStream, CompletionOptions, CompletionStream, Embedding};

@@ -277,10 +277,20 @@ pub async fn create_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
 }

 async fn resolve_model_path(model_id: &str) -> String {
-    let (registry, name) = parse_model_id(model_id);
-    let registry = ModelRegistry::new(registry).await;
-    let path = registry.get_model_entry_path(name);
-    path.unwrap().display().to_string()
+    let path = PathBuf::from(model_id);
+    let path = if path.exists() {
+        path.join("ggml").join(format!(
+            "{}00001.gguf",
+            GGML_MODEL_PARTITIONED_PREFIX.to_owned()
+        ))
+    } else {
+        let (registry, name) = parse_model_id(model_id);
+        let registry = ModelRegistry::new(registry).await;
+        registry
+            .get_model_entry_path(name)
+            .expect("Model not found")
+    };
+    path.display().to_string()
 }

 #[derive(Deserialize)]
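
For context, the rewritten resolve_model_path treats model_id as a filesystem path first: if the path exists on disk, it is assumed to be a local model directory and the first GGUF partition under its ggml/ subdirectory is used; otherwise model_id is parsed as a "registry/name" pair and resolved through ModelRegistry as before. Below is a minimal standalone sketch of that resolution order; the prefix value "model-" and the fallback string are illustrative assumptions, not taken from this diff:

    use std::path::PathBuf;

    // Assumed value for illustration; the real constant lives in
    // tabby_common::registry::GGML_MODEL_PARTITIONED_PREFIX.
    const GGML_MODEL_PARTITIONED_PREFIX: &str = "model-";

    fn resolve_model_path_sketch(model_id: &str) -> String {
        let path = PathBuf::from(model_id);
        if path.exists() {
            // Local directory: point at the first partition of the GGUF model.
            path.join("ggml")
                .join(format!("{GGML_MODEL_PARTITIONED_PREFIX}00001.gguf"))
                .display()
                .to_string()
        } else {
            // Not a local path: would fall back to registry resolution
            // ("registry/name"); elided here to keep the sketch self-contained.
            format!("<resolved via ModelRegistry for {model_id}>")
        }
    }

    fn main() {
        // An existing directory resolves to its ggml/ partition file.
        println!("{}", resolve_model_path_sketch("/tmp"));
        // Anything else is treated as a registry model id.
        println!("{}", resolve_model_path_sketch("TabbyML/StarCoder-1B"));
    }

One consequence of this ordering worth noting: an existing path that is not actually a model directory still resolves to a ggml/...gguf path that may not exist, so such mistakes surface later at model load time rather than during resolution.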