Skip to content

Commit

Permalink
Merge pull request #15 from theoforger/feature/unit-testing
Browse files Browse the repository at this point in the history
Add unit testing
  • Loading branch information
theoforger authored Sep 20, 2024
2 parents 8315d78 + bec161b commit 28a1374
Show file tree
Hide file tree
Showing 12 changed files with 279 additions and 18 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ tokio = { version = "1.40.0", features = ["rt", "rt-multi-thread", "macros"] }
dotenv = "0.15.0"
serde_json = "1.0.128"
serde = { version = "1.0.210", features = ["derive"] }
comfy-table = "7.1.1"
comfy-table = "7.1.1"
httpmock = "0.7.0" # NOTE(review): only referenced from #[cfg(test)] code — consider moving to [dev-dependencies]
11 changes: 4 additions & 7 deletions src/api/chat_completions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::json_models::chat_completion::ChatCompletionResponse;
use super::Instance;
use crate::clue::ClueCollection;
use crate::json_models::chat_completion::ChatCompletionResponse;
use serde_json::json;

const SYSTEM_PROMPT: &str = r#"
Expand Down Expand Up @@ -36,21 +36,18 @@ impl Instance {
.json(&request_body)
.send()
.await
.map_err(|_| "Failed to fetch clue collection from API server")?;
.map_err(|e| format!("Failed to fetch clue collection from API server: {}", e))?;

let parsed_response = response
.json::<ChatCompletionResponse>()
.await
.map_err(|_| "Failed to parse clues from API server")?;
.map_err(|e| format!("Failed to parse clues from API server: {}", e))?;

// Extract usage information from the parsed response
let token_usage = parsed_response.usage;

// Extract clue strings from the parsed response
let clue_strings = parsed_response
.choices
.first()
.ok_or("Failed to parse clues from API server")?
let clue_strings = parsed_response.choices[0]
.message
.content
.lines()
Expand Down
8 changes: 4 additions & 4 deletions src/api/language_models.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
use crate::json_models::language_model::ModelsResponse;
use super::Instance;
use crate::json_models::language_model::ModelsResponse;

impl Instance {
pub async fn fetch_all_model_ids(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> {
pub async fn fetch_language_model_ids(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> {
let response = self
.client
.get(format!("{}models", self.base_url))
.bearer_auth(&self.key)
.send()
.await
.map_err(|_| "Failed to fetch model IDs from API server")?;
.map_err(|e| format!("Failed to fetch model IDs from API server: {}", e))?;

let mut all_model_ids = response
.json::<ModelsResponse>()
.await
.map_err(|_| "Failed to parse model IDs from API server")?
.map_err(|e| format!("Failed to parse model IDs from API server: {}", e))?
.data
.iter()
.map(|model| model.id.trim().to_string())
Expand Down
6 changes: 5 additions & 1 deletion src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Instance {
model_id: String,
) -> Result<(), Box<dyn std::error::Error>> {
// Return Error if the chosen model is not valid
let valid_model_ids = self.fetch_all_model_ids().await?;
let valid_model_ids = self.fetch_language_model_ids().await?;
if !valid_model_ids.contains(&model_id) {
return Err(format!(
"{} is not a valid language model from your provider",
Expand All @@ -54,4 +54,8 @@ impl Instance {
self.model_id = model_id;
Ok(())
}

/// Overrides the API base URL for this instance.
///
/// Accepts anything convertible into a `String` (e.g. `&str` or `String`),
/// which keeps existing `String` callers working while allowing literals.
/// Used by the unit tests to redirect requests to a local mock server.
pub fn set_base_url(&mut self, base_url: impl Into<String>) {
    self.base_url = base_url.into();
}
}
20 changes: 16 additions & 4 deletions src/clue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl ClueCollection {
pub fn display_table(&self) {
println!("{}", self.generate_table());
}

pub fn display_token_info(&self) {
eprintln!(
"\nToken Usage:\n\
Expand All @@ -108,9 +108,21 @@ impl ClueCollection {
Completion Tokens: {}\n\
----------------------\n\
Total Tokens: {}",
self.usage.prompt_tokens,
self.usage.completion_tokens,
self.usage.total_tokens
self.usage.prompt_tokens, self.usage.completion_tokens, self.usage.total_tokens
);
}

/// Renders the collection as plain text, one clue per line, in the form
/// `<clue_word> <count> - <linked, words>`; each line ends with `\n`.
pub fn generate_raw_list(&self) -> String {
    self.clues
        .iter()
        .map(|clue| {
            format!(
                "{} {} - {}\n",
                clue.clue_word,
                clue.count,
                clue.linked_words.join(", ")
            )
        })
        .collect()
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use std::path::PathBuf;
pub mod api;
mod clue;
mod json_models;
#[cfg(test)]
mod tests;

/// Mastermind - An LLM-powered CLI tool to help you be a better spymaster in Codenames
#[derive(Parser)]
Expand Down
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// If -g is set, call the models API endpoint instead
if args.get {
println!("{}", api_instance.fetch_all_model_ids().await?.join("\n"));
println!("{}", api_instance.fetch_language_model_ids().await?.join("\n"));
return Ok(());
}

Expand Down
5 changes: 5 additions & 0 deletions src/tests/expected_outputs/chat_completions.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
music 2 - sound, bee
film 2 - bond, tokyo
free 2 - park, penny
dive 2 - scuba diver, hospital
large 2 - walrus, scuba diver
13 changes: 13 additions & 0 deletions src/tests/expected_outputs/language_models.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
distil-whisper-large-v3-en
gemma-7b-it
gemma2-9b-it
llama-3.1-70b-versatile
llama-3.1-8b-instant
llama-guard-3-8b
llama3-70b-8192
llama3-8b-8192
llama3-groq-70b-8192-tool-use-preview
llama3-groq-8b-8192-tool-use-preview
llava-v1.5-7b-4096-preview
mixtral-8x7b-32768
whisper-large-v3
30 changes: 30 additions & 0 deletions src/tests/mock_responses/chat_completions.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"id": "chatcmpl-869ede85-2f46-4834-a039-28d757e958a5",
"object": "chat.completion",
"created": 1726870549,
"model": "llama-3.1-70b-versatile",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "music, 2, sound, bee\nfilm, 2, bond, tokyo\nfree, 2, park, penny\ndive, 2, scuba diver, hospital\nlarge, 2, walrus, scuba diver"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"queue_time": 0.005669406999999987,
"prompt_tokens": 222,
"prompt_time": 0.068204384,
"completion_tokens": 53,
"completion_time": 0.214023764,
"total_tokens": 275,
"total_time": 0.282228148
},
"system_fingerprint": "fp_b6828be2c9",
"x_groq": {
"id": "req_01j88r2wfmecr9zgpjn2zmnprb"
}
}
122 changes: 122 additions & 0 deletions src/tests/mock_responses/language_models.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
{
"object": "list",
"data": [
{
"id": "llava-v1.5-7b-4096-preview",
"object": "model",
"created": 1725402373,
"owned_by": "Other",
"active": true,
"context_window": 4096,
"public_apps": null
},
{
"id": "gemma-7b-it",
"object": "model",
"created": 1693721698,
"owned_by": "Google",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "llama-3.1-8b-instant",
"object": "model",
"created": 1693721698,
"owned_by": "Meta",
"active": true,
"context_window": 131072,
"public_apps": null
},
{
"id": "whisper-large-v3",
"object": "model",
"created": 1693721698,
"owned_by": "OpenAI",
"active": true,
"context_window": 448,
"public_apps": null
},
{
"id": "mixtral-8x7b-32768",
"object": "model",
"created": 1693721698,
"owned_by": "Mistral AI",
"active": true,
"context_window": 32768,
"public_apps": null
},
{
"id": "llama3-8b-8192",
"object": "model",
"created": 1693721698,
"owned_by": "Meta",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "distil-whisper-large-v3-en",
"object": "model",
"created": 1693721698,
"owned_by": "Hugging Face",
"active": true,
"context_window": 448,
"public_apps": null
},
{
"id": "llama-guard-3-8b",
"object": "model",
"created": 1693721698,
"owned_by": "Meta",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "gemma2-9b-it",
"object": "model",
"created": 1693721698,
"owned_by": "Google",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "llama3-70b-8192",
"object": "model",
"created": 1693721698,
"owned_by": "Meta",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "llama3-groq-70b-8192-tool-use-preview",
"object": "model",
"created": 1693721698,
"owned_by": "Groq",
"active": true,
"context_window": 8192,
"public_apps": null
},
{
"id": "llama-3.1-70b-versatile",
"object": "model",
"created": 1693721698,
"owned_by": "Meta",
"active": true,
"context_window": 131072,
"public_apps": null
},
{
"id": "llama3-groq-8b-8192-tool-use-preview",
"object": "model",
"created": 1693721698,
"owned_by": "Groq",
"active": true,
"context_window": 8192,
"public_apps": null
}
]
}
75 changes: 75 additions & 0 deletions src/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use super::*;
use crate::api::Instance;
use httpmock::prelude::*;

#[test]
fn test_api_instance() {
    // Constructing an API instance from the environment should succeed.
    assert!(api::Instance::new().is_ok());
}

#[test]
fn test_read_words_from_file() {
    // Both bundled example word lists should parse without error.
    for example in ["examples/link.txt", "examples/avoid.txt"] {
        assert!(read_words_from_file(PathBuf::from(example)).is_ok());
    }
}

#[tokio::test]
async fn test_fetch_language_models() {
    // Spin up a local mock server that impersonates the models endpoint.
    let server = MockServer::start_async().await;
    let models_mock = server.mock(|when, then| {
        when.method(GET).path("/models");
        then.status(200)
            .header("content-type", "application/json")
            .body_from_file("src/tests/mock_responses/language_models.json");
    });

    // Point a fresh API instance at the mock server.
    let mut api_instance = Instance::new().unwrap();
    api_instance.set_base_url(server.url("/"));

    // Fetch the model IDs and confirm the mock endpoint was actually hit.
    let model_ids = api_instance.fetch_language_model_ids().await.unwrap();
    models_mock.assert();

    // The joined ID list must match the checked-in expected output.
    let expected_output =
        fs::read_to_string("src/tests/expected_outputs/language_models.txt").unwrap();
    assert_eq!(model_ids.join("\n"), expected_output);
}

#[tokio::test]
async fn test_fetch_clue_collection() {
    // Spin up a local mock server that impersonates the chat completions endpoint.
    let server = MockServer::start_async().await;
    let chat_mock = server.mock(|when, then| {
        when.method(POST).path("/chat/completions");
        then.status(200)
            .header("content-type", "application/json")
            .body_from_file("src/tests/mock_responses/chat_completions.json");
    });

    // Point a fresh API instance at the mock server.
    let mut api_instance = Instance::new().unwrap();
    api_instance.set_base_url(server.url("/"));

    // Request a clue collection (empty word lists suffice against the mock)
    // and confirm the mock endpoint was actually hit.
    let clues = api_instance
        .fetch_clue_collection(vec![], vec![])
        .await
        .unwrap();
    chat_mock.assert();

    // The rendered clue list must match the checked-in expected output.
    let expected_output =
        fs::read_to_string("src/tests/expected_outputs/chat_completions.txt").unwrap();
    assert_eq!(clues.generate_raw_list(), expected_output);
}

0 comments on commit 28a1374

Please sign in to comment.