Skip to content

Commit

Permalink
fix: should check local model before resolving model id
Browse files Browse the repository at this point in the history
  • Loading branch information
zwpaper committed Nov 26, 2024
1 parent 8e3e449 commit 56581c7
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions crates/llama-cpp-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use serde::Deserialize;
use supervisor::LlamaCppSupervisor;
use tabby_common::{
config::{HttpModelConfigBuilder, LocalModelConfig, ModelConfig},
registry::{parse_model_id, ModelRegistry},
registry::{parse_model_id, ModelRegistry, GGML_MODEL_PARTITIONED_PREFIX},
};
use tabby_inference::{ChatCompletionStream, CompletionOptions, CompletionStream, Embedding};

Expand Down Expand Up @@ -277,10 +277,20 @@ pub async fn create_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
}

async fn resolve_model_path(model_id: &str) -> String {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
let path = registry.get_model_entry_path(name);
path.unwrap().display().to_string()
let path = PathBuf::from(model_id);
let path = if path.exists() {
path.join("ggml").join(format!(
"{}00001.gguf",
GGML_MODEL_PARTITIONED_PREFIX.to_owned()
))

Check warning on line 285 in crates/llama-cpp-server/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/llama-cpp-server/src/lib.rs#L280-L285

Added lines #L280 - L285 were not covered by tests
} else {
let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await;
registry
.get_model_entry_path(name)
.expect("Model not found")

Check warning on line 291 in crates/llama-cpp-server/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/llama-cpp-server/src/lib.rs#L287-L291

Added lines #L287 - L291 were not covered by tests
};
path.display().to_string()

Check warning on line 293 in crates/llama-cpp-server/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

crates/llama-cpp-server/src/lib.rs#L293

Added line #L293 was not covered by tests
}

#[derive(Deserialize)]
Expand Down

0 comments on commit 56581c7

Please sign in to comment.