use std::sync::Arc;

use anyhow::{Context as _, Result};
use client::Client;
use futures::AsyncReadExt as _;
use gpui::{App, AppContext, Context, Entity, Subscription, Task};
use http_client::{HttpClient, Method};
use language_model::{LlmApiToken, RefreshLlmTokenListener};
use web_search::{WebSearchProvider, WebSearchProviderId};
use zed_llm_client::{
    CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, EXPIRED_LLM_TOKEN_HEADER_NAME,
    WebSearchBody, WebSearchResponse,
};
13
14pub struct CloudWebSearchProvider {
15 state: Entity<State>,
16}
17
18impl CloudWebSearchProvider {
19 pub fn new(client: Arc<Client>, cx: &mut App) -> Self {
20 let state = cx.new(|cx| State::new(client, cx));
21
22 Self { state }
23 }
24}
25
26pub struct State {
27 client: Arc<Client>,
28 llm_api_token: LlmApiToken,
29 _llm_token_subscription: Subscription,
30}
31
32impl State {
33 pub fn new(client: Arc<Client>, cx: &mut Context<Self>) -> Self {
34 let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
35
36 Self {
37 client,
38 llm_api_token: LlmApiToken::default(),
39 _llm_token_subscription: cx.subscribe(
40 &refresh_llm_token_listener,
41 |this, _, _event, cx| {
42 let client = this.client.clone();
43 let llm_api_token = this.llm_api_token.clone();
44 cx.spawn(async move |_this, _cx| {
45 llm_api_token.refresh(&client).await?;
46 anyhow::Ok(())
47 })
48 .detach_and_log_err(cx);
49 },
50 ),
51 }
52 }
53}
54
/// Identifier reported by [`CloudWebSearchProvider`] through `WebSearchProvider::id`.
pub const ZED_WEB_SEARCH_PROVIDER_ID: &str = "zed.dev";
56
57impl WebSearchProvider for CloudWebSearchProvider {
58 fn id(&self) -> WebSearchProviderId {
59 WebSearchProviderId(ZED_WEB_SEARCH_PROVIDER_ID.into())
60 }
61
62 fn search(&self, query: String, cx: &mut App) -> Task<Result<WebSearchResponse>> {
63 let state = self.state.read(cx);
64 let client = state.client.clone();
65 let llm_api_token = state.llm_api_token.clone();
66 let body = WebSearchBody { query };
67 cx.background_spawn(async move { perform_web_search(client, llm_api_token, body).await })
68 }
69}
70
71async fn perform_web_search(
72 client: Arc<Client>,
73 llm_api_token: LlmApiToken,
74 body: WebSearchBody,
75) -> Result<WebSearchResponse> {
76 let http_client = &client.http_client();
77
78 let token = llm_api_token.acquire(&client).await?;
79
80 let request = http_client::Request::builder()
81 .method(Method::POST)
82 .uri(http_client.build_zed_llm_url("/web_search", &[])?.as_ref())
83 .header("Content-Type", "application/json")
84 .header("Authorization", format!("Bearer {token}"))
85 .header(CLIENT_SUPPORTS_EXA_WEB_SEARCH_PROVIDER_HEADER_NAME, "true")
86 .body(serde_json::to_string(&body)?.into())?;
87 let mut response = http_client
88 .send(request)
89 .await
90 .context("failed to send web search request")?;
91
92 if response.status().is_success() {
93 let mut body = String::new();
94 response.body_mut().read_to_string(&mut body).await?;
95 return Ok(serde_json::from_str(&body)?);
96 } else {
97 let mut body = String::new();
98 response.body_mut().read_to_string(&mut body).await?;
99 anyhow::bail!(
100 "error performing web search.\nStatus: {:?}\nBody: {body}",
101 response.status(),
102 );
103 }
104}