From 1d1edfec6ebc09526399abee428f9406efd31b87 Mon Sep 17 00:00:00 2001
From: Anton Kushakov <57725022+SpeedCrash100@users.noreply.github.com>
Date: Sun, 26 May 2024 04:13:32 +0300
Subject: [PATCH] feat: Added ollama api connection options (#2227)

* feat: Add ollama-api-bindings crate

Add support for it in http-api-bindings as well

* fix(ollama-api-bindings): return empty stream if error occurs in completion

* refactor(ollama-api-bindings): Control model pulling via env var

Added the TABBY_OLLAMA_ALLOW_PULL env var to enable pulling in Ollama

* refactor(ollama-api-bindings): Do not use first available model if template is specified.

It should be assumed that the user has tuned the prompt or chat template for a
specific model, so it's better to ask the user to specify a model explicitly
instead of using whatever model is available in Ollama.

* refactor(http-api-bindings): Update ollama embedding kind name

* refactor(ollama-api-bindings): apply formatting

* refactor(http-api-bindings): Update ollama completion kind name

* refactor(http-api-bindings): Update ollama chat kind name
---
 Cargo.lock                                    |  32 ++++-
 Cargo.toml                                    |   1 +
 crates/http-api-bindings/Cargo.toml           |   3 +-
 crates/http-api-bindings/src/chat/mod.rs      |  14 +-
 .../http-api-bindings/src/completion/mod.rs   |  18 ++-
 crates/http-api-bindings/src/embedding/mod.rs |  34 +++--
 crates/llama-cpp-server/src/lib.rs            |   6 +-
 crates/ollama-api-bindings/Cargo.toml         |  24 ++++
 crates/ollama-api-bindings/src/chat.rs        |  93 +++++++++++++
 crates/ollama-api-bindings/src/completion.rs  |  66 +++++++++
 crates/ollama-api-bindings/src/embedding.rs   |  36 +++++
 crates/ollama-api-bindings/src/lib.rs         |  10 ++
 crates/ollama-api-bindings/src/model.rs       | 128 ++++++++++++++++++
 crates/tabby/src/services/model/mod.rs        |   4 +-
 14 files changed, 435 insertions(+), 34 deletions(-)
 create mode 100644 crates/ollama-api-bindings/Cargo.toml
 create mode 100644 crates/ollama-api-bindings/src/chat.rs
 create mode 100644 crates/ollama-api-bindings/src/completion.rs
 create mode 100644 crates/ollama-api-bindings/src/embedding.rs
 create mode 100644 crates/ollama-api-bindings/src/lib.rs
 create mode 100644 crates/ollama-api-bindings/src/model.rs

diff --git a/Cargo.lock b/Cargo.lock
index 3e2fd06c17c0..f692707ee0dc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2110,6 +2110,7 @@ dependencies = [
  "async-stream",
  "async-trait",
  "futures",
+ "ollama-api-bindings",
  "reqwest 0.12.4",
  "reqwest-eventsource",
  "serde",
@@ -3336,6 +3337,33 @@ dependencies = [
  "url",
 ]

+[[package]]
+name = "ollama-api-bindings"
+version = "0.12.0-dev.0"
+dependencies = [
+ "anyhow",
+ "async-stream",
+ "async-trait",
+ "futures",
+ "ollama-rs",
+ "tabby-common",
+ "tabby-inference",
+ "tracing",
+]
+
+[[package]]
+name = "ollama-rs"
+version = "0.1.9"
+source = "git+https://github.com/pepperoni21/ollama-rs.git?rev=56e8157d98d4185bc171fe9468d3d09bc56e9dd3#56e8157d98d4185bc171fe9468d3d09bc56e9dd3"
+dependencies = [
+ "reqwest 0.12.4",
+ "serde",
+ "serde_json",
+ "tokio",
+ "tokio-stream",
+ "url",
+]
+
 [[package]]
 name = "omnicopy_to_output"
 version = "0.1.1"
@@ -5989,9 +6017,9 @@ dependencies = [

 [[package]]
 name = "tokio-stream"
-version = "0.1.14"
+version = "0.1.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "397c988d37662c7dda6d2208364a706264bf3d6138b11d436cbac0ad38832842"
+checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af"
 dependencies = [
  "futures-core",
  "pin-project-lite",
diff --git a/Cargo.toml b/Cargo.toml
index 988b041204d4..473f6e660083 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -11,6 +11,7 @@ members = [
     "crates/aim-downloader",
     "crates/http-api-bindings",
     "crates/llama-cpp-server",
+    "crates/ollama-api-bindings",
     "ee/tabby-webserver",
     "ee/tabby-db",
diff --git a/crates/http-api-bindings/Cargo.toml b/crates/http-api-bindings/Cargo.toml
index 082c90a8319d..022cd3455385 100644
--- a/crates/http-api-bindings/Cargo.toml
+++ b/crates/http-api-bindings/Cargo.toml
@@ -17,7 +17,8 @@ serde.workspace = true
 serde_json = { workspace = true }
 tabby-common = { path = "../tabby-common" }
 tabby-inference = { path = "../tabby-inference" }
+ollama-api-bindings = { path = "../ollama-api-bindings" }
 tracing.workspace = true

 [dev-dependencies]
-tokio ={ workspace = true, features = ["rt", "macros"]}
+tokio = { workspace = true, features = ["rt", "macros"] }
diff --git a/crates/http-api-bindings/src/chat/mod.rs b/crates/http-api-bindings/src/chat/mod.rs
index b3ebdad18f41..4e9e3d00474c 100644
--- a/crates/http-api-bindings/src/chat/mod.rs
+++ b/crates/http-api-bindings/src/chat/mod.rs
@@ -6,15 +6,15 @@ use openai_chat::OpenAIChatEngine;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::ChatCompletionStream;

-pub fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
-    if model.kind == "openai-chat" {
-        let engine = OpenAIChatEngine::create(
+pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
+    match model.kind.as_str() {
+        "openai-chat" => Arc::new(OpenAIChatEngine::create(
             &model.api_endpoint,
             model.model_name.as_deref().unwrap_or_default(),
             model.api_key.clone(),
-        );
-        Arc::new(engine)
-    } else {
-        panic!("Only openai-chat are supported for http chat");
+        )),
+        "ollama/chat" => ollama_api_bindings::create_chat(model).await,
+
+        unsupported_kind => panic!("Unsupported kind for http chat: {}", unsupported_kind),
     }
 }
diff --git a/crates/http-api-bindings/src/completion/mod.rs b/crates/http-api-bindings/src/completion/mod.rs
index 591e22241b55..d97051446292 100644
--- a/crates/http-api-bindings/src/completion/mod.rs
+++ b/crates/http-api-bindings/src/completion/mod.rs
@@ -6,11 +6,17 @@ use llama::LlamaCppEngine;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::CompletionStream;

-pub fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
-    if model.kind == "llama.cpp/completion" {
-        let engine = LlamaCppEngine::create(&model.api_endpoint, model.api_key.clone());
-        Arc::new(engine)
-    } else {
-        panic!("Unsupported model kind: {}", model.kind);
+pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
+    match model.kind.as_str() {
+        "llama.cpp/completion" => {
+            let engine = LlamaCppEngine::create(&model.api_endpoint, model.api_key.clone());
+            Arc::new(engine)
+        }
+        "ollama/completion" => ollama_api_bindings::create_completion(model).await,
+
+        unsupported_kind => panic!(
+            "Unsupported model kind for http completion: {}",
+            unsupported_kind
+        ),
     }
 }
diff --git a/crates/http-api-bindings/src/embedding/mod.rs b/crates/http-api-bindings/src/embedding/mod.rs
index 5181378debf1..686ccbfd565c 100644
--- a/crates/http-api-bindings/src/embedding/mod.rs
+++ b/crates/http-api-bindings/src/embedding/mod.rs
@@ -1,6 +1,7 @@
 mod llama;
 mod openai;

+use core::panic;
 use std::sync::Arc;

 use llama::LlamaCppEngine;
@@ -9,18 +10,25 @@ use tabby_inference::Embedding;

 use self::openai::OpenAIEmbeddingEngine;

-pub fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
-    if config.kind == "llama.cpp/embedding" {
-        let engine = LlamaCppEngine::create(&config.api_endpoint, config.api_key.clone());
-        Arc::new(engine)
-    } else if config.kind == "openai-embedding" {
-        let engine = OpenAIEmbeddingEngine::create(
-            &config.api_endpoint,
-            config.model_name.as_deref().unwrap_or_default(),
-            config.api_key.clone(),
-        );
-        Arc::new(engine)
-    } else {
-        panic!("Only llama are supported for http embedding");
+pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
+    match config.kind.as_str() {
+        "llama.cpp/embedding" => {
+            let engine = LlamaCppEngine::create(&config.api_endpoint, config.api_key.clone());
+            Arc::new(engine)
+        }
+        "openai-embedding" => {
+            let engine = OpenAIEmbeddingEngine::create(
+                &config.api_endpoint,
+                config.model_name.as_deref().unwrap_or_default(),
+                config.api_key.clone(),
+            );
+            Arc::new(engine)
+        }
+        "ollama/embedding" => ollama_api_bindings::create_embedding(config).await,
+
+        unsupported_kind => panic!(
+            "Unsupported kind for http embedding model: {}",
+            unsupported_kind
+        ),
     }
 }
diff --git a/crates/llama-cpp-server/src/lib.rs b/crates/llama-cpp-server/src/lib.rs
index 330412ac1a00..3fa69522a8f1 100644
--- a/crates/llama-cpp-server/src/lib.rs
+++ b/crates/llama-cpp-server/src/lib.rs
@@ -35,7 +35,7 @@ impl EmbeddingServer {

         Self {
             server,
-            embedding: http_api_bindings::create_embedding(&config),
+            embedding: http_api_bindings::create_embedding(&config).await,
         }
     }
 }
@@ -61,7 +61,7 @@ impl CompletionServer {
             .kind("llama.cpp/completion".to_string())
             .build()
             .expect("Failed to create HttpModelConfig");
-        let completion = http_api_bindings::create(&config);
+        let completion = http_api_bindings::create(&config).await;
         Self { server, completion }
     }
 }
@@ -83,7 +83,7 @@ pub async fn create_completion(

 pub async fn create_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
     match config {
-        ModelConfig::Http(http) => http_api_bindings::create_embedding(http),
+        ModelConfig::Http(http) => http_api_bindings::create_embedding(http).await,
         ModelConfig::Local(llama) => {
             if fs::metadata(&llama.model_id).is_ok() {
                 let path = PathBuf::from(&llama.model_id);
diff --git a/crates/ollama-api-bindings/Cargo.toml b/crates/ollama-api-bindings/Cargo.toml
new file mode 100644
index 000000000000..f20927f66b87
--- /dev/null
+++ b/crates/ollama-api-bindings/Cargo.toml
@@ -0,0 +1,24 @@
+[package]
+name = "ollama-api-bindings"
+version.workspace = true
+edition.workspace = true
+authors.workspace = true
+homepage.workspace = true
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+tabby-common = { path = "../tabby-common" }
+tabby-inference = { path = "../tabby-inference" }
+
+anyhow.workspace = true
+async-stream.workspace = true
+async-trait.workspace = true
+futures.workspace = true
+tracing.workspace = true
+
+# Use the git version for now: https://github.com/pepperoni21/ollama-rs/issues/44 is required to work correctly with normal URLs
+[dependencies.ollama-rs]
+git = "https://github.com/pepperoni21/ollama-rs.git"
+rev = "56e8157d98d4185bc171fe9468d3d09bc56e9dd3"
+features = ["stream"]
diff --git a/crates/ollama-api-bindings/src/chat.rs b/crates/ollama-api-bindings/src/chat.rs
new file mode 100644
index 000000000000..a627d5b43317
--- /dev/null
+++ b/crates/ollama-api-bindings/src/chat.rs
@@ -0,0 +1,93 @@
+use std::sync::Arc;
+
+use anyhow::{bail, Result};
+use async_trait::async_trait;
+use futures::{stream::BoxStream, StreamExt};
+use ollama_rs::{
+    generation::{
+        chat::{request::ChatMessageRequest, ChatMessage, MessageRole},
+        options::GenerationOptions,
+    },
+    Ollama,
+};
+use tabby_common::{api::chat::Message, config::HttpModelConfig};
+use tabby_inference::{ChatCompletionOptions, ChatCompletionStream};
+
+use crate::model::OllamaModelExt;
+
+/// A special adapter to convert Tabby messages to ollama-rs messages
+struct ChatMessageAdapter(ChatMessage);
+
+impl TryFrom<Message> for ChatMessageAdapter {
+    type Error = anyhow::Error;
+    fn try_from(value: Message) -> Result<Self> {
+        let role = match value.role.as_str() {
+            "system" => MessageRole::System,
+            "assistant" => MessageRole::Assistant,
+            "user" => MessageRole::User,
+            other => bail!("Unsupported chat message role: {other}"),
+        };
+
+        Ok(ChatMessageAdapter(ChatMessage::new(role, value.content)))
+    }
+}
+
+impl From<ChatMessageAdapter> for ChatMessage {
+    fn from(val: ChatMessageAdapter) -> Self {
+        val.0
+    }
+}
+
+/// Ollama chat completions
+pub struct OllamaChat {
+    /// Connection to the Ollama API
+    connection: Ollama,
+    /// Model name
+    model: String,
+}
+
+#[async_trait]
+impl ChatCompletionStream for OllamaChat {
+    async fn chat_completion(
+        &self,
+        messages: &[Message],
+        options: ChatCompletionOptions,
+    ) -> Result<BoxStream<String>> {
+        let messages = messages
+            .iter()
+            .map(|m| ChatMessageAdapter::try_from(m.to_owned()))
+            .collect::<Result<Vec<_>, _>>()?;
+
+        let messages = messages.into_iter().map(|m| m.into()).collect::<Vec<_>>();
+
+        let options = GenerationOptions::default()
+            .seed(options.seed as i32)
+            .temperature(options.sampling_temperature)
+            .num_predict(options.max_decoding_tokens);
+
+        let request = ChatMessageRequest::new(self.model.to_owned(), messages).options(options);
+
+        let stream = self.connection.send_chat_messages_stream(request).await?;
+
+        let stream = stream
+            .map(|x| match x {
+                Ok(response) => response.message,
+                Err(_) => None,
+            })
+            .map(|x| match x {
+                Some(e) => e.content,
+                None => "".to_owned(),
+            });
+
+        Ok(stream.boxed())
+    }
+}
+
+pub async fn create(config: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
+    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+        .expect("Failed to create connection to Ollama, URL invalid");
+
+    let model = connection.select_model_or_default(config).await.unwrap();
+
+    Arc::new(OllamaChat { connection, model })
+}
diff --git a/crates/ollama-api-bindings/src/completion.rs b/crates/ollama-api-bindings/src/completion.rs
new file mode 100644
index 000000000000..7161cdb5d442
--- /dev/null
+++ b/crates/ollama-api-bindings/src/completion.rs
@@ -0,0 +1,66 @@
+use std::sync::Arc;
+
+use async_stream::stream;
+use async_trait::async_trait;
+use futures::{stream::BoxStream, StreamExt};
+use ollama_rs::{
+    generation::{completion::request::GenerationRequest, options::GenerationOptions},
+    Ollama,
+};
+use tabby_common::config::HttpModelConfig;
+use tabby_inference::{CompletionOptions, CompletionStream};
+use tracing::error;
+
+use crate::model::OllamaModelExt;
+
+pub struct OllamaCompletion {
+    /// Connection to the Ollama API
+    connection: Ollama,
+    /// Model name
+    model: String,
+}
+
+#[async_trait]
+impl CompletionStream for OllamaCompletion {
+    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
+        let ollama_options = GenerationOptions::default()
+            .num_ctx(options.max_input_length as u32)
+            .num_predict(options.max_decoding_tokens)
+            .seed(options.seed as i32)
+            .temperature(options.sampling_temperature);
+        let request = GenerationRequest::new(self.model.to_owned(), prompt.to_owned())
+            .template("{{ .Prompt }}".to_string())
+            .options(ollama_options);
+
+        // The trait method does not return a Result, so log errors and fall back to an empty stream
+        match self.connection.generate_stream(request).await {
+            Ok(stream) => {
+                let tabby_stream = stream! {
+
+                    for await response in stream {
+                        let parts = response.unwrap();
+                        for part in parts {
+                            yield part.response
+                        }
+                    }
+
+                };
+
+                tabby_stream.boxed()
+            }
+            Err(err) => {
+                error!("Failed to generate completion: {}", err);
+                futures::stream::empty().boxed()
+            }
+        }
+    }
+}
+
+pub async fn create(config: &HttpModelConfig) -> Arc<dyn CompletionStream> {
+    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+        .expect("Failed to create connection to Ollama, URL invalid");
+
+    let model = connection.select_model_or_default(config).await.unwrap();
+
+    Arc::new(OllamaCompletion { connection, model })
+}
diff --git a/crates/ollama-api-bindings/src/embedding.rs b/crates/ollama-api-bindings/src/embedding.rs
new file mode 100644
index 000000000000..153a460cba39
--- /dev/null
+++ b/crates/ollama-api-bindings/src/embedding.rs
@@ -0,0 +1,36 @@
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use ollama_rs::Ollama;
+use tabby_common::config::HttpModelConfig;
+use tabby_inference::Embedding;
+
+use crate::model::OllamaModelExt;
+
+pub struct OllamaCompletion {
+    /// Connection to the Ollama API
+    connection: Ollama,
+    /// Model name
+    model: String,
+}
+
+#[async_trait]
+impl Embedding for OllamaCompletion {
+    async fn embed(&self, prompt: &str) -> anyhow::Result<Vec<f32>> {
+        self.connection
+            .generate_embeddings(self.model.to_owned(), prompt.to_owned(), None)
+            .await
+            .map(|x| x.embeddings)
+            .map(|e| e.iter().map(|v| *v as f32).collect())
+            .map_err(|err| err.into())
+    }
+}
+
+pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
+    let connection = Ollama::try_new(config.api_endpoint.to_owned())
+        .expect("Failed to create connection to Ollama, URL invalid");
+
+    let model = connection.select_model_or_default(config).await.unwrap();
+
+    Arc::new(OllamaCompletion { connection, model })
+}
diff --git a/crates/ollama-api-bindings/src/lib.rs b/crates/ollama-api-bindings/src/lib.rs
new file mode 100644
index 000000000000..424434d531bf
--- /dev/null
+++ b/crates/ollama-api-bindings/src/lib.rs
@@ -0,0 +1,10 @@
+mod model;
+
+mod chat;
+pub use chat::create as create_chat;
+
+mod completion;
+pub use completion::create as create_completion;
+
+mod embedding;
+pub use embedding::create as create_embedding;
diff --git a/crates/ollama-api-bindings/src/model.rs b/crates/ollama-api-bindings/src/model.rs
new file mode 100644
index 000000000000..8b75a72d1a86
--- /dev/null
+++ b/crates/ollama-api-bindings/src/model.rs
@@ -0,0 +1,128 @@
+//!
+//! Ollama model management utils
+//!
+
+use anyhow::{anyhow, bail, Result};
+use async_trait::async_trait;
+use futures::StreamExt;
+use ollama_rs::Ollama;
+use tabby_common::config::HttpModelConfig;
+use tracing::{info, warn};
+
+/// Env variable for allowing pulling models with Ollama
+static ALLOW_PULL_ENV: &str = "TABBY_OLLAMA_ALLOW_PULL";
+
+#[async_trait]
+pub trait OllamaModelExt {
+    /// Check if a model is available in the remote Ollama instance
+    async fn model_available(&self, name: impl AsRef<str> + Send) -> Result<bool>;
+
+    /// Get the first available model in the remote Ollama instance
+    async fn get_first_available_model(&self) -> Result<Option<String>>;
+
+    /// For the input model specification:
+    /// - If a model is specified, check that it is available in the remote Ollama instance and return its name
+    /// - If no model is specified but prompt/chat templates are specified, return an error because that is unsound
+    /// - If no model is specified and no prompt/chat templates are specified, get the first available model in the remote Ollama instance and return its name
+    /// - If no model is available, return an error
+    /// - If a model is specified but not available, try to pull it if the env var `TABBY_OLLAMA_ALLOW_PULL` equals `1`, `y`, or `yes`,
+    ///   and return an error if the environment variable is not set or has a wrong value
+    ///
+    /// # Parameters
+    /// - `config`: model configuration
+    ///
+    /// # Returns
+    /// - model name to use
+    async fn select_model_or_default(&self, config: &HttpModelConfig) -> Result<String>;
+
+    /// Pull the model and report progress via tracing
+    async fn pull_model_with_tracing(&self, model: &str) -> Result<()>;
+}
+
+#[async_trait]
+impl OllamaModelExt for Ollama {
+    async fn model_available(&self, name: impl AsRef<str> + Send) -> Result<bool> {
+        let name = name.as_ref();
+
+        let models_available = self.list_local_models().await?;
+
+        Ok(models_available.into_iter().any(|model| model.name == name))
+    }
+
+    async fn get_first_available_model(&self) -> Result<Option<String>> {
+        let models_available = self.list_local_models().await?;
+
+        Ok(models_available.first().map(|x| x.name.to_owned()))
+    }
+
+    async fn select_model_or_default(&self, config: &HttpModelConfig) -> Result<String> {
+        let prompt_or_chat_templates_set =
+            config.prompt_template.is_some() || config.chat_template.is_some();
+
+        let model = match config.model_name.to_owned() {
+            Some(ref model) => model.to_owned(),
+            None => {
+                let model = self
+                    .get_first_available_model()
+                    .await?
+                    .ok_or(anyhow!("Ollama instance does not have any models"))?;
+
+                if prompt_or_chat_templates_set {
+                    bail!("No model name is provided but prompt or chat templates are set. Please set model name explicitly")
+                }
+
+                warn!(
+                    "No model name is provided, using first available: {}",
+                    model
+                );
+                model
+            }
+        };
+
+        let available = self.model_available(&model).await?;
+
+        let allow_pull = std::env::var_os(ALLOW_PULL_ENV)
+            .map(|x| x == "1" || x.to_ascii_lowercase() == "y" || x.to_ascii_lowercase() == "yes")
+            .unwrap_or(false);
+
+        match (available, allow_pull) {
+            (true, _) => Ok(model),
+            (false, true) => {
+                info!("Model is not available, pulling it");
+                self.pull_model_with_tracing(model.as_str()).await?;
+                Ok(model)
+            }
+            (false, false) => {
+                bail!("Model is not available, and pulling is disabled")
+            }
+        }
+    }
+
+    async fn pull_model_with_tracing(&self, model: &str) -> Result<()> {
+        let mut stream = self.pull_model_stream(model.to_owned(), false).await?;
+
+        let mut last_status = "".to_string();
+        let mut last_progress = 0.0;
+
+        while let Some(result) = stream.next().await {
+            let response = result?;
+            let status = response.message;
+            if last_status != status {
+                info!("Status: {}", status);
+                last_status = status;
+                last_progress = 0.0;
+            }
+
+            // Report progress only when it has advanced by at least 1%
+            if let (Some(completed), Some(total)) = (response.completed, response.total) {
+                let progress = completed as f64 / total as f64;
+                if progress - last_progress > 0.01 {
+                    info!("Progress: {:.2}%", progress * 100.0);
+                    last_progress = progress;
+                }
+            }
+        }
+
+        Ok(())
+    }
+}
diff --git a/crates/tabby/src/services/model/mod.rs b/crates/tabby/src/services/model/mod.rs
index 8924fd48e944..573f3ae5ea60 100644
--- a/crates/tabby/src/services/model/mod.rs
+++ b/crates/tabby/src/services/model/mod.rs
@@ -15,7 +15,7 @@ use crate::fatal;

 pub async fn load_chat_completion(chat: &ModelConfig) -> Arc<dyn ChatCompletionStream> {
     match chat {
-        ModelConfig::Http(http) => http_api_bindings::create_chat(http),
+        ModelConfig::Http(http) => http_api_bindings::create_chat(http).await,
         ModelConfig::Local(_) => {
             let (engine, PromptInfo { chat_template, .. }) = load_completion(chat).await;

@@ -41,7 +41,7 @@ pub async fn load_code_generation(model: &ModelConfig) -> (Arc<CodeGeneration>,
 async fn load_completion(model: &ModelConfig) -> (Arc<dyn CompletionStream>, PromptInfo) {
     match model {
         ModelConfig::Http(http) => {
-            let engine = http_api_bindings::create(http);
+            let engine = http_api_bindings::create(http).await;
             (
                 engine,
                 PromptInfo {
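
Reviewer note: below is a minimal config.toml sketch for exercising the new Ollama kinds
added by this patch. The `kind` strings ("ollama/completion", "ollama/chat",
"ollama/embedding") and the `api_endpoint` / `model_name` fields are the ones the patch
reads from `HttpModelConfig`; the `[model.*.http]` table names, the endpoint, and the
model tags are assumptions about a local setup, not something this patch defines, and may
need adjusting.

    # Assumed table layout for Tabby's config.toml; adjust to the actual schema in use.
    [model.completion.http]
    kind = "ollama/completion"
    api_endpoint = "http://localhost:11434"   # assumed default Ollama port
    model_name = "codellama:7b"               # any model tag present on the Ollama host

    [model.chat.http]
    kind = "ollama/chat"
    api_endpoint = "http://localhost:11434"
    model_name = "mistral:7b"

    [model.embedding.http]
    kind = "ollama/embedding"
    api_endpoint = "http://localhost:11434"
    model_name = "nomic-embed-text"

If `model_name` is omitted, `select_model_or_default` falls back to the first model the
Ollama host reports (unless a prompt/chat template is set, in which case it bails out),
and a missing model is pulled only when TABBY_OLLAMA_ALLOW_PULL is set to `1`, `y`, or `yes`.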