Skip to content

Commit

Permalink
Rename various components
Browse files Browse the repository at this point in the history
  • Loading branch information
theoforger committed Sep 25, 2024
1 parent 7a915dd commit 9634cbd
Show file tree
Hide file tree
Showing 9 changed files with 19 additions and 19 deletions.
6 changes: 3 additions & 3 deletions src/api/chat_completions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::Instance;
use crate::clue::ClueCollection;
use crate::json_models::chat_completion::ChatCompletionResponse;
use crate::json_models::chat_completions::ChatCompletionsResponse;
use serde_json::json;

const SYSTEM_PROMPT: &str = r#"
Expand All @@ -21,7 +21,7 @@ Here are the requirements:
"#;

impl Instance {
pub async fn fetch_clue_collection(
pub async fn post_chat_completions(
&self,
link_words: Vec<String>,
avoid_words: Vec<String>,
Expand All @@ -39,7 +39,7 @@ impl Instance {
.map_err(|e| format!("Failed to fetch clue collection from API server: {}", e))?;

let parsed_response = response
.json::<ChatCompletionResponse>()
.json::<ChatCompletionsResponse>()
.await
.map_err(|e| format!("Failed to parse clues from API server: {}", e))?;

Expand Down
4 changes: 2 additions & 2 deletions src/api/language_models.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::Instance;
use crate::json_models::language_model::ModelsResponse;
use crate::json_models::language_models::ModelsResponse;

impl Instance {
pub async fn fetch_language_model_ids(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> {
pub async fn get_models(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> {
let response = self
.client
.get(format!("{}models", self.base_url))
Expand Down
2 changes: 1 addition & 1 deletion src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Instance {
model_id: String,
) -> Result<(), Box<dyn std::error::Error>> {
// Return Error if the chosen model is not valid
let valid_model_ids = self.fetch_language_model_ids().await?;
let valid_model_ids = self.get_models().await?;
if !valid_model_ids.contains(&model_id) {
return Err(format!(
"{} is not a valid language model from your provider",
Expand Down
2 changes: 1 addition & 1 deletion src/clue.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::json_models::chat_completion::Usage;
use crate::json_models::chat_completions::Usage;
use comfy_table::modifiers::UTF8_ROUND_CORNERS;
use comfy_table::presets::UTF8_FULL;
use comfy_table::{Attribute, Cell, CellAlignment, ContentArrangement, Table};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub struct Usage {
}

#[derive(Deserialize)]
pub struct ChatCompletionResponse {
pub struct ChatCompletionsResponse {
pub choices: Vec<Choice>,
pub usage: Usage,
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use serde::Deserialize;

#[derive(Deserialize)]
pub struct LanguageModel {
pub struct Model {
pub id: String,
}

#[derive(Deserialize)]
pub struct ModelsResponse {
pub data: Vec<LanguageModel>,
}
pub data: Vec<Model>,
}
4 changes: 2 additions & 2 deletions src/json_models/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pub mod chat_completion;
pub mod language_model;
pub mod chat_completions;
pub mod language_models;
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// If -g is set, call the models API endpoint instead
if args.get {
println!("{}", api_instance.fetch_language_model_ids().await?.join("\n"));
println!("{}", api_instance.get_models().await?.join("\n"));
return Ok(());
}

Expand All @@ -28,7 +28,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

// Get clues from API
let clue_collection = api_instance
.fetch_clue_collection(link_words, avoid_words)
.post_chat_completions(link_words, avoid_words)
.await?;

// Output
Expand Down
8 changes: 4 additions & 4 deletions src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ fn test_read_words_from_file() {
}

#[tokio::test]
async fn test_fetch_language_models() {
async fn test_get_models() {
// Start a lightweight mock server.
let server = MockServer::start_async().await;

Expand All @@ -34,7 +34,7 @@ async fn test_fetch_language_models() {
api_instance.set_base_url(server.url("/"));

// Get response from mock server
let response = api_instance.fetch_language_model_ids().await.unwrap();
let response = api_instance.get_models().await.unwrap();
mock.assert();

// Compare outputs
Expand All @@ -44,7 +44,7 @@ async fn test_fetch_language_models() {
}

#[tokio::test]
async fn test_fetch_clue_collection() {
async fn test_post_chat_completions() {
// Start a lightweight mock server.
let server = MockServer::start_async().await;

Expand All @@ -62,7 +62,7 @@ async fn test_fetch_clue_collection() {

// Get response from mock server
let response = api_instance
.fetch_clue_collection(vec![], vec![])
.post_chat_completions(vec![], vec![])
.await
.unwrap();
mock.assert();
Expand Down

0 comments on commit 9634cbd

Please sign in to comment.