diff --git a/Cargo.lock b/Cargo.lock index 62f59f80d1..f404972802 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5472,6 +5472,7 @@ dependencies = [ "test-log", "thiserror", "tokio", + "tokio-retry", "tokio-test", "tonic 0.11.0", "tonic-types", @@ -5893,6 +5894,17 @@ dependencies = [ "syn 2.0.79", ] +[[package]] +name = "tokio-retry" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f" +dependencies = [ + "pin-project", + "rand", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" diff --git a/Cargo.toml b/Cargo.toml index 77941f7749..1258628f5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ headers = "0.3.9" # previous version until hyper is updated to 1+ http = "0.2.12" # previous version until hyper is updated to 1+ insta = { version = "1.38.0", features = ["json"] } tokio = { version = "1.37.0", features = ["rt", "time"] } +tokio-retry = "0.3" reqwest = { version = "0.11", features = [ "json", "rustls-tls", @@ -64,6 +65,7 @@ rustls-pemfile = { version = "1.0.4" } schemars = { version = "0.8.17", features = ["derive"] } hyper = { version = "0.14.28", features = ["server"], default-features = false } tokio = { workspace = true } +tokio-retry = { workspace = true } anyhow = { workspace = true } reqwest = { workspace = true } derive_setters = "0.1.6" diff --git a/src/cli/llm/error.rs b/src/cli/llm/error.rs index c2b44f9ca3..49f00cefd2 100644 --- a/src/cli/llm/error.rs +++ b/src/cli/llm/error.rs @@ -1,11 +1,46 @@ -use derive_more::From; -use strum_macros::Display; +use reqwest::StatusCode; +use thiserror::Error; -#[derive(Debug, From, Display, thiserror::Error)] +#[derive(Debug, Error)] +pub enum WebcError { + #[error("Response failed with status {status}: {body}")] + ResponseFailedStatus { status: StatusCode, body: String }, + #[error("Reqwest error: {0}")] + Reqwest(#[from] reqwest::Error), +} + +#[derive(Debug, Error)] pub enum Error { + #[error("GenAI error: {0}")] GenAI(genai::Error), + #[error("Webc error: {0}")] + Webc(WebcError), + #[error("Empty response")] EmptyResponse, - Serde(serde_json::Error), + #[error("Serde error: {0}")] + Serde(#[from] serde_json::Error), +} + +impl From for Error { + fn from(err: genai::Error) -> Self { + if let genai::Error::WebModelCall { webc_error, .. } = &err { + let error_str = webc_error.to_string(); + if error_str.contains("ResponseFailedStatus") { + // Extract status and body from the error message + let parts: Vec<&str> = error_str.splitn(3, ": ").collect(); + if parts.len() >= 3 { + if let Ok(status) = parts[1].parse::() { + return Error::Webc(WebcError::ResponseFailedStatus { + status: StatusCode::from_u16(status) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR), + body: parts[2].to_string(), + }); + } + } + } + }; + err.into() + } } -pub type Result = std::result::Result; +pub type Result = std::result::Result; diff --git a/src/cli/llm/infer_type_name.rs b/src/cli/llm/infer_type_name.rs index 7183eacbde..8a2d361568 100644 --- a/src/cli/llm/infer_type_name.rs +++ b/src/cli/llm/infer_type_name.rs @@ -94,6 +94,7 @@ impl InferTypeName { } pub async fn generate(&mut self, config: &Config) -> Result> { + let mut new_name_mappings: HashMap = HashMap::new(); // Filter out root operation types and types with non-auto-generated names let types_to_be_processed = config @@ -123,6 +124,7 @@ impl InferTypeName { .collect(), }; + let mut delay = 3; loop { let answer = self.wizard.ask(question.clone()).await; @@ -137,32 +139,19 @@ impl InferTypeName { new_name_mappings.insert(type_name.to_owned(), name); break; } - tracing::info!( - "Suggestions for {}: [{}] - {}/{}", - type_name, - name, - i + 1, - total - ); - - // TODO: case where suggested names are already used, then extend the base - // question with `suggest different names, we have already used following - // names: [names list]` + new_name_mappings.insert(name, type_name.to_owned()); break; } - Err(e) => { - // TODO: log errors after certain number of retries. - if let Error::GenAI(_) = e { - // TODO: retry only when it's required. - tracing::warn!( - "Unable to retrieve a name for the type '{}'. Retrying in {}s", - type_name, - delay - ); - tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; - delay *= std::cmp::min(delay * 2, 60); - } - } + tracing::info!( + "Suggestions for {}: [{}] - {}/{}", + type_name, + name, + i + 1, + total + ); + } + Err(e) => { + tracing::error!("Failed to generate name for {}: {:?}", type_name, e); } } } diff --git a/src/cli/llm/wizard.rs b/src/cli/llm/wizard.rs index 46d7a18624..2788926c47 100644 --- a/src/cli/llm/wizard.rs +++ b/src/cli/llm/wizard.rs @@ -1,10 +1,12 @@ +use super::error::{Error, Result, WebcError}; use derive_setters::Setters; use genai::adapter::AdapterKind; use genai::chat::{ChatOptions, ChatRequest, ChatResponse}; use genai::resolver::AuthResolver; use genai::Client; - -use super::Result; +use reqwest::StatusCode; +use tokio_retry::strategy::{jitter, ExponentialBackoff}; +use tokio_retry::RetryIf; #[derive(Setters, Clone)] pub struct Wizard { @@ -40,13 +42,23 @@ impl Wizard { pub async fn ask(&self, q: Q) -> Result where - Q: TryInto, - A: TryFrom, + Q: TryInto + Clone, + A: TryFrom, { - let response = self - .client - .exec_chat(self.model.as_str(), q.try_into()?, None) - .await?; - A::try_from(response) + let retry_strategy = ExponentialBackoff::from_millis(1000).map(jitter).take(5); + + RetryIf::spawn( + retry_strategy, + || async { + let request = q.clone().try_into()?; + self.client + .exec_chat(self.model.as_str(), request, None) + .await + .map_err(Error::from) + .and_then(A::try_from) + }, + |err: &Error| matches!(err, Error::Webc(WebcError::ResponseFailedStatus { status, .. }) if *status == StatusCode::TOO_MANY_REQUESTS) + ) + .await } }