web_search_tool.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use crate::schema::json_schema_for;
  4use crate::ui::ToolCallCardHeader;
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
  7use futures::{Future, FutureExt, TryFutureExt};
  8use gpui::{
  9    AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
 10};
 11use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 12use project::Project;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use ui::{IconName, Tooltip, prelude::*};
 16use web_search::WebSearchRegistry;
 17use workspace::Workspace;
 18use zed_llm_client::{WebSearchCitation, WebSearchResponse};
 19
// Arguments the model supplies when it calls the `web_search` tool. `WebSearchTool::run`
// deserializes them from JSON, and `WebSearchTool::input_schema` derives the schema
// advertised to the model from this type via `json_schema_for`.
//
// NOTE: maintainer notes here deliberately use `//` rather than `///`: the `JsonSchema`
// derive turns `///` doc comments into schema descriptions, which the model sees.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    /// The search term or question to query on the web.
    query: String,
}
 25
/// The `web_search` agent tool: sends the model's query to the active web search
/// provider, returns the JSON-serialized response as the tool output, and shows the
/// progress/results in a `WebSearchToolCard`.
///
/// `RegisterComponent` registers this type's `Component` impl, which supplies the
/// in-progress / successful / error card previews.
#[derive(RegisterComponent)]
pub struct WebSearchTool;
 28
 29impl Tool for WebSearchTool {
 30    fn name(&self) -> String {
 31        "web_search".into()
 32    }
 33
 34    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 35        false
 36    }
 37
 38    fn description(&self) -> String {
 39        "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
 40    }
 41
 42    fn icon(&self) -> IconName {
 43        IconName::Globe
 44    }
 45
 46    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 47        json_schema_for::<WebSearchToolInput>(format)
 48    }
 49
 50    fn ui_text(&self, _input: &serde_json::Value) -> String {
 51        "Searching the Web".to_string()
 52    }
 53
 54    fn run(
 55        self: Arc<Self>,
 56        input: serde_json::Value,
 57        _messages: &[LanguageModelRequestMessage],
 58        _project: Entity<Project>,
 59        _action_log: Entity<ActionLog>,
 60        _window: Option<AnyWindowHandle>,
 61        cx: &mut App,
 62    ) -> ToolResult {
 63        let input = match serde_json::from_value::<WebSearchToolInput>(input) {
 64            Ok(input) => input,
 65            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 66        };
 67        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 68            return Task::ready(Err(anyhow!("Web search is not available."))).into();
 69        };
 70
 71        let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
 72        let output = cx.background_spawn({
 73            let search_task = search_task.clone();
 74            async move {
 75                let response = search_task.await.map_err(|err| anyhow!(err))?;
 76                serde_json::to_string(&response).context("Failed to serialize search results")
 77            }
 78        });
 79
 80        ToolResult {
 81            output,
 82            card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
 83        }
 84    }
 85}
 86
