cloud.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, Result};
  4use client::Client;
  5use futures::AsyncReadExt as _;
  6use gpui::{App, AppContext, Context, Entity, Subscription, Task};
  7use http_client::{HttpClient, Method};
  8use language_model::{LlmApiToken, RefreshLlmTokenListener};
  9use web_search::{WebSearchProvider, WebSearchProviderId};
 10use zed_llm_client::{
 11    CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME,
 12    WebSearchBody, WebSearchResponse,
 13};
 14
 15pub struct CloudWebSearchProvider {
 16    state: Entity<State>,
 17}
 18
 19impl CloudWebSearchProvider {
 20    pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
 21        let state = cx.new(|cx| State::new(client, cx));
 22
 23        Self { state }
 24    }
 25}
 26
 27pub struct State {
 28    client: Arc<Client>,
 29    llm_api_token: LlmApiToken,
 30    _llm_token_subscription: Subscription,
 31}
 32
 33impl State {
 34    pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
 35        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 36
 37        Self {
 38            client,
 39            llm_api_token: LlmApiToken::default(),
 40            _llm_token_subscription: cx.subscribe(
 41                &refresh_llm_token_listener,
 42                |this, _, _event, cx| {
 43                    let client = this.client.clone();
 44                    let llm_api_token = this.llm_api_token.clone();
 45                    cx.spawn(async move |_this, _cx| {
 46                        llm_api_token.refresh(&client).await?;
 47                        anyhow::Ok(())
 48                    })
 49                    .detach_and_log_err(cx);
 50                },
 51            ),
 52        }
 53    }
 54}
 55
 /// Identifier returned by `CloudWebSearchProvider`'s `WebSearchProvider::id`.
// `const` items already have a `'static` lifetime, so the explicit `'static` is
// redundant (clippy::redundant_static_lifetimes); the type is unchanged.
pub const ZED_WEB_SEARCH_PROVIDER_ID: &str = "zed.dev";
 57
 58impl WebSearchProvider for CloudWebSearchProvider {
 59    fn id(&self) -> WebSearchProviderId {
 60        WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
 61    }
 62
 63    fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
 64        let state = self.state.read(cx);
 65        let client = state.client.clone();
 66        let llm_api_token = state.llm_api_token.clone();
 67        let body = WebSearchBody { query };
 68        cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
 69    }
 70}
 71
 72async fn perform_web_search(
 73    client: Arc<Client>,
 74    llm_api_token: LlmApiToken,
 75    body: WebSearchBody,
 76) -> Result<WebSearchResponse> {
 77    const MAX_RETRIES: usize = 3;
 78
 79    let http_client = &client.http_client();
 80    let mut retries_remaining = MAX_RETRIES;
 81    let mut token = llm_api_token.acquire(&client).await?;
 82
 83    loop {
 84        if retries_remaining == 0 {
 85            return Err(anyhow::anyhow!(
 86                "error performing web search, max retries exceeded"
 87            ));
 88        }
 89
 90        let request = http_client::Request::builder()
 91            .method(Method::POST)
 92            .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
 93            .header("Content-Type", "application/json")
 94            .header("Authorization", format!("Bearer {token}"))
 95            .header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
 96            .body(serde_json::to_string(&body)?.into())?;
 97        let mut response = http_client
 98            .send(request)
 99            .await
100            .context("failed to send web search request")?;
101
102        if response.status().is_success() {
103            let mut body = String::new();
104            response.body_mut().read_to_string(&mut body).await?;
105            return Ok(serde_json::from_str(&body)?);
106        } else if response
107            .headers()
108            .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
109            .is_some()
110        {
111            token = llm_api_token.refresh(&client).await?;
112            retries_remaining -= 1;
113        } else {
114            // For now we will only retry if the LLM token is expired,
115            // not if the request failed for any other reason.
116            let mut body = String::new();
117            response.body_mut().read_to_string(&mut body).await?;
118            anyhow::bail!(
119                "error performing web search.\nStatus: {:?}\nBody: {body}",
120                response.status(),
121            );
122        }
123    }
124}