web_search_tool.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use crate::schema::json_schema_for;
  4use anyhow::{Context as _, Result, anyhow};
  5use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
  6use futures::{FutureExt, TryFutureExt};
  7use gpui::{
  8    Animation, AnimationExt, App, AppContext, Context, Entity, IntoElement, Task, Window,
  9    pulsating_between,
 10};
 11use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 12use project::Project;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use ui::{IconName, Tooltip, prelude::*};
 16use web_search::WebSearchRegistry;
 17use zed_llm_client::{WebSearchCitation, WebSearchResponse};
 18
 19#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 20pub struct WebSearchToolInput {
 21    /// The search term or question to query on the web.
 22    query: String,
 23}
 24
 25#[derive(RegisterComponent)]
 26pub struct WebSearchTool;
 27
 28impl Tool for WebSearchTool {
 29    fn name(&self) -> String {
 30        "web_search".into()
 31    }
 32
 33    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 34        false
 35    }
 36
 37    fn description(&self) -> String {
 38        "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
 39    }
 40
 41    fn icon(&self) -> IconName {
 42        IconName::Globe
 43    }
 44
 45    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 46        json_schema_for::<WebSearchToolInput>(format)
 47    }
 48
 49    fn ui_text(&self, _input: &serde_json::Value) -> String {
 50        "Web Search".to_string()
 51    }
 52
 53    fn run(
 54        self: Arc<Self>,
 55        input: serde_json::Value,
 56        _messages: &[LanguageModelRequestMessage],
 57        _project: Entity<Project>,
 58        _action_log: Entity<ActionLog>,
 59        cx: &mut App,
 60    ) -> ToolResult {
 61        let input = match serde_json::from_value::<WebSearchToolInput>(input) {
 62            Ok(input) => input,
 63            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 64        };
 65        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 66            return Task::ready(Err(anyhow!("Web search is not available."))).into();
 67        };
 68
 69        let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
 70        let output = cx.background_spawn({
 71            let search_task = search_task.clone();
 72            async move {
 73                let response = search_task.await.map_err(|err| anyhow!(err))?;
 74                serde_json::to_string(&response).context("Failed to serialize search results")
 75            }
 76        });
 77
 78        ToolResult {
 79            output,
 80            card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
 81        }
 82    }
 83}
 84
 85struct WebSearchToolCard {
 86    response: Option<Result<WebSearchResponse>>,
 87    _task: Task<()>,
 88}
 89
 90impl WebSearchToolCard {
 91    fn new(
 92        search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
 93        cx: &mut Context<Self>,
 94    ) -> Self {
 95        let _task = cx.spawn(async move |this, cx| {
 96            let response = search_task.await.map_err(|err| anyhow!(err));
 97            this.update(cx, |this, cx| {
 98                this.response = Some(response);
 99                cx.notify();
100            })
101            .ok();
102        });
103
104        Self {
105            response: None,
106            _task,
107        }
108    }
109}
110
111impl ToolCard for WebSearchToolCard {
112    fn render(
113        &mut self,
114        _status: &ToolUseStatus,
115        _window: &mut Window,
116        cx: &mut Context<Self>,
117    ) -> impl IntoElement {
118        let header = h_flex()
119            .id("tool-label-container")
120            .gap_1p5()
121            .max_w_full()
122            .overflow_x_scroll()
123            .child(
124                Icon::new(IconName::Globe)
125                    .size(IconSize::XSmall)
126                    .color(Color::Muted),
127            )
128            .child(match self.response.as_ref() {
129                Some(Ok(response)) => {
130                    let text: SharedString = if response.citations.len() == 1 {
131                        "1 result".into()
132                    } else {
133                        format!("{} results", response.citations.len()).into()
134                    };
135                    h_flex()
136                        .gap_1p5()
137                        .child(Label::new("Searched the Web").size(LabelSize::Small))
138                        .child(
139                            div()
140                                .size(px(3.))
141                                .rounded_full()
142                                .bg(cx.theme().colors().text),
143                        )
144                        .child(Label::new(text).size(LabelSize::Small))
145                        .into_any_element()
146                }
147                Some(Err(error)) => div()
148                    .id("web-search-error")
149                    .child(Label::new("Web Search failed").size(LabelSize::Small))
150                    .tooltip(Tooltip::text(error.to_string()))
151                    .into_any_element(),
152
153                None => Label::new("Searching the Web…")
154                    .size(LabelSize::Small)
155                    .with_animation(
156                        "web-search-label",
157                        Animation::new(Duration::from_secs(2))
158                            .repeat()
159                            .with_easing(pulsating_between(0.6, 1.)),
160                        |label, delta| label.alpha(delta),
161                    )
162                    .into_any_element(),
163            })
164            .into_any();
165
166        let content =
167            self.response.as_ref().and_then(|response| match response {
168                Ok(response) => {
169                    Some(
170                        v_flex()
171                            .ml_1p5()
172                            .pl_1p5()
173                            .border_l_1()
174                            .border_color(cx.theme().colors().border_variant)
175                            .gap_1()
176                            .children(response.citations.iter().enumerate().map(
177                                |(index, citation)| {
178                                    let title = citation.title.clone();
179                                    let url = citation.url.clone();
180
181                                    Button::new(("citation", index), title)
182                                        .label_size(LabelSize::Small)
183                                        .color(Color::Muted)
184                                        .icon(IconName::ArrowUpRight)
185                                        .icon_size(IconSize::XSmall)
186                                        .icon_position(IconPosition::End)
187                                        .truncate(true)
188                                        .tooltip({
189                                            let url = url.clone();
190                                            move |window, cx| {
191                                                Tooltip::with_meta(
192                                                    "Citation Link",
193                                                    None,
194                                                    url.clone(),
195                                                    window,
196                                                    cx,
197                                                )
198                                            }
199                                        })
200                                        .on_click({
201                                            let url = url.clone();
202                                            move |_, _, cx| cx.open_url(&url)
203                                        })
204                                },
205                            ))
206                            .into_any(),
207                    )
208                }
209                Err(_) => None,
210            });
211
212        v_flex().my_2().gap_1().child(header).children(content)
213    }
214}
215
216impl Component for WebSearchTool {
217    fn scope() -> ComponentScope {
218        ComponentScope::Agent
219    }
220
221    fn sort_name() -> &'static str {
222        "ToolWebSearch"
223    }
224
225    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
226        let in_progress_search = cx.new(|cx| WebSearchToolCard {
227            response: None,
228            _task: cx.spawn(async move |_this, cx| {
229                loop {
230                    cx.background_executor()
231                        .timer(Duration::from_secs(60))
232                        .await
233                }
234            }),
235        });
236
237        let successful_search = cx.new(|_cx| WebSearchToolCard {
238            response: Some(Ok(example_search_response())),
239            _task: Task::ready(()),
240        });
241
242        let error_search = cx.new(|_cx| WebSearchToolCard {
243            response: Some(Err(anyhow!("Failed to resolve https://google.com"))),
244            _task: Task::ready(()),
245        });
246
247        Some(
248            v_flex()
249                .gap_6()
250                .children(vec![example_group(vec![
251                    single_example(
252                        "In Progress",
253                        div()
254                            .size_full()
255                            .child(in_progress_search.update(cx, |tool, cx| {
256                                tool.render(&ToolUseStatus::Pending, window, cx)
257                                    .into_any_element()
258                            }))
259                            .into_any_element(),
260                    ),
261                    single_example(
262                        "Successful",
263                        div()
264                            .size_full()
265                            .child(successful_search.update(cx, |tool, cx| {
266                                tool.render(&ToolUseStatus::Finished("".into()), window, cx)
267                                    .into_any_element()
268                            }))
269                            .into_any_element(),
270                    ),
271                    single_example(
272                        "Error",
273                        div()
274                            .size_full()
275                            .child(error_search.update(cx, |tool, cx| {
276                                tool.render(&ToolUseStatus::Error("".into()), window, cx)
277                                    .into_any_element()
278                            }))
279                            .into_any_element(),
280                    ),
281                ])])
282                .into_any_element(),
283        )
284    }
285}
286
287fn example_search_response() -> WebSearchResponse {
288    WebSearchResponse {
289        summary: r#"Toronto boasts a vibrant culinary scene with a diverse array of..."#
290            .to_string(),
291        citations: vec![
292            WebSearchCitation {
293                title: "Alo".to_string(),
294                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
295                range: Some(147..213),
296            },
297            WebSearchCitation {
298                title: "Edulis".to_string(),
299                url: "https://www.google.com/maps/search/Edulis%2C+Toronto%2C+Canada".to_string(),
300                range: Some(447..519),
301            },
302            WebSearchCitation {
303                title: "Sushi Masaki Saito".to_string(),
304                url: "https://www.google.com/maps/search/Sushi+Masaki+Saito%2C+Toronto%2C+Canada"
305                    .to_string(),
306                range: Some(776..872),
307            },
308            WebSearchCitation {
309                title: "Shoushin".to_string(),
310                url: "https://www.google.com/maps/search/Shoushin%2C+Toronto%2C+Canada".to_string(),
311                range: Some(1072..1148),
312            },
313            WebSearchCitation {
314                title: "Restaurant 20 Victoria".to_string(),
315                url:
316                    "https://www.google.com/maps/search/Restaurant+20+Victoria%2C+Toronto%2C+Canada"
317                        .to_string(),
318                range: Some(1291..1395),
319            },
320        ],
321    }
322}