// cloud.rs

use std::sync::Arc;

use anyhow::{Context as _, Result, anyhow};
use client::Client;
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
use http_client::{HttpClient, Method};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use web_search::{WebSearchProvider, WebSearchProviderId};
use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
 11
 12pub struct CloudWebSearchProvider {
 13    state: Entity<State>,
 14}
 15
 16impl CloudWebSearchProvider {
 17    pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
 18        let state = cx.new(|cx| State::new(client, cx));
 19
 20        Self { state }
 21    }
 22}
 23
 24pub struct State {
 25    client: Arc<Client>,
 26    llm_api_token: LlmApiToken,
 27    _llm_token_subscription: Subscription,
 28}
 29
 30impl State {
 31    pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
 32        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 33
 34        Self {
 35            client,
 36            llm_api_token: LlmApiToken::default(),
 37            _llm_token_subscription: cx.subscribe(
 38                &refresh_llm_token_listener,
 39                |this, _, _event, cx| {
 40                    let client = this.client.clone();
 41                    let llm_api_token = this.llm_api_token.clone();
 42                    cx.spawn(async move |_this, _cx| {
 43                        llm_api_token.refresh(&client).await?;
 44                        anyhow::Ok(())
 45                    })
 46                    .detach_and_log_err(cx);
 47                },
 48            ),
 49        }
 50    }
 51}
 52
 53pub const ZED_WEB_SEARCH_PROVIDER_ID: &'static str = "zed.dev";
 54
 55impl WebSearchProvider for CloudWebSearchProvider {
 56    fn id(&self) -> WebSearchProviderId {
 57        WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
 58    }
 59
 60    fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
 61        let state = self.state.read(cx);
 62        let client = state.client.clone();
 63        let llm_api_token = state.llm_api_token.clone();
 64        let body = WebSearchBody { query };
 65        cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
 66    }
 67}
 68
 69async fn perform_web_search(
 70    client: Arc<Client>,
 71    llm_api_token: LlmApiToken,
 72    body: WebSearchBody,
 73) -> Result<WebSearchResponse> {
 74    let http_client = &client.http_client();
 75
 76    let token = llm_api_token.acquire(&client).await?;
 77
 78    let request_builder = http_client::Request::builder().method(Method::POST);
 79    let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
 80        request_builder.uri(web_search_url)
 81    } else {
 82        request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
 83    };
 84    let request = request_builder
 85        .header("Content-Type", "application/json")
 86        .header("Authorization", format!("Bearer {token}"))
 87        .body(serde_json::to_string(&body)?.into())?;
 88    let mut response = http_client
 89        .send(request)
 90        .await
 91        .context("failed to send web search request")?;
 92
 93    if response.status().is_success() {
 94        let mut body = String::new();
 95        response.body_mut().read_to_string(&mut body).await?;
 96        return Ok(serde_json::from_str(&body)?);
 97    } else {
 98        let mut body = String::new();
 99        response.body_mut().read_to_string(&mut body).await?;
100        return Err(anyhow!(
101            "error performing web search.\nStatus: {:?}\nBody: {body}",
102            response.status(),
103        ));
104    }
105}