1use std::sync::Arc;
2
3use crate::{AgentTool, ToolCallEventStream};
4use agent_client_protocol as acp;
5use anyhow::{Result, anyhow};
6use cloud_llm_client::WebSearchResponse;
7use futures::FutureExt as _;
8use gpui::{App, AppContext, Task};
9use language_model::{
10 LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
11};
12use schemars::JsonSchema;
13use serde::{Deserialize, Serialize};
14use ui::prelude::*;
15use web_search::WebSearchRegistry;
16
// NOTE: `#[derive(JsonSchema)]` makes schemars turn the `///` doc comments on
// this struct and its field into schema descriptions, so their wording is
// runtime-visible text. Presumably it is part of what the model sees as the
// tool description; TODO confirm against the `AgentTool` machinery before
// rewording.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    /// The search term or question to query on the web.
    // Handed unchanged to the active provider's `search` call in `WebSearchTool::run`.
    query: String,
}
25
26#[derive(Debug, Serialize, Deserialize)]
27#[serde(transparent)]
28pub struct WebSearchToolOutput(WebSearchResponse);
29
30impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
31 fn from(value: WebSearchToolOutput) -> Self {
32 serde_json::to_string(&value.0)
33 .expect("Failed to serialize WebSearchResponse")
34 .into()
35 }
36}
37
38pub struct WebSearchTool;
39
40impl AgentTool for WebSearchTool {
41 type Input = WebSearchToolInput;
42 type Output = WebSearchToolOutput;
43
44 fn name() -> &'static str {
45 "web_search"
46 }
47
48 fn kind() -> acp::ToolKind {
49 acp::ToolKind::Fetch
50 }
51
52 fn initial_title(
53 &self,
54 _input: Result<Self::Input, serde_json::Value>,
55 _cx: &mut App,
56 ) -> SharedString {
57 "Searching the Web".into()
58 }
59
60 /// We currently only support Zed Cloud as a provider.
61 fn supports_provider(provider: &LanguageModelProviderId) -> bool {
62 provider == &ZED_CLOUD_PROVIDER_ID
63 }
64
65 fn run(
66 self: Arc<Self>,
67 input: Self::Input,
68 event_stream: ToolCallEventStream,
69 cx: &mut App,
70 ) -> Task<Result<Self::Output>> {
71 let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
72 return Task::ready(Err(anyhow!("Web search is not available.")));
73 };
74
75 let search_task = provider.search(input.query, cx);
76 cx.background_spawn(async move {
77 let response = futures::select! {
78 result = search_task.fuse() => {
79 match result {
80 Ok(response) => response,
81 Err(err) => {
82 event_stream
83 .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
84 return Err(err);
85 }
86 }
87 }
88 _ = event_stream.cancelled_by_user().fuse() => {
89 anyhow::bail!("Web search cancelled by user");
90 }
91 };
92
93 emit_update(&response, &event_stream);
94 Ok(WebSearchToolOutput(response))
95 })
96 }
97
98 fn replay(
99 &self,
100 _input: Self::Input,
101 output: Self::Output,
102 event_stream: ToolCallEventStream,
103 _cx: &mut App,
104 ) -> Result<()> {
105 emit_update(&output.0, &event_stream);
106 Ok(())
107 }
108}
109
110fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
111 let result_text = if response.results.len() == 1 {
112 "1 result".to_string()
113 } else {
114 format!("{} results", response.results.len())
115 };
116 event_stream.update_fields(
117 acp::ToolCallUpdateFields::new()
118 .title(format!("Searched the web: {result_text}"))
119 .content(
120 response
121 .results
122 .iter()
123 .map(|result| {
124 acp::ToolCallContent::Content(acp::Content::new(
125 acp::ContentBlock::ResourceLink(
126 acp::ResourceLink::new(result.title.clone(), result.url.clone())
127 .title(result.title.clone())
128 .description(result.text.clone()),
129 ),
130 ))
131 })
132 .collect::<Vec<_>>(),
133 ),
134 );
135}