Skip to content

Commit

Permalink
Fix unit testing
Browse files Browse the repository at this point in the history
  • Loading branch information
theoforger committed Sep 25, 2024
1 parent ea4e5e5 commit 7d65ef0
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 22 deletions.
33 changes: 16 additions & 17 deletions src/api/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ impl Instance {
avoid_words: &[String],
model_id: &String,
) -> Result<ChatCompletionsResponse, Box<dyn std::error::Error>> {
self.validate_model_id(model_id).await?;
let request_body = self.build_request_body(link_words, avoid_words, model_id);

// Get response from API endpoint
Expand All @@ -48,6 +47,22 @@ impl Instance {
Ok(parsed_response)
}

pub async fn validate_model_id(&self, model_id: &String) -> Result<(), Box<dyn std::error::Error>> {
let models_response = self.get_models().await?;
let model_collection = ModelCollection::new(models_response);

// Return Error if the chosen model is not valid
if !model_collection.contains(model_id) {
return Err(format!(
"{} is not a valid language model from your provider",
model_id
)
.into());
}

Ok(())
}

fn build_request_body(
&self,
link_words: &[String],
Expand Down Expand Up @@ -75,20 +90,4 @@ impl Instance {
"model": model_id
})
}

async fn validate_model_id(&self, model_id: &String) -> Result<(), Box<dyn std::error::Error>> {
let models_response = self.get_models().await?;
let model_collection = ModelCollection::new(models_response);

// Return Error if the chosen model is not valid
if !model_collection.contains(model_id) {
return Err(format!(
"{} is not a valid language model from your provider",
model_id
)
.into());
}

Ok(())
}
}
5 changes: 3 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let link_words = read_words_from_file(args.to_link.unwrap())?;
let avoid_words = read_words_from_file(args.to_avoid.unwrap())?;

// If -m is present and has values, use preferred language models
// If -m is present and has values, use the preferred language models
// If -m is present but doesn't have a value, prompt interactive menu
// If -m is not present, use the default from environment variable
let selected_model_ids = match args.models {
Expand All @@ -43,9 +43,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.map_err(|_| "Cannot read environment variable: DEFAULT_MODEL_ID")?],
};

// Get responses and build ClueCollection
// Aggregate responses from each language model and build ClueCollection
let mut responses: Vec<ChatCompletionsResponse> = vec![];
for model_id in &selected_model_ids {
api_instance.validate_model_id(model_id).await?;
let response = api_instance
.post_chat_completions(&link_words, &avoid_words, model_id)
.await?;
Expand Down
7 changes: 4 additions & 3 deletions src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::*;
use crate::api::Instance;
use httpmock::prelude::*;
use crate::clue::ClueCollection;
use crate::model_collection::ModelCollection;

#[test]
fn test_api_instance() {
Expand Down Expand Up @@ -35,11 +36,11 @@ async fn test_get_models() {
api_instance.set_base_url(server.url("/"));

// Get response from mock server
let response = api_instance.get_models().await.unwrap();
let response = ModelCollection::new(api_instance.get_models().await.unwrap());
mock.assert();

// Compare outputs
let output = response.join("\n");
let output = response.generate_string();
let expected_output = fs::read_to_string("src/tests/expected_outputs/language_models.txt").unwrap();
assert_eq!(output, expected_output);
}
Expand All @@ -63,7 +64,7 @@ async fn test_post_chat_completions() {

// Get responses from mock server
let responses = vec![api_instance
.post_chat_completions(&Vec::<String>::new(), &Vec::<String>::new())
.post_chat_completions(&Vec::<String>::new(), &Vec::<String>::new(),&String::new())
.await
.unwrap()];
mock.assert();
Expand Down

0 comments on commit 7d65ef0

Please sign in to comment.