Skip to content

Commit

Permalink
refactor: Move retry logic from infer_type_name to wizard
Browse files Browse the repository at this point in the history
  • Loading branch information
mehul-m-prajapati committed Dec 18, 2024
1 parent d8947cd commit f807e90
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 45 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ headers = "0.3.9" # previous version until hyper is updated to 1+
http = "0.2.12" # previous version until hyper is updated to 1+
insta = { version = "1.38.0", features = ["json"] }
tokio = { version = "1.37.0", features = ["rt", "time"] }
tokio-retry = "0.3"
reqwest = { version = "0.11", features = [
"json",
"rustls-tls",
Expand Down Expand Up @@ -66,6 +67,7 @@ rustls-pemfile = { version = "1.0.4" }
schemars = { version = "0.8.17", features = ["derive"] }
hyper = { version = "0.14.28", features = ["server"], default-features = false }
tokio = { workspace = true }
tokio-retry = { workspace = true }
anyhow = { workspace = true }
reqwest = { workspace = true }
derive_setters = "0.1.6"
Expand Down
1 change: 1 addition & 0 deletions src/cli/llm/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub enum Error {
GenAI(genai::Error),
EmptyResponse,
Serde(serde_json::Error),
Reqwest(reqwest::Error),
}

pub type Result<A> = std::result::Result<A, Error>;
67 changes: 30 additions & 37 deletions src/cli/llm/infer_type_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,46 +123,39 @@ impl InferTypeName {
.collect(),
};

let mut delay = 3;
loop {
let answer = self.wizard.ask(question.clone()).await;
match answer {
Ok(answer) => {
let name = &answer.suggestions.join(", ");
for name in answer.suggestions {
if config.types.contains_key(&name) || used_type_names.contains(&name) {
continue;
}
used_type_names.insert(name.clone());
new_name_mappings.insert(type_name.to_owned(), name);
break;
// Directly use the wizard's ask method to get a result
let answer = self.wizard.ask(question.clone()).await;

match answer {
Ok(answer) => {
let name = &answer.suggestions.join(", ");
for name in answer.suggestions {
if config.types.contains_key(&name) || used_type_names.contains(&name) {
continue;
}
tracing::info!(
"Suggestions for {}: [{}] - {}/{}",
type_name,
name,
i + 1,
total
);

// TODO: case where suggested names are already used, then extend the base
// question with `suggest different names, we have already used following
// names: [names list]`
used_type_names.insert(name.clone());
new_name_mappings.insert(type_name.to_owned(), name);
break;
}
Err(e) => {
// TODO: log errors after certain number of retries.
if let Error::GenAI(_) = e {
// TODO: retry only when it's required.
tracing::warn!(
"Unable to retrieve a name for the type '{}'. Retrying in {}s",
type_name,
delay
);
tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await;
delay *= std::cmp::min(delay * 2, 60);
}
}
tracing::info!(
"Suggestions for {}: [{}] - {}/{}",
type_name,
name,
i + 1,
total
);

// TODO: case where suggested names are already used, then extend the base
// question with `suggest different names, we have already used following
// names: [names list]`
}
Err(e) => {
// Handle errors in case of failure
tracing::error!(
"Failed to get suggestions for type '{}': {:?}",
type_name,
e
);
}
}
}
Expand Down
38 changes: 30 additions & 8 deletions src/cli/llm/wizard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ use genai::adapter::AdapterKind;
use genai::chat::{ChatOptions, ChatRequest, ChatResponse};
use genai::resolver::AuthResolver;
use genai::Client;

use super::Result;
use super::error::{Error, Result};
use reqwest::{StatusCode};
use tokio_retry::strategy::{ExponentialBackoff};
use tokio_retry::RetryIf;

#[derive(Setters, Clone)]
pub struct Wizard<Q, A> {
Expand Down Expand Up @@ -40,13 +42,33 @@ impl<Q, A> Wizard<Q, A> {

pub async fn ask(&self, q: Q) -> Result<A>
where
Q: TryInto<ChatRequest, Error = super::Error>,
Q: TryInto<ChatRequest, Error = super::Error> + Clone,
A: TryFrom<ChatResponse, Error = super::Error>,
{
let response = self
.client
.exec_chat(self.model.as_str(), q.try_into()?, None)
.await?;
A::try_from(response)
let retry_strategy = ExponentialBackoff::from_millis(500)
.max_delay(std::time::Duration::from_secs(30))
.take(5);

RetryIf::spawn(
retry_strategy,
|| async {
let request = q.clone().try_into()?; // Convert the question to a request
self.client
.exec_chat(self.model.as_str(), request, None) // Execute chat request
.await
.map_err(Error::from)
.and_then(A::try_from) // Convert the response into the desired result
},
|err: &Error| {
// Check if the error is a ReqwestError and if the status is 429
if let Error::Reqwest(reqwest_err) = err {
if let Some(status) = reqwest_err.status() {
return status == StatusCode::TOO_MANY_REQUESTS;
}
}
false
}
)
.await
}
}

0 comments on commit f807e90

Please sign in to comment.