-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Added ollama api connection options (#2227)
* feat: Add ollama-api-binding crate Add support for it in http-api-binding as well * fix(ollama-api-binding): return empty stream if error occurs in completion * refactor(ollama-api-binding): Control model pulling via env var Added TABBY_OLLAMA_ALLOW_PULL env to enable pulling in Ollama * refactor(ollama-api-bindings): Do not use first available model if template is specified. It should be assumed that the user has tuned the prompt or chat template for a specific model, so it's better to ask the user to specify a model explicitly instead of using whatever model is available in Ollama. * refactor(http-api-bindings): Update ollama embedding kind name * refactor(ollama-api-bindings): apply formatting * refactor(http-api-bindings): Update ollama completion kind name * refactor(http-api-bindings): Update ollama chat kind name
- Loading branch information
1 parent
5aedf9c
commit 1d1edfe
Showing
14 changed files
with
435 additions
and
34 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
# Crate manifest for ollama-api-bindings: Tabby's adapter to the Ollama HTTP API
# (chat, completion, and model-selection support).
[package]
name = "ollama-api-bindings"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
# Sibling workspace crates: shared config/API types and the inference trait definitions.
tabby-common = { path = "../tabby-common" }
tabby-inference = { path = "../tabby-inference" }

anyhow.workspace = true
async-stream.workspace = true
async-trait.workspace = true
futures.workspace = true
tracing.workspace = true

# Use a git version for now: https://github.com/pepperoni21/ollama-rs/issues/44 is required to work correctly with normal URLs
[dependencies.ollama-rs]
git = "https://github.com/pepperoni21/ollama-rs.git"
rev = "56e8157d98d4185bc171fe9468d3d09bc56e9dd3"
features = ["stream"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
use std::sync::Arc; | ||
|
||
use anyhow::{bail, Result}; | ||
use async_trait::async_trait; | ||
use futures::{stream::BoxStream, StreamExt}; | ||
use ollama_rs::{ | ||
generation::{ | ||
chat::{request::ChatMessageRequest, ChatMessage, MessageRole}, | ||
options::GenerationOptions, | ||
}, | ||
Ollama, | ||
}; | ||
use tabby_common::{api::chat::Message, config::HttpModelConfig}; | ||
use tabby_inference::{ChatCompletionOptions, ChatCompletionStream}; | ||
|
||
use crate::model::OllamaModelExt; | ||
|
||
/// A special adapter to convert Tabby messages to ollama-rs messages.
/// Newtype wrapper so `TryFrom<Message>` can be implemented without
/// violating the orphan rule (both `Message` and `ChatMessage` are foreign types).
struct ChatMessageAdapter(ChatMessage);
|
||
impl TryFrom<Message> for ChatMessageAdapter { | ||
type Error = anyhow::Error; | ||
fn try_from(value: Message) -> Result<ChatMessageAdapter> { | ||
let role = match value.role.as_str() { | ||
"system" => MessageRole::System, | ||
"assistant" => MessageRole::Assistant, | ||
"user" => MessageRole::User, | ||
other => bail!("Unsupported chat message role: {other}"), | ||
}; | ||
|
||
Ok(ChatMessageAdapter(ChatMessage::new(role, value.content))) | ||
} | ||
} | ||
|
||
impl From<ChatMessageAdapter> for ChatMessage { | ||
fn from(val: ChatMessageAdapter) -> Self { | ||
val.0 | ||
} | ||
} | ||
|
||
/// Ollama chat completions.
/// Streams chat responses from an Ollama server for Tabby's chat feature.
pub struct OllamaChat {
    /// Connection to Ollama API
    connection: Ollama,
    /// Model name, <model>
    model: String,
}
|
||
#[async_trait] | ||
impl ChatCompletionStream for OllamaChat { | ||
async fn chat_completion( | ||
&self, | ||
messages: &[Message], | ||
options: ChatCompletionOptions, | ||
) -> Result<BoxStream<String>> { | ||
let messages = messages | ||
.iter() | ||
.map(|m| ChatMessageAdapter::try_from(m.to_owned())) | ||
.collect::<Result<Vec<_>, _>>()?; | ||
|
||
let messages = messages.into_iter().map(|m| m.into()).collect::<Vec<_>>(); | ||
|
||
let options = GenerationOptions::default() | ||
.seed(options.seed as i32) | ||
.temperature(options.sampling_temperature) | ||
.num_predict(options.max_decoding_tokens); | ||
|
||
let request = ChatMessageRequest::new(self.model.to_owned(), messages).options(options); | ||
|
||
let stream = self.connection.send_chat_messages_stream(request).await?; | ||
|
||
let stream = stream | ||
.map(|x| match x { | ||
Ok(response) => response.message, | ||
Err(_) => None, | ||
}) | ||
.map(|x| match x { | ||
Some(e) => e.content, | ||
None => "".to_owned(), | ||
}); | ||
|
||
Ok(stream.boxed()) | ||
} | ||
} | ||
|
||
pub async fn create(config: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> { | ||
let connection = Ollama::try_new(config.api_endpoint.to_owned()) | ||
.expect("Failed to create connection to Ollama, URL invalid"); | ||
|
||
let model = connection.select_model_or_default(config).await.unwrap(); | ||
|
||
Arc::new(OllamaChat { connection, model }) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
use std::sync::Arc; | ||
|
||
use async_stream::stream; | ||
use async_trait::async_trait; | ||
use futures::{stream::BoxStream, StreamExt}; | ||
use ollama_rs::{ | ||
generation::{completion::request::GenerationRequest, options::GenerationOptions}, | ||
Ollama, | ||
}; | ||
use tabby_common::config::HttpModelConfig; | ||
use tabby_inference::{CompletionOptions, CompletionStream}; | ||
use tracing::error; | ||
|
||
use crate::model::OllamaModelExt; | ||
|
||
/// Ollama-backed code completion: streams raw generation output from an
/// Ollama server for Tabby's completion feature.
pub struct OllamaCompletion {
    /// Connection to Ollama API
    connection: Ollama,
    /// Model name, <model>
    model: String,
}
|
||
#[async_trait] | ||
impl CompletionStream for OllamaCompletion { | ||
async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> { | ||
let ollama_options = GenerationOptions::default() | ||
.num_ctx(options.max_input_length as u32) | ||
.num_predict(options.max_decoding_tokens) | ||
.seed(options.seed as i32) | ||
.temperature(options.sampling_temperature); | ||
let request = GenerationRequest::new(self.model.to_owned(), prompt.to_owned()) | ||
.template("{{ .Prompt }}".to_string()) | ||
.options(ollama_options); | ||
|
||
// Why this function returns not Result? | ||
match self.connection.generate_stream(request).await { | ||
Ok(stream) => { | ||
let tabby_stream = stream! { | ||
|
||
for await response in stream { | ||
let parts = response.unwrap(); | ||
for part in parts { | ||
yield part.response | ||
} | ||
} | ||
|
||
}; | ||
|
||
tabby_stream.boxed() | ||
} | ||
Err(err) => { | ||
error!("Failed to generate completion: {}", err); | ||
futures::stream::empty().boxed() | ||
} | ||
} | ||
} | ||
} | ||
|
||
pub async fn create(config: &HttpModelConfig) -> Arc<dyn CompletionStream> { | ||
let connection = Ollama::try_new(config.api_endpoint.to_owned()) | ||
.expect("Failed to create connection to Ollama, URL invalid"); | ||
|
||
let model = connection.select_model_or_default(config).await.unwrap(); | ||
|
||
Arc::new(OllamaCompletion { connection, model }) | ||
} |
Oops, something went wrong.