1use std::sync::Arc;
2
3use crate::{
4 AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision,
5 decide_permission_from_settings,
6};
7use agent_client_protocol as acp;
8use agent_settings::AgentSettings;
9use anyhow::Result;
10use cloud_llm_client::WebSearchResponse;
11use futures::FutureExt as _;
12use gpui::{App, Task};
13use language_model::{
14 LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
15};
16use schemars::JsonSchema;
17use serde::{Deserialize, Serialize};
18use settings::Settings;
19use ui::prelude::*;
20use util::markdown::MarkdownInlineCode;
21use web_search::WebSearchRegistry;
22
// NOTE(maintainers): the `///` text on this type and on its fields is not just
// rustdoc. `#[derive(JsonSchema)]` (schemars) turns doc comments into schema
// `description`s, and that schema is presumably what the model is shown as the
// tool's description and argument help. Editing that text changes model-facing
// behavior. Put maintainer-only notes in `//` comments like this one.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    // Forwarded verbatim to the permission check, the confirmation prompt and
    // the search provider by `WebSearchTool::run`.
    /// The search term or question to query on the web.
    query: String,
}
31
/// Outcome of a `web_search` tool call, on both the success and the error path.
///
/// The enum is `#[serde(untagged)]`:
/// - a success serializes as the bare `WebSearchResponse`;
/// - a failure serializes as `{ "error": "..." }`.
///
/// When deserializing, serde tries the variants in declaration order. Keep
/// `Success` first so that stored responses are not misread as errors.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum WebSearchToolOutput {
    /// The provider returned a response. It is rendered to the UI by
    /// `emit_update` and sent to the model as JSON.
    Success(WebSearchResponse),
    /// The search did not produce a response. This covers bad input, denied or
    /// declined permission, no active provider, provider error, or cancellation.
    /// `error` is a human-readable message.
    Error { error: String },
}
38
39impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
40 fn from(value: WebSearchToolOutput) -> Self {
41 match value {
42 WebSearchToolOutput::Success(response) => serde_json::to_string(&response)
43 .unwrap_or_else(|e| format!("Failed to serialize web search response: {e}"))
44 .into(),
45 WebSearchToolOutput::Error { error } => error.into(),
46 }
47 }
48}
49
50pub struct WebSearchTool;
51
52impl AgentTool for WebSearchTool {
53 type Input = WebSearchToolInput;
54 type Output = WebSearchToolOutput;
55
56 const NAME: &'static str = "web_search";
57
58 fn kind() -> acp::ToolKind {
59 acp::ToolKind::Fetch
60 }
61
62 fn initial_title(
63 &self,
64 _input: Result<Self::Input, serde_json::Value>,
65 _cx: &mut App,
66 ) -> SharedString {
67 "Searching the Web".into()
68 }
69
70 /// We currently only support Zed Cloud as a provider.
71 fn supports_provider(provider: &LanguageModelProviderId) -> bool {
72 provider == &ZED_CLOUD_PROVIDER_ID
73 }
74
75 fn run(
76 self: Arc<Self>,
77 input: ToolInput<Self::Input>,
78 event_stream: ToolCallEventStream,
79 cx: &mut App,
80 ) -> Task<Result<Self::Output, Self::Output>> {
81 cx.spawn(async move |cx| {
82 let input = input
83 .recv()
84 .await
85 .map_err(|e| WebSearchToolOutput::Error {
86 error: format!("Failed to receive tool input: {e}"),
87 })?;
88
89 let (authorize, search_task) = cx.update(|cx| {
90 let decision = decide_permission_from_settings(
91 Self::NAME,
92 std::slice::from_ref(&input.query),
93 AgentSettings::get_global(cx),
94 );
95
96 let authorize = match decision {
97 ToolPermissionDecision::Allow => None,
98 ToolPermissionDecision::Deny(reason) => {
99 return Err(WebSearchToolOutput::Error { error: reason });
100 }
101 ToolPermissionDecision::Confirm => {
102 let context =
103 crate::ToolPermissionContext::new(Self::NAME, vec![input.query.clone()]);
104 Some(event_stream.authorize(
105 format!("Search the web for {}", MarkdownInlineCode(&input.query)),
106 context,
107 cx,
108 ))
109 }
110 };
111
112 let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
113 return Err(WebSearchToolOutput::Error {
114 error: "Web search is not available.".to_string(),
115 });
116 };
117
118 let search_task = provider.search(input.query, cx);
119 Ok((authorize, search_task))
120 })?;
121
122 if let Some(authorize) = authorize {
123 authorize.await.map_err(|e| WebSearchToolOutput::Error { error: e.to_string() })?;
124 }
125
126 let response = futures::select! {
127 result = search_task.fuse() => {
128 match result {
129 Ok(response) => response,
130 Err(err) => {
131 event_stream
132 .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
133 return Err(WebSearchToolOutput::Error { error: err.to_string() });
134 }
135 }
136 }
137 _ = event_stream.cancelled_by_user().fuse() => {
138 return Err(WebSearchToolOutput::Error { error: "Web search cancelled by user".to_string() });
139 }
140 };
141
142 emit_update(&response, &event_stream);
143 Ok(WebSearchToolOutput::Success(response))
144 })
145 }
146
147 fn replay(
148 &self,
149 _input: Self::Input,
150 output: Self::Output,
151 event_stream: ToolCallEventStream,
152 _cx: &mut App,
153 ) -> Result<()> {
154 if let WebSearchToolOutput::Success(response) = &output {
155 emit_update(response, &event_stream);
156 }
157 Ok(())
158 }
159}
160
161fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
162 let result_text = if response.results.len() == 1 {
163 "1 result".to_string()
164 } else {
165 format!("{} results", response.results.len())
166 };
167 event_stream.update_fields(
168 acp::ToolCallUpdateFields::new()
169 .title(format!("Searched the web: {result_text}"))
170 .content(
171 response
172 .results
173 .iter()
174 .map(|result| {
175 acp::ToolCallContent::Content(acp::Content::new(
176 acp::ContentBlock::ResourceLink(
177 acp::ResourceLink::new(result.title.clone(), result.url.clone())
178 .title(result.title.clone())
179 .description(result.text.clone()),
180 ),
181 ))
182 })
183 .collect::<Vec<_>>(),
184 ),
185 );
186}