cloud.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, Result};
  4use client::Client;
  5use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag};
  6use futures::AsyncReadExt as _;
  7use gpui::{App, AppContext, Context, Entity, Subscription, Task};
  8use http_client::{HttpClient, Method};
  9use language_model::{LlmApiToken, RefreshLlmTokenListener};
 10use web_search::{WebSearchProvider, WebSearchProviderId};
 11use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
 12
 13pub struct CloudWebSearchProvider {
 14    state: Entity<State>,
 15}
 16
 17impl CloudWebSearchProvider {
 18    pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
 19        let state = cx.new(|cx| State::new(client, cx));
 20
 21        Self { state }
 22    }
 23}
 24
 25pub struct State {
 26    client: Arc<Client>,
 27    llm_api_token: LlmApiToken,
 28    _llm_token_subscription: Subscription,
 29}
 30
 31impl State {
 32    pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
 33        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 34
 35        Self {
 36            client,
 37            llm_api_token: LlmApiToken::default(),
 38            _llm_token_subscription: cx.subscribe(
 39                &refresh_llm_token_listener,
 40                |this, _, _event, cx| {
 41                    let client = this.client.clone();
 42                    let llm_api_token = this.llm_api_token.clone();
 43                    cx.spawn(async move |_this, _cx| {
 44                        llm_api_token.refresh(&client).await?;
 45                        anyhow::Ok(())
 46                    })
 47                    .detach_and_log_err(cx);
 48                },
 49            ),
 50        }
 51    }
 52}
 53
 54pub const ZED_WEB_SEARCH_PROVIDER_ID: &'static str = "zed.dev";
 55
 56impl WebSearchProvider for CloudWebSearchProvider {
 57    fn id(&self) -> WebSearchProviderId {
 58        WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
 59    }
 60
 61    fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
 62        let state = self.state.read(cx);
 63        let client = state.client.clone();
 64        let llm_api_token = state.llm_api_token.clone();
 65        let body = WebSearchBody { query };
 66        let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>();
 67        cx.background_spawn(async move {
 68            perform_web_search(client, llm_api_token, body, use_cloud).await
 69        })
 70    }
 71}
 72
 73async fn perform_web_search(
 74    client: Arc<Client>,
 75    llm_api_token: LlmApiToken,
 76    body: WebSearchBody,
 77    use_cloud: bool,
 78) -> Result<WebSearchResponse> {
 79    const MAX_RETRIES: usize = 3;
 80
 81    let http_client = &client.http_client();
 82    let mut retries_remaining = MAX_RETRIES;
 83    let mut token = llm_api_token.acquire(&client).await?;
 84
 85    loop {
 86        if retries_remaining == 0 {
 87            return Err(anyhow::anyhow!(
 88                "error performing web search, max retries exceeded"
 89            ));
 90        }
 91
 92        let request = http_client::Request::builder()
 93            .method(Method::POST)
 94            .uri(
 95                http_client
 96                    .build_zed_llm_url("/web_search", &[], use_cloud)?
 97                    .as_ref(),
 98            )
 99            .header("Content-Type", "application/json")
100            .header("Authorization", format!("Bearer {token}"))
101            .body(serde_json::to_string(&body)?.into())?;
102        let mut response = http_client
103            .send(request)
104            .await
105            .context("failed to send web search request")?;
106
107        if response.status().is_success() {
108            let mut body = String::new();
109            response.body_mut().read_to_string(&mut body).await?;
110            return Ok(serde_json::from_str(&body)?);
111        } else if response
112            .headers()
113            .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
114            .is_some()
115        {
116            token = llm_api_token.refresh(&client).await?;
117            retries_remaining -= 1;
118        } else {
119            // For now we will only retry if the LLM token is expired,
120            // not if the request failed for any other reason.
121            let mut body = String::new();
122            response.body_mut().read_to_string(&mut body).await?;
123            anyhow::bail!(
124                "error performing web search.\nStatus: {:?}\nBody: {body}",
125                response.status(),
126            );
127        }
128    }
129}