agent: Fix issue where web search could return 401 (#33639)

Bennet Bo Fenner created

Closes #33524

Release Notes:

- agent: Fix an issue where performing a web search request would
sometimes fail

Change summary

crates/web_search_providers/src/cloud.rs | 68 ++++++++++++++++---------
1 file changed, 44 insertions(+), 24 deletions(-)

Detailed changes

crates/web_search_providers/src/cloud.rs 🔗

@@ -8,7 +8,8 @@ use http_client::{HttpClient, Method};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use web_search::{WebSearchProvider, WebSearchProviderId};
 use zed_llm_client::{
-    CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, WebSearchBody, WebSearchResponse,
+    CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME,
+    WebSearchBody, WebSearchResponse,
 };
 
 pub struct CloudWebSearchProvider {
@@ -73,32 +74,51 @@ async fn perform_web_search(
     llm_api_token: LlmApiToken,
     body: WebSearchBody,
 ) -> Result<WebSearchResponse> {
+    const MAX_RETRIES: usize = 3;
+
     let http_client = &client.http_client();
+    let mut retries_remaining = MAX_RETRIES;
+    let mut token = llm_api_token.acquire(&client).await?;
 
-    let token = llm_api_token.acquire(&client).await?;
+    loop {
+        if retries_remaining == 0 {
+            return Err(anyhow::anyhow!(
+                "error performing web search, max retries exceeded"
+            ));
+        }
 
-    let request = http_client::Request::builder()
-        .method(Method::POST)
-        .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {token}"))
-        .header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
-        .body(serde_json::to_string(&body)?.into())?;
-    let mut response = http_client
-        .send(request)
-        .await
-        .context("failed to send web search request")?;
+        let request = http_client::Request::builder()
+            .method(Method::POST)
+            .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {token}"))
+            .header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
+            .body(serde_json::to_string(&body)?.into())?;
+        let mut response = http_client
+            .send(request)
+            .await
+            .context("failed to send web search request")?;
 
-    if response.status().is_success() {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-        return Ok(serde_json::from_str(&body)?);
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-        anyhow::bail!(
-            "error performing web search.\nStatus: {:?}\nBody: {body}",
-            response.status(),
-        );
+        if response.status().is_success() {
+            let mut body = String::new();
+            response.body_mut().read_to_string(&mut body).await?;
+            return Ok(serde_json::from_str(&body)?);
+        } else if response
+            .headers()
+            .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+            .is_some()
+        {
+            token = llm_api_token.refresh(&client).await?;
+            retries_remaining -= 1;
+        } else {
+            // For now we will only retry if the LLM token is expired,
+            // not if the request failed for any other reason.
+            let mut body = String::new();
+            response.body_mut().read_to_string(&mut body).await?;
+            anyhow::bail!(
+                "error performing web search.\nStatus: {:?}\nBody: {body}",
+                response.status(),
+            );
+        }
     }
 }