/// Card shown in the UI for a single `web_search` tool call.
struct WebSearchToolCard {
    /// `None` while the search is still running (the header shows a loading state);
    /// afterwards the search response or the error it failed with. In
    /// `WebSearchToolCard::new` this is filled in by the spawned `_task`.
    response: Option<Result<WebSearchResponse>>,
    /// Held only to keep the spawned work tied to the card's lifetime
    /// (presumably dropping a gpui `Task` cancels it — confirm).
    _task: Task<()>,
}
 91
 92impl WebSearchToolCard {
 93    fn new(
 94        search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
 95        cx: &mut Context<Self>,
 96    ) -> Self {
 97        let _task = cx.spawn(async move |this, cx| {
 98            let response = search_task.await.map_err(|err| anyhow!(err));
 99            this.update(cx, |this, cx| {
100                this.response = Some(response);
101                cx.notify();
102            })
103            .ok();
104        });
105
106        Self {
107            response: None,
108            _task,
109        }
110    }
111}
112
113impl ToolCard for WebSearchToolCard {
114    fn render(
115        &mut self,
116        _status: &ToolUseStatus,
117        _window: &mut Window,
118        _workspace: WeakEntity<Workspace>,
119        cx: &mut Context<Self>,
120    ) -> impl IntoElement {
121        let header = match self.response.as_ref() {
122            Some(Ok(response)) => {
123                let text: SharedString = if response.citations.len() == 1 {
124                    "1 result".into()
125                } else {
126                    format!("{} results", response.citations.len()).into()
127                };
128                ToolCallCardHeader::new(IconName::Globe, "Searched the Web")
129                    .with_secondary_text(text)
130            }
131            Some(Err(error)) => {
132                ToolCallCardHeader::new(IconName::Globe, "Web Search").with_error(error.to_string())
133            }
134            None => ToolCallCardHeader::new(IconName::Globe, "Searching the Web").loading(),
135        };
136
137        let content =
138            self.response.as_ref().and_then(|response| match response {
139                Ok(response) => {
140                    Some(
141                        v_flex()
142                            .overflow_hidden()
143                            .ml_1p5()
144                            .pl(px(5.))
145                            .border_l_1()
146                            .border_color(cx.theme().colors().border_variant)
147                            .gap_1()
148                            .children(response.citations.iter().enumerate().map(
149                                |(index, citation)| {
150                                    let title = citation.title.clone();
151                                    let url = citation.url.clone();
152
153                                    Button::new(("citation", index), title)
154                                        .label_size(LabelSize::Small)
155                                        .color(Color::Muted)
156                                        .icon(IconName::ArrowUpRight)
157                                        .icon_size(IconSize::XSmall)
158                                        .icon_position(IconPosition::End)
159                                        .truncate(true)
160                                        .tooltip({
161                                            let url = url.clone();
162                                            move |window, cx| {
163                                                Tooltip::with_meta(
164                                                    "Citation Link",
165                                                    None,
166                                                    url.clone(),
167                                                    window,
168                                                    cx,
169                                                )
170                                            }
171                                        })
172                                        .on_click({
173                                            let url = url.clone();
174                                            move |_, _, cx| cx.open_url(&url)
175                                        })
176                                },
177                            ))
178                            .into_any(),
179                    )
180                }
181                Err(_) => None,
182            });
183
184        v_flex().mb_3().gap_1().child(header).children(content)
185    }
186}
187
188impl Component for WebSearchTool {
189    fn scope() -> ComponentScope {
190        ComponentScope::Agent
191    }
192
193    fn sort_name() -> &'static str {
194        "ToolWebSearch"
195    }
196
197    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
198        let in_progress_search = cx.new(|cx| WebSearchToolCard {
199            response: None,
200            _task: cx.spawn(async move |_this, cx| {
201                loop {
202                    cx.background_executor()
203                        .timer(Duration::from_secs(60))
204                        .await
205                }
206            }),
207        });
208
209        let successful_search = cx.new(|_cx| WebSearchToolCard {
210            response: Some(Ok(example_search_response())),
211            _task: Task::ready(()),
212        });
213
214        let error_search = cx.new(|_cx| WebSearchToolCard {
215            response: Some(Err(anyhow!("Failed to resolve https://google.com"))),
216            _task: Task::ready(()),
217        });
218
219        Some(
220            v_flex()
221                .gap_6()
222                .children(vec![example_group(vec![
223                    single_example(
224                        "In Progress",
225                        div()
226                            .size_full()
227                            .child(in_progress_search.update(cx, |tool, cx| {
228                                tool.render(
229                                    &ToolUseStatus::Pending,
230                                    window,
231                                    WeakEntity::new_invalid(),
232                                    cx,
233                                )
234                                .into_any_element()
235                            }))
236                            .into_any_element(),
237                    ),
238                    single_example(
239                        "Successful",
240                        div()
241                            .size_full()
242                            .child(successful_search.update(cx, |tool, cx| {
243                                tool.render(
244                                    &ToolUseStatus::Finished("".into()),
245                                    window,
246                                    WeakEntity::new_invalid(),
247                                    cx,
248                                )
249                                .into_any_element()
250                            }))
251                            .into_any_element(),
252                    ),
253                    single_example(
254                        "Error",
255                        div()
256                            .size_full()
257                            .child(error_search.update(cx, |tool, cx| {
258                                tool.render(
259                                    &ToolUseStatus::Error("".into()),
260                                    window,
261                                    WeakEntity::new_invalid(),
262                                    cx,
263                                )
264                                .into_any_element()
265                            }))
266                            .into_any_element(),
267                    ),
268                ])])
269                .into_any_element(),
270        )
271    }
272}
273
274fn example_search_response() -> WebSearchResponse {
275    WebSearchResponse {
276        summary: r#"Toronto boasts a vibrant culinary scene with a diverse array of..."#
277            .to_string(),
278        citations: vec![
279            WebSearchCitation {
280                title: "Alo".to_string(),
281                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
282                range: Some(147..213),
283            },
284            WebSearchCitation {
285                title: "Edulis".to_string(),
286                url: "https://www.google.com/maps/search/Edulis%2C+Toronto%2C+Canada".to_string(),
287                range: Some(447..519),
288            },
289            WebSearchCitation {
290                title: "Sushi Masaki Saito".to_string(),
291                url: "https://www.google.com/maps/search/Sushi+Masaki+Saito%2C+Toronto%2C+Canada"
292                    .to_string(),
293                range: Some(776..872),
294            },
295            WebSearchCitation {
296                title: "Shoushin".to_string(),
297                url: "https://www.google.com/maps/search/Shoushin%2C+Toronto%2C+Canada".to_string(),
298                range: Some(1072..1148),
299            },
300            WebSearchCitation {
301                title: "Restaurant 20 Victoria".to_string(),
302                url:
303                    "https://www.google.com/maps/search/Restaurant+20+Victoria%2C+Toronto%2C+Canada"
304                        .to_string(),
305                range: Some(1291..1395),
306            },
307        ],
308    }
309}