feat: Added ollama api connection options (#2227)
* feat: Add ollama-api-bindings crate

Add support for it in http-api-bindings as well

* fix(ollama-api-bindings): return an empty stream if an error occurs in completion

* refactor(ollama-api-bindings): Control model pulling via env var

Added the TABBY_OLLAMA_ALLOW_PULL env var to enable model pulling in Ollama (a hedged sketch of such a gate follows this commit message).

* refactor(ollama-api-bindings): Do not use the first available model if a template is specified.

It should be assumed that the user has tuned the prompt or chat template for a specific model, so it is better to ask the user to specify a model explicitly rather than falling back to whatever model happens to be available in Ollama.

* refactor(http-api-bindings): Update ollama embedding kind name

* refactor(ollama-api-bindings): apply formatting

* refactor(http-api-bindings): Update ollama completion kind name

* refactor(http-api-bindings): Update ollama chat kind name
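The TABBY_OLLAMA_ALLOW_PULL gate itself lives in crates/ollama-api-bindings/src/model.rs, which this page does not render. The snippet below is only a minimal sketch of how such an environment-variable check could look; the helper name and the accepted values ("1", "true", "yes") are assumptions, not taken from the commit.

// Hypothetical sketch only: the actual check in model.rs is not shown in this diff,
// and the accepted values for TABBY_OLLAMA_ALLOW_PULL are assumed.
fn pull_allowed() -> bool {
    std::env::var("TABBY_OLLAMA_ALLOW_PULL")
        .map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
        .unwrap_or(false)
}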
SpeedCrash100 authored May 26, 2024
1 parent 5aedf9c commit 1d1edfe
Showing 14 changed files with 435 additions and 34 deletions.
32 changes: 30 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -11,6 +11,7 @@ members = [
     "crates/aim-downloader",
     "crates/http-api-bindings",
     "crates/llama-cpp-server",
+    "crates/ollama-api-bindings",
 
     "ee/tabby-webserver",
     "ee/tabby-db",
3 changes: 2 additions & 1 deletion crates/http-api-bindings/Cargo.toml
@@ -17,7 +17,8 @@ serde.workspace = true
 serde_json = { workspace = true }
 tabby-common = { path = "../tabby-common" }
 tabby-inference = { path = "../tabby-inference" }
+ollama-api-bindings = { path = "../ollama-api-bindings" }
 tracing.workspace = true
 
 [dev-dependencies]
-tokio ={ workspace = true, features = ["rt", "macros"]}
+tokio = { workspace = true, features = ["rt", "macros"] }
14 changes: 7 additions & 7 deletions crates/http-api-bindings/src/chat/mod.rs
@@ -6,15 +6,15 @@ use openai_chat::OpenAIChatEngine;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::ChatCompletionStream;
 
-pub fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
-    if model.kind == "openai-chat" {
-        let engine = OpenAIChatEngine::create(
+pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
+    match model.kind.as_str() {
+        "openai-chat" => Arc::new(OpenAIChatEngine::create(
             &model.api_endpoint,
             model.model_name.as_deref().unwrap_or_default(),
             model.api_key.clone(),
-        );
-        Arc::new(engine)
-    } else {
-        panic!("Only openai-chat are supported for http chat");
+        )),
+        "ollama/chat" => ollama_api_bindings::create_chat(model).await,
+
+        unsupported_kind => panic!("Unsupported kind for http chat: {}", unsupported_kind),
     }
 }
18 changes: 12 additions & 6 deletions crates/http-api-bindings/src/completion/mod.rs
@@ -6,11 +6,17 @@ use llama::LlamaCppEngine;
 use tabby_common::config::HttpModelConfig;
 use tabby_inference::CompletionStream;
 
-pub fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
-    if model.kind == "llama.cpp/completion" {
-        let engine = LlamaCppEngine::create(&model.api_endpoint, model.api_key.clone());
-        Arc::new(engine)
-    } else {
-        panic!("Unsupported model kind: {}", model.kind);
+pub async fn create(model: &HttpModelConfig) -> Arc<dyn CompletionStream> {
+    match model.kind.as_str() {
+        "llama.cpp/completion" => {
+            let engine = LlamaCppEngine::create(&model.api_endpoint, model.api_key.clone());
+            Arc::new(engine)
+        }
+        "ollama/completion" => ollama_api_bindings::create_completion(model).await,
+
+        unsupported_kind => panic!(
+            "Unsupported model kind for http completion: {}",
+            unsupported_kind
+        ),
     }
 }
34 changes: 21 additions & 13 deletions crates/http-api-bindings/src/embedding/mod.rs
@@ -1,6 +1,7 @@
 mod llama;
 mod openai;
 
+use core::panic;
 use std::sync::Arc;
 
 use llama::LlamaCppEngine;
@@ -9,18 +10,25 @@ use tabby_inference::Embedding;
 
 use self::openai::OpenAIEmbeddingEngine;
 
-pub fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
-    if config.kind == "llama.cpp/embedding" {
-        let engine = LlamaCppEngine::create(&config.api_endpoint, config.api_key.clone());
-        Arc::new(engine)
-    } else if config.kind == "openai-embedding" {
-        let engine = OpenAIEmbeddingEngine::create(
-            &config.api_endpoint,
-            config.model_name.as_deref().unwrap_or_default(),
-            config.api_key.clone(),
-        );
-        Arc::new(engine)
-    } else {
-        panic!("Only llama are supported for http embedding");
+pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
+    match config.kind.as_str() {
+        "llama.cpp/embedding" => {
+            let engine = LlamaCppEngine::create(&config.api_endpoint, config.api_key.clone());
+            Arc::new(engine)
+        }
+        "openai-embedding" => {
+            let engine = OpenAIEmbeddingEngine::create(
+                &config.api_endpoint,
+                config.model_name.as_deref().unwrap_or_default(),
+                config.api_key.clone(),
+            );
+            Arc::new(engine)
+        }
+        "ollama/embedding" => ollama_api_bindings::create_embedding(config).await,
+
+        unsupported_kind => panic!(
+            "Unsupported kind for http embedding model: {}",
+            unsupported_kind
+        ),
     }
 }
6 changes: 3 additions & 3 deletions crates/llama-cpp-server/src/lib.rs
@@ -35,7 +35,7 @@ impl EmbeddingServer
 
         Self {
             server,
-            embedding: http_api_bindings::create_embedding(&config),
+            embedding: http_api_bindings::create_embedding(&config).await,
         }
     }
 }
@@ -61,7 +61,7 @@ impl CompletionServer
             .kind("llama.cpp/completion".to_string())
             .build()
             .expect("Failed to create HttpModelConfig");
-        let completion = http_api_bindings::create(&config);
+        let completion = http_api_bindings::create(&config).await;
         Self { server, completion }
     }
 }
@@ -83,7 +83,7 @@ pub async fn create_completion(
 
 pub async fn create_embedding(config: &ModelConfig) -> Arc<dyn Embedding> {
     match config {
-        ModelConfig::Http(http) => http_api_bindings::create_embedding(http),
+        ModelConfig::Http(http) => http_api_bindings::create_embedding(http).await,
         ModelConfig::Local(llama) => {
             if fs::metadata(&llama.model_id).is_ok() {
                 let path = PathBuf::from(&llama.model_id);
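The builder calls above show the shape a caller would use for the new Ollama kinds as well. The sketch below is not code from this commit: the api_endpoint and model_name builder methods and their exact signatures are inferred from the HttpModelConfig fields referenced elsewhere in this diff, and the model name is a placeholder.

// Hedged usage sketch (not from this commit). Builder methods other than `kind`,
// `build`, and `expect` are assumed from the HttpModelConfig fields used above;
// "codellama:7b" is a placeholder model name.
use std::sync::Arc;

use tabby_common::config::HttpModelConfig;
use tabby_inference::CompletionStream;

async fn ollama_completion_from_config() -> Arc<dyn CompletionStream> {
    let config = HttpModelConfig::builder()
        .api_endpoint("http://localhost:11434".to_string())
        .kind("ollama/completion".to_string())
        .model_name(Some("codellama:7b".to_string()))
        .build()
        .expect("Failed to create HttpModelConfig");

    // `http_api_bindings::create` is now async, so callers must await it.
    http_api_bindings::create(&config).await
}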
24 changes: 24 additions & 0 deletions crates/ollama-api-bindings/Cargo.toml
@@ -0,0 +1,24 @@
[package]
name = "ollama-api-bindings"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
tabby-common = { path = "../tabby-common" }
tabby-inference = { path = "../tabby-inference" }

anyhow.workspace = true
async-stream.workspace = true
async-trait.workspace = true
futures.workspace = true
tracing.workspace = true

# Use git version for now: https://github.com/pepperoni21/ollama-rs/issues/44 is required to work correctly with normal URLs
[dependencies.ollama-rs]
git = "https://github.com/pepperoni21/ollama-rs.git"
rev = "56e8157d98d4185bc171fe9468d3d09bc56e9dd3"
features = ["stream"]
93 changes: 93 additions & 0 deletions crates/ollama-api-bindings/src/chat.rs
@@ -0,0 +1,93 @@
use std::sync::Arc;

use anyhow::{bail, Result};
use async_trait::async_trait;
use futures::{stream::BoxStream, StreamExt};
use ollama_rs::{
    generation::{
        chat::{request::ChatMessageRequest, ChatMessage, MessageRole},
        options::GenerationOptions,
    },
    Ollama,
};
use tabby_common::{api::chat::Message, config::HttpModelConfig};
use tabby_inference::{ChatCompletionOptions, ChatCompletionStream};

use crate::model::OllamaModelExt;

/// A special adapter to convert Tabby messages to ollama-rs messages
struct ChatMessageAdapter(ChatMessage);

impl TryFrom<Message> for ChatMessageAdapter {
    type Error = anyhow::Error;
    fn try_from(value: Message) -> Result<ChatMessageAdapter> {
        let role = match value.role.as_str() {
            "system" => MessageRole::System,
            "assistant" => MessageRole::Assistant,
            "user" => MessageRole::User,
            other => bail!("Unsupported chat message role: {other}"),
        };

        Ok(ChatMessageAdapter(ChatMessage::new(role, value.content)))
    }
}

impl From<ChatMessageAdapter> for ChatMessage {
    fn from(val: ChatMessageAdapter) -> Self {
        val.0
    }
}

/// Ollama chat completions
pub struct OllamaChat {
    /// Connection to Ollama API
    connection: Ollama,
    /// Model name, <model>
    model: String,
}

#[async_trait]
impl ChatCompletionStream for OllamaChat {
    async fn chat_completion(
        &self,
        messages: &[Message],
        options: ChatCompletionOptions,
    ) -> Result<BoxStream<String>> {
        let messages = messages
            .iter()
            .map(|m| ChatMessageAdapter::try_from(m.to_owned()))
            .collect::<Result<Vec<_>, _>>()?;

        let messages = messages.into_iter().map(|m| m.into()).collect::<Vec<_>>();

        let options = GenerationOptions::default()
            .seed(options.seed as i32)
            .temperature(options.sampling_temperature)
            .num_predict(options.max_decoding_tokens);

        let request = ChatMessageRequest::new(self.model.to_owned(), messages).options(options);

        let stream = self.connection.send_chat_messages_stream(request).await?;

        let stream = stream
            .map(|x| match x {
                Ok(response) => response.message,
                Err(_) => None,
            })
            .map(|x| match x {
                Some(e) => e.content,
                None => "".to_owned(),
            });

        Ok(stream.boxed())
    }
}

pub async fn create(config: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
    let connection = Ollama::try_new(config.api_endpoint.to_owned())
        .expect("Failed to create connection to Ollama, URL invalid");

    let model = connection.select_model_or_default(config).await.unwrap();

    Arc::new(OllamaChat { connection, model })
}
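A minimal consumption sketch for the adapter above, assuming Message can be built from role/content strings and that ChatCompletionOptions implements Default; neither assumption is confirmed by this diff.

// Hedged sketch (not from this commit): drive the stream returned by chat_completion.
// The Message struct literal and ChatCompletionOptions::default() are assumptions.
use futures::StreamExt;
use tabby_common::{api::chat::Message, config::HttpModelConfig};
use tabby_inference::ChatCompletionOptions;

async fn demo_chat(config: &HttpModelConfig) -> anyhow::Result<String> {
    let chat = create(config).await;

    let messages = vec![Message {
        role: "user".to_owned(),
        content: "Say hello.".to_owned(),
    }];

    let mut stream = chat
        .chat_completion(&messages, ChatCompletionOptions::default())
        .await?;

    // Each stream item is an incremental chunk of the reply; stream-level errors
    // are mapped to empty strings inside OllamaChat.
    let mut reply = String::new();
    while let Some(chunk) = stream.next().await {
        reply.push_str(&chunk);
    }
    Ok(reply)
}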
66 changes: 66 additions & 0 deletions crates/ollama-api-bindings/src/completion.rs
@@ -0,0 +1,66 @@
use std::sync::Arc;

use async_stream::stream;
use async_trait::async_trait;
use futures::{stream::BoxStream, StreamExt};
use ollama_rs::{
    generation::{completion::request::GenerationRequest, options::GenerationOptions},
    Ollama,
};
use tabby_common::config::HttpModelConfig;
use tabby_inference::{CompletionOptions, CompletionStream};
use tracing::error;

use crate::model::OllamaModelExt;

pub struct OllamaCompletion {
    /// Connection to Ollama API
    connection: Ollama,
    /// Model name, <model>
    model: String,
}

#[async_trait]
impl CompletionStream for OllamaCompletion {
    async fn generate(&self, prompt: &str, options: CompletionOptions) -> BoxStream<String> {
        let ollama_options = GenerationOptions::default()
            .num_ctx(options.max_input_length as u32)
            .num_predict(options.max_decoding_tokens)
            .seed(options.seed as i32)
            .temperature(options.sampling_temperature);
        let request = GenerationRequest::new(self.model.to_owned(), prompt.to_owned())
            .template("{{ .Prompt }}".to_string())
            .options(ollama_options);

        // Why does this function not return a Result?
        match self.connection.generate_stream(request).await {
            Ok(stream) => {
                let tabby_stream = stream! {
                    for await response in stream {
                        let parts = response.unwrap();
                        for part in parts {
                            yield part.response
                        }
                    }
                };

                tabby_stream.boxed()
            }
            Err(err) => {
                error!("Failed to generate completion: {}", err);
                futures::stream::empty().boxed()
            }
        }
    }
}

pub async fn create(config: &HttpModelConfig) -> Arc<dyn CompletionStream> {
    let connection = Ollama::try_new(config.api_endpoint.to_owned())
        .expect("Failed to create connection to Ollama, URL invalid");

    let model = connection.select_model_or_default(config).await.unwrap();

    Arc::new(OllamaCompletion { connection, model })
}