1use std::sync::Arc;
2
3use anyhow::{Context as _, Result};
4use client::Client;
5use feature_flags::{FeatureFlagAppExt as _, ZedCloudFeatureFlag};
6use futures::AsyncReadExt as _;
7use gpui::{App, AppContext, Context, Entity, Subscription, Task};
8use http_client::{HttpClient, Method};
9use language_model::{LlmApiToken, RefreshLlmTokenListener};
10use web_search::{WebSearchProvider, WebSearchProviderId};
11use zed_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
12
13pub struct CloudWebSearchProvider {
14 state: Entity<State>,
15}
16
17impl CloudWebSearchProvider {
18 pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
19 let state = cx.new(|cx| State::new(client, cx));
20
21 Self { state }
22 }
23}
24
25pub struct State {
26 client: Arc<Client>,
27 llm_api_token: LlmApiToken,
28 _llm_token_subscription: Subscription,
29}
30
31impl State {
32 pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
33 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
34
35 Self {
36 client,
37 llm_api_token: LlmApiToken::default(),
38 _llm_token_subscription: cx.subscribe(
39 &refresh_llm_token_listener,
40 |this, _, _event, cx| {
41 let client = this.client.clone();
42 let llm_api_token = this.llm_api_token.clone();
43 cx.spawn(async move |_this, _cx| {
44 llm_api_token.refresh(&client).await?;
45 anyhow::Ok(())
46 })
47 .detach_and_log_err(cx);
48 },
49 ),
50 }
51 }
52}
53
/// Identifier that [`CloudWebSearchProvider`] reports from `WebSearchProvider::id`.
pub const ZED_WEB_SEARCH_PROVIDER_ID: &str = "zed.dev";
55
56impl WebSearchProvider for CloudWebSearchProvider {
57 fn id(&self) -> WebSearchProviderId {
58 WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
59 }
60
61 fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
62 let state = self.state.read(cx);
63 let client = state.client.clone();
64 let llm_api_token = state.llm_api_token.clone();
65 let body = WebSearchBody { query };
66 let use_cloud = cx.has_flag::<ZedCloudFeatureFlag>();
67 cx.background_spawn(async move {
68 perform_web_search(client, llm_api_token, body, use_cloud).await
69 })
70 }
71}
72
73async fn perform_web_search(
74 client: Arc<Client>,
75 llm_api_token: LlmApiToken,
76 body: WebSearchBody,
77 use_cloud: bool,
78) -> Result<WebSearchResponse> {
79 const MAX_RETRIES: usize = 3;
80
81 let http_client = &client.http_client();
82 let mut retries_remaining = MAX_RETRIES;
83 let mut token = llm_api_token.acquire(&client).await?;
84
85 loop {
86 if retries_remaining == 0 {
87 return Err(anyhow::anyhow!(
88 "error performing web search, max retries exceeded"
89 ));
90 }
91
92 let request = http_client::Request::builder()
93 .method(Method::POST)
94 .uri(
95 http_client
96 .build_zed_llm_url("/web_search", &[], use_cloud)?
97 .as_ref(),
98 )
99 .header("Content-Type", "application/json")
100 .header("Authorization", format!("Bearer {token}"))
101 .body(serde_json::to_string(&body)?.into())?;
102 let mut response = http_client
103 .send(request)
104 .await
105 .context("failed to send web search request")?;
106
107 if response.status().is_success() {
108 let mut body = String::new();
109 response.body_mut().read_to_string(&mut body).await?;
110 return Ok(serde_json::from_str(&body)?);
111 } else if response
112 .headers()
113 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
114 .is_some()
115 {
116 token = llm_api_token.refresh(&client).await?;
117 retries_remaining -= 1;
118 } else {
119 // For now we will only retry if the LLM token is expired,
120 // not if the request failed for any other reason.
121 let mut body = String::new();
122 response.body_mut().read_to_string(&mut body).await?;
123 anyhow::bail!(
124 "error performing web search.\nStatus: {:?}\nBody: {body}",
125 response.status(),
126 );
127 }
128 }
129}