@@ -8,7 +8,8 @@ use http_client::{HttpClient, Method};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use web_search::{WebSearchProvider, WebSearchProviderId};
use zed_llm_client::{
- CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, WebSearchBody, WebSearchResponse,
+ CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME,
+ WebSearchBody, WebSearchResponse,
};
pub struct CloudWebSearchProvider {
@@ -73,32 +74,51 @@ async fn perform_web_search(
llm_api_token: LlmApiToken,
body: WebSearchBody,
) -> Result<WebSearchResponse> {
+ const MAX_RETRIES: usize = 3;
+
let http_client = &client.http_client();
+ let mut retries_remaining = MAX_RETRIES;
+ let mut token = llm_api_token.acquire(&client).await?;
- let token = llm_api_token.acquire(&client).await?;
+ loop {
+ if retries_remaining == 0 {
+ return Err(anyhow::anyhow!(
+ "error performing web search, max retries exceeded"
+ ));
+ }
- let request = http_client::Request::builder()
- .method(Method::POST)
- .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {token}"))
- .header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
- .body(serde_json::to_string(&body)?.into())?;
- let mut response = http_client
- .send(request)
- .await
- .context("failed to send web search request")?;
+ let request = http_client::Request::builder()
+ .method(Method::POST)
+ .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {token}"))
+ .header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
+ .body(serde_json::to_string(&body)?.into())?;
+ let mut response = http_client
+ .send(request)
+ .await
+ .context("failed to send web search request")?;
- if response.status().is_success() {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- return Ok(serde_json::from_str(&body)?);
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- anyhow::bail!(
- "error performing web search.\nStatus: {:?}\nBody: {body}",
- response.status(),
- );
+ if response.status().is_success() {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Ok(serde_json::from_str(&body)?);
+ } else if response
+ .headers()
+ .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+ .is_some()
+ {
+ token = llm_api_token.refresh(&client).await?;
+ retries_remaining -= 1;
+ } else {
+ // For now we will only retry if the LLM token is expired,
+ // not if the request failed for any other reason.
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ anyhow::bail!(
+ "error performing web search.\nStatus: {:?}\nBody: {body}",
+ response.status(),
+ );
+ }
}
}