web_search_tool.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use crate::schema::json_schema_for;
  4use crate::ui::ToolCallCardHeader;
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
  7use futures::{Future, FutureExt, TryFutureExt};
  8use gpui::{App, AppContext, Context, Entity, IntoElement, Task, Window};
  9use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 10use project::Project;
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use ui::{IconName, Tooltip, prelude::*};
 14use web_search::WebSearchRegistry;
 15use zed_llm_client::{WebSearchCitation, WebSearchResponse};
 16
// Input accepted by `WebSearchTool::run`, deserialized from the tool-call JSON
// value. `input_schema` derives the tool's schema from this type via
// `json_schema_for`, so the `///` text on `query` is (presumably) surfaced as the
// field description in that schema — treat it as runtime text, not a comment.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    /// The search term or question to query on the web.
    query: String,
}

// Stateless unit type; it implements `Tool` (the `web_search` tool) and
// `Component` (a UI preview of the search card) elsewhere in this file.
#[derive(RegisterComponent)]
pub struct WebSearchTool;
 25
 26impl Tool for WebSearchTool {
 27    fn name(&self) -> String {
 28        "web_search".into()
 29    }
 30
 31    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 32        false
 33    }
 34
 35    fn description(&self) -> String {
 36        "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
 37    }
 38
 39    fn icon(&self) -> IconName {
 40        IconName::Globe
 41    }
 42
 43    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 44        json_schema_for::<WebSearchToolInput>(format)
 45    }
 46
 47    fn ui_text(&self, _input: &serde_json::Value) -> String {
 48        "Searching the Web".to_string()
 49    }
 50
 51    fn run(
 52        self: Arc<Self>,
 53        input: serde_json::Value,
 54        _messages: &[LanguageModelRequestMessage],
 55        _project: Entity<Project>,
 56        _action_log: Entity<ActionLog>,
 57        cx: &mut App,
 58    ) -> ToolResult {
 59        let input = match serde_json::from_value::<WebSearchToolInput>(input) {
 60            Ok(input) => input,
 61            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 62        };
 63        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 64            return Task::ready(Err(anyhow!("Web search is not available."))).into();
 65        };
 66
 67        let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
 68        let output = cx.background_spawn({
 69            let search_task = search_task.clone();
 70            async move {
 71                let response = search_task.await.map_err(|err| anyhow!(err))?;
 72                serde_json::to_string(&response).context("Failed to serialize search results")
 73            }
 74        });
 75
 76        ToolResult {
 77            output,
 78            card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
 79        }
 80    }
 81}
 82
 83struct WebSearchToolCard {
 84    response: Option<Result<WebSearchResponse>>,
 85    _task: Task<()>,
 86}
 87
 88impl WebSearchToolCard {
 89    fn new(
 90        search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
 91        cx: &mut Context<Self>,
 92    ) -> Self {
 93        let _task = cx.spawn(async move |this, cx| {
 94            let response = search_task.await.map_err(|err| anyhow!(err));
 95            this.update(cx, |this, cx| {
 96                this.response = Some(response);
 97                cx.notify();
 98            })
 99            .ok();
100        });
101
102        Self {
103            response: None,
104            _task,
105        }
106    }
107}
108
109impl ToolCard for WebSearchToolCard {
110    fn render(
111        &mut self,
112        _status: &ToolUseStatus,
113        _window: &mut Window,
114        cx: &mut Context<Self>,
115    ) -> impl IntoElement {
116        let header = match self.response.as_ref() {
117            Some(Ok(response)) => {
118                let text: SharedString = if response.citations.len() == 1 {
119                    "1 result".into()
120                } else {
121                    format!("{} results", response.citations.len()).into()
122                };
123                ToolCallCardHeader::new(IconName::Globe, "Searched the Web")
124                    .with_secondary_text(text)
125            }
126            Some(Err(error)) => {
127                ToolCallCardHeader::new(IconName::Globe, "Web Search").with_error(error.to_string())
128            }
129            None => ToolCallCardHeader::new(IconName::Globe, "Searching the Web").loading(),
130        };
131
132        let content =
133            self.response.as_ref().and_then(|response| match response {
134                Ok(response) => {
135                    Some(
136                        v_flex()
137                            .overflow_hidden()
138                            .ml_1p5()
139                            .pl(px(5.))
140                            .border_l_1()
141                            .border_color(cx.theme().colors().border_variant)
142                            .gap_1()
143                            .children(response.citations.iter().enumerate().map(
144                                |(index, citation)| {
145                                    let title = citation.title.clone();
146                                    let url = citation.url.clone();
147
148                                    Button::new(("citation", index), title)
149                                        .label_size(LabelSize::Small)
150                                        .color(Color::Muted)
151                                        .icon(IconName::ArrowUpRight)
152                                        .icon_size(IconSize::XSmall)
153                                        .icon_position(IconPosition::End)
154                                        .truncate(true)
155                                        .tooltip({
156                                            let url = url.clone();
157                                            move |window, cx| {
158                                                Tooltip::with_meta(
159                                                    "Citation Link",
160                                                    None,
161                                                    url.clone(),
162                                                    window,
163                                                    cx,
164                                                )
165                                            }
166                                        })
167                                        .on_click({
168                                            let url = url.clone();
169                                            move |_, _, cx| cx.open_url(&url)
170                                        })
171                                },
172                            ))
173                            .into_any(),
174                    )
175                }
176                Err(_) => None,
177            });
178
179        v_flex().mb_3().gap_1().child(header).children(content)
180    }
181}
182
183impl Component for WebSearchTool {
184    fn scope() -> ComponentScope {
185        ComponentScope::Agent
186    }
187
188    fn sort_name() -> &'static str {
189        "ToolWebSearch"
190    }
191
192    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
193        let in_progress_search = cx.new(|cx| WebSearchToolCard {
194            response: None,
195            _task: cx.spawn(async move |_this, cx| {
196                loop {
197                    cx.background_executor()
198                        .timer(Duration::from_secs(60))
199                        .await
200                }
201            }),
202        });
203
204        let successful_search = cx.new(|_cx| WebSearchToolCard {
205            response: Some(Ok(example_search_response())),
206            _task: Task::ready(()),
207        });
208
209        let error_search = cx.new(|_cx| WebSearchToolCard {
210            response: Some(Err(anyhow!("Failed to resolve https://google.com"))),
211            _task: Task::ready(()),
212        });
213
214        Some(
215            v_flex()
216                .gap_6()
217                .children(vec![example_group(vec![
218                    single_example(
219                        "In Progress",
220                        div()
221                            .size_full()
222                            .child(in_progress_search.update(cx, |tool, cx| {
223                                tool.render(&ToolUseStatus::Pending, window, cx)
224                                    .into_any_element()
225                            }))
226                            .into_any_element(),
227                    ),
228                    single_example(
229                        "Successful",
230                        div()
231                            .size_full()
232                            .child(successful_search.update(cx, |tool, cx| {
233                                tool.render(&ToolUseStatus::Finished("".into()), window, cx)
234                                    .into_any_element()
235                            }))
236                            .into_any_element(),
237                    ),
238                    single_example(
239                        "Error",
240                        div()
241                            .size_full()
242                            .child(error_search.update(cx, |tool, cx| {
243                                tool.render(&ToolUseStatus::Error("".into()), window, cx)
244                                    .into_any_element()
245                            }))
246                            .into_any_element(),
247                    ),
248                ])])
249                .into_any_element(),
250        )
251    }
252}
253
254fn example_search_response() -> WebSearchResponse {
255    WebSearchResponse {
256        summary: r#"Toronto boasts a vibrant culinary scene with a diverse array of..."#
257            .to_string(),
258        citations: vec![
259            WebSearchCitation {
260                title: "Alo".to_string(),
261                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
262                range: Some(147..213),
263            },
264            WebSearchCitation {
265                title: "Edulis".to_string(),
266                url: "https://www.google.com/maps/search/Edulis%2C+Toronto%2C+Canada".to_string(),
267                range: Some(447..519),
268            },
269            WebSearchCitation {
270                title: "Sushi Masaki Saito".to_string(),
271                url: "https://www.google.com/maps/search/Sushi+Masaki+Saito%2C+Toronto%2C+Canada"
272                    .to_string(),
273                range: Some(776..872),
274            },
275            WebSearchCitation {
276                title: "Shoushin".to_string(),
277                url: "https://www.google.com/maps/search/Shoushin%2C+Toronto%2C+Canada".to_string(),
278                range: Some(1072..1148),
279            },
280            WebSearchCitation {
281                title: "Restaurant 20 Victoria".to_string(),
282                url:
283                    "https://www.google.com/maps/search/Restaurant+20+Victoria%2C+Toronto%2C+Canada"
284                        .to_string(),
285                range: Some(1291..1395),
286            },
287        ],
288    }
289}