From f807e908cb4d9c2f34f09098bc447cd0c4f7ebce Mon Sep 17 00:00:00 2001 From: Mehul Date: Wed, 18 Dec 2024 16:05:13 -0500 Subject: [PATCH 1/2] refactor: Move retry logic from infer_type_name to wizard --- Cargo.lock | 12 ++++++ Cargo.toml | 2 + src/cli/llm/error.rs | 1 + src/cli/llm/infer_type_name.rs | 67 +++++++++++++++------------------- src/cli/llm/wizard.rs | 38 +++++++++++++++---- 5 files changed, 75 insertions(+), 45 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5e0fd46574..1f9c722a74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5711,6 +5711,7 @@ dependencies = [ "test-log", "thiserror 1.0.69", "tokio", + "tokio-retry", "tokio-test", "tonic 0.11.0", "tonic-types", @@ -6169,6 +6170,17 @@ dependencies = [ "syn 2.0.90", ] +[[package]] +name = "tokio-retry" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f" +dependencies = [ + "pin-project", + "rand", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" diff --git a/Cargo.toml b/Cargo.toml index 68f25713f8..ea60a11603 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ headers = "0.3.9" # previous version until hyper is updated to 1+ http = "0.2.12" # previous version until hyper is updated to 1+ insta = { version = "1.38.0", features = ["json"] } tokio = { version = "1.37.0", features = ["rt", "time"] } +tokio-retry = "0.3" reqwest = { version = "0.11", features = [ "json", "rustls-tls", @@ -66,6 +67,7 @@ rustls-pemfile = { version = "1.0.4" } schemars = { version = "0.8.17", features = ["derive"] } hyper = { version = "0.14.28", features = ["server"], default-features = false } tokio = { workspace = true } +tokio-retry = { workspace = true } anyhow = { workspace = true } reqwest = { workspace = true } derive_setters = "0.1.6" diff --git a/src/cli/llm/error.rs b/src/cli/llm/error.rs index c2b44f9ca3..0712266908 100644 --- a/src/cli/llm/error.rs +++ b/src/cli/llm/error.rs @@ -6,6 +6,7 @@ pub enum Error { GenAI(genai::Error), EmptyResponse, Serde(serde_json::Error), + Reqwest(reqwest::Error), } pub type Result = std::result::Result; diff --git a/src/cli/llm/infer_type_name.rs b/src/cli/llm/infer_type_name.rs index 7183eacbde..0a0ca20dd1 100644 --- a/src/cli/llm/infer_type_name.rs +++ b/src/cli/llm/infer_type_name.rs @@ -123,46 +123,39 @@ impl InferTypeName { .collect(), }; - let mut delay = 3; - loop { - let answer = self.wizard.ask(question.clone()).await; - match answer { - Ok(answer) => { - let name = &answer.suggestions.join(", "); - for name in answer.suggestions { - if config.types.contains_key(&name) || used_type_names.contains(&name) { - continue; - } - used_type_names.insert(name.clone()); - new_name_mappings.insert(type_name.to_owned(), name); - break; + // Directly use the wizard's ask method to get a result + let answer = self.wizard.ask(question.clone()).await; + + match answer { + Ok(answer) => { + let name = &answer.suggestions.join(", "); + for name in answer.suggestions { + if config.types.contains_key(&name) || used_type_names.contains(&name) { + continue; } - tracing::info!( - "Suggestions for {}: [{}] - {}/{}", - type_name, - name, - i + 1, - total - ); - - // TODO: case where suggested names are already used, then extend the base - // question with `suggest different names, we have already used following - // names: [names list]` + used_type_names.insert(name.clone()); + new_name_mappings.insert(type_name.to_owned(), name); break; } - Err(e) => { - // TODO: log errors after certain number of retries. - if let Error::GenAI(_) = e { - // TODO: retry only when it's required. - tracing::warn!( - "Unable to retrieve a name for the type '{}'. Retrying in {}s", - type_name, - delay - ); - tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; - delay *= std::cmp::min(delay * 2, 60); - } - } + tracing::info!( + "Suggestions for {}: [{}] - {}/{}", + type_name, + name, + i + 1, + total + ); + + // TODO: case where suggested names are already used, then extend the base + // question with `suggest different names, we have already used following + // names: [names list]` + } + Err(e) => { + // Handle errors in case of failure + tracing::error!( + "Failed to get suggestions for type '{}': {:?}", + type_name, + e + ); } } } diff --git a/src/cli/llm/wizard.rs b/src/cli/llm/wizard.rs index 46d7a18624..46c7a3288b 100644 --- a/src/cli/llm/wizard.rs +++ b/src/cli/llm/wizard.rs @@ -3,8 +3,10 @@ use genai::adapter::AdapterKind; use genai::chat::{ChatOptions, ChatRequest, ChatResponse}; use genai::resolver::AuthResolver; use genai::Client; - -use super::Result; +use super::error::{Error, Result}; +use reqwest::{StatusCode}; +use tokio_retry::strategy::{ExponentialBackoff}; +use tokio_retry::RetryIf; #[derive(Setters, Clone)] pub struct Wizard { @@ -40,13 +42,33 @@ impl Wizard { pub async fn ask(&self, q: Q) -> Result where - Q: TryInto, + Q: TryInto + Clone, A: TryFrom, { - let response = self - .client - .exec_chat(self.model.as_str(), q.try_into()?, None) - .await?; - A::try_from(response) + let retry_strategy = ExponentialBackoff::from_millis(500) + .max_delay(std::time::Duration::from_secs(30)) + .take(5); + + RetryIf::spawn( + retry_strategy, + || async { + let request = q.clone().try_into()?; // Convert the question to a request + self.client + .exec_chat(self.model.as_str(), request, None) // Execute chat request + .await + .map_err(Error::from) + .and_then(A::try_from) // Convert the response into the desired result + }, + |err: &Error| { + // Check if the error is a ReqwestError and if the status is 429 + if let Error::Reqwest(reqwest_err) = err { + if let Some(status) = reqwest_err.status() { + return status == StatusCode::TOO_MANY_REQUESTS; + } + } + false + } + ) + .await } } From d97c68ee46766f9369f116d7f0084dbf81918a85 Mon Sep 17 00:00:00 2001 From: Mehul Date: Wed, 18 Dec 2024 16:36:58 -0500 Subject: [PATCH 2/2] refactor: lint fixes --- src/cli/llm/infer_type_name.rs | 5 +++-- src/cli/llm/wizard.rs | 12 +++++++----- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/cli/llm/infer_type_name.rs b/src/cli/llm/infer_type_name.rs index 0a0ca20dd1..2f2812c7d2 100644 --- a/src/cli/llm/infer_type_name.rs +++ b/src/cli/llm/infer_type_name.rs @@ -145,8 +145,9 @@ impl InferTypeName { total ); - // TODO: case where suggested names are already used, then extend the base - // question with `suggest different names, we have already used following + // TODO: case where suggested names are already used, then + // extend the base question with + // `suggest different names, we have already used following // names: [names list]` } Err(e) => { diff --git a/src/cli/llm/wizard.rs b/src/cli/llm/wizard.rs index 46c7a3288b..90a063626c 100644 --- a/src/cli/llm/wizard.rs +++ b/src/cli/llm/wizard.rs @@ -3,11 +3,12 @@ use genai::adapter::AdapterKind; use genai::chat::{ChatOptions, ChatRequest, ChatResponse}; use genai::resolver::AuthResolver; use genai::Client; -use super::error::{Error, Result}; -use reqwest::{StatusCode}; -use tokio_retry::strategy::{ExponentialBackoff}; +use reqwest::StatusCode; +use tokio_retry::strategy::ExponentialBackoff; use tokio_retry::RetryIf; +use super::error::{Error, Result}; + #[derive(Setters, Clone)] pub struct Wizard { client: Client, @@ -57,7 +58,8 @@ impl Wizard { .exec_chat(self.model.as_str(), request, None) // Execute chat request .await .map_err(Error::from) - .and_then(A::try_from) // Convert the response into the desired result + .and_then(A::try_from) // Convert the response into the + // desired result }, |err: &Error| { // Check if the error is a ReqwestError and if the status is 429 @@ -67,7 +69,7 @@ impl Wizard { } } false - } + }, ) .await }