cloud.rs

  1use std::sync::Arc;
  2
  3use anyhow::{Context as _, Result};
  4use client::{Client, UserStore};
  5use cloud_api_types::OrganizationId;
  6use cloud_llm_client::{WebSearchBody, WebSearchResponse};
  7use futures::AsyncReadExt as _;
  8use gpui::{App, AppContext, Context, Entity, Subscription, Task};
  9use http_client::{HttpClient, Method};
 10use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
 11use web_search::{WebSearchProvider, WebSearchProviderId};
 12
 13pub struct CloudWebSearchProvider {
 14    state: Entity<State>,
 15}
 16
 17impl CloudWebSearchProvider {
 18    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) -> Self {
 19        let state = cx.new(|cx| State::new(client, user_store, cx));
 20
 21        Self { state }
 22    }
 23}
 24
 25pub struct State {
 26    client: Arc<Client>,
 27    user_store: Entity<UserStore>,
 28    llm_api_token: LlmApiToken,
 29    _llm_token_subscription: Subscription,
 30}
 31
 32impl State {
 33    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 34        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 35
 36        Self {
 37            client,
 38            user_store,
 39            llm_api_token: LlmApiToken::default(),
 40            _llm_token_subscription: cx.subscribe(
 41                &refresh_llm_token_listener,
 42                |this, _, _event, cx| {
 43                    let client = this.client.clone();
 44                    let llm_api_token = this.llm_api_token.clone();
 45                    let organization_id = this
 46                        .user_store
 47                        .read(cx)
 48                        .current_organization()
 49                        .map(|o| o.id.clone());
 50                    cx.spawn(async move |_this, _cx| {
 51                        llm_api_token.refresh(&client, organization_id).await?;
 52                        anyhow::Ok(())
 53                    })
 54                    .detach_and_log_err(cx);
 55                },
 56            ),
 57        }
 58    }
 59}
 60
 61pub const ZED_WEB_SEARCH_PROVIDER_ID: &str = "zed.dev";
 62
 63impl WebSearchProvider for CloudWebSearchProvider {
 64    fn id(&self) -> WebSearchProviderId {
 65        WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
 66    }
 67
 68    fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
 69        let state = self.state.read(cx);
 70        let client = state.client.clone();
 71        let llm_api_token = state.llm_api_token.clone();
 72        let organization_id = state
 73            .user_store
 74            .read(cx)
 75            .current_organization()
 76            .map(|o| o.id.clone());
 77        let body = WebSearchBody { query };
 78        cx.background_spawn(async move {
 79            perform_web_search(client, llm_api_token, organization_id, body).await
 80        })
 81    }
 82}
 83
 84async fn perform_web_search(
 85    client: Arc<Client>,
 86    llm_api_token: LlmApiToken,
 87    organization_id: Option<OrganizationId>,
 88    body: WebSearchBody,
 89) -> Result<WebSearchResponse> {
 90    const MAX_RETRIES: usize = 3;
 91
 92    let http_client = &client.http_client();
 93    let mut retries_remaining = MAX_RETRIES;
 94    let mut token = llm_api_token
 95        .acquire(&client, organization_id.clone())
 96        .await?;
 97
 98    loop {
 99        if retries_remaining == 0 {
100            return Err(anyhow::anyhow!(
101                "error performing web search, max retries exceeded"
102            ));
103        }
104
105        let request = http_client::Request::builder()
106            .method(Method::POST)
107            .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
108            .header("Content-Type", "application/json")
109            .header("Authorization", format!("Bearer {token}"))
110            .body(serde_json::to_string(&body)?.into())?;
111        let mut response = http_client
112            .send(request)
113            .await
114            .context("failed to send web search request")?;
115
116        if response.status().is_success() {
117            let mut body = String::new();
118            response.body_mut().read_to_string(&mut body).await?;
119            return Ok(serde_json::from_str(&body)?);
120        } else if response.needs_llm_token_refresh() {
121            token = llm_api_token
122                .refresh(&client, organization_id.clone())
123                .await?;
124            retries_remaining -= 1;
125        } else {
126            // For now we will only retry if the LLM token is expired,
127            // not if the request failed for any other reason.
128            let mut body = String::new();
129            response.body_mut().read_to_string(&mut body).await?;
130            anyhow::bail!(
131                "error performing web search.\nStatus: {:?}\nBody: {body}",
132                response.status(),
133            );
134        }
135    }
136}