1use std::sync::Arc;
2
3use anyhow::{Context as _, Result, anyhow};
4use client::Client;
5use futures::AsyncReadExt as _;
6use gpui::{App, AppContext, Context, Entity, Subscription, Task};
7use http_client::{HttpClient, Method};
8use language_model::{LlmApiToken, RefreshLlmTokenListener};
9use web_search::{WebSearchProvider, WebSearchProviderId};
10use zed_llm_client::{WebSearchBody, WebSearchResponse};
11
12pub struct CloudWebSearchProvider {
13 state: Entity<State>,
14}
15
16impl CloudWebSearchProvider {
17 pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
18 let state = cx.new(|cx| State::new(client, cx));
19
20 Self { state }
21 }
22}
23
24pub struct State {
25 client: Arc<Client>,
26 llm_api_token: LlmApiToken,
27 _llm_token_subscription: Subscription,
28}
29
30impl State {
31 pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
32 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
33
34 Self {
35 client,
36 llm_api_token: LlmApiToken::default(),
37 _llm_token_subscription: cx.subscribe(
38 &refresh_llm_token_listener,
39 |this, _, _event, cx| {
40 let client = this.client.clone();
41 let llm_api_token = this.llm_api_token.clone();
42 cx.spawn(async move |_this, _cx| {
43 llm_api_token.refresh(&client).await?;
44 anyhow::Ok(())
45 })
46 .detach_and_log_err(cx);
47 },
48 ),
49 }
50 }
51}
52
/// Identifier returned by [`CloudWebSearchProvider`]'s [`WebSearchProvider::id`].
pub const ZED_WEB_SEARCH_PROVIDER_ID: &str = "zed.dev";
54
55impl WebSearchProvider for CloudWebSearchProvider {
56 fn id(&self) -> WebSearchProviderId {
57 WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
58 }
59
60 fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
61 let state = self.state.read(cx);
62 let client = state.client.clone();
63 let llm_api_token = state.llm_api_token.clone();
64 let body = WebSearchBody { query };
65 cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
66 }
67}
68
69async fn perform_web_search(
70 client: Arc<Client>,
71 llm_api_token: LlmApiToken,
72 body: WebSearchBody,
73) -> Result<WebSearchResponse> {
74 let http_client = &client.http_client();
75
76 let token = llm_api_token.acquire(&client).await?;
77
78 let request_builder = http_client::Request::builder().method(Method::POST);
79 let request_builder = if let Ok(web_search_url) = std::env::var("ZED_WEB_SEARCH_URL") {
80 request_builder.uri(web_search_url)
81 } else {
82 request_builder.uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
83 };
84 let request = request_builder
85 .header("Content-Type", "application/json")
86 .header("Authorization", format!("Bearer {token}"))
87 .body(serde_json::to_string(&body)?.into())?;
88 let mut response = http_client
89 .send(request)
90 .await
91 .context("failed to send web search request")?;
92
93 if response.status().is_success() {
94 let mut body = String::new();
95 response.body_mut().read_to_string(&mut body).await?;
96 return Ok(serde_json::from_str(&body)?);
97 } else {
98 let mut body = String::new();
99 response.body_mut().read_to_string(&mut body).await?;
100 return Err(anyhow!(
101 "error performing web search.\nStatus: {:?}\nBody: {body}",
102 response.status(),
103 ));
104 }
105}