cloud.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, Result};
  4use client::Client;
  5use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
  6use futures::AsyncReadExt as _;
  7use gpui::{App, AppContext, Context, Entity, Subscription, Task};
  8use http_client::{HttpClient, Method};
  9use language_model::{LlmApiToken, RefreshLlmTokenListener};
 10use web_search::{WebSearchProvider, WebSearchProviderId};
 11
 12pub struct CloudWebSearchProvider {
 13    state: Entity<State>,
 14}
 15
 16impl CloudWebSearchProvider {
 17    pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
 18        let state = cx.new(|cx| State::new(client, cx));
 19
 20        Self { state }
 21    }
 22}
 23
 24pub struct State {
 25    client: Arc<Client>,
 26    llm_api_token: LlmApiToken,
 27    _llm_token_subscription: Subscription,
 28}
 29
 30impl State {
 31    pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
 32        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 33
 34        Self {
 35            client,
 36            llm_api_token: LlmApiToken::default(),
 37            _llm_token_subscription: cx.subscribe(
 38                &refresh_llm_token_listener,
 39                |this, _, _event, cx| {
 40                    let client = this.client.clone();
 41                    let llm_api_token = this.llm_api_token.clone();
 42                    cx.spawn(async move |_this, _cx| {
 43                        llm_api_token.refresh(&client).await?;
 44                        anyhow::Ok(())
 45                    })
 46                    .detach_and_log_err(cx);
 47                },
 48            ),
 49        }
 50    }
 51}
 52
 53pub const ZED_WEB_SEARCH_PROVIDER_ID: &str = "zed.dev";
 54
 55impl WebSearchProvider for CloudWebSearchProvider {
 56    fn id(&self) -> WebSearchProviderId {
 57        WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
 58    }
 59
 60    fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
 61        let state = self.state.read(cx);
 62        let client = state.client.clone();
 63        let llm_api_token = state.llm_api_token.clone();
 64        let body = WebSearchBody { query };
 65        cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
 66    }
 67}
 68
 69async fn perform_web_search(
 70    client: Arc<Client>,
 71    llm_api_token: LlmApiToken,
 72    body: WebSearchBody,
 73) -> Result<WebSearchResponse> {
 74    const MAX_RETRIES: usize = 3;
 75
 76    let http_client = &client.http_client();
 77    let mut retries_remaining = MAX_RETRIES;
 78    let mut token = llm_api_token.acquire(&client).await?;
 79
 80    loop {
 81        if retries_remaining == 0 {
 82            return Err(anyhow::anyhow!(
 83                "error performing web search, max retries exceeded"
 84            ));
 85        }
 86
 87        let request = http_client::Request::builder()
 88            .method(Method::POST)
 89            .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
 90            .header("Content-Type", "application/json")
 91            .header("Authorization", format!("Bearer {token}"))
 92            .body(serde_json::to_string(&body)?.into())?;
 93        let mut response = http_client
 94            .send(request)
 95            .await
 96            .context("failed to send web search request")?;
 97
 98        if response.status().is_success() {
 99            let mut body = String::new();
100            response.body_mut().read_to_string(&mut body).await?;
101            return Ok(serde_json::from_str(&body)?);
102        } else if response
103            .headers()
104            .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
105            .is_some()
106        {
107            token = llm_api_token.refresh(&client).await?;
108            retries_remaining -= 1;
109        } else {
110            // For now we will only retry if the LLM token is expired,
111            // not if the request failed for any other reason.
112            let mut body = String::new();
113            response.body_mut().read_to_string(&mut body).await?;
114            anyhow::bail!(
115                "error performing web search.\nStatus: {:?}\nBody: {body}",
116                response.status(),
117            );
118        }
119    }
120}