web_search_tool.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use crate::schema::json_schema_for;
  4use crate::ui::ToolCallCardHeader;
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{
  7    ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
  8};
  9use futures::{Future, FutureExt, TryFutureExt};
 10use gpui::{
 11    AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
 12};
 13use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 14use project::Project;
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use ui::{IconName, Tooltip, prelude::*};
 18use web_search::WebSearchRegistry;
 19use workspace::Workspace;
 20use zed_llm_client::{WebSearchResponse, WebSearchResult};
 21
 22#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 23pub struct WebSearchToolInput {
 24    /// The search term or question to query on the web.
 25    query: String,
 26}
 27
/// Tool that runs a web search through the globally registered active web
/// search provider (`WebSearchRegistry`).
///
/// Stateless unit type; all behavior lives in its `Tool` implementation.
pub struct WebSearchTool;
 29
 30impl Tool for WebSearchTool {
 31    fn name(&self) -> String {
 32        "web_search".into()
 33    }
 34
 35    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 36        false
 37    }
 38
 39    fn description(&self) -> String {
 40        "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
 41    }
 42
 43    fn icon(&self) -> IconName {
 44        IconName::Globe
 45    }
 46
 47    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 48        json_schema_for::<WebSearchToolInput>(format)
 49    }
 50
 51    fn ui_text(&self, _input: &serde_json::Value) -> String {
 52        "Searching the Web".to_string()
 53    }
 54
 55    fn run(
 56        self: Arc<Self>,
 57        input: serde_json::Value,
 58        _request: Arc<LanguageModelRequest>,
 59        _project: Entity<Project>,
 60        _action_log: Entity<ActionLog>,
 61        _model: Arc<dyn LanguageModel>,
 62        _window: Option<AnyWindowHandle>,
 63        cx: &mut App,
 64    ) -> ToolResult {
 65        let input = match serde_json::from_value::<WebSearchToolInput>(input) {
 66            Ok(input) => input,
 67            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 68        };
 69        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 70            return Task::ready(Err(anyhow!("Web search is not available."))).into();
 71        };
 72
 73        let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
 74        let output = cx.background_spawn({
 75            let search_task = search_task.clone();
 76            async move {
 77                let response = search_task.await.map_err(|err| anyhow!(err))?;
 78                Ok(ToolResultOutput {
 79                    content: ToolResultContent::Text(
 80                        serde_json::to_string(&response)
 81                            .context("Failed to serialize search results")?,
 82                    ),
 83                    output: Some(serde_json::to_value(response)?),
 84                })
 85            }
 86        });
 87
 88        ToolResult {
 89            output,
 90            card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
 91        }
 92    }
 93
 94    fn deserialize_card(
 95        self: Arc<Self>,
 96        output: serde_json::Value,
 97        _project: Entity<Project>,
 98        _window: &mut Window,
 99        cx: &mut App,
100    ) -> Option<assistant_tool::AnyToolCard> {
101        let output = serde_json::from_value::<WebSearchResponse>(output).ok()?;
102        let card = cx.new(|cx| WebSearchToolCard::new(Task::ready(Ok(output)), cx));
103        Some(card.into())
104    }
105}
106
/// Card rendered for a single `web_search` tool call; it shows a loading
/// header, then either the result links or the error.
#[derive(RegisterComponent)]
struct WebSearchToolCard {
    /// `None` while the search is still running; afterwards its outcome.
    response: Option<Result<WebSearchResponse>>,
    /// Task that awaits the search and writes `response` (see `new`). It is
    /// stored so it lives as long as the card; gpui tasks are cancelled when
    /// dropped unless detached.
    _task: Task<()>,
}
112
113impl WebSearchToolCard {
114    fn new(
115        search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
116        cx: &mut Context<Self>,
117    ) -> Self {
118        let _task = cx.spawn(async move |this, cx| {
119            let response = search_task.await.map_err(|err| anyhow!(err));
120            this.update(cx, |this, cx| {
121                this.response = Some(response);
122                cx.notify();
123            })
124            .ok();
125        });
126
127        Self {
128            response: None,
129            _task,
130        }
131    }
132}
133
134impl ToolCard for WebSearchToolCard {
135    fn render(
136        &mut self,
137        _status: &ToolUseStatus,
138        _window: &mut Window,
139        _workspace: WeakEntity<Workspace>,
140        cx: &mut Context<Self>,
141    ) -> impl IntoElement {
142        let header = match self.response.as_ref() {
143            Some(Ok(response)) => {
144                let text: SharedString = if response.results.len() == 1 {
145                    "1 result".into()
146                } else {
147                    format!("{} results", response.results.len()).into()
148                };
149                ToolCallCardHeader::new(IconName::Globe, "Searched the Web")
150                    .with_secondary_text(text)
151            }
152            Some(Err(error)) => {
153                ToolCallCardHeader::new(IconName::Globe, "Web Search").with_error(error.to_string())
154            }
155            None => ToolCallCardHeader::new(IconName::Globe, "Searching the Web").loading(),
156        };
157
158        let content = self.response.as_ref().and_then(|response| match response {
159            Ok(response) => Some(
160                v_flex()
161                    .overflow_hidden()
162                    .ml_1p5()
163                    .pl(px(5.))
164                    .border_l_1()
165                    .border_color(cx.theme().colors().border_variant)
166                    .gap_1()
167                    .children(response.results.iter().enumerate().map(|(index, result)| {
168                        let title = result.title.clone();
169                        let url = SharedString::from(result.url.clone());
170
171                        Button::new(("result", index), title)
172                            .label_size(LabelSize::Small)
173                            .color(Color::Muted)
174                            .icon(IconName::ArrowUpRight)
175                            .icon_size(IconSize::XSmall)
176                            .icon_position(IconPosition::End)
177                            .truncate(true)
178                            .tooltip({
179                                let url = url.clone();
180                                move |window, cx| {
181                                    Tooltip::with_meta(
182                                        "Web Search Result",
183                                        None,
184                                        url.clone(),
185                                        window,
186                                        cx,
187                                    )
188                                }
189                            })
190                            .on_click({
191                                let url = url.clone();
192                                move |_, _, cx| cx.open_url(&url)
193                            })
194                    }))
195                    .into_any(),
196            ),
197            Err(_) => None,
198        });
199
200        v_flex().mb_3().gap_1().child(header).children(content)
201    }
202}
203
204impl Component for WebSearchToolCard {
205    fn scope() -> ComponentScope {
206        ComponentScope::Agent
207    }
208
209    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
210        let in_progress_search = cx.new(|cx| WebSearchToolCard {
211            response: None,
212            _task: cx.spawn(async move |_this, cx| {
213                loop {
214                    cx.background_executor()
215                        .timer(Duration::from_secs(60))
216                        .await
217                }
218            }),
219        });
220
221        let successful_search = cx.new(|_cx| WebSearchToolCard {
222            response: Some(Ok(example_search_response())),
223            _task: Task::ready(()),
224        });
225
226        let error_search = cx.new(|_cx| WebSearchToolCard {
227            response: Some(Err(anyhow!("Failed to resolve https://google.com"))),
228            _task: Task::ready(()),
229        });
230
231        Some(
232            v_flex()
233                .gap_6()
234                .children(vec![example_group(vec![
235                    single_example(
236                        "In Progress",
237                        div()
238                            .size_full()
239                            .child(in_progress_search.update(cx, |tool, cx| {
240                                tool.render(
241                                    &ToolUseStatus::Pending,
242                                    window,
243                                    WeakEntity::new_invalid(),
244                                    cx,
245                                )
246                                .into_any_element()
247                            }))
248                            .into_any_element(),
249                    ),
250                    single_example(
251                        "Successful",
252                        div()
253                            .size_full()
254                            .child(successful_search.update(cx, |tool, cx| {
255                                tool.render(
256                                    &ToolUseStatus::Finished("".into()),
257                                    window,
258                                    WeakEntity::new_invalid(),
259                                    cx,
260                                )
261                                .into_any_element()
262                            }))
263                            .into_any_element(),
264                    ),
265                    single_example(
266                        "Error",
267                        div()
268                            .size_full()
269                            .child(error_search.update(cx, |tool, cx| {
270                                tool.render(
271                                    &ToolUseStatus::Error("".into()),
272                                    window,
273                                    WeakEntity::new_invalid(),
274                                    cx,
275                                )
276                                .into_any_element()
277                            }))
278                            .into_any_element(),
279                    ),
280                ])])
281                .into_any_element(),
282        )
283    }
284}
285
286fn example_search_response() -> WebSearchResponse {
287    WebSearchResponse {
288        results: vec![
289            WebSearchResult {
290                title: "Alo".to_string(),
291                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
292                text: "Alo is a popular restaurant in Toronto.".to_string(),
293            },
294            WebSearchResult {
295                title: "Alo".to_string(),
296                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
297                text: "Information about Alo restaurant in Toronto.".to_string(),
298            },
299            WebSearchResult {
300                title: "Edulis".to_string(),
301                url: "https://www.google.com/maps/search/Edulis%2C+Toronto%2C+Canada".to_string(),
302                text: "Details about Edulis restaurant in Toronto.".to_string(),
303            },
304            WebSearchResult {
305                title: "Sushi Masaki Saito".to_string(),
306                url: "https://www.google.com/maps/search/Sushi+Masaki+Saito%2C+Toronto%2C+Canada"
307                    .to_string(),
308                text: "Information about Sushi Masaki Saito in Toronto.".to_string(),
309            },
310            WebSearchResult {
311                title: "Shoushin".to_string(),
312                url: "https://www.google.com/maps/search/Shoushin%2C+Toronto%2C+Canada".to_string(),
313                text: "Details about Shoushin restaurant in Toronto.".to_string(),
314            },
315            WebSearchResult {
316                title: "Restaurant 20 Victoria".to_string(),
317                url:
318                    "https://www.google.com/maps/search/Restaurant+20+Victoria%2C+Toronto%2C+Canada"
319                        .to_string(),
320                text: "Information about Restaurant 20 Victoria in Toronto.".to_string(),
321            },
322        ],
323    }
324}