1use std::sync::Arc;
2
3use crate::{
4 AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings,
5};
6use agent_client_protocol as acp;
7use agent_settings::AgentSettings;
8use anyhow::Result;
9use cloud_llm_client::WebSearchResponse;
10use futures::FutureExt as _;
11use gpui::{App, AppContext, Task};
12use language_model::{
13 LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use settings::Settings;
18use ui::prelude::*;
19use util::markdown::MarkdownInlineCode;
20use web_search::WebSearchRegistry;
21
// NOTE(review): `JsonSchema` is derived here, so the `///` comments below are presumably
// emitted into the generated input schema that the model sees. Treat them as part of the
// tool's prompt and edit them with care.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    // Private: only read inside this module (the query passed to the permission check,
    // the confirmation prompt, and the provider search).
    /// The search term or question to query on the web.
    query: String,
}
30
/// Outcome of a `web_search` tool call, produced by `WebSearchTool::run` and consumed by
/// `WebSearchTool::replay`.
///
/// Serialized `untagged`: `Success` is written as the bare `WebSearchResponse` and `Error`
/// as `{ "error": "..." }`. When deserializing, serde tries the variants in declaration
/// order, so `Success` must stay first.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum WebSearchToolOutput {
    /// The provider's response; handed to the model as its JSON text.
    Success(WebSearchResponse),
    /// A human-readable failure message (denied permission, unavailable provider,
    /// failed or cancelled search); handed to the model verbatim.
    Error { error: String },
}
37
38impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
39 fn from(value: WebSearchToolOutput) -> Self {
40 match value {
41 WebSearchToolOutput::Success(response) => serde_json::to_string(&response)
42 .unwrap_or_else(|e| format!("Failed to serialize web search response: {e}"))
43 .into(),
44 WebSearchToolOutput::Error { error } => error.into(),
45 }
46 }
47}
48
/// The `web_search` agent tool. It carries no state; all of its behavior lives in its
/// `AgentTool` implementation.
pub struct WebSearchTool;
50
51impl AgentTool for WebSearchTool {
52 type Input = WebSearchToolInput;
53 type Output = WebSearchToolOutput;
54
55 const NAME: &'static str = "web_search";
56
57 fn kind() -> acp::ToolKind {
58 acp::ToolKind::Fetch
59 }
60
61 fn initial_title(
62 &self,
63 _input: Result<Self::Input, serde_json::Value>,
64 _cx: &mut App,
65 ) -> SharedString {
66 "Searching the Web".into()
67 }
68
69 /// We currently only support Zed Cloud as a provider.
70 fn supports_provider(provider: &LanguageModelProviderId) -> bool {
71 provider == &ZED_CLOUD_PROVIDER_ID
72 }
73
74 fn run(
75 self: Arc<Self>,
76 input: Self::Input,
77 event_stream: ToolCallEventStream,
78 cx: &mut App,
79 ) -> Task<Result<Self::Output, Self::Output>> {
80 let settings = AgentSettings::get_global(cx);
81 let decision = decide_permission_from_settings(
82 Self::NAME,
83 std::slice::from_ref(&input.query),
84 settings,
85 );
86
87 let authorize = match decision {
88 ToolPermissionDecision::Allow => None,
89 ToolPermissionDecision::Deny(reason) => {
90 return Task::ready(Err(WebSearchToolOutput::Error { error: reason }));
91 }
92 ToolPermissionDecision::Confirm => {
93 let context =
94 crate::ToolPermissionContext::new(Self::NAME, vec![input.query.clone()]);
95 Some(event_stream.authorize(
96 format!("Search the web for {}", MarkdownInlineCode(&input.query)),
97 context,
98 cx,
99 ))
100 }
101 };
102
103 let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
104 return Task::ready(Err(WebSearchToolOutput::Error {
105 error: "Web search is not available.".to_string(),
106 }));
107 };
108
109 let search_task = provider.search(input.query, cx);
110 cx.background_spawn(async move {
111 if let Some(authorize) = authorize {
112 authorize.await.map_err(|e| WebSearchToolOutput::Error { error: e.to_string() })?;
113 }
114
115 let response = futures::select! {
116 result = search_task.fuse() => {
117 match result {
118 Ok(response) => response,
119 Err(err) => {
120 event_stream
121 .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
122 return Err(WebSearchToolOutput::Error { error: err.to_string() });
123 }
124 }
125 }
126 _ = event_stream.cancelled_by_user().fuse() => {
127 return Err(WebSearchToolOutput::Error { error: "Web search cancelled by user".to_string() });
128 }
129 };
130
131 emit_update(&response, &event_stream);
132 Ok(WebSearchToolOutput::Success(response))
133 })
134 }
135
136 fn replay(
137 &self,
138 _input: Self::Input,
139 output: Self::Output,
140 event_stream: ToolCallEventStream,
141 _cx: &mut App,
142 ) -> Result<()> {
143 if let WebSearchToolOutput::Success(response) = &output {
144 emit_update(response, &event_stream);
145 }
146 Ok(())
147 }
148}
149
150fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
151 let result_text = if response.results.len() == 1 {
152 "1 result".to_string()
153 } else {
154 format!("{} results", response.results.len())
155 };
156 event_stream.update_fields(
157 acp::ToolCallUpdateFields::new()
158 .title(format!("Searched the web: {result_text}"))
159 .content(
160 response
161 .results
162 .iter()
163 .map(|result| {
164 acp::ToolCallContent::Content(acp::Content::new(
165 acp::ContentBlock::ResourceLink(
166 acp::ResourceLink::new(result.title.clone(), result.url.clone())
167 .title(result.title.clone())
168 .description(result.text.clone()),
169 ),
170 ))
171 })
172 .collect::<Vec<_>>(),
173 ),
174 );
175}