diff --git a/Cargo.lock b/Cargo.lock
index 3e2fd06c17c0..f692707ee0dc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2110,6 +2110,7 @@ dependencies = [
  "async-stream",
  "async-trait",
  "futures",
+ "ollama-api-bindings",
  "reqwest 0.12.4",
  "reqwest-eventsource",
  "serde",
@@ -3336,6 +3337,33 @@ dependencies = [
  "url",
 ]

+[[package]]
+name = "ollama-api-bindings"
+version = "0.12.0-dev.0"
+dependencies = [
+ "anyhow",
+ "async-stream",
+ "async-trait",
+ "futures",
+ "ollama-rs",
+ "tabby-common",
+ "tabby-inference",
+ "tracing",
+]
+
+[[package]]
+name = "ollama-rs"
+version = "0.1.9"
+source = "git+https://github.com/pepperoni21/ollama-rs.git?rev=56e8157d98d4185bc171fe9468d3d09bc56e9dd3#56e8157d98d4185bc171fe9468d3d09bc56e9dd3"
+dependencies = [
+ "reqwest 0.12.4",
+ "serde",
+ "serde_json",
+ "tokio",
+ "tokio-stream",
+ "url",
+]
+
 [[package]]
 name = "omnicopy_to_output"
 version = "0.1.1"
@@ -5989,9 +6017,9 @@ dependencies = [

 [[package]]
 name = "tokio-stream"
-version = "0.1.14"
+version = "0.1.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
+checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af"
 dependencies = [
  "futures-core",
  "pin-project-lite",
diff --git a/Cargo.toml b/Cargo.toml
index 988b041204d4..473f6e660083 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,6 +11,7 @@ members = [
     "crates/aim-downloader",
     "crates/http-api-bindings",
     "crates/llama-cpp-server",
+    "crates/ollama-api-bindings",

     "ee/tabby-webserver",
     "ee/tabby-db",
diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml
index 082c90a8319d..022cd3455385 100644
--- a/crates/http-api-bindings/Cargo.toml
+++ b/crates/http-api-bindings/Cargo.toml
@@ -17,7 +17,8 @@ serde.workspace = true
 serde_json = { workspace = true }
 tabby-common = { path = "../tabby-common" }
 tabby-inference = { path = "../tabby-inference" }
+ollama-api-bindings = { path = "../ollama-api-bindings" }
 tracing.workspace = true

 [dev-dependencies]
-tokio ={ workspace = true, features = ["rt", "macros"]}
+tokio = { workspace = true, features = ["rt", "macros"] }
diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs
index b3ebdad18f41..4e9e3d00474c 100644
--- a/crates/http-api-bindings/src/chat/mod.rs
+++ b/crates/http-api-bindings/src/chat/mod.rs
@@ -6,15 +6,15 @@ use openai_chat::OpenAIChatEngine;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::ChatCompletionStream;

-pub fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
-    if model.kind == "openai-chat" {
-        let engine = OpenAIChatEngine::create(
+pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
+    match model.kind.as_str() {
+        "openai-chat" => Arc::new(OpenAIChatEngine::create(
             &model.api_endpoint,
             model.model_name.as_deref().unwrap_or_default(),
             model.api_key.clone(),
-        );
-        Arc::new(engine)
-    } else {
-        panic!("Only openai-chat are supported for http chat");
+        )),
+        "ollama/chat" => ollama_api_bindings::create_chat(model).await,
+
+        unsupported_kind => panic!("Unsupported kind for http chat: {}", unsupported_kind),
     }
 }
diff --git a/crates/http-api-bindings/src/completion/mod.rs b/crates/http-api-bindings/src/completion/mod.rs
index 591e22241b55..d97051446292 100644
--- a/crates/http-api-bindings/src/completion/mod.rs
+++ b/crates/http-api-bindings/src/completion/mod.rs
@@ -6,11 +6,17 @@ use llama::LlamaCppEngine;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::CompletionStream;

-pub fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
-    if model.kind == "llama.cpp/completion" {
-        let engine = LlamaCppEngine::create(&model.api_endpoint, model.api_key.clone());
-        Arc::new(engine)
-    } else {
-        panic!("Unsupported model kind: {}", model.kind);
+pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
+    match model.kind.as_str() {
+        "llama.cpp/completion" => {
+            let engine = LlamaCppEngine::create(&model.api_endpoint, model.api_key.clone());
+            Arc::new(engine)
+        }
+        "ollama/completion" => ollama_api_bindings::create_completion(model).await,
+
+        unsupported_kind => panic!(
+            "Unsupported model kind for http completion: {}",
+            unsupported_kind
+        ),
     }
 }
diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs
index 5181378debf1..686ccbfd565c 100644
--- a/crates/http-api-bindings/src/embedding/mod.rs
+++ b/crates/http-api-bindings/src/embedding/mod.rs
@@ -1,6 +1,7 @@
 mod llama;
 mod openai;

+use core::panic;
 use std::sync::Arc;

 use llama::LlamaCppEngine;
@@ -9,18 +10,25 @@ use tabby_inference::Embedding;

 use self::openai::OpenAIEmbeddingEngine;

-pub fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
-    if config.kind == "llama.cpp/embedding" {
-        let engine = LlamaCppEngine::create(&config.api_endpoint, config.api_key.clone());
-        Arc::new(engine)
-    } else if config.kind == "openai-embedding" {
-        let engine = OpenAIEmbeddingEngine::create(
-            &config.api_endpoint,
-            config.model_name.as_deref().unwrap_or_default(),
-            config.api_key.clone(),
-        );
-        Arc::new(engine)
-    } else {
-        panic!("Only llama are supported for http embedding");
+pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
+    match config.kind.as_str() {
+        "llama.cpp/embedding" => {
+            let engine = LlamaCppEngine::create(&config.api_endpoint, config.api_key.clone());
+            Arc::new(engine)
+        }
+        "openai-embedding" => {
+            let engine = OpenAIEmbeddingEngine::create(
+                &config.api_endpoint,
+                config.model_name.as_deref().unwrap_or_default(),
+                config.api_key.clone(),
+            );
+            Arc::new(engine)
+        }
+        "ollama/embedding" => ollama_api_bindings::create_embedding(config).await,
+
+        unsupported_kind => panic!(
+            "Unsupported kind for http embedding model: {}",
+            unsupported_kind
+        ),
     }
 }
diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs
index 330412ac1a00..3fa69522a8f1 100644
--- a/crates/llama-cpp-server/src/lib.rs
+++ b/crates/llama-cpp-server/src/lib.rs
@@ -35,7 +35,7 @@ impl EmbeddingServer {

         Self {
             server,
-            embedding: http_api_bindings::create_embedding(&config),
+            embedding: http_api_bindings::create_embedding(&config).await,
         }
     }
 }
@@ -61,7 +61,7 @@ impl CompletionServer {
             .kind("llama.cpp/completion".to_string())
             .build()
             .expect("Failed to create HttpModelConfig");
-        let completion = http_api_bindings::create(&config);
+        let completion = http_api_bindings::create(&config).await;
         Self { server, completion }
     }
 }
@@ -83,7 +83,7 @@ pub async fn create_completion(

 pub async fn create_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
     match config {
-        ModelConfig::Http(http) => http_api_bindings::create_embedding(http),
+        ModelConfig::Http(http) => http_api_bindings::create_embedding(http).await,
         ModelConfig::Local(llama) => {
             if fs::metadata(&llama.model_id).is_ok() {
                 let path = PathBuf::from(&llama.model_id);
diff --git a/crates/ollama-api-bindings/Cargo.toml b/crates/ollama-api-bindings/Cargo.toml
new file mode 100644
index 000000000000..f20927f66b87
--- /dev/null
+++ b/crates/ollama-api-bindings/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "ollama-api-bindings"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+tabby-common = { path = "../tabby-common" }
+tabby-inference = { path = "../tabby-inference" }
+
+anyhow.workspace = true
+async-stream.workspace = true
+async-trait.workspace = true
+futures.workspace = true
+tracing.workspace = true
+
+# Use the git version for now: https://github.com/pepperoni21/ollama-rs/issues/44 is required to work correctly with normal URLs
+[dependencies.ollama-rs]
+git = "https://github.com/pepperoni21/ollama-rs.git"
+rev = "56e8157d98d4185bc171fe9468d3d09bc56e9dd3"
+features = ["stream"]
diff --git a/crates/ollama-api-bindings/src/chat.rs b/crates/ollama-api-bindings/src/chat.rs
new file mode 100644
index 000000000000..a627d5b43317
--- /dev/null
+++ b/crates/ollama-api-bindings/src/chat.rs
@@ -0,0 +1,93 @@
+use std::sync::Arc;
+
+use anyhow::{bail, Result};
+use async_trait::async_trait;
+use futures::{stream::BoxStream, StreamExt};
+use ollama_rs::{
+    generation::{
+        chat::{request::ChatMessageRequest, ChatMessage, MessageRole},
+        options::GenerationOptions,
+    },
+    Ollama,
+};
+use tabby_common::{api::chat::Message, config::HttpModelConfig};
+use tabby_inference::{ChatCompletionOptions, ChatCompletionStream};
+
+use crate::model::OllamaModelExt;
+
+/// A special adapter to convert Tabby messages to ollama-rs messages
+struct ChatMessageAdapter(ChatMessage);
+
+impl TryFrom<Message> for ChatMessageAdapter {
+    type Error = anyhow::Error;
+    fn try_from(value: Message) -> Result<ChatMessageAdapter> {
+        let role = match value.role.as_str() {
+            "system" => MessageRole::System,
+            "assistant" => MessageRole::Assistant,
+            "user" => MessageRole::User,
+            other => bail!("Unsupported chat message role: {other}"),
+        };
+
+        Ok(ChatMessageAdapter(ChatMessage::new(role, value.content)))
+    }
+}
+
+impl From<ChatMessageAdapter> for ChatMessage {
+    fn from(val: ChatMessageAdapter) -> Self {
+        val.0
+    }
+}
+
+/// Ollama chat completions
+pub struct OllamaChat {
+    /// Connection to Ollama API
+    connection: Ollama,
+    /// Model name
+    model: String,
+}
+
+#[async_trait]
+impl ChatCompletionStream for OllamaChat {
+    async fn chat_completion(
+        &self,
+        messages: &[Message],
+        options: ChatCompletionOptions,
+    ) -> Result<BoxStream<String>> {
+        let messages = messages
+            .iter()
+            .map(|m| ChatMessageAdapter::try_from(m.to_owned()))
+            .collect::<Result<Vec<_>, _>>()?;
+
+        let messages = messages.into_iter().map(|m| m.into()).collect::<Vec<_>>();
+
+        let options = GenerationOptions::default()
+            .seed(options.seed as i32)
+            .temperature(options.sampling_temperature)
+            .num_predict(options.max_decoding_tokens);
+
+        let request = ChatMessageRequest::new(self.model.to_owned(), messages).options(options);
+
+        let stream = self.connection.send_chat_messages_stream(request).await?;
+
+        let stream = stream
+            .map(|x| match x {
+                Ok(response) => response.message,
+                Err(_) => None,
+            })
+            .map(|x| match x {
+                Some(e) => e.content,
+                None => "".to_owned(),
+            });
+
+        Ok(stream.boxed())
+    }
+}
+
+pub async fn create(config: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
+    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+        .expect("Failed to create connection to Ollama, URL invalid");
+
+    let model = connection.select_model_or_default(config).await.unwrap();
+
+    Arc::new(OllamaChat { connection, model })
+}
diff --git a/crates/ollama-api-bindings/src/completion.rs b/crates/ollama-api-bindings/src/completion.rs
new file mode 100644
index 000000000000..7161cdb5d442
--- /dev/null
+++ b/crates/ollama-api-bindings/src/completion.rs
@@ -0,0 +1,66 @@
+use std::sync::Arc;
+
+use async_stream::stream;
+use async_trait::async_trait;
+use futures::{stream::BoxStream, StreamExt};
+use ollama_rs::{
+    generation::{completion::request::GenerationRequest, options::GenerationOptions},
+    Ollama,
+};
+use tabby_common::config::HttpModelConfig;
+use tabby_inference::{CompletionOptions, CompletionStream};
+use tracing::error;
+
+use crate::model::OllamaModelExt;
+
+pub struct OllamaCompletion {
+    /// Connection to Ollama API
+    connection: Ollama,
+    /// Model name
+    model: String,
+}
+
+#[async_trait]
+impl CompletionStream for OllamaCompletion {
+    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
+        let ollama_options = GenerationOptions::default()
+            .num_ctx(options.max_input_length as u32)
+            .num_predict(options.max_decoding_tokens)
+            .seed(options.seed as i32)
+            .temperature(options.sampling_temperature);
+        let request = GenerationRequest::new(self.model.to_owned(), prompt.to_owned())
+            .template("{{ .Prompt }}".to_string())
+            .options(ollama_options);
+
+        // Why doesn't this function return a Result?
+        match self.connection.generate_stream(request).await {
+            Ok(stream) => {
+                let tabby_stream = stream! {
+
+                    for await response in stream {
+                        let parts = response.unwrap();
+                        for part in parts {
+                            yield part.response
+                        }
+                    }
+
+                };
+
+                tabby_stream.boxed()
+            }
+            Err(err) => {
+                error!("Failed to generate completion: {}", err);
+                futures::stream::empty().boxed()
+            }
+        }
+    }
+}
+
+pub async fn create(config: &HttpModelConfig) -> Arc<dyn CompletionStream> {
+    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+        .expect("Failed to create connection to Ollama, URL invalid");
+
+    let model = connection.select_model_or_default(config).await.unwrap();
+
+    Arc::new(OllamaCompletion { connection, model })
+}
diff --git a/crates/ollama-api-bindings/src/embedding.rs b/crates/ollama-api-bindings/src/embedding.rs
new file mode 100644
index 000000000000..153a460cba39
--- /dev/null
+++ b/crates/ollama-api-bindings/src/embedding.rs
@@ -0,0 +1,36 @@
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use ollama_rs::Ollama;
+use tabby_common::config::HttpModelConfig;
+use tabby_inference::Embedding;
+
+use crate::model::OllamaModelExt;
+
+pub struct OllamaCompletion {
+    /// Connection to Ollama API
+    connection: Ollama,
+    /// Model name
+    model: String,
+}
+
+#[async_trait]
+impl Embedding for OllamaCompletion {
+    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
+        self.connection
+            .generate_embeddings(self.model.to_owned(), prompt.to_owned(), None)
+            .await
+            .map(|x| x.embeddings)
+            .map(|e| e.iter().map(|v| *v as f32).collect())
+            .map_err(|err| err.into())
+    }
+}
+
+pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
+    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+        .expect("Failed to create connection to Ollama, URL invalid");
+
+    let model = connection.select_model_or_default(config).await.unwrap();
+
+    Arc::new(OllamaCompletion { connection, model })
+}
diff --git a/crates/ollama-api-bindings/src/lib.rs b/crates/ollama-api-bindings/src/lib.rs
new file mode 100644
index 000000000000..424434d531bf
--- /dev/null
+++ b/crates/ollama-api-bindings/src/lib.rs
@@ -0,0 +1,10 @@
+mod model;
+
+mod chat;
+pub use chat::create as create_chat;
+
+mod completion;
+pub use completion::create as create_completion;
+
+mod embedding;
+pub use embedding::create as create_embedding;
diff --git a/crates/ollama-api-bindings/src/model.rs b/crates/ollama-api-bindings/src/model.rs
new file mode 100644
index 000000000000..8b75a72d1a86
--- /dev/null
+++ b/crates/ollama-api-bindings/src/model.rs
@@ -0,0 +1,128 @@
+//!
+//! Ollama model management utils
+//!
+
+use anyhow::{anyhow, bail, Result};
+use async_trait::async_trait;
+use futures::StreamExt;
+use ollama_rs::Ollama;
+use tabby_common::config::HttpModelConfig;
+use tracing::{info, warn};
+
+/// Env variable that allows pulling models with Ollama
+static ALLOW_PULL_ENV: &str = "TABBY_OLLAMA_ALLOW_PULL";
+
+#[async_trait]
+pub trait OllamaModelExt {
+    /// Check if a model is available in the remote Ollama instance
+    async fn model_available(&self, name: impl AsRef<str> + Send) -> Result<bool>;
+
+    /// Get the first available model in the remote Ollama instance
+    async fn get_first_available_model(&self) -> Result<Option<String>>;
+
+    /// For the input model specification:
+    /// - If a model is specified, check that it is available in the remote Ollama instance and return its name
+    /// - If no model is specified but prompt/chat templates are set, return an error because that is unsound
+    /// - If neither a model nor prompt/chat templates are specified, return the first model available in the remote Ollama instance
+    /// - If no model is available at all, return an error
+    /// - If the specified model is not available, try to pull it when the env variable `TABBY_OLLAMA_ALLOW_PULL` equals `1`, `y`, or `yes`,
+    ///   and return an error if the variable is not set or has another value
+    ///
+    /// # Parameters
+    /// - `config`: model configuration
+    ///
+    /// # Returns
+    /// - model name to use
+    async fn select_model_or_default(&self, config: &HttpModelConfig) -> Result<String>;
+
+    /// Pull the model and report progress via tracing
+    async fn pull_model_with_tracing(&self, model: &str) -> Result<()>;
+}
+
+#[async_trait]
+impl OllamaModelExt for Ollama {
+    async fn model_available(&self, name: impl AsRef<str> + Send) -> Result<bool> {
+        let name = name.as_ref();
+
+        let models_available = self.list_local_models().await?;
+
+        Ok(models_available.into_iter().any(|model| model.name == name))
+    }
+
+    async fn get_first_available_model(&self) -> Result<Option<String>> {
+        let models_available = self.list_local_models().await?;
+
+        Ok(models_available.first().map(|x| x.name.to_owned()))
+    }
+
+    async fn select_model_or_default(&self, config: &HttpModelConfig) -> Result<String> {
+        let prompt_or_chat_templates_set =
+            config.prompt_template.is_some() || config.chat_template.is_some();
+
+        let model = match config.model_name.to_owned() {
+            Some(ref model) => model.to_owned(),
+            None => {
+                let model = self
+                    .get_first_available_model()
+                    .await?
+                    .ok_or(anyhow!("Ollama instance does not have any models"))?;
+
+                if prompt_or_chat_templates_set {
+                    bail!("No model name is provided but prompt or chat templates are set. Please set model name explicitly")
+                }
+
+                warn!(
+                    "No model name is provided, using first available: {}",
+                    model
+                );
+                model
+            }
+        };
+
+        let available = self.model_available(&model).await?;
+
+        let allow_pull = std::env::var_os(ALLOW_PULL_ENV)
+            .map(|x| x == "1" || x.to_ascii_lowercase() == "y" || x.to_ascii_lowercase() == "yes")
+            .unwrap_or(false);
+
+        match (available, allow_pull) {
+            (true, _) => Ok(model),
+            (false, true) => {
+                info!("Model is not available, pulling it");
+                self.pull_model_with_tracing(model.as_str()).await?;
+                Ok(model)
+            }
+            (false, false) => {
+                bail!("Model is not available, and pulling is disabled")
+            }
+        }
+    }
+
+    async fn pull_model_with_tracing(&self, model: &str) -> Result<()> {
+        let mut stream = self.pull_model_stream(model.to_owned(), false).await?;
+
+        let mut last_status = "".to_string();
+        let mut last_progress = 0.0;
+
+        while let Some(result) = stream.next().await {
+            let response = result?;
+            let status = response.message;
+            if last_status != status {
+                info!("Status: {}", status);
+                last_status = status;
+                last_progress = 0.0;
+            }
+
+            // Report progress only when at least 1% has been gained
+            if let (Some(completed), Some(total)) = (response.completed, response.total) {
+                let progress = completed as f64 / total as f64;
+                if progress - last_progress > 0.01 {
+                    info!("Progress: {:.2}%", progress * 100.0);
+                    last_progress = progress;
+                }
+            }
+        }
+
+        Ok(())
+    }
+}
diff --git a/crates/tabby/src/services/model/mod.rs b/crates/tabby/src/services/model/mod.rs
index 8924fd48e944..573f3ae5ea60 100644
--- a/crates/tabby/src/services/model/mod.rs
+++ b/crates/tabby/src/services/model/mod.rs
@@ -15,7 +15,7 @@ use crate::fatal;

 pub async fn load_chat_completion(chat: &ModelConfig) -> Arc<dyn ChatCompletionStream> {
     match chat {
-        ModelConfig::Http(http) => http_api_bindings::create_chat(http),
+        ModelConfig::Http(http) => http_api_bindings::create_chat(http).await,
         ModelConfig::Local(_) => {
             let (engine, PromptInfo { chat_template, .. }) = load_completion(chat).await;

@@ -41,7 +41,7 @@ pub async fn load_code_generation(model: &ModelConfig) -> (Arc<CodeGeneration>,
 async fn load_completion(model: &ModelConfig) -> (Arc<dyn CompletionStream>, PromptInfo) {
     match model {
         ModelConfig::Http(http) => {
-            let engine = http_api_bindings::create(http);
+            let engine = http_api_bindings::create(http).await;
             (
                 engine,
                 PromptInfo {
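
For reference, a minimal configuration sketch for the new kinds. This is an illustrative assumption, not part of the diff: the `[model.completion.http]` section layout and the `kind`/`api_endpoint`/`model_name` fields follow Tabby's HTTP model config convention (`HttpModelConfig` in `tabby-common`), the endpoint is Ollama's default local port, and the model name is a placeholder.

```toml
# Hypothetical example: code completion served by a local Ollama instance.
[model.completion.http]
kind = "ollama/completion"
api_endpoint = "http://localhost:11434"
# Optional: when omitted, select_model_or_default picks the first model the Ollama instance reports.
model_name = "codellama:7b"
```

If the named model is not present on the Ollama side, it is only pulled when `TABBY_OLLAMA_ALLOW_PULL` is set to `1`, `y`, or `yes`; otherwise `select_model_or_default` returns an error.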