1use std::sync::Arc;
2
3use anyhow::{Context as _, Result};
4use client::{Client, UserStore};
5use cloud_api_types::OrganizationId;
6use cloud_llm_client::{WebSearchBody, WebSearchResponse};
7use futures::AsyncReadExt as _;
8use gpui::{App, AppContext, Context, Entity, Subscription, Task};
9use http_client::{HttpClient, Method};
10use language_model::{LlmApiToken, NeedsLlmTokenRefresh, RefreshLlmTokenListener};
11use web_search::{WebSearchProvider, WebSearchProviderId};
12
13pub struct CloudWebSearchProvider {
14 state: Entity<State>,
15}
16
17impl CloudWebSearchProvider {
18 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) -> Self {
19 let state = cx.new(|cx| State::new(client, user_store, cx));
20
21 Self { state }
22 }
23}
24
25pub struct State {
26 client: Arc<Client>,
27 user_store: Entity<UserStore>,
28 llm_api_token: LlmApiToken,
29 _llm_token_subscription: Subscription,
30}
31
32impl State {
33 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
34 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
35
36 Self {
37 client,
38 user_store,
39 llm_api_token: LlmApiToken::default(),
40 _llm_token_subscription: cx.subscribe(
41 &refresh_llm_token_listener,
42 |this, _, _event, cx| {
43 let client = this.client.clone();
44 let llm_api_token = this.llm_api_token.clone();
45 let organization_id = this
46 .user_store
47 .read(cx)
48 .current_organization()
49 .map(|o| o.id.clone());
50 cx.spawn(async move |_this, _cx| {
51 llm_api_token.refresh(&client, organization_id).await?;
52 anyhow::Ok(())
53 })
54 .detach_and_log_err(cx);
55 },
56 ),
57 }
58 }
59}
60
61pub const ZED_WEB_SEARCH_PROVIDER_ID: &str = "zed.dev";
62
63impl WebSearchProvider for CloudWebSearchProvider {
64 fn id(&self) -> WebSearchProviderId {
65 WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
66 }
67
68 fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
69 let state = self.state.read(cx);
70 let client = state.client.clone();
71 let llm_api_token = state.llm_api_token.clone();
72 let organization_id = state
73 .user_store
74 .read(cx)
75 .current_organization()
76 .map(|o| o.id.clone());
77 let body = WebSearchBody { query };
78 cx.background_spawn(async move {
79 perform_web_search(client, llm_api_token, organization_id, body).await
80 })
81 }
82}
83
84async fn perform_web_search(
85 client: Arc<Client>,
86 llm_api_token: LlmApiToken,
87 organization_id: Option<OrganizationId>,
88 body: WebSearchBody,
89) -> Result<WebSearchResponse> {
90 const MAX_RETRIES: usize = 3;
91
92 let http_client = &client.http_client();
93 let mut retries_remaining = MAX_RETRIES;
94 let mut token = llm_api_token
95 .acquire(&client, organization_id.clone())
96 .await?;
97
98 loop {
99 if retries_remaining == 0 {
100 return Err(anyhow::anyhow!(
101 "error performing web search, max retries exceeded"
102 ));
103 }
104
105 let request = http_client::Request::builder()
106 .method(Method::POST)
107 .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
108 .header("Content-Type", "application/json")
109 .header("Authorization", format!("Bearer {token}"))
110 .body(serde_json::to_string(&body)?.into())?;
111 let mut response = http_client
112 .send(request)
113 .await
114 .context("failed to send web search request")?;
115
116 if response.status().is_success() {
117 let mut body = String::new();
118 response.body_mut().read_to_string(&mut body).await?;
119 return Ok(serde_json::from_str(&body)?);
120 } else if response.needs_llm_token_refresh() {
121 token = llm_api_token
122 .refresh(&client, organization_id.clone())
123 .await?;
124 retries_remaining -= 1;
125 } else {
126 // For now we will only retry if the LLM token is expired,
127 // not if the request failed for any other reason.
128 let mut body = String::new();
129 response.body_mut().read_to_string(&mut body).await?;
130 anyhow::bail!(
131 "error performing web search.\nStatus: {:?}\nBody: {body}",
132 response.status(),
133 );
134 }
135 }
136}