diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs index 25e612a5a93c..b62942352311 100644 --- a/crates/llama-cpp-server/src/lib.rs +++ b/crates/llama-cpp-server/src/lib.rs @@ -279,7 +279,7 @@ pub async fn create_embedding(config: &ModelConfig) -> Arc { async fn resolve_model_path(model_id: &str) -> String { let path = PathBuf::from(model_id); let path = if path.exists() { - path.join(GGML_MODEL_RELATIVE_PATH) + path.join(GGML_MODEL_RELATIVE_PATH.as_str()) } else { let (registry, name) = parse_model_id(model_id); let registry = ModelRegistry::new(registry).await; diff --git a/crates/tabby-common/src/api/event.rs b/crates/tabby-common/src/api/event.rs index de8b48f31a14..674ad25f3599 100644 --- a/crates/tabby-common/src/api/event.rs +++ b/crates/tabby-common/src/api/event.rs @@ -67,7 +67,7 @@ pub enum Event { #[serde(skip_serializing_if = "Option::is_none")] segments: Option, choices: Vec, - user_agent: String, + user_agent: Option, }, ChatCompletion { completion_id: String, diff --git a/crates/tabby-common/src/registry.rs b/crates/tabby-common/src/registry.rs index 81950c709876..862e2388c02c 100644 --- a/crates/tabby-common/src/registry.rs +++ b/crates/tabby-common/src/registry.rs @@ -1,6 +1,7 @@ use std::{fs, path::PathBuf}; use anyhow::{Context, Result}; +use lazy_static::lazy_static; use serde::{Deserialize, Serialize}; use crate::path::models_dir; @@ -76,7 +77,7 @@ impl ModelRegistry { let model_path = self.get_model_path(name); let old_model_path = self .get_model_dir(name) - .join(LEGACY_GGML_MODEL_RELATIVE_PATH); + .join(LEGACY_GGML_MODEL_RELATIVE_PATH.as_str()); if !model_path.exists() && old_model_path.exists() { std::fs::rename(&old_model_path, &model_path)?; @@ -89,7 +90,8 @@ impl ModelRegistry { } pub fn get_model_path(&self, name: &str) -> PathBuf { - self.get_model_dir(name).join(GGML_MODEL_RELATIVE_PATH) + self.get_model_dir(name) + .join(GGML_MODEL_RELATIVE_PATH.as_str()) } pub fn save_model_info(&self, name: &str) { @@ -118,8 +120,12 @@ pub fn parse_model_id(model_id: &str) -> (&str, &str) { } } -pub static LEGACY_GGML_MODEL_RELATIVE_PATH: &str = "ggml/q8_0.v2.gguf"; -pub static GGML_MODEL_RELATIVE_PATH: &str = "ggml/model.gguf"; +lazy_static! { + pub static ref LEGACY_GGML_MODEL_RELATIVE_PATH: String = + format!("ggml{}q8_0.v2.gguf", std::path::MAIN_SEPARATOR_STR); + pub static ref GGML_MODEL_RELATIVE_PATH: String = + format!("ggml{}model.gguf", std::path::MAIN_SEPARATOR_STR); +} #[cfg(test)] mod tests { @@ -136,7 +142,7 @@ mod tests { let registry = ModelRegistry::new("TabbyML").await; let dir = registry.get_model_dir("StarCoder-1B"); - let old_model_path = dir.join(LEGACY_GGML_MODEL_RELATIVE_PATH); + let old_model_path = dir.join(LEGACY_GGML_MODEL_RELATIVE_PATH.as_str()); tokio::fs::create_dir_all(old_model_path.parent().unwrap()) .await .unwrap(); diff --git a/crates/tabby/src/routes/completions.rs b/crates/tabby/src/routes/completions.rs index dbb6438eb709..0e4fda02736a 100644 --- a/crates/tabby/src/routes/completions.rs +++ b/crates/tabby/src/routes/completions.rs @@ -26,14 +26,16 @@ use crate::services::completion::{CompletionRequest, CompletionResponse, Complet pub async fn completions( State(state): State>, TypedHeader(MaybeUser(user)): TypedHeader, - TypedHeader(user_agent): TypedHeader, + user_agent: Option>, Json(mut request): Json, ) -> Result, StatusCode> { if let Some(user) = user { request.user.replace(user); } - match state.generate(&request, &user_agent.to_string()).await { + let user_agent = user_agent.map(|x| x.0.to_string()); + + match state.generate(&request, user_agent.as_deref()).await { Ok(resp) => Ok(Json(resp)), Err(err) => { warn!("{}", err); diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index b33dd8d54937..926a6ce0031b 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -294,7 +294,7 @@ impl CompletionService { pub async fn generate( &self, request: &CompletionRequest, - user_agent: &str, + user_agent: Option<&str>, ) -> Result { let completion_id = format!("cmpl-{}", uuid::Uuid::new_v4()); let language = request.language_or_unknown(); @@ -338,7 +338,7 @@ impl CompletionService { index: 0, text: text.clone(), }], - user_agent: user_agent.to_string(), + user_agent: user_agent.map(|x| x.to_owned()), }, ); @@ -462,7 +462,7 @@ mod tests { }; let response = completion_service - .generate(&request, "test user agent") + .generate(&request, Some("test user agent")) .await .unwrap(); assert_eq!(response.choices[0].text, r#""Hello, world!""#); diff --git a/crates/tabby/tests/goldentests.rs b/crates/tabby/tests/goldentests.rs index 9b10999970f3..d2cdd02797c0 100644 --- a/crates/tabby/tests/goldentests.rs +++ b/crates/tabby/tests/goldentests.rs @@ -79,6 +79,16 @@ async fn golden_test(body: serde_json::Value) -> serde_json::Value { }), ); + let resp = CLIENT + .post("http://127.0.0.1:9090/v1/completions") + .json(&body) + .send() + .await + .unwrap(); + + let info = resp.text().await.unwrap(); + eprintln!("info {}", info); + let actual: serde_json::Value = CLIENT .post("http://127.0.0.1:9090/v1/completions") .json(&body) diff --git a/ee/tabby-webserver/src/service/event_logger.rs b/ee/tabby-webserver/src/service/event_logger.rs index 174a87adfd6b..2f6124b8330c 100644 --- a/ee/tabby-webserver/src/service/event_logger.rs +++ b/ee/tabby-webserver/src/service/event_logger.rs @@ -165,7 +165,7 @@ mod tests { prompt: "testprompt".into(), segments: None, choices: vec![], - user_agent: "ide: version test".into(), + user_agent: Some("ide: version test".into()), }, ); @@ -242,7 +242,7 @@ mod tests { prompt: "testprompt".into(), segments: None, choices: vec![], - user_agent: "ide: version unknown".into(), + user_agent: Some("ide: version unknown".into()), }, ); @@ -257,7 +257,7 @@ mod tests { prompt: "testprompt".into(), segments: None, choices: vec![], - user_agent: "ide: version unknown".into(), + user_agent: Some("ide: version unknown".into()), }, );