1use std::sync::Arc;
2
3use anyhow::{Context as _, Result};
4use client::{Client, UserStore};
5use cloud_api_types::OrganizationId;
6use cloud_llm_client::{WebSearchBody, WebSearchResponse};
7use futures::AsyncReadExt as _;
8use gpui::{App, AppContext, Context, Entity, Task};
9use http_client::{HttpClient, Method};
10use language_model::{LlmApiToken, NeedsLlmTokenRefresh};
11use web_search::{WebSearchProvider, WebSearchProviderId};
12
13pub struct CloudWebSearchProvider {
14 state: Entity<State>,
15}
16
17impl CloudWebSearchProvider {
18 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) -> Self {
19 let state = cx.new(|cx| State::new(client, user_store, cx));
20
21 Self { state }
22 }
23}
24
25pub struct State {
26 client: Arc<Client>,
27 user_store: Entity<UserStore>,
28 llm_api_token: LlmApiToken,
29}
30
31impl State {
32 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
33 let llm_api_token = LlmApiToken::global(cx);
34
35 Self {
36 client,
37 user_store,
38 llm_api_token,
39 }
40 }
41}
42
43pub const ZED_WEB_SEARCH_PROVIDER_ID: &str = "zed.dev";
44
45impl WebSearchProvider for CloudWebSearchProvider {
46 fn id(&self) -> WebSearchProviderId {
47 WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
48 }
49
50 fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
51 let state = self.state.read(cx);
52 let client = state.client.clone();
53 let llm_api_token = state.llm_api_token.clone();
54 let organization_id = state
55 .user_store
56 .read(cx)
57 .current_organization()
58 .map(|organization| organization.id.clone());
59 let body = WebSearchBody { query };
60 cx.background_spawn(async move {
61 perform_web_search(client, llm_api_token, organization_id, body).await
62 })
63 }
64}
65
66async fn perform_web_search(
67 client: Arc<Client>,
68 llm_api_token: LlmApiToken,
69 organization_id: Option<OrganizationId>,
70 body: WebSearchBody,
71) -> Result<WebSearchResponse> {
72 const MAX_RETRIES: usize = 3;
73
74 let http_client = &client.http_client();
75 let mut retries_remaining = MAX_RETRIES;
76 let mut token = llm_api_token
77 .acquire(&client, organization_id.clone())
78 .await?;
79
80 loop {
81 if retries_remaining == 0 {
82 return Err(anyhow::anyhow!(
83 "error performing web search, max retries exceeded"
84 ));
85 }
86
87 let request = http_client::Request::builder()
88 .method(Method::POST)
89 .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
90 .header("Content-Type", "application/json")
91 .header("Authorization", format!("Bearer {token}"))
92 .body(serde_json::to_string(&body)?.into())?;
93 let mut response = http_client
94 .send(request)
95 .await
96 .context("failed to send web search request")?;
97
98 if response.status().is_success() {
99 let mut body = String::new();
100 response.body_mut().read_to_string(&mut body).await?;
101 return Ok(serde_json::from_str(&body)?);
102 } else if response.needs_llm_token_refresh() {
103 token = llm_api_token
104 .refresh(&client, organization_id.clone())
105 .await?;
106 retries_remaining -= 1;
107 } else {
108 // For now we will only retry if the LLM token is expired,
109 // not if the request failed for any other reason.
110 let mut body = String::new();
111 response.body_mut().read_to_string(&mut body).await?;
112 anyhow::bail!(
113 "error performing web search.\nStatus: {:?}\nBody: {body}",
114 response.status(),
115 );
116 }
117 }
118}