1use std::sync::Arc;
2
3use anyhow::{Context as _, Result, anyhow};
4use client::Client;
5use futures::AsyncReadExt as _;
6use gpui::{App, AppContext, Context, Entity, Subscription, Task};
7use http_client::{HttpClient, Method};
8use language_model::{LlmApiToken, RefreshLlmTokenListener};
9use web_search::{WebSearchProvider, WebSearchProviderId};
10use zed_llm_client::{WebSearchBody, WebSearchResponse};
11
12pub struct CloudWebSearchProvider {
13 state: Entity<State>,
14}
15
16impl CloudWebSearchProvider {
17 pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
18 let state = cx.new(|cx| State::new(client, cx));
19
20 Self { state }
21 }
22}
23
24pub struct State {
25 client: Arc<Client>,
26 llm_api_token: LlmApiToken,
27 _llm_token_subscription: Subscription,
28}
29
30impl State {
31 pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
32 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
33
34 Self {
35 client,
36 llm_api_token: LlmApiToken::default(),
37 _llm_token_subscription: cx.subscribe(
38 &refresh_llm_token_listener,
39 |this, _, _event, cx| {
40 let client = this.client.clone();
41 let llm_api_token = this.llm_api_token.clone();
42 cx.spawn(async move |_this, _cx| {
43 llm_api_token.refresh(&client).await?;
44 anyhow::Ok(())
45 })
46 .detach_and_log_err(cx);
47 },
48 ),
49 }
50 }
51}
52
53impl WebSearchProvider for CloudWebSearchProvider {
54 fn id(&self) -> WebSearchProviderId {
55 WebSearchProviderId("zed.dev".into())
56 }
57
58 fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
59 let state = self.state.read(cx);
60 let client = state.client.clone();
61 let llm_api_token = state.llm_api_token.clone();
62 let body = WebSearchBody { query };
63 cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
64 }
65}
66
67async fn perform_web_search(
68 client: Arc<Client>,
69 llm_api_token: LlmApiToken,
70 body: WebSearchBody,
71) -> Result<WebSearchResponse> {
72 let http_client = &client.http_client();
73
74 let token = llm_api_token.acquire(&client).await?;
75
76 let request_builder = http_client::Request::builder().method(Method::POST);
77 let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
78 request_builder.uri(web_search_url)
79 } else {
80 request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
81 };
82 let request = request_builder
83 .header("Content-Type", "application/json")
84 .header("Authorization", format!("Bearer {token}"))
85 .body(serde_json::to_string(&body)?.into())?;
86 let mut response = http_client
87 .send(request)
88 .await
89 .context("failed to send web search request")?;
90
91 if response.status().is_success() {
92 let mut body = String::new();
93 response.body_mut().read_to_string(&mut body).await?;
94 return Ok(serde_json::from_str(&body)?);
95 } else {
96 let mut body = String::new();
97 response.body_mut().read_to_string(&mut body).await?;
98 return Err(anyhow!(
99 "error performing web search.\nStatus: {:?}\nBody: {body}",
100 response.status(),
101 ));
102 }
103}