diff --git a/src/api/chat_completions.rs b/src/api/chat_completions.rs
index edd117a..2de9465 100644
--- a/src/api/chat_completions.rs
+++ b/src/api/chat_completions.rs
@@ -1,5 +1,6 @@
 use super::Instance;
 use crate::json_models::chat_completions::ChatCompletionsResponse;
+use crate::model_collection::ModelCollection;
 use serde_json::json;
 
 const SYSTEM_PROMPT: &str = r#"
@@ -24,8 +25,10 @@ impl Instance {
         &self,
         link_words: &[String],
         avoid_words: &[String],
+        model_id: &String,
     ) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error>> {
-        let request_body = self.build_request_body(link_words, avoid_words);
+        self.validate_model_id(model_id).await?;
+        let request_body = self.build_request_body(link_words, avoid_words, model_id);
 
         // Get response from API endpoint
         let response = self
@@ -41,7 +44,7 @@ impl Instance {
             .json::<ChatCompletionsResponse>()
             .await
             .map_err(|e| format!("Failed to parse clues from API server: {}", e))?;
-        
+
         Ok(parsed_response)
     }
 
@@ -49,6 +52,7 @@ impl Instance {
         &self,
         link_words: &[String],
         avoid_words: &[String],
+        model_id: &String,
     ) -> serde_json::Value {
         // Aggregate two sets of words into one prompt
         let content = format!(
@@ -68,7 +72,23 @@ impl Instance {
                     "content": content
                 }
             ],
-            "model": self.model_id
+            "model": model_id
         })
     }
+
+    async fn validate_model_id(&self, model_id: &String) -> Result<(), Box<dyn std::error::Error>> {
+        let models_response = self.get_models().await?;
+        let model_collection = ModelCollection::new(models_response);
+
+        // Return Error if the chosen model is not valid
+        if !model_collection.contains(model_id) {
+            return Err(format!(
+                "{} is not a valid language model from your provider",
+                model_id
+            )
+            .into());
+        }
+
+        Ok(())
+    }
 }
diff --git a/src/api/mod.rs b/src/api/mod.rs
index fe92b0c..d4286a2 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -3,7 +3,6 @@ mod language_models;
 
 use dotenv::dotenv;
 use std::env;
-use crate::model_collection::ModelCollection;
 
 pub struct Instance {
     client: reqwest::Client,
@@ -35,26 +34,6 @@ impl Instance {
             .map_err(|_| format!("Cannot read environment variable: {}", var_name).into())
     }
 
-    pub async fn set_model_id(
-        &mut self,
-        model_id: String,
-    ) -> Result<(), Box<dyn std::error::Error>> {
-        // Return Error if the chosen model is not valid
-        let models_response = self.get_models().await?;
-        let model_collection = ModelCollection::new(models_response);
-
-        if !model_collection.contains(&model_id) {
-            return Err(format!(
-                "{} is not a valid language model from your provider",
-                model_id
-            )
-            .into());
-        }
-
-        self.model_id = model_id;
-        Ok(())
-    }
-
     pub fn set_base_url(&mut self, base_url: String) {
         self.base_url = base_url;
     }
diff --git a/src/lib.rs b/src/lib.rs
index 8a3e54d..33faa31 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -19,7 +19,7 @@ pub struct Args {
 
     /// Select language model(s)
     #[arg(short, long = "set-models", default_missing_value = "interactive", num_args = 0..)]
-    pub model: Option<Vec<String>>,
+    pub models: Option<Vec<String>>,
 
     /// Specify an output file
     #[arg(short, long, value_name = "FILE")]
diff --git a/src/main.rs b/src/main.rs
index 40d7cfa..fbd8b03 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,3 @@
-use std::env;
 use clap::Parser;
 use dotenv::dotenv;
 use mastermind::api::Instance;
@@ -6,6 +5,7 @@ use mastermind::clue::ClueCollection;
 use mastermind::json_models::chat_completions::ChatCompletionsResponse;
 use mastermind::model_collection::ModelCollection;
 use mastermind::*;
+use std::env;
 
 #[tokio::main]
 async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -14,7 +14,7 @@
     dotenv().ok();
 
     // Create an API instance and get all available models
-    let mut api_instance = Instance::new()?;
+    let api_instance = Instance::new()?;
     let models_response = api_instance.get_models().await?;
     let model_collection = ModelCollection::new(models_response);
 
@@ -24,45 +24,33 @@
         return Ok(());
     }
 
-    // If -m is set, use a preferred language model
-    if let Some(model_ids) = args.model {
-        if model_ids[0] == "interactive" {
-            let selected_model = model_collection.prompt_selection()[0].to_string();
-            api_instance.set_model_id(selected_model).await?;
-        } else {
-            let selected_model = env::var("DEFAULT_MODEL_ID")
-                .map_err(|_| "Cannot read environment variable: DEFAULT_MODEL_ID".into())?;
-            api_instance.set_model_id(selected_model).await?;
-        }
-    }
-
-    // Attempt to read words from the two files
+    // Read words from the two files
     let link_words = read_words_from_file(args.to_link.unwrap())?;
     let avoid_words = read_words_from_file(args.to_avoid.unwrap())?;
 
-    // Get responses
-    // If -m is set, use a preferred language model(s)
-    // Otherwise, call the API straight away
-    let responses = match args.model {
+    // If -m is present and has values, use preferred language models
+    // If -m is present but doesn't have a value, prompt interactive menu
+    // If -m is not present, use the default from environment variable
+    let selected_model_ids = match args.models {
         Some(model_ids) => {
-            let mut responses: Vec<ChatCompletionsResponse> = vec![];
-            for model_id in model_ids {
-                api_instance.set_model_id(model_id).await?;
-                let response = api_instance
-                    .post_chat_completions(&link_words, &avoid_words)
-                    .await?;
-                responses.push(response);
+            if model_ids[0] == "interactive" {
+                model_collection.prompt_selection()
+            } else {
+                model_ids
             }
-            responses
         }
-        None => vec![
-            api_instance
-                .post_chat_completions(&link_words, &avoid_words)
-                .await?,
-        ],
+        None => vec![env::var("DEFAULT_MODEL_ID")
+            .map_err(|_| "Cannot read environment variable: DEFAULT_MODEL_ID")?],
    };
 
-    // Build ClueCollection from the responses
+    // Get responses and build ClueCollection
+    let mut responses: Vec<ChatCompletionsResponse> = vec![];
+    for model_id in &selected_model_ids {
+        let response = api_instance
+            .post_chat_completions(&link_words, &avoid_words, model_id)
+            .await?;
+        responses.push(response);
+    }
     let clue_collection = ClueCollection::new(responses);
 
     // Output