diff --git a/src/api/mod.rs b/src/api/mod.rs index d9ac4f1..fe92b0c 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -9,7 +9,6 @@ pub struct Instance { client: reqwest::Client, base_url: String, key: String, - model_id: String, } impl Instance { @@ -23,13 +22,11 @@ impl Instance { base_url }; let key = Self::get_env_var("API_KEY")?; - let model_id = Self::get_env_var("DEFAULT_MODEL_ID")?; Ok(Self { client: reqwest::Client::new(), base_url, key, - model_id, }) } diff --git a/src/main.rs b/src/main.rs index 01019bb..40d7cfa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,6 @@ +use std::env; use clap::Parser; - +use dotenv::dotenv; use mastermind::api::Instance; use mastermind::clue::ClueCollection; use mastermind::json_models::chat_completions::ChatCompletionsResponse; @@ -10,6 +11,7 @@ use mastermind::*; async fn main() -> Result<(), Box> { // Read arguments and environment variables let args = Args::parse(); + dotenv().ok(); // Create an API instance and get all available models let mut api_instance = Instance::new()?; @@ -23,12 +25,14 @@ async fn main() -> Result<(), Box> { } // If -m is set, use a preferred language model - if let Some(model_id) = args.model { - if model_id == "interactive" { + if let Some(model_ids) = args.model { + if model_ids[0] == "interactive" { let selected_model = model_collection.prompt_selection()[0].to_string(); api_instance.set_model_id(selected_model).await?; } else { - api_instance.set_model_id(model_id).await?; + let selected_model = env::var("DEFAULT_MODEL_ID") + .map_err(|_| "Cannot read environment variable: DEFAULT_MODEL_ID".into())?; + api_instance.set_model_id(selected_model).await?; } }