Retry on 5xx errors from cloud language model providers (#27584)

Richard Feldman created

Release Notes:

- N/A

Change summary

crates/language_models/src/provider/cloud.rs | 64 ++++++++++++++-------
1 file changed, 43 insertions(+), 21 deletions(-)

Detailed changes

crates/language_models/src/provider/cloud.rs 🔗

@@ -27,9 +27,11 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use serde_json::value::RawValue;
 use settings::{Settings, SettingsStore};
 use smol::io::{AsyncReadExt, BufReader};
+use smol::Timer;
 use std::{
     future,
     sync::{Arc, LazyLock},
+    time::Duration,
 };
 use strum::IntoEnumIterator;
 use ui::{prelude::*, TintColor};
@@ -456,6 +458,8 @@ pub struct CloudLanguageModel {
 }
 
 impl CloudLanguageModel {
+    const MAX_RETRIES: usize = 3;
+
     async fn perform_llm_completion(
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
@@ -464,9 +468,10 @@ impl CloudLanguageModel {
         let http_client = &client.http_client();
 
         let mut token = llm_api_token.acquire(&client).await?;
-        let mut did_retry = false;
+        let mut retries_remaining = Self::MAX_RETRIES;
+        let mut retry_delay = Duration::from_secs(1);
 
-        let response = loop {
+        loop {
             let request_builder = http_client::Request::builder();
             let request = request_builder
                 .method(Method::POST)
@@ -475,36 +480,53 @@ impl CloudLanguageModel {
                 .header("Authorization", format!("Bearer {token}"))
                 .body(serde_json::to_string(&body)?.into())?;
             let mut response = http_client.send(request).await?;
-            if response.status().is_success() {
-                break response;
-            } else if !did_retry
-                && response
-                    .headers()
-                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
-                    .is_some()
+            let status = response.status();
+            if status.is_success() {
+                return Ok(response);
+            } else if response
+                .headers()
+                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+                .is_some()
             {
-                did_retry = true;
+                retries_remaining -= 1;
                 token = llm_api_token.refresh(&client).await?;
-            } else if response.status() == StatusCode::FORBIDDEN
+            } else if status == StatusCode::FORBIDDEN
                 && response
                     .headers()
                     .get(MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME)
                     .is_some()
             {
-                break Err(anyhow!(MaxMonthlySpendReachedError))?;
-            } else if response.status() == StatusCode::PAYMENT_REQUIRED {
-                break Err(anyhow!(PaymentRequiredError))?;
+                return Err(anyhow!(MaxMonthlySpendReachedError));
+            } else if status.as_u16() >= 500 && status.as_u16() < 600 {
+                // If we encounter an error in the 500 range, retry after a delay.
+                // We've seen at least these in the wild from API providers:
+                // * 500 Internal Server Error
+                // * 502 Bad Gateway
+                // * 529 Service Overloaded
+
+                if retries_remaining == 0 {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    return Err(anyhow!(
+                        "cloud language model completion failed after {} retries with status {status}: {body}",
+                        Self::MAX_RETRIES
+                    ));
+                }
+
+                Timer::after(retry_delay).await;
+
+                retries_remaining -= 1;
+                retry_delay *= 2; // If it fails again, wait longer.
+            } else if status == StatusCode::PAYMENT_REQUIRED {
+                return Err(anyhow!(PaymentRequiredError));
             } else {
                 let mut body = String::new();
                 response.body_mut().read_to_string(&mut body).await?;
-                break Err(anyhow!(
-                    "cloud language model completion failed with status {}: {body}",
-                    response.status()
-                ))?;
+                return Err(anyhow!(
+                    "cloud language model completion failed with status {status}: {body}",
+                ));
             }
-        };
-
-        Ok(response)
+        }
     }
 }