1use std::sync::Arc;
2
3use anyhow::{Context as _, Result};
4use client::Client;
5use cloud_llm_client::{EXPIRED_LLM_TOKEN_HEADER_NAME, WebSearchBody, WebSearchResponse};
6use futures::AsyncReadExt as _;
7use gpui::{App, AppContext, Context, Entity, Subscription, Task};
8use http_client::{HttpClient, Method};
9use language_model::{LlmApiToken, RefreshLlmTokenListener};
10use web_search::{WebSearchProvider, WebSearchProviderId};
11
12pub struct CloudWebSearchProvider {
13 state: Entity<State>,
14}
15
16impl CloudWebSearchProvider {
17 pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
18 let state = cx.new(|cx| State::new(client, cx));
19
20 Self { state }
21 }
22}
23
24pub struct State {
25 client: Arc<Client>,
26 llm_api_token: LlmApiToken,
27 _llm_token_subscription: Subscription,
28}
29
30impl State {
31 pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
32 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
33
34 Self {
35 client,
36 llm_api_token: LlmApiToken::default(),
37 _llm_token_subscription: cx.subscribe(
38 &refresh_llm_token_listener,
39 |this, _, _event, cx| {
40 let client = this.client.clone();
41 let llm_api_token = this.llm_api_token.clone();
42 cx.spawn(async move |_this, _cx| {
43 llm_api_token.refresh(&client).await?;
44 anyhow::Ok(())
45 })
46 .detach_and_log_err(cx);
47 },
48 ),
49 }
50 }
51}
52
53pub const ZED_WEB_SEARCH_PROVIDER_ID: &str = "zed.dev";
54
55impl WebSearchProvider for CloudWebSearchProvider {
56 fn id(&self) -> WebSearchProviderId {
57 WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
58 }
59
60 fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
61 let state = self.state.read(cx);
62 let client = state.client.clone();
63 let llm_api_token = state.llm_api_token.clone();
64 let body = WebSearchBody { query };
65 cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
66 }
67}
68
69async fn perform_web_search(
70 client: Arc<Client>,
71 llm_api_token: LlmApiToken,
72 body: WebSearchBody,
73) -> Result<WebSearchResponse> {
74 const MAX_RETRIES: usize = 3;
75
76 let http_client = &client.http_client();
77 let mut retries_remaining = MAX_RETRIES;
78 let mut token = llm_api_token.acquire(&client).await?;
79
80 loop {
81 if retries_remaining == 0 {
82 return Err(anyhow::anyhow!(
83 "error performing web search, max retries exceeded"
84 ));
85 }
86
87 let request = http_client::Request::builder()
88 .method(Method::POST)
89 .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
90 .header("Content-Type", "application/json")
91 .header("Authorization", format!("Bearer {token}"))
92 .body(serde_json::to_string(&body)?.into())?;
93 let mut response = http_client
94 .send(request)
95 .await
96 .context("failed to send web search request")?;
97
98 if response.status().is_success() {
99 let mut body = String::new();
100 response.body_mut().read_to_string(&mut body).await?;
101 return Ok(serde_json::from_str(&body)?);
102 } else if response
103 .headers()
104 .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
105 .is_some()
106 {
107 token = llm_api_token.refresh(&client).await?;
108 retries_remaining -= 1;
109 } else {
110 // For now we will only retry if the LLM token is expired,
111 // not if the request failed for any other reason.
112 let mut body = String::new();
113 response.body_mut().read_to_string(&mut body).await?;
114 anyhow::bail!(
115 "error performing web search.\nStatus: {:?}\nBody: {body}",
116 response.status(),
117 );
118 }
119 }
120}