Skip to content

Commit

Permalink
Finalize feature
Browse files Browse the repository at this point in the history
  • Loading branch information
theoforger committed Sep 25, 2024
1 parent 0f06867 commit 1b70ef8
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 58 deletions.
26 changes: 23 additions & 3 deletions src/api/chat_completions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::Instance;
use crate::json_models::chat_completions::ChatCompletionsResponse;
use crate::model_collection::ModelCollection;
use serde_json::json;

const SYSTEM_PROMPT: &str = r#"
Expand All @@ -24,8 +25,10 @@ impl Instance {
&self,
link_words: &[String],
avoid_words: &[String],
model_id: &String,
) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error>> {
let request_body = self.build_request_body(link_words, avoid_words);
self.validate_model_id(model_id).await?;
let request_body = self.build_request_body(link_words, avoid_words, model_id);

// Get response from API endpoint
let response = self
Expand All @@ -41,14 +44,15 @@ impl Instance {
.json::<ChatCompletionsResponse>()
.await
.map_err(|e| format!("Failed to parse clues from API server: {}", e))?;

Ok(parsed_response)
}

fn build_request_body(
&self,
link_words: &[String],
avoid_words: &[String],
model_id: &String,
) -> serde_json::Value {
// Aggregate two sets of words into one prompt
let content = format!(
Expand All @@ -68,7 +72,23 @@ impl Instance {
"content": content
}
],
"model": self.model_id
"model": model_id
})
}

async fn validate_model_id(&self, model_id: &String) -> Result<(), Box<dyn std::error::Error>> {
let models_response = self.get_models().await?;
let model_collection = ModelCollection::new(models_response);

// Return Error if the chosen model is not valid
if !model_collection.contains(model_id) {
return Err(format!(
"{} is not a valid language model from your provider",
model_id
)
.into());
}

Ok(())
}
}
21 changes: 0 additions & 21 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ mod language_models;

use dotenv::dotenv;
use std::env;
use crate::model_collection::ModelCollection;

pub struct Instance {
client: reqwest::Client,
Expand Down Expand Up @@ -35,26 +34,6 @@ impl Instance {
.map_err(|_| format!("Cannot read environment variable: {}", var_name).into())
}

/// Sets this instance's active language model after validating it against
/// the models the provider actually offers.
///
/// Returns an error — leaving `self.model_id` untouched — when `model_id`
/// is not in the provider's model list.
pub async fn set_model_id(
    &mut self,
    model_id: String,
) -> Result<(), Box<dyn std::error::Error>> {
    // Fetch the provider's model list and wrap it for membership lookup.
    let available = ModelCollection::new(self.get_models().await?);

    if available.contains(&model_id) {
        // Validation passed: remember the chosen model.
        self.model_id = model_id;
        Ok(())
    } else {
        Err(format!(
            "{} is not a valid language model from your provider",
            model_id
        )
        .into())
    }
}

/// Overrides the base URL this instance uses for subsequent API requests.
pub fn set_base_url(&mut self, base_url: String) {
    self.base_url = base_url;
}
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct Args {

/// Select language model(s)
#[arg(short, long = "set-models", default_missing_value = "interactive", num_args = 0..)]
pub model: Option<Vec<String>>,
pub models: Option<Vec<String>>,

/// Specify an output file
#[arg(short, long, value_name = "FILE")]
Expand Down
54 changes: 21 additions & 33 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use std::env;
use clap::Parser;
use dotenv::dotenv;
use mastermind::api::Instance;
use mastermind::clue::ClueCollection;
use mastermind::json_models::chat_completions::ChatCompletionsResponse;
use mastermind::model_collection::ModelCollection;
use mastermind::*;
use std::env;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -14,7 +14,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
dotenv().ok();

// Create an API instance and get all available models
let mut api_instance = Instance::new()?;
let api_instance = Instance::new()?;
let models_response = api_instance.get_models().await?;
let model_collection = ModelCollection::new(models_response);

Expand All @@ -24,45 +24,33 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
return Ok(());
}

// If -m is set, use a preferred language model
if let Some(model_ids) = args.model {
if model_ids[0] == "interactive" {
let selected_model = model_collection.prompt_selection()[0].to_string();
api_instance.set_model_id(selected_model).await?;
} else {
let selected_model = env::var("DEFAULT_MODEL_ID")
.map_err(|_| "Cannot read environment variable: DEFAULT_MODEL_ID".into())?;
api_instance.set_model_id(selected_model).await?;
}
}

// Attempt to read words from the two files
// Read words from the two files
let link_words = read_words_from_file(args.to_link.unwrap())?;
let avoid_words = read_words_from_file(args.to_avoid.unwrap())?;

// Get responses
// If -m is set, use a preferred language model(s)
// Otherwise, call the API straight away
let responses = match args.model {
// If -m is present and has values, use preferred language models
// If -m is present but doesn't have a value, prompt interactive menu
// If -m is not present, use the default from environment variable
let selected_model_ids = match args.models {
Some(model_ids) => {
let mut responses: Vec<ChatCompletionsResponse> = vec![];
for model_id in model_ids {
api_instance.set_model_id(model_id).await?;
let response = api_instance
.post_chat_completions(&link_words, &avoid_words)
.await?;
responses.push(response);
if model_ids[0] == "interactive" {
model_collection.prompt_selection()
} else {
model_ids
}
responses
}
None => vec![
api_instance
.post_chat_completions(&link_words, &avoid_words)
.await?,
],
None => vec![env::var("DEFAULT_MODEL_ID")
.map_err(|_| "Cannot read environment variable: DEFAULT_MODEL_ID")?],
};

// Build ClueCollection from the responses
// Get responses and build ClueCollection
let mut responses: Vec<ChatCompletionsResponse> = vec![];
for model_id in &selected_model_ids {
let response = api_instance
.post_chat_completions(&link_words, &avoid_words, model_id)
.await?;
responses.push(response);
}
let clue_collection = ClueCollection::new(responses);

// Output
Expand Down

0 comments on commit 1b70ef8

Please sign in to comment.