diff --git a/Cargo.toml b/Cargo.toml index 191b8b2..66e4003 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,4 +11,7 @@ tokio = { version = "1.40.0", features = ["rt", "rt-multi-thread", "macros"] } dotenv = "0.15.0" serde_json = "1.0.128" serde = { version = "1.0.210", features = ["derive"] } -comfy-table = "7.1.1" \ No newline at end of file +comfy-table = "7.1.1" + +[dev-dependencies] +httpmock = "0.7.0" \ No newline at end of file diff --git a/src/api/chat_completions.rs b/src/api/chat_completions.rs index 9126d48..759fc9e 100644 --- a/src/api/chat_completions.rs +++ b/src/api/chat_completions.rs @@ -1,6 +1,6 @@ -use crate::json_models::chat_completion::ChatCompletionResponse; use super::Instance; use crate::clue::ClueCollection; +use crate::json_models::chat_completion::ChatCompletionResponse; use serde_json::json; const SYSTEM_PROMPT: &str = r#" @@ -36,21 +36,18 @@ .json(&request_body) .send() .await - .map_err(|_| "Failed to fetch clue collection from API server")?; + .map_err(|e| format!("Failed to fetch clue collection from API server: {}", e))?; let parsed_response = response .json::<ChatCompletionResponse>() .await - .map_err(|_| "Failed to parse clues from API server")?; + .map_err(|e| format!("Failed to parse clues from API server: {}", e))?; // Extract usage information from the parsed response let token_usage = parsed_response.usage; // Extract clue strings from the parsed response - let clue_strings = parsed_response - .choices - .first() - .ok_or("Failed to parse clues from API server")?
+ let clue_strings = parsed_response.choices.first().ok_or("Failed to parse clues from API server")? .message .content .lines() diff --git a/src/api/language_models.rs b/src/api/language_models.rs index 34c9890..a1a00c5 100644 --- a/src/api/language_models.rs +++ b/src/api/language_models.rs @@ -1,20 +1,20 @@ -use crate::json_models::language_model::ModelsResponse; use super::Instance; +use crate::json_models::language_model::ModelsResponse; impl Instance { - pub async fn fetch_all_model_ids(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> { + pub async fn fetch_language_model_ids(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> { let response = self .client .get(format!("{}models", self.base_url)) .bearer_auth(&self.key) .send() .await - .map_err(|_| "Failed to fetch model IDs from API server")?; + .map_err(|e| format!("Failed to fetch model IDs from API server: {}", e))?; let mut all_model_ids = response .json::<ModelsResponse>() .await - .map_err(|_| "Failed to parse model IDs from API server")? + .map_err(|e| format!("Failed to parse model IDs from API server: {}", e))? .data .iter() .map(|model| model.id.trim().to_string()) diff --git a/src/api/mod.rs b/src/api/mod.rs index 004cef5..cec4b7c 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -42,7 +42,7 @@ impl Instance { model_id: String, ) -> Result<(), Box<dyn std::error::Error>> { // Return Error if the chosen model is not valid - let valid_model_ids = self.fetch_all_model_ids().await?; + let valid_model_ids = self.fetch_language_model_ids().await?; if !valid_model_ids.contains(&model_id) { return Err(format!( "{} is not a valid language model from your provider", @@ -54,4 +54,8 @@ impl Instance { self.model_id = model_id; Ok(()) } + + pub fn set_base_url(&mut self, base_url: String) { + self.base_url = base_url; + } } diff --git a/src/clue.rs b/src/clue.rs index eb8dc9d..e5f9392 100644 --- a/src/clue.rs +++ b/src/clue.rs @@ -99,7 +99,7 @@ impl ClueCollection { pub fn display_table(&self) { println!("{}", self.generate_table()); } - + pub fn display_token_info(&self) { eprintln!( "\nToken Usage:\n\ @@ -108,9 +108,21 @@
Completion Tokens: {}\n\ ----------------------\n\ Total Tokens: {}", - self.usage.prompt_tokens, - self.usage.completion_tokens, - self.usage.total_tokens + self.usage.prompt_tokens, self.usage.completion_tokens, self.usage.total_tokens ); } + + pub fn generate_raw_list(&self) -> String { + let mut raw_list = String::new(); + for clue in &self.clues { + let clue_string = format!( + "{} {} - {}\n", + clue.clue_word, + clue.count, + clue.linked_words.join(", ") + ); + raw_list.push_str(clue_string.as_str()); + } + raw_list + } } diff --git a/src/lib.rs b/src/lib.rs index fa4dbf0..07c18ff 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,8 @@ use std::path::PathBuf; pub mod api; mod clue; mod json_models; +#[cfg(test)] +mod tests; /// Mastermind - An LLM-powered CLI tool to help you be a better spymaster in Codenames #[derive(Parser)] diff --git a/src/main.rs b/src/main.rs index e4fdaa5..2089dd4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,7 +13,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { // If -g is set, call the models API endpoint instead if args.get { - println!("{}", api_instance.fetch_all_model_ids().await?.join("\n")); + println!("{}", api_instance.fetch_language_model_ids().await?.join("\n")); return Ok(()); } diff --git a/src/tests/expected_outputs/chat_completions.txt b/src/tests/expected_outputs/chat_completions.txt new file mode 100644 index 0000000..6e7afb0 --- /dev/null +++ b/src/tests/expected_outputs/chat_completions.txt @@ -0,0 +1,5 @@ +music 2 - sound, bee +film 2 - bond, tokyo +free 2 - park, penny +dive 2 - scuba diver, hospital +large 2 - walrus, scuba diver diff --git a/src/tests/expected_outputs/language_models.txt b/src/tests/expected_outputs/language_models.txt new file mode 100644 index 0000000..51b43bf --- /dev/null +++ b/src/tests/expected_outputs/language_models.txt @@ -0,0 +1,13 @@ +distil-whisper-large-v3-en +gemma-7b-it +gemma2-9b-it +llama-3.1-70b-versatile +llama-3.1-8b-instant +llama-guard-3-8b +llama3-70b-8192 +llama3-8b-8192
+llama3-groq-70b-8192-tool-use-preview +llama3-groq-8b-8192-tool-use-preview +llava-v1.5-7b-4096-preview +mixtral-8x7b-32768 +whisper-large-v3 \ No newline at end of file diff --git a/src/tests/mock_responses/chat_completions.json b/src/tests/mock_responses/chat_completions.json new file mode 100644 index 0000000..008fab6 --- /dev/null +++ b/src/tests/mock_responses/chat_completions.json @@ -0,0 +1,30 @@ +{ + "id": "chatcmpl-869ede85-2f46-4834-a039-28d757e958a5", + "object": "chat.completion", + "created": 1726870549, + "model": "llama-3.1-70b-versatile", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "music, 2, sound, bee\nfilm, 2, bond, tokyo\nfree, 2, park, penny\ndive, 2, scuba diver, hospital\nlarge, 2, walrus, scuba diver" + }, + "logprobs": null, + "finish_reason": "stop" + } + ], + "usage": { + "queue_time": 0.005669406999999987, + "prompt_tokens": 222, + "prompt_time": 0.068204384, + "completion_tokens": 53, + "completion_time": 0.214023764, + "total_tokens": 275, + "total_time": 0.282228148 + }, + "system_fingerprint": "fp_b6828be2c9", + "x_groq": { + "id": "req_01j88r2wfmecr9zgpjn2zmnprb" + } +} \ No newline at end of file diff --git a/src/tests/mock_responses/language_models.json b/src/tests/mock_responses/language_models.json new file mode 100644 index 0000000..8eb52c1 --- /dev/null +++ b/src/tests/mock_responses/language_models.json @@ -0,0 +1,122 @@ +{ + "object": "list", + "data": [ + { + "id": "llava-v1.5-7b-4096-preview", + "object": "model", + "created": 1725402373, + "owned_by": "Other", + "active": true, + "context_window": 4096, + "public_apps": null + }, + { + "id": "gemma-7b-it", + "object": "model", + "created": 1693721698, + "owned_by": "Google", + "active": true, + "context_window": 8192, + "public_apps": null + }, + { + "id": "llama-3.1-8b-instant", + "object": "model", + "created": 1693721698, + "owned_by": "Meta", + "active": true, + "context_window": 131072, + "public_apps": null + }, + { + "id": 
"whisper-large-v3", + "object": "model", + "created": 1693721698, + "owned_by": "OpenAI", + "active": true, + "context_window": 448, + "public_apps": null + }, + { + "id": "mixtral-8x7b-32768", + "object": "model", + "created": 1693721698, + "owned_by": "Mistral AI", + "active": true, + "context_window": 32768, + "public_apps": null + }, + { + "id": "llama3-8b-8192", + "object": "model", + "created": 1693721698, + "owned_by": "Meta", + "active": true, + "context_window": 8192, + "public_apps": null + }, + { + "id": "distil-whisper-large-v3-en", + "object": "model", + "created": 1693721698, + "owned_by": "Hugging Face", + "active": true, + "context_window": 448, + "public_apps": null + }, + { + "id": "llama-guard-3-8b", + "object": "model", + "created": 1693721698, + "owned_by": "Meta", + "active": true, + "context_window": 8192, + "public_apps": null + }, + { + "id": "gemma2-9b-it", + "object": "model", + "created": 1693721698, + "owned_by": "Google", + "active": true, + "context_window": 8192, + "public_apps": null + }, + { + "id": "llama3-70b-8192", + "object": "model", + "created": 1693721698, + "owned_by": "Meta", + "active": true, + "context_window": 8192, + "public_apps": null + }, + { + "id": "llama3-groq-70b-8192-tool-use-preview", + "object": "model", + "created": 1693721698, + "owned_by": "Groq", + "active": true, + "context_window": 8192, + "public_apps": null + }, + { + "id": "llama-3.1-70b-versatile", + "object": "model", + "created": 1693721698, + "owned_by": "Meta", + "active": true, + "context_window": 131072, + "public_apps": null + }, + { + "id": "llama3-groq-8b-8192-tool-use-preview", + "object": "model", + "created": 1693721698, + "owned_by": "Groq", + "active": true, + "context_window": 8192, + "public_apps": null + } + ] +} \ No newline at end of file diff --git a/src/tests/mod.rs b/src/tests/mod.rs new file mode 100644 index 0000000..5cbbdf3 --- /dev/null +++ b/src/tests/mod.rs @@ -0,0 +1,75 @@ +use super::*; +use crate::api::Instance; +use 
httpmock::prelude::*; + +#[test] +fn test_api_instance() { + let api_instance = api::Instance::new(); + assert!(api_instance.is_ok()); +} + +#[test] +fn test_read_words_from_file() { + let to_link = read_words_from_file(PathBuf::from("examples/link.txt")); + assert!(to_link.is_ok()); + let to_avoid = read_words_from_file(PathBuf::from("examples/avoid.txt")); + assert!(to_avoid.is_ok()); +} + +#[tokio::test] +async fn test_fetch_language_models() { + // Start a lightweight mock server. + let server = MockServer::start_async().await; + + // Create a mock on the server. + let mock = server.mock(|when, then| { + when.method(GET).path("/models"); + then.status(200) + .header("content-type", "application/json") + .body_from_file("src/tests/mock_responses/language_models.json"); + }); + + // Create an API instance and set the base url to mock server url + let mut api_instance = Instance::new().unwrap(); + api_instance.set_base_url(server.url("/")); + + // Get response from mock server + let response = api_instance.fetch_language_model_ids().await.unwrap(); + mock.assert(); + + // Compare outputs + let output = response.join("\n"); + let expected_output = fs::read_to_string("src/tests/expected_outputs/language_models.txt").unwrap(); + assert_eq!(output, expected_output); +} + +#[tokio::test] +async fn test_fetch_clue_collection() { + // Start a lightweight mock server. + let server = MockServer::start_async().await; + + // Create a mock on the server. 
+ let mock = server.mock(|when, then| { + when.method(POST).path("/chat/completions"); + then.status(200) + .header("content-type", "application/json") + .body_from_file("src/tests/mock_responses/chat_completions.json"); + }); + + // Create an API instance and set the base url to mock server url + let mut api_instance = Instance::new().unwrap(); + api_instance.set_base_url(server.url("/")); + + // Get response from mock server + let response = api_instance + .fetch_clue_collection(vec![], vec![]) + .await + .unwrap(); + mock.assert(); + + // Compare outputs + let output = response.generate_raw_list(); + let expected_output = + fs::read_to_string("src/tests/expected_outputs/chat_completions.txt").unwrap(); + assert_eq!(output, expected_output); +}