@@ -27,9 +27,11 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::value::RawValue;
use settings::{Settings, SettingsStore};
use smol::io::{AsyncReadExt, BufReader};
+use smol::Timer;
use std::{
future,
sync::{Arc, LazyLock},
+ time::Duration,
};
use strum::IntoEnumIterator;
use ui::{prelude::*, TintColor};
@@ -456,6 +458,8 @@ pub struct CloudLanguageModel {
}
impl CloudLanguageModel {
+ const MAX_RETRIES: usize = 3;
+
async fn perform_llm_completion(
client: Arc<Client>,
llm_api_token: LlmApiToken,
@@ -464,9 +468,10 @@ impl CloudLanguageModel {
let http_client = &client.http_client();
let mut token = llm_api_token.acquire(&client).await?;
- let mut did_retry = false;
+ let mut retries_remaining = Self::MAX_RETRIES;
+ let mut retry_delay = Duration::from_secs(1);
- let response = loop {
+ loop {
let request_builder = http_client::Request::builder();
let request = request_builder
.method(Method::POST)
@@ -475,36 +480,57 @@ impl CloudLanguageModel {
.header("Authorization", format!("Bearer {token}"))
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client.send(request).await?;
- if response.status().is_success() {
- break response;
- } else if !did_retry
- && response
- .headers()
- .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
- .is_some()
+ let status = response.status();
+ if status.is_success() {
+ return Ok(response);
+ } else if retries_remaining > 0
+ && response
+ .headers()
+ .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+ .is_some()
{
- did_retry = true;
+ // The token expired: refresh it and retry. This draws from the same
+ // retry budget as the 5xx branch; the `retries_remaining > 0` guard
+ // above keeps this `usize` decrement from underflowing.
+ retries_remaining -= 1;
token = llm_api_token.refresh(&client).await?;
- } else if response.status() == StatusCode::FORBIDDEN
+ } else if status == StatusCode::FORBIDDEN
&& response
.headers()
.get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
.is_some()
{
- break Err(anyhow!(MaxMonthlySpendReachedError))?;
- } else if response.status() == StatusCode::PAYMENT_REQUIRED {
- break Err(anyhow!(PaymentRequiredError))?;
+ return Err(anyhow!(MaxMonthlySpendReachedError));
+ } else if status.is_server_error() {
+ // If we encounter an error in the 500 range, retry after a delay.
+ // We've seen at least these in the wild from API providers:
+ // * 500 Internal Server Error
+ // * 502 Bad Gateway
+ // * 529 Service Overloaded
+
+ if retries_remaining == 0 {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow!(
+ "cloud language model completion failed after {} retries with status {status}: {body}",
+ Self::MAX_RETRIES
+ ));
+ }
+
+ Timer::after(retry_delay).await;
+
+ retries_remaining -= 1;
+ retry_delay *= 2; // If it fails again, wait longer.
+ } else if status == StatusCode::PAYMENT_REQUIRED {
+ return Err(anyhow!(PaymentRequiredError));
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
- break Err(anyhow!(
- "cloud language model completion failed with status {}: {body}",
- response.status()
- ))?;
+ return Err(anyhow!(
+ "cloud language model completion failed with status {status}: {body}",
+ ));
}
- };
-
- Ok(response)
+ }
}
}