1use std::sync::Arc;
2
3use crate::{
4 AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings,
5};
6use agent_client_protocol as acp;
7use agent_settings::AgentSettings;
8use anyhow::{Result, anyhow};
9use cloud_llm_client::WebSearchResponse;
10use futures::FutureExt as _;
11use gpui::{App, AppContext, Task};
12use language_model::{
13 LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
14};
15use schemars::JsonSchema;
16use serde::{Deserialize, Serialize};
17use settings::Settings;
18use ui::prelude::*;
19use util::markdown::MarkdownInlineCode;
20use web_search::WebSearchRegistry;
21
// NOTE(review): the `///` doc comments below are consumed by the `JsonSchema`
// derive and become the schema's descriptions (presumably sent to the model as
// the tool's input schema), so treat them as prompt text, not ordinary docs.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    // `run` passes this string both to the settings-based permission check and,
    // unchanged, to the active search provider.
    /// The search term or question to query on the web.
    query: String,
}
30
31#[derive(Debug, Serialize, Deserialize)]
32#[serde(transparent)]
33pub struct WebSearchToolOutput(WebSearchResponse);
34
35impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
36 fn from(value: WebSearchToolOutput) -> Self {
37 serde_json::to_string(&value.0)
38 .expect("Failed to serialize WebSearchResponse")
39 .into()
40 }
41}
42
43pub struct WebSearchTool;
44
45impl AgentTool for WebSearchTool {
46 type Input = WebSearchToolInput;
47 type Output = WebSearchToolOutput;
48
49 const NAME: &'static str = "web_search";
50
51 fn kind() -> acp::ToolKind {
52 acp::ToolKind::Fetch
53 }
54
55 fn initial_title(
56 &self,
57 _input: Result<Self::Input, serde_json::Value>,
58 _cx: &mut App,
59 ) -> SharedString {
60 "Searching the Web".into()
61 }
62
63 /// We currently only support Zed Cloud as a provider.
64 fn supports_provider(provider: &LanguageModelProviderId) -> bool {
65 provider == &ZED_CLOUD_PROVIDER_ID
66 }
67
68 fn run(
69 self: Arc<Self>,
70 input: Self::Input,
71 event_stream: ToolCallEventStream,
72 cx: &mut App,
73 ) -> Task<Result<Self::Output>> {
74 let settings = AgentSettings::get_global(cx);
75 let decision = decide_permission_from_settings(
76 Self::NAME,
77 std::slice::from_ref(&input.query),
78 settings,
79 );
80
81 let authorize = match decision {
82 ToolPermissionDecision::Allow => None,
83 ToolPermissionDecision::Deny(reason) => {
84 return Task::ready(Err(anyhow!("{}", reason)));
85 }
86 ToolPermissionDecision::Confirm => {
87 let context = crate::ToolPermissionContext {
88 tool_name: Self::NAME.to_string(),
89 input_values: vec![input.query.clone()],
90 };
91 Some(event_stream.authorize(
92 format!("Search the web for {}", MarkdownInlineCode(&input.query)),
93 context,
94 cx,
95 ))
96 }
97 };
98
99 let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
100 return Task::ready(Err(anyhow!("Web search is not available.")));
101 };
102
103 let search_task = provider.search(input.query, cx);
104 cx.background_spawn(async move {
105 if let Some(authorize) = authorize {
106 authorize.await?;
107 }
108
109 let response = futures::select! {
110 result = search_task.fuse() => {
111 match result {
112 Ok(response) => response,
113 Err(err) => {
114 event_stream
115 .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
116 return Err(err);
117 }
118 }
119 }
120 _ = event_stream.cancelled_by_user().fuse() => {
121 anyhow::bail!("Web search cancelled by user");
122 }
123 };
124
125 emit_update(&response, &event_stream);
126 Ok(WebSearchToolOutput(response))
127 })
128 }
129
130 fn replay(
131 &self,
132 _input: Self::Input,
133 output: Self::Output,
134 event_stream: ToolCallEventStream,
135 _cx: &mut App,
136 ) -> Result<()> {
137 emit_update(&output.0, &event_stream);
138 Ok(())
139 }
140}
141
142fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
143 let result_text = if response.results.len() == 1 {
144 "1 result".to_string()
145 } else {
146 format!("{} results", response.results.len())
147 };
148 event_stream.update_fields(
149 acp::ToolCallUpdateFields::new()
150 .title(format!("Searched the web: {result_text}"))
151 .content(
152 response
153 .results
154 .iter()
155 .map(|result| {
156 acp::ToolCallContent::Content(acp::Content::new(
157 acp::ContentBlock::ResourceLink(
158 acp::ResourceLink::new(result.title.clone(), result.url.clone())
159 .title(result.title.clone())
160 .description(result.text.clone()),
161 ),
162 ))
163 })
164 .collect::<Vec<_>>(),
165 ),
166 );
167}