cloud.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, Result, anyhow};
  4use client::Client;
  5use futures::AsyncReadExt as _;
  6use gpui::{App, AppContext, Context, Entity, Subscription, Task};
  7use http_client::{HttpClient, Method};
  8use language_model::{LlmApiToken, RefreshLlmTokenListener};
  9use web_search::{WebSearchProvider, WebSearchProviderId};
 10use zed_llm_client::{WebSearchBody, WebSearchResponse};
 11
 12pub struct CloudWebSearchProvider {
 13    state: Entity<State>,
 14}
 15
 16impl CloudWebSearchProvider {
 17    pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
 18        let state = cx.new(|cx| State::new(client, cx));
 19
 20        Self { state }
 21    }
 22}
 23
 24pub struct State {
 25    client: Arc<Client>,
 26    llm_api_token: LlmApiToken,
 27    _llm_token_subscription: Subscription,
 28}
 29
 30impl State {
 31    pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
 32        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 33
 34        Self {
 35            client,
 36            llm_api_token: LlmApiToken::default(),
 37            _llm_token_subscription: cx.subscribe(
 38                &refresh_llm_token_listener,
 39                |this, _, _event, cx| {
 40                    let client = this.client.clone();
 41                    let llm_api_token = this.llm_api_token.clone();
 42                    cx.spawn(async move |_this, _cx| {
 43                        llm_api_token.refresh(&client).await?;
 44                        anyhow::Ok(())
 45                    })
 46                    .detach_and_log_err(cx);
 47                },
 48            ),
 49        }
 50    }
 51}
 52
 53impl WebSearchProvider for CloudWebSearchProvider {
 54    fn id(&self) -> WebSearchProviderId {
 55        WebSearchProviderId("zed.dev".into())
 56    }
 57
 58    fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
 59        let state = self.state.read(cx);
 60        let client = state.client.clone();
 61        let llm_api_token = state.llm_api_token.clone();
 62        let body = WebSearchBody { query };
 63        cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
 64    }
 65}
 66
 67async fn perform_web_search(
 68    client: Arc<Client>,
 69    llm_api_token: LlmApiToken,
 70    body: WebSearchBody,
 71) -> Result<WebSearchResponse> {
 72    let http_client = &client.http_client();
 73
 74    let token = llm_api_token.acquire(&client).await?;
 75
 76    let request_builder = http_client::Request::builder().method(Method::POST);
 77    let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
 78        request_builder.uri(web_search_url)
 79    } else {
 80        request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
 81    };
 82    let request = request_builder
 83        .header("Content-Type", "application/json")
 84        .header("Authorization", format!("Bearer {token}"))
 85        .body(serde_json::to_string(&body)?.into())?;
 86    let mut response = http_client
 87        .send(request)
 88        .await
 89        .context("failed to send web search request")?;
 90
 91    if response.status().is_success() {
 92        let mut body = String::new();
 93        response.body_mut().read_to_string(&mut body).await?;
 94        return Ok(serde_json::from_str(&body)?);
 95    } else {
 96        let mut body = String::new();
 97        response.body_mut().read_to_string(&mut body).await?;
 98        return Err(anyhow!(
 99            "error performing web search.\nStatus: {:?}\nBody: {body}",
100            response.status(),
101        ));
102    }
103}