cloud.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, Result};
  4use client::Client;
  5use futures::AsyncReadExt as _;
  6use gpui::{App, AppContext, Context, Entity, Subscription, Task};
  7use http_client::{HttpClient, Method};
  8use language_model::{LlmApiToken, RefreshLlmTokenListener};
  9use web_search::{WebSearchProvider, WebSearchProviderId};
 10use zed_llm_client::{
 11    CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, WebSearchBody, WebSearchResponse,
 12};
 13
 14pub struct CloudWebSearchProvider {
 15    state: Entity<State>,
 16}
 17
 18impl CloudWebSearchProvider {
 19    pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
 20        let state = cx.new(|cx| State::new(client, cx));
 21
 22        Self { state }
 23    }
 24}
 25
 26pub struct State {
 27    client: Arc<Client>,
 28    llm_api_token: LlmApiToken,
 29    _llm_token_subscription: Subscription,
 30}
 31
 32impl State {
 33    pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
 34        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 35
 36        Self {
 37            client,
 38            llm_api_token: LlmApiToken::default(),
 39            _llm_token_subscription: cx.subscribe(
 40                &refresh_llm_token_listener,
 41                |this, _, _event, cx| {
 42                    let client = this.client.clone();
 43                    let llm_api_token = this.llm_api_token.clone();
 44                    cx.spawn(async move |_this, _cx| {
 45                        llm_api_token.refresh(&client).await?;
 46                        anyhow::Ok(())
 47                    })
 48                    .detach_and_log_err(cx);
 49                },
 50            ),
 51        }
 52    }
 53}
 54
 55pub const ZED_WEB_SEARCH_PROVIDER_ID: &'static str = "zed.dev";
 56
 57impl WebSearchProvider for CloudWebSearchProvider {
 58    fn id(&self) -> WebSearchProviderId {
 59        WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
 60    }
 61
 62    fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
 63        let state = self.state.read(cx);
 64        let client = state.client.clone();
 65        let llm_api_token = state.llm_api_token.clone();
 66        let body = WebSearchBody { query };
 67        cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
 68    }
 69}
 70
 71async fn perform_web_search(
 72    client: Arc<Client>,
 73    llm_api_token: LlmApiToken,
 74    body: WebSearchBody,
 75) -> Result<WebSearchResponse> {
 76    let http_client = &client.http_client();
 77
 78    let token = llm_api_token.acquire(&client).await?;
 79
 80    let request = http_client::Request::builder()
 81        .method(Method::POST)
 82        .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
 83        .header("Content-Type", "application/json")
 84        .header("Authorization", format!("Bearer {token}"))
 85        .header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
 86        .body(serde_json::to_string(&body)?.into())?;
 87    let mut response = http_client
 88        .send(request)
 89        .await
 90        .context("failed to send web search request")?;
 91
 92    if response.status().is_success() {
 93        let mut body = String::new();
 94        response.body_mut().read_to_string(&mut body).await?;
 95        return Ok(serde_json::from_str(&body)?);
 96    } else {
 97        let mut body = String::new();
 98        response.body_mut().read_to_string(&mut body).await?;
 99        anyhow::bail!(
100            "error performing web search.\nStatus: {:?}\nBody: {body}",
101            response.status(),
102        );
103    }
104}