diff --git a/Cargo.lock b/Cargo.lock
index 08ecc2d..af0ce2d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -151,6 +151,28 @@ dependencies = [
  "tracing",
 ]
 
+[[package]]
+name = "async-stream"
+version = "0.3.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0b5a71a6f37880a80d1d7f19efd781e4b5de42c88f0722cc13bcb6cc2cfe8476"
+dependencies = [
+ "async-stream-impl",
+ "futures-core",
+ "pin-project-lite",
+]
+
+[[package]]
+name = "async-stream-impl"
+version = "0.3.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.79",
+]
+
 [[package]]
 name = "async-trait"
 version = "0.1.83"
@@ -1560,6 +1582,23 @@ dependencies = [
  "memchr",
 ]
 
+[[package]]
+name = "ollama-rs"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "46483ac9e1f9e93da045b5875837ca3c9cf014fd6ab89b4d9736580ddefc4759"
+dependencies = [
+ "async-stream",
+ "async-trait",
+ "log",
+ "reqwest",
+ "serde",
+ "serde_json",
+ "tokio",
+ "tokio-stream",
+ "url",
+]
+
 [[package]]
 name = "once_cell"
 version = "1.20.1"
@@ -2172,6 +2211,7 @@ dependencies = [
  "image",
  "lazy_static",
  "log",
+ "ollama-rs",
  "regex",
  "reqwest",
  "rustyline",
@@ -2181,6 +2221,7 @@ dependencies = [
  "tempfile",
  "termimad",
  "tokio",
+ "tokio-stream",
  "toml",
  "unicode-width 0.2.0",
  "walkdir",
diff --git a/Cargo.toml b/Cargo.toml
index 3d52608..02e7b0e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -8,6 +8,7 @@ edition = "2021"
 [dependencies]
 reqwest = { version = "0.12", features = ["json"] }
 tokio = { version = "1", features = ["full"] }
+tokio-stream = "0.1"
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 async-openai = "0.24.0"
@@ -30,3 +31,4 @@ ignore = "0.4"
 log = "0.4"
 env_logger = "0.11"
 async-trait = "0.1"
+ollama-rs = { version = "0.2", features = ["stream"] }
diff --git a/docs/docs/getting_started.md b/docs/docs/getting_started.md
index 1398964..796ebb0 100644
--- a/docs/docs/getting_started.md
+++ b/docs/docs/getting_started.md
@@ -38,6 +38,27 @@ Recommended persona: [Recommended Persona]
 - A default `config.toml` file is generated in the `.rusty` directory.
 - This file includes the recommended persona and sets default models for chat and commit message generation.
 
+## Choosing Your AI Provider
+
+After configuring your environment, you can choose between AI backends such as OpenAI and Ollama, depending on your needs.
+
+### Using the Ollama Backend
+
+To use Ollama, point the relevant models in your `config.toml` at the Ollama backend:
+```toml
+[ai]
+chat_model = "ollama_32"
+commit_model = "ollama_32"
+wish_model = "ollama_32"
+
+[[models]]
+name = "ollama_32"
+api_name = "llama3.2"
+backend = "Ollama"
+url = "http://localhost:11434"
+
+```
+
 ## Example Usage
 
 Once your setup is complete, you can start using Rusty Buddy right away. Here are a few common scenarios:
diff --git a/docs/docs/installation.md b/docs/docs/installation.md
index 863de73..9165f50 100644
--- a/docs/docs/installation.md
+++ b/docs/docs/installation.md
@@ -46,4 +46,16 @@ If you prefer to have more control over the installation or need to modify the s
 - Ensure that Rust and Cargo are installed on your system. You can install them via [rustup](https://rustup.rs/).
 - Network access may be required for both installation methods, particularly for downloading dependencies or connecting with the OpenAI API.
 
-By following these instructions, you will be able to set up Rusty Buddy and harness its capability for your development workflows. Choose the installation method that aligns with your needs and system configuration.
\ No newline at end of file
+By following these instructions, you will be able to set up Rusty Buddy and harness its capability for your development workflows. Choose the installation method that aligns with your needs and system configuration.
+
+## Additional Requirements for Ollama
+
+To use Ollama with Rusty Buddy, you need to install and configure the Ollama service. This section covers the additional setup required.
+
+### Step 1: Install Ollama
+
+Ensure that the Ollama service is installed and running on your machine. You can follow the installation guide in the [official Ollama documentation](https://ollama.com).
+
+### Step 2: Configure Firewall and Ports
+
+Make sure your firewall allows traffic on the port Ollama listens on (11434 by default).
\ No newline at end of file
diff --git a/src/chat/service_builder.rs b/src/chat/service_builder.rs
index 51493a1..b1c2d3d 100644
--- a/src/chat/service_builder.rs
+++ b/src/chat/service_builder.rs
@@ -2,6 +2,7 @@ use crate::chat::interface::{ChatBackend, ChatStorage};
 use crate::chat::service::ChatService;
 use crate::config::{AIBackend, CONFIG};
 use crate::persona::Persona;
+use crate::provider::ollama::ollama_interface::OllamaInterface;
 use crate::provider::openai::openai_interface::OpenAIInterface;
 use log::debug;
 use std::error::Error;
@@ -55,7 +56,10 @@ impl ChatServiceBuilder {
         // Check which provider to use based on the model
         let backend: Box<dyn ChatBackend> = match &model.backend {
             AIBackend::OpenAI => Box::new(OpenAIInterface::new(model.api_name.clone())), // Additional backends can be added here
-            _ => return Err(format!("Unknown backend for model: {:?}", model.backend).into()),
+            AIBackend::Ollama => Box::new(OllamaInterface::new(
+                model.api_name.clone(),
+                model.url.clone(),
+            )),
         };
 
         Ok(ChatService::new(backend, storage, persona, self.directory))
diff --git a/src/cli/init/mod.rs b/src/cli/init/mod.rs
index c42693f..ecfa16d 100644
--- a/src/cli/init/mod.rs
+++ b/src/cli/init/mod.rs
@@ -126,6 +126,12 @@ backend = "OpenAI"
 name = "openai_complex"
 api_name = "gpt-4o-2024-08-06"
 backend = "OpenAI"
+
+[[models]]
+name = "ollama_complex"
+api_name = "llama3.2"
+backend = "Ollama"
+url = "http://localhost:11434"
 "#,
         recommended_persona
     );
diff --git a/src/config/config_file.rs b/src/config/config_file.rs
index e29b197..e4bfa87 100644
--- a/src/config/config_file.rs
+++ b/src/config/config_file.rs
@@ -33,10 +33,7 @@ pub struct AI {
 pub struct Model {
     pub name: String,
     pub api_name: String,
-    #[allow(dead_code)]
     pub url: Option<String>,
-    #[allow(dead_code)]
-    pub port: Option<u16>,
     pub backend: AIBackend,
 }
 
diff --git a/src/provider/mod.rs b/src/provider/mod.rs
index d8c3087..3ef32f6 100644
--- a/src/provider/mod.rs
+++ b/src/provider/mod.rs
@@ -1 +1,2 @@
+pub mod ollama;
 pub mod openai;
diff --git a/src/provider/ollama/mod.rs b/src/provider/ollama/mod.rs
new file mode 100644
index 0000000..bbfbdad
--- /dev/null
+++ b/src/provider/ollama/mod.rs
@@ -0,0 +1 @@
+pub mod ollama_interface;
diff --git a/src/provider/ollama/ollama_interface.rs b/src/provider/ollama/ollama_interface.rs
new file mode 100644
index 0000000..06e5b70
--- /dev/null
+++ b/src/provider/ollama/ollama_interface.rs
@@ -0,0 +1,78 @@
+// src/provider/ollama/ollama_interface.rs
+
+use crate::chat::interface::{ChatBackend, Message, MessageRole};
+use async_trait::async_trait;
+use ollama_rs::{
+    generation::chat::{request::ChatMessageRequest, ChatMessage, ChatMessageResponseStream},
+    IntoUrlSealed, Ollama,
+};
+use std::error::Error;
+use tokio_stream::StreamExt;
+
+pub struct OllamaInterface {
+    ollama: Ollama,
+    model: String,
+}
+
+impl OllamaInterface {
+    pub fn new(model: String, ourl: Option<String>) -> Self {
+        let url = ourl.unwrap_or("http://localhost:11434".into());
+        OllamaInterface {
+            ollama: Ollama::from_url(url.clone().into_url().unwrap()),
+            model,
+        }
+    }
+
+    fn convert_messages(messages: &[Message]) -> Vec<ChatMessage> {
+        let mut chat_messages: Vec<ChatMessage> = Vec::new();
+
+        // Convert Message into ChatMessage for ollama
+        for msg in messages {
+            match msg.role {
+                MessageRole::User => {
+                    chat_messages.push(ChatMessage::user(msg.content.clone()));
+                }
+                MessageRole::Assistant => {
+                    chat_messages.push(ChatMessage::assistant(msg.content.clone()));
+                }
+                MessageRole::Context => {
+                    chat_messages.push(ChatMessage::system(msg.content.clone()));
+                }
+                MessageRole::System => {
+                    chat_messages.push(ChatMessage::system(msg.content.clone()));
+                }
+            }
+        }
+        chat_messages
+    }
+}
+
+#[async_trait]
+impl ChatBackend for OllamaInterface {
+    async fn send_request(
+        &mut self,
+        messages: &[Message],
+        _use_tools: bool,
+    ) -> Result<String, Box<dyn Error>> {
+        let chat_messages = Self::convert_messages(messages);
+
+        let request = ChatMessageRequest::new(self.model.clone(), chat_messages.clone());
+
+        let mut stream: ChatMessageResponseStream =
+            self.ollama.send_chat_messages_stream(request).await?;
+
+        let mut response = String::new();
+
+        while let Some(Ok(res)) = stream.next().await {
+            if let Some(assistant_message) = res.message {
+                response += &assistant_message.content;
+            }
+        }
+        Ok(response)
+    }
+
+    fn print_statistics(&self) {
+        // Implement statistics if required
+        println!("Using Ollama model: {}", self.model);
+    }
+}
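
For reviewers, a minimal sketch of how the new backend can be exercised directly through the `ChatBackend` trait, outside the `ChatServiceBuilder` dispatch above. It assumes a default local Ollama install serving `llama3.2`, and it assumes `Message` can be built from just a `role` and a `content` field; the real `Message` definition is not part of this diff.

```rust
// Sketch only: drives the new OllamaInterface through the ChatBackend trait.
// Assumes Ollama is reachable at http://localhost:11434 with llama3.2 pulled,
// and that `Message { role, content }` matches the actual struct definition.
use crate::chat::interface::{ChatBackend, Message, MessageRole};
use crate::provider::ollama::ollama_interface::OllamaInterface;

async fn demo_ollama() -> Result<(), Box<dyn std::error::Error>> {
    // Passing `None` falls back to the default URL, http://localhost:11434.
    let mut backend = OllamaInterface::new("llama3.2".to_string(), None);

    let messages = vec![Message {
        role: MessageRole::User,
        content: "Summarize this change in one sentence.".to_string(),
    }];

    // `_use_tools` is ignored by the Ollama backend, so `false` is fine here.
    let reply = backend.send_request(&messages, false).await?;
    println!("{reply}");
    Ok(())
}
```

Streaming chunks are accumulated into a single `String` before returning, so callers see the same blocking `send_request` behavior as the OpenAI backend.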