1use std::sync::Arc;
2
3use anyhow::{Context as _, Result};
4use client::Client;
5use futures::AsyncReadExt as _;
6use gpui::{App, AppContext, Context, Entity, Subscription, Task};
7use http_client::{HttpClient, Method};
8use language_model::{LlmApiToken, RefreshLlmTokenListener};
9use web_search::{WebSearchProvider, WebSearchProviderId};
10use zed_llm_client::{
11 CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME,
12 WebSearchBody, WebSearchResponse,
13};
14
15pub struct CloudWebSearchProvider {
16 state: Entity<State>,
17}
18
19impl CloudWebSearchProvider {
20 pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
21 let state = cx.new(|cx| State::new(client, cx));
22
23 Self { state }
24 }
25}
26
27pub struct State {
28 client: Arc<Client>,
29 llm_api_token: LlmApiToken,
30 _llm_token_subscription: Subscription,
31}
32
33impl State {
34 pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
35 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
36
37 Self {
38 client,
39 llm_api_token: LlmApiToken::default(),
40 _llm_token_subscription: cx.subscribe(
41 &refresh_llm_token_listener,
42 |this, _, _event, cx| {
43 let client = this.client.clone();
44 let llm_api_token = this.llm_api_token.clone();
45 cx.spawn(async move |_this, _cx| {
46 llm_api_token.refresh(&client).await?;
47 anyhow::Ok(())
48 })
49 .detach_and_log_err(cx);
50 },
51 ),
52 }
53 }
54}
55
/// Identifier reported by [`CloudWebSearchProvider`] as its [`WebSearchProviderId`].
///
/// `const` items of reference type already have a `'static` lifetime, so none is spelled out.
pub const ZED_WEB_SEARCH_PROVIDER_ID: &str = "zed.dev";
57
58impl WebSearchProvider for CloudWebSearchProvider {
59 fn id(&self) -> WebSearchProviderId {
60 WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
61 }
62
63 fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
64 let state = self.state.read(cx);
65 let client = state.client.clone();
66 let llm_api_token = state.llm_api_token.clone();
67 let body = WebSearchBody { query };
68 cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
69 }
70}
71
72async fn perform_web_search(
73 client: Arc<Client>,
74 llm_api_token: LlmApiToken,
75 body: WebSearchBody,
76) -> Result<WebSearchResponse> {
77 const MAX_RETRIES: usize = 3;
78
79 let http_client = &client.http_client();
80 let mut retries_remaining = MAX_RETRIES;
81 let mut token = llm_api_token.acquire(&client).await?;
82
83 loop {
84 if retries_remaining == 0 {
85 return Err(anyhow::anyhow!(
86 "error performing web search, max retries exceeded"
87 ));
88 }
89
90 let request = http_client::Request::builder()
91 .method(Method::POST)
92 .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
93 .header("Content-Type", "application/json")
94 .header("Authorization", format!("Bearer {token}"))
95 .header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
96 .body(serde_json::to_string(&body)?.into())?;
97 let mut response = http_client
98 .send(request)
99 .await
100 .context("failed to send web search request")?;
101
102 if response.status().is_success() {
103 let mut body = String::new();
104 response.body_mut().read_to_string(&mut body).await?;
105 return Ok(serde_json::from_str(&body)?);
106 } else if response
107 .headers()
108 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
109 .is_some()
110 {
111 token = llm_api_token.refresh(&client).await?;
112 retries_remaining -= 1;
113 } else {
114 // For now we will only retry if the LLM token is expired,
115 // not if the request failed for any other reason.
116 let mut body = String::new();
117 response.body_mut().read_to_string(&mut body).await?;
118 anyhow::bail!(
119 "error performing web search.\nStatus: {:?}\nBody: {body}",
120 response.status(),
121 );
122 }
123 }
124}