web_search_tool.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use crate::schema::json_schema_for;
  4use crate::ui::ToolCallCardHeader;
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolResultOutput, ToolUseStatus};
  7use futures::{Future, FutureExt, TryFutureExt};
  8use gpui::{
  9    AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
 10};
 11use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 12use project::Project;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use ui::{IconName, Tooltip, prelude::*};
 16use web_search::WebSearchRegistry;
 17use workspace::Workspace;
 18use zed_llm_client::{WebSearchResponse, WebSearchResult};
 19
// Arguments of the `web_search` tool, deserialized from the tool-call JSON in
// `WebSearchTool::run` and described to the model via `input_schema`.
// NOTE(review): the `///` comment on `query` is presumably turned into the
// schema `description` by the `JsonSchema` derive, so editing it changes what
// the model is told — treat it as part of the tool's contract.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    /// The search term or question to query on the web.
    query: String,
}
 25
/// The `web_search` tool: sends the model's query to the currently active
/// web search provider and reports the results back.
pub struct WebSearchTool;
 27
 28impl Tool for WebSearchTool {
 29    fn name(&self) -> String {
 30        "web_search".into()
 31    }
 32
 33    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 34        false
 35    }
 36
 37    fn description(&self) -> String {
 38        "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
 39    }
 40
 41    fn icon(&self) -> IconName {
 42        IconName::Globe
 43    }
 44
 45    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 46        json_schema_for::<WebSearchToolInput>(format)
 47    }
 48
 49    fn ui_text(&self, _input: &serde_json::Value) -> String {
 50        "Searching the Web".to_string()
 51    }
 52
 53    fn run(
 54        self: Arc<Self>,
 55        input: serde_json::Value,
 56        _request: Arc<LanguageModelRequest>,
 57        _project: Entity<Project>,
 58        _action_log: Entity<ActionLog>,
 59        _model: Arc<dyn LanguageModel>,
 60        _window: Option<AnyWindowHandle>,
 61        cx: &mut App,
 62    ) -> ToolResult {
 63        let input = match serde_json::from_value::<WebSearchToolInput>(input) {
 64            Ok(input) => input,
 65            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 66        };
 67        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 68            return Task::ready(Err(anyhow!("Web search is not available."))).into();
 69        };
 70
 71        let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
 72        let output = cx.background_spawn({
 73            let search_task = search_task.clone();
 74            async move {
 75                let response = search_task.await.map_err(|err| anyhow!(err))?;
 76                Ok(ToolResultOutput {
 77                    content: serde_json::to_string(&response)
 78                        .context("Failed to serialize search results")?,
 79                    output: Some(serde_json::to_value(response)?),
 80                })
 81            }
 82        });
 83
 84        ToolResult {
 85            output,
 86            card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
 87        }
 88    }
 89
 90    fn deserialize_card(
 91        self: Arc<Self>,
 92        output: serde_json::Value,
 93        _project: Entity<Project>,
 94        _window: &mut Window,
 95        cx: &mut App,
 96    ) -> Option<assistant_tool::AnyToolCard> {
 97        let output = serde_json::from_value::<WebSearchResponse>(output).ok()?;
 98        let card = cx.new(|cx| WebSearchToolCard::new(Task::ready(Ok(output)), cx));
 99        Some(card.into())
100    }
101}
102
/// Card displayed for a single `web_search` tool call.
///
/// Shows a loading header until the search settles, then either the list of
/// result links or the error message.
#[derive(RegisterComponent)]
struct WebSearchToolCard {
    /// `None` while the search is still running; `Some(Ok(..))` with the
    /// results or `Some(Err(..))` with the failure once it has settled.
    response: Option<Result<WebSearchResponse>>,
    /// Task that awaits the search and writes its outcome into `response`.
    /// Stored so it lives as long as the card (presumably dropping a gpui
    /// `Task` cancels it — confirm).
    _task: Task<()>,
}
108
109impl WebSearchToolCard {
110    fn new(
111        search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
112        cx: &mut Context<Self>,
113    ) -> Self {
114        let _task = cx.spawn(async move |this, cx| {
115            let response = search_task.await.map_err(|err| anyhow!(err));
116            this.update(cx, |this, cx| {
117                this.response = Some(response);
118                cx.notify();
119            })
120            .ok();
121        });
122
123        Self {
124            response: None,
125            _task,
126        }
127    }
128}
129
130impl ToolCard for WebSearchToolCard {
131    fn render(
132        &mut self,
133        _status: &ToolUseStatus,
134        _window: &mut Window,
135        _workspace: WeakEntity<Workspace>,
136        cx: &mut Context<Self>,
137    ) -> impl IntoElement {
138        let header = match self.response.as_ref() {
139            Some(Ok(response)) => {
140                let text: SharedString = if response.results.len() == 1 {
141                    "1 result".into()
142                } else {
143                    format!("{} results", response.results.len()).into()
144                };
145                ToolCallCardHeader::new(IconName::Globe, "Searched the Web")
146                    .with_secondary_text(text)
147            }
148            Some(Err(error)) => {
149                ToolCallCardHeader::new(IconName::Globe, "Web Search").with_error(error.to_string())
150            }
151            None => ToolCallCardHeader::new(IconName::Globe, "Searching the Web").loading(),
152        };
153
154        let content = self.response.as_ref().and_then(|response| match response {
155            Ok(response) => Some(
156                v_flex()
157                    .overflow_hidden()
158                    .ml_1p5()
159                    .pl(px(5.))
160                    .border_l_1()
161                    .border_color(cx.theme().colors().border_variant)
162                    .gap_1()
163                    .children(response.results.iter().enumerate().map(|(index, result)| {
164                        let title = result.title.clone();
165                        let url = result.url.clone();
166
167                        Button::new(("result", index), title)
168                            .label_size(LabelSize::Small)
169                            .color(Color::Muted)
170                            .icon(IconName::ArrowUpRight)
171                            .icon_size(IconSize::XSmall)
172                            .icon_position(IconPosition::End)
173                            .truncate(true)
174                            .tooltip({
175                                let url = url.clone();
176                                move |window, cx| {
177                                    Tooltip::with_meta(
178                                        "Web Search Result",
179                                        None,
180                                        url.clone(),
181                                        window,
182                                        cx,
183                                    )
184                                }
185                            })
186                            .on_click({
187                                let url = url.clone();
188                                move |_, _, cx| cx.open_url(&url)
189                            })
190                    }))
191                    .into_any(),
192            ),
193            Err(_) => None,
194        });
195
196        v_flex().mb_3().gap_1().child(header).children(content)
197    }
198}
199
200impl Component for WebSearchToolCard {
201    fn scope() -> ComponentScope {
202        ComponentScope::Agent
203    }
204
205    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
206        let in_progress_search = cx.new(|cx| WebSearchToolCard {
207            response: None,
208            _task: cx.spawn(async move |_this, cx| {
209                loop {
210                    cx.background_executor()
211                        .timer(Duration::from_secs(60))
212                        .await
213                }
214            }),
215        });
216
217        let successful_search = cx.new(|_cx| WebSearchToolCard {
218            response: Some(Ok(example_search_response())),
219            _task: Task::ready(()),
220        });
221
222        let error_search = cx.new(|_cx| WebSearchToolCard {
223            response: Some(Err(anyhow!("Failed to resolve https://google.com"))),
224            _task: Task::ready(()),
225        });
226
227        Some(
228            v_flex()
229                .gap_6()
230                .children(vec![example_group(vec![
231                    single_example(
232                        "In Progress",
233                        div()
234                            .size_full()
235                            .child(in_progress_search.update(cx, |tool, cx| {
236                                tool.render(
237                                    &ToolUseStatus::Pending,
238                                    window,
239                                    WeakEntity::new_invalid(),
240                                    cx,
241                                )
242                                .into_any_element()
243                            }))
244                            .into_any_element(),
245                    ),
246                    single_example(
247                        "Successful",
248                        div()
249                            .size_full()
250                            .child(successful_search.update(cx, |tool, cx| {
251                                tool.render(
252                                    &ToolUseStatus::Finished("".into()),
253                                    window,
254                                    WeakEntity::new_invalid(),
255                                    cx,
256                                )
257                                .into_any_element()
258                            }))
259                            .into_any_element(),
260                    ),
261                    single_example(
262                        "Error",
263                        div()
264                            .size_full()
265                            .child(error_search.update(cx, |tool, cx| {
266                                tool.render(
267                                    &ToolUseStatus::Error("".into()),
268                                    window,
269                                    WeakEntity::new_invalid(),
270                                    cx,
271                                )
272                                .into_any_element()
273                            }))
274                            .into_any_element(),
275                    ),
276                ])])
277                .into_any_element(),
278        )
279    }
280}
281
282fn example_search_response() -> WebSearchResponse {
283    WebSearchResponse {
284        results: vec![
285            WebSearchResult {
286                title: "Alo".to_string(),
287                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
288                text: "Alo is a popular restaurant in Toronto.".to_string(),
289            },
290            WebSearchResult {
291                title: "Alo".to_string(),
292                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
293                text: "Information about Alo restaurant in Toronto.".to_string(),
294            },
295            WebSearchResult {
296                title: "Edulis".to_string(),
297                url: "https://www.google.com/maps/search/Edulis%2C+Toronto%2C+Canada".to_string(),
298                text: "Details about Edulis restaurant in Toronto.".to_string(),
299            },
300            WebSearchResult {
301                title: "Sushi Masaki Saito".to_string(),
302                url: "https://www.google.com/maps/search/Sushi+Masaki+Saito%2C+Toronto%2C+Canada"
303                    .to_string(),
304                text: "Information about Sushi Masaki Saito in Toronto.".to_string(),
305            },
306            WebSearchResult {
307                title: "Shoushin".to_string(),
308                url: "https://www.google.com/maps/search/Shoushin%2C+Toronto%2C+Canada".to_string(),
309                text: "Details about Shoushin restaurant in Toronto.".to_string(),
310            },
311            WebSearchResult {
312                title: "Restaurant 20 Victoria".to_string(),
313                url:
314                    "https://www.google.com/maps/search/Restaurant+20+Victoria%2C+Toronto%2C+Canada"
315                        .to_string(),
316                text: "Information about Restaurant 20 Victoria in Toronto.".to_string(),
317            },
318        ],
319    }
320}