diff --git a/README.md b/README.md index 14d8383..ff28d59 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,7 @@ Feel free to run the program multiple times to get the best result! - `-o`, `--output` : Specify an output file - `-h`, `--help` : Print help - `-V`, `--version` : Print version +- `-t`, `--token-usage` : Print token usage ## 🛠️ Building diff --git a/src/api/chat_completions.rs b/src/api/chat_completions.rs index 9306b34..faa7e43 100644 --- a/src/api/chat_completions.rs +++ b/src/api/chat_completions.rs @@ -1,8 +1,10 @@ -use super::json_models::chat_completion::ChatCompletionResponse; +use super::json_models::chat_completion::{ChatCompletionResponse}; use super::Instance; use crate::clue::ClueCollection; use serde_json::json; + + const SYSTEM_PROMPT: &str = r#" You are the spymaster in Codenames. I will give you a list of [agent word], followed by a list of [avoid word]. @@ -38,12 +40,20 @@ impl Instance { .await .map_err(|_| "Failed to fetch clue collection from API server")?; - // Deserialize the response - let clue_strings = response + let parsed_response = response .json::() .await - .map_err(|_| "Failed to parse clues from API server")? - .choices[0] + .map_err(|_| "Failed to parse clues from API server")?; + + // Extract usage information from the parsed response + let token_usage = parsed_response.usage; + + + // Extract clue strings from the parsed response + let clue_strings = parsed_response + .choices + .get(0) + .ok_or("No choices returned from API")? .message .content .lines() @@ -51,7 +61,7 @@ impl Instance { .collect::>(); // Build clues - let clue_collection = ClueCollection::new(clue_strings); + let clue_collection = ClueCollection::new(clue_strings, token_usage); Ok(clue_collection) } diff --git a/src/api/json_models/chat_completion.rs b/src/api/json_models/chat_completion.rs index 47fd4b2..08e1469 100644 --- a/src/api/json_models/chat_completion.rs +++ b/src/api/json_models/chat_completion.rs @@ -10,7 +10,15 @@ pub struct Choice { pub message: Message, } +#[derive(Deserialize)] +pub struct Usage { + pub prompt_tokens: usize, + pub completion_tokens: usize, + pub total_tokens: usize, +} + #[derive(Deserialize)] pub struct ChatCompletionResponse { pub choices: Vec, + pub usage: Usage, } diff --git a/src/api/json_models/mod.rs b/src/api/json_models/mod.rs index 7381405..c19183e 100644 --- a/src/api/json_models/mod.rs +++ b/src/api/json_models/mod.rs @@ -1,2 +1,2 @@ -pub mod chat_completion; pub mod language_model; +pub mod chat_completion; diff --git a/src/api/mod.rs b/src/api/mod.rs index b95b552..8b9fa6f 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,5 +1,5 @@ -mod chat_completions; -mod json_models; +pub mod chat_completions; +pub mod json_models; mod language_models; use dotenv::dotenv; diff --git a/src/clue.rs b/src/clue.rs index 3f53ce9..4736355 100644 --- a/src/clue.rs +++ b/src/clue.rs @@ -1,7 +1,7 @@ use comfy_table::modifiers::UTF8_ROUND_CORNERS; use comfy_table::presets::UTF8_FULL; use comfy_table::{Attribute, Cell, CellAlignment, ContentArrangement, Table}; - +use crate::api::json_models::chat_completion::Usage; struct Clue { clue_word: String, count: usize, @@ -10,6 +10,7 @@ struct Clue { pub struct ClueCollection { clues: Vec, + pub usage: Usage, } impl Clue { @@ -17,6 +18,8 @@ impl Clue { pub fn new(clue_line: &str) -> Option { let chunks: Vec<&str> = clue_line.split(", ").collect(); + + // Discard empty lines as well as clues with only one word linked if chunks.len() < 4 { return None; @@ -44,13 +47,13 @@ impl Clue { impl ClueCollection { /// Create an instance of `ClueCollection` from `Vec`, which contains lines of clue response from the API - pub fn new(clue_strings: Vec) -> Self { + pub fn new(clue_strings: Vec, usage: Usage) -> Self { let mut clues: Vec = clue_strings.iter().filter_map(|s| Clue::new(s)).collect(); // Sort the clues by the number of words they link together clues.sort_by(|a, b| b.count.cmp(&a.count)); - Self { clues } + Self { clues, usage } } pub fn is_empty(&self) -> bool { diff --git a/src/lib.rs b/src/lib.rs index d66cda0..14e64f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,10 @@ pub struct Args { #[arg(short, long, value_name = "FILE")] pub output: Option, + /// Print all token usage information + #[arg(short, long = "token-usage")] + pub token: bool, + /// File containing words to link together - the words from your team #[arg(required_unless_present = "get")] pub to_link: Option, diff --git a/src/main.rs b/src/main.rs index fbdf4f0..983c8f3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,6 +31,7 @@ async fn main() -> Result<(), Box> { .fetch_clue_collection(link_words, avoid_words) .await?; + // Output if clue_collection.is_empty() { println!("The language model didn't return any useful clues. Maybe try again?"); @@ -41,5 +42,12 @@ async fn main() -> Result<(), Box> { clue_collection.display(); } + // If -t is set, output the token usage information + if args.token { + // Write to stderr in the format: prompt_tokens, completion_tokens, total_tokens + eprintln!("\nTokens Usage\n----------------------\nPrompt Tokens: {}\nCompletion Tokens: {}\n----------------------\nTotal Tokens: {}", + clue_collection.usage.prompt_tokens, clue_collection.usage.completion_tokens, clue_collection.usage.total_tokens); + } + Ok(()) }