diff --git a/Cargo.lock b/Cargo.lock
index e36f881d60..bd8d994061 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -5711,6 +5711,7 @@ dependencies = [
"test-log",
"thiserror 1.0.69",
"tokio",
+ "tokio-retry",
"tokio-test",
"tonic 0.11.0",
"tonic-types",
@@ -6170,6 +6171,17 @@ dependencies = [
"syn 2.0.90",
]
+[[package]]
+name = "tokio-retry"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f"
+dependencies = [
+ "pin-project",
+ "rand",
+ "tokio",
+]
+
[[package]]
name = "tokio-rustls"
version = "0.24.1"
diff --git a/Cargo.toml b/Cargo.toml
index 9ee5d9b070..c378b63a27 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,6 +16,7 @@ headers = "0.3.9" # previous version until hyper is updated to 1+
http = "0.2.12" # previous version until hyper is updated to 1+
insta = { version = "1.38.0", features = ["json"] }
tokio = { version = "1.37.0", features = ["rt", "time"] }
+tokio-retry = "0.3"
reqwest = { version = "0.11", features = [
"json",
"rustls-tls",
@@ -66,6 +67,7 @@ rustls-pemfile = { version = "1.0.4" }
schemars = { version = "0.8.17", features = ["derive"] }
hyper = { version = "0.14.28", features = ["server"], default-features = false }
tokio = { workspace = true }
+tokio-retry = { workspace = true }
anyhow = { workspace = true }
reqwest = { workspace = true }
derive_setters = "0.1.6"
diff --git a/src/cli/llm/error.rs b/src/cli/llm/error.rs
index c2b44f9ca3..0712266908 100644
--- a/src/cli/llm/error.rs
+++ b/src/cli/llm/error.rs
@@ -6,6 +6,7 @@ pub enum Error {
GenAI(genai::Error),
EmptyResponse,
Serde(serde_json::Error),
+ Reqwest(reqwest::Error),
}
pub type Result = std::result::Result;
diff --git a/src/cli/llm/infer_type_name.rs b/src/cli/llm/infer_type_name.rs
index 7183eacbde..2f2812c7d2 100644
--- a/src/cli/llm/infer_type_name.rs
+++ b/src/cli/llm/infer_type_name.rs
@@ -123,46 +123,40 @@ impl InferTypeName {
.collect(),
};
- let mut delay = 3;
- loop {
- let answer = self.wizard.ask(question.clone()).await;
- match answer {
- Ok(answer) => {
- let name = &answer.suggestions.join(", ");
- for name in answer.suggestions {
- if config.types.contains_key(&name) || used_type_names.contains(&name) {
- continue;
- }
- used_type_names.insert(name.clone());
- new_name_mappings.insert(type_name.to_owned(), name);
- break;
+ // Directly use the wizard's ask method to get a result
+ let answer = self.wizard.ask(question.clone()).await;
+
+ match answer {
+ Ok(answer) => {
+ let name = &answer.suggestions.join(", ");
+ for name in answer.suggestions {
+ if config.types.contains_key(&name) || used_type_names.contains(&name) {
+ continue;
}
- tracing::info!(
- "Suggestions for {}: [{}] - {}/{}",
- type_name,
- name,
- i + 1,
- total
- );
-
- // TODO: case where suggested names are already used, then extend the base
- // question with `suggest different names, we have already used following
- // names: [names list]`
+ used_type_names.insert(name.clone());
+ new_name_mappings.insert(type_name.to_owned(), name);
break;
}
- Err(e) => {
- // TODO: log errors after certain number of retries.
- if let Error::GenAI(_) = e {
- // TODO: retry only when it's required.
- tracing::warn!(
- "Unable to retrieve a name for the type '{}'. Retrying in {}s",
- type_name,
- delay
- );
- tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await;
- delay *= std::cmp::min(delay * 2, 60);
- }
- }
+ tracing::info!(
+ "Suggestions for {}: [{}] - {}/{}",
+ type_name,
+ name,
+ i + 1,
+ total
+ );
+
+ // TODO: case where suggested names are already used, then
+ // extend the base question with
+ // `suggest different names, we have already used following
+ // names: [names list]`
+ }
+ Err(e) => {
+ // Handle errors in case of failure
+ tracing::error!(
+ "Failed to get suggestions for type '{}': {:?}",
+ type_name,
+ e
+ );
}
}
}
diff --git a/src/cli/llm/wizard.rs b/src/cli/llm/wizard.rs
index 46d7a18624..90a063626c 100644
--- a/src/cli/llm/wizard.rs
+++ b/src/cli/llm/wizard.rs
@@ -3,8 +3,11 @@ use genai::adapter::AdapterKind;
use genai::chat::{ChatOptions, ChatRequest, ChatResponse};
use genai::resolver::AuthResolver;
use genai::Client;
+use reqwest::StatusCode;
+use tokio_retry::strategy::ExponentialBackoff;
+use tokio_retry::RetryIf;
-use super::Result;
+use super::error::{Error, Result};
#[derive(Setters, Clone)]
pub struct Wizard {
@@ -40,13 +43,34 @@ impl Wizard {
pub async fn ask(&self, q: Q) -> Result
where
- Q: TryInto,
+ Q: TryInto + Clone,
A: TryFrom,
{
- let response = self
- .client
- .exec_chat(self.model.as_str(), q.try_into()?, None)
- .await?;
- A::try_from(response)
+ let retry_strategy = ExponentialBackoff::from_millis(500)
+ .max_delay(std::time::Duration::from_secs(30))
+ .take(5);
+
+ RetryIf::spawn(
+ retry_strategy,
+ || async {
+ let request = q.clone().try_into()?; // Convert the question to a request
+ self.client
+ .exec_chat(self.model.as_str(), request, None) // Execute chat request
+ .await
+ .map_err(Error::from)
+ .and_then(A::try_from) // Convert the response into the
+ // desired result
+ },
+ |err: &Error| {
+ // Check if the error is a ReqwestError and if the status is 429
+ if let Error::Reqwest(reqwest_err) = err {
+ if let Some(status) = reqwest_err.status() {
+ return status == StatusCode::TOO_MANY_REQUESTS;
+ }
+ }
+ false
+ },
+ )
+ .await
}
}