Skip to content

Commit

Permalink
fix(wizard): use RetryIf for more specific error handling
Browse files Browse the repository at this point in the history
Signed-off-by: David Anyatonwu <[email protected]>
  • Loading branch information
onyedikachi-david committed Aug 27, 2024
1 parent 5ee7b1a commit 49863dd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 28 deletions.
4 changes: 2 additions & 2 deletions src/cli/llm/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ impl From<genai::Error> for Error {
}
}
}
}
Error::GenAI(err)
};
err.into()
}
}

Expand Down
40 changes: 14 additions & 26 deletions src/cli/llm/wizard.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@

use super::error::{Error, Result, WebcError};
use derive_setters::Setters;
use genai::adapter::AdapterKind;
use genai::chat::{ChatOptions, ChatRequest, ChatResponse};
use genai::resolver::AuthResolver;
use genai::Client;
use reqwest::StatusCode;
use tokio_retry::strategy::{jitter, ExponentialBackoff};
use tokio_retry::Retry;
use super::error::{Error, Result, WebcError};
use crate::cli::llm::model::Model;

use tokio_retry::RetryIf;

#[derive(Setters, Clone)]
pub struct Wizard<Q, A> {
Expand Down Expand Up @@ -50,27 +47,18 @@ impl<Q, A> Wizard<Q, A> {
{
let retry_strategy = ExponentialBackoff::from_millis(1000).map(jitter).take(5);

Retry::spawn(retry_strategy, || async {
let request = q.clone().try_into()?;
match self
.client
.exec_chat(self.model.as_str(), request, None)
.await
{
Ok(response) => Ok(A::try_from(response)?),
Err(err) => {
let error = Error::from(err);
match &error {
Error::Webc(WebcError::ResponseFailedStatus { status, .. })
if *status == StatusCode::TOO_MANY_REQUESTS =>
{
Err(error) // Propagate the error to trigger a retry
}
_ => Ok(Err(error)?), // Other errors are returned without retrying
}
}
}
})
RetryIf::spawn(
retry_strategy,
|| async {
let request = q.clone().try_into()?;
self.client
.exec_chat(self.model.as_str(), request, None)
.await
.map_err(Error::from)
.and_then(A::try_from)
},
|err: &Error| matches!(err, Error::Webc(WebcError::ResponseFailedStatus { status, .. }) if *status == StatusCode::TOO_MANY_REQUESTS)
)
.await
}
}

0 comments on commit 49863dd

Please sign in to comment.