web_search_tool.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use crate::schema::json_schema_for;
  4use crate::ui::ToolCallCardHeader;
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{
  7    ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
  8};
  9use futures::{Future, FutureExt, TryFutureExt};
 10use gpui::{
 11    AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
 12};
 13use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 14use project::Project;
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use ui::{IconName, Tooltip, prelude::*};
 18use web_search::WebSearchRegistry;
 19use workspace::Workspace;
 20use zed_llm_client::{WebSearchResponse, WebSearchResult};
 21
 22#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 23pub struct WebSearchToolInput {
 24    /// The search term or question to query on the web.
 25    query: String,
 26}
 27
 28pub struct WebSearchTool;
 29
 30impl Tool for WebSearchTool {
 31    type Input = WebSearchToolInput;
 32
 33    fn name(&self) -> String {
 34        "web_search".into()
 35    }
 36
 37    fn needs_confirmation(&self, _: &Self::Input, _: &App) -> bool {
 38        false
 39    }
 40
 41    fn may_perform_edits(&self) -> bool {
 42        false
 43    }
 44
 45    fn description(&self) -> String {
 46        "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
 47    }
 48
 49    fn icon(&self) -> IconName {
 50        IconName::Globe
 51    }
 52
 53    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 54        json_schema_for::<WebSearchToolInput>(format)
 55    }
 56
 57    fn ui_text(&self, _input: &Self::Input) -> String {
 58        "Searching the Web".to_string()
 59    }
 60
 61    fn run(
 62        self: Arc<Self>,
 63        input: Self::Input,
 64        _request: Arc<LanguageModelRequest>,
 65        _project: Entity<Project>,
 66        _action_log: Entity<ActionLog>,
 67        _model: Arc<dyn LanguageModel>,
 68        _window: Option<AnyWindowHandle>,
 69        cx: &mut App,
 70    ) -> ToolResult {
 71        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 72            return Task::ready(Err(anyhow!("Web search is not available."))).into();
 73        };
 74
 75        let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
 76        let output = cx.background_spawn({
 77            let search_task = search_task.clone();
 78            async move {
 79                let response = search_task.await.map_err(|err| anyhow!(err))?;
 80                Ok(ToolResultOutput {
 81                    content: ToolResultContent::Text(
 82                        serde_json::to_string(&response)
 83                            .context("Failed to serialize search results")?,
 84                    ),
 85                    output: Some(serde_json::to_value(response)?),
 86                })
 87            }
 88        });
 89
 90        ToolResult {
 91            output,
 92            card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
 93        }
 94    }
 95
 96    fn deserialize_card(
 97        self: Arc<Self>,
 98        output: serde_json::Value,
 99        _project: Entity<Project>,
100        _window: &mut Window,
101        cx: &mut App,
102    ) -> Option<assistant_tool::AnyToolCard> {
103        let output = serde_json::from_value::<WebSearchResponse>(output).ok()?;
104        let card = cx.new(|cx| WebSearchToolCard::new(Task::ready(Ok(output)), cx));
105        Some(card.into())
106    }
107}
108
109#[derive(RegisterComponent)]
110struct WebSearchToolCard {
111    response: Option<Result<WebSearchResponse>>,
112    _task: Task<()>,
113}
114
115impl WebSearchToolCard {
116    fn new(
117        search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
118        cx: &mut Context<Self>,
119    ) -> Self {
120        let _task = cx.spawn(async move |this, cx| {
121            let response = search_task.await.map_err(|err| anyhow!(err));
122            this.update(cx, |this, cx| {
123                this.response = Some(response);
124                cx.notify();
125            })
126            .ok();
127        });
128
129        Self {
130            response: None,
131            _task,
132        }
133    }
134}
135
136impl ToolCard for WebSearchToolCard {
137    fn render(
138        &mut self,
139        _status: &ToolUseStatus,
140        _window: &mut Window,
141        _workspace: WeakEntity<Workspace>,
142        cx: &mut Context<Self>,
143    ) -> impl IntoElement {
144        let header = match self.response.as_ref() {
145            Some(Ok(response)) => {
146                let text: SharedString = if response.results.len() == 1 {
147                    "1 result".into()
148                } else {
149                    format!("{} results", response.results.len()).into()
150                };
151                ToolCallCardHeader::new(IconName::Globe, "Searched the Web")
152                    .with_secondary_text(text)
153            }
154            Some(Err(error)) => {
155                ToolCallCardHeader::new(IconName::Globe, "Web Search").with_error(error.to_string())
156            }
157            None => ToolCallCardHeader::new(IconName::Globe, "Searching the Web").loading(),
158        };
159
160        let content = self.response.as_ref().and_then(|response| match response {
161            Ok(response) => Some(
162                v_flex()
163                    .overflow_hidden()
164                    .ml_1p5()
165                    .pl(px(5.))
166                    .border_l_1()
167                    .border_color(cx.theme().colors().border_variant)
168                    .gap_1()
169                    .children(response.results.iter().enumerate().map(|(index, result)| {
170                        let title = result.title.clone();
171                        let url = SharedString::from(result.url.clone());
172
173                        Button::new(("result", index), title)
174                            .label_size(LabelSize::Small)
175                            .color(Color::Muted)
176                            .icon(IconName::ArrowUpRight)
177                            .icon_size(IconSize::XSmall)
178                            .icon_position(IconPosition::End)
179                            .truncate(true)
180                            .tooltip({
181                                let url = url.clone();
182                                move |window, cx| {
183                                    Tooltip::with_meta(
184                                        "Web Search Result",
185                                        None,
186                                        url.clone(),
187                                        window,
188                                        cx,
189                                    )
190                                }
191                            })
192                            .on_click({
193                                let url = url.clone();
194                                move |_, _, cx| cx.open_url(&url)
195                            })
196                    }))
197                    .into_any(),
198            ),
199            Err(_) => None,
200        });
201
202        v_flex().mb_3().gap_1().child(header).children(content)
203    }
204}
205
206impl Component for WebSearchToolCard {
207    fn scope() -> ComponentScope {
208        ComponentScope::Agent
209    }
210
211    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
212        let in_progress_search = cx.new(|cx| WebSearchToolCard {
213            response: None,
214            _task: cx.spawn(async move |_this, cx| {
215                loop {
216                    cx.background_executor()
217                        .timer(Duration::from_secs(60))
218                        .await
219                }
220            }),
221        });
222
223        let successful_search = cx.new(|_cx| WebSearchToolCard {
224            response: Some(Ok(example_search_response())),
225            _task: Task::ready(()),
226        });
227
228        let error_search = cx.new(|_cx| WebSearchToolCard {
229            response: Some(Err(anyhow!("Failed to resolve https://google.com"))),
230            _task: Task::ready(()),
231        });
232
233        Some(
234            v_flex()
235                .gap_6()
236                .children(vec![example_group(vec![
237                    single_example(
238                        "In Progress",
239                        div()
240                            .size_full()
241                            .child(in_progress_search.update(cx, |tool, cx| {
242                                tool.render(
243                                    &ToolUseStatus::Pending,
244                                    window,
245                                    WeakEntity::new_invalid(),
246                                    cx,
247                                )
248                                .into_any_element()
249                            }))
250                            .into_any_element(),
251                    ),
252                    single_example(
253                        "Successful",
254                        div()
255                            .size_full()
256                            .child(successful_search.update(cx, |tool, cx| {
257                                tool.render(
258                                    &ToolUseStatus::Finished("".into()),
259                                    window,
260                                    WeakEntity::new_invalid(),
261                                    cx,
262                                )
263                                .into_any_element()
264                            }))
265                            .into_any_element(),
266                    ),
267                    single_example(
268                        "Error",
269                        div()
270                            .size_full()
271                            .child(error_search.update(cx, |tool, cx| {
272                                tool.render(
273                                    &ToolUseStatus::Error("".into()),
274                                    window,
275                                    WeakEntity::new_invalid(),
276                                    cx,
277                                )
278                                .into_any_element()
279                            }))
280                            .into_any_element(),
281                    ),
282                ])])
283                .into_any_element(),
284        )
285    }
286}
287
288fn example_search_response() -> WebSearchResponse {
289    WebSearchResponse {
290        results: vec![
291            WebSearchResult {
292                title: "Alo".to_string(),
293                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
294                text: "Alo is a popular restaurant in Toronto.".to_string(),
295            },
296            WebSearchResult {
297                title: "Alo".to_string(),
298                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
299                text: "Information about Alo restaurant in Toronto.".to_string(),
300            },
301            WebSearchResult {
302                title: "Edulis".to_string(),
303                url: "https://www.google.com/maps/search/Edulis%2C+Toronto%2C+Canada".to_string(),
304                text: "Details about Edulis restaurant in Toronto.".to_string(),
305            },
306            WebSearchResult {
307                title: "Sushi Masaki Saito".to_string(),
308                url: "https://www.google.com/maps/search/Sushi+Masaki+Saito%2C+Toronto%2C+Canada"
309                    .to_string(),
310                text: "Information about Sushi Masaki Saito in Toronto.".to_string(),
311            },
312            WebSearchResult {
313                title: "Shoushin".to_string(),
314                url: "https://www.google.com/maps/search/Shoushin%2C+Toronto%2C+Canada".to_string(),
315                text: "Details about Shoushin restaurant in Toronto.".to_string(),
316            },
317            WebSearchResult {
318                title: "Restaurant 20 Victoria".to_string(),
319                url:
320                    "https://www.google.com/maps/search/Restaurant+20+Victoria%2C+Toronto%2C+Canada"
321                        .to_string(),
322                text: "Information about Restaurant 20 Victoria in Toronto.".to_string(),
323            },
324        ],
325    }
326}