diff --git a/src/api/chat_completions.rs b/src/api/chat_completions.rs index 30a5974..a78df05 100644 --- a/src/api/chat_completions.rs +++ b/src/api/chat_completions.rs @@ -78,7 +78,9 @@ impl Instance { #[cfg(test)] mod tests { use super::*; + use crate::read_words_from_file; use httpmock::prelude::*; + use std::path::PathBuf; #[tokio::test] async fn test_post_chat_completions() { @@ -104,4 +106,39 @@ mod tests { .unwrap(); mock.assert(); } + + #[test] + fn test_build_request_body() { + // Mock input data + let link_words = vec!["link1".to_string(), "link2".to_string()]; + let avoid_words = vec!["avoid1".to_string(), "avoid2".to_string()]; + let model_id = "model".to_string(); + + // Assign result to the result of build_request_body() method + let result = Instance::build_request_body(&link_words, &avoid_words, &model_id); + + // Format expected content + let expected_content = format!( + "To Link:\n{}\n\nTo Avoid:\n{}", + link_words.join("\n"), + avoid_words.join("\n") + ); + + // Mock expected output + let expected = json!( + { + "messages": [ + { + "role": "system", + "content": SYSTEM_PROMPT, + }, + { + "role": "user", + "content": expected_content + } + ], + "model": model_id + }); + assert_eq!(expected, result); + } }