cloud.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, Result};
  4use client::{Client, UserStore};
  5use cloud_api_types::OrganizationId;
  6use cloud_llm_client::{WebSearchBody, WebSearchResponse};
  7use futures::AsyncReadExt as _;
  8use gpui::{App, AppContext, Context, Entity, Task};
  9use http_client::{HttpClient, Method};
 10use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
 11use web_search::{WebSearchProvider, WebSearchProviderId};
 12
 13pub struct CloudWebSearchProvider {
 14    state: Entity<State>,
 15}
 16
 17impl CloudWebSearchProvider {
 18    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) -> Self {
 19        let state = cx.new(|cx| State::new(client, user_store, cx));
 20
 21        Self { state }
 22    }
 23}
 24
 25pub struct State {
 26    client: Arc<Client>,
 27    user_store: Entity<UserStore>,
 28    llm_api_token: LlmApiToken,
 29}
 30
 31impl State {
 32    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 33        let llm_api_token = LlmApiToken::global(cx);
 34
 35        Self {
 36            client,
 37            user_store,
 38            llm_api_token,
 39        }
 40    }
 41}
 42
 43pub const ZED_WEB_SEARCH_PROVIDER_ID: &str = "zed.dev";
 44
 45impl WebSearchProvider for CloudWebSearchProvider {
 46    fn id(&self) -> WebSearchProviderId {
 47        WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
 48    }
 49
 50    fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
 51        let state = self.state.read(cx);
 52        let client = state.client.clone();
 53        let llm_api_token = state.llm_api_token.clone();
 54        let organization_id = state
 55            .user_store
 56            .read(cx)
 57            .current_organization()
 58            .map(|organization| organization.id.clone());
 59        let body = WebSearchBody { query };
 60        cx.background_spawn(async move {
 61            perform_web_search(client, llm_api_token, organization_id, body).await
 62        })
 63    }
 64}
 65
 66async fn perform_web_search(
 67    client: Arc<Client>,
 68    llm_api_token: LlmApiToken,
 69    organization_id: Option<OrganizationId>,
 70    body: WebSearchBody,
 71) -> Result<WebSearchResponse> {
 72    const MAX_RETRIES: usize = 3;
 73
 74    let http_client = &client.http_client();
 75    let mut retries_remaining = MAX_RETRIES;
 76    let mut token = llm_api_token
 77        .acquire(&client, organization_id.clone())
 78        .await?;
 79
 80    loop {
 81        if retries_remaining == 0 {
 82            return Err(anyhow::anyhow!(
 83                "error performing web search, max retries exceeded"
 84            ));
 85        }
 86
 87        let request = http_client::Request::builder()
 88            .method(Method::POST)
 89            .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
 90            .header("Content-Type", "application/json")
 91            .header("Authorization", format!("Bearer {token}"))
 92            .body(serde_json::to_string(&body)?.into())?;
 93        let mut response = http_client
 94            .send(request)
 95            .await
 96            .context("failed to send web search request")?;
 97
 98        if response.status().is_success() {
 99            let mut body = String::new();
100            response.body_mut().read_to_string(&mut body).await?;
101            return Ok(serde_json::from_str(&body)?);
102        } else if response.needs_llm_token_refresh() {
103            token = llm_api_token
104                .refresh(&client, organization_id.clone())
105                .await?;
106            retries_remaining -= 1;
107        } else {
108            // For now we will only retry if the LLM token is expired,
109            // not if the request failed for any other reason.
110            let mut body = String::new();
111            response.body_mut().read_to_string(&mut body).await?;
112            anyhow::bail!(
113                "error performing web search.\nStatus: {:?}\nBody: {body}",
114                response.status(),
115            );
116        }
117    }
118}