web_search_tool.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use crate::schema::json_schema_for;
  4use crate::ui::ToolCallCardHeader;
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{
  7    ActionLog, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
  8};
  9use cloud_llm_client::{WebSearchResponse, WebSearchResult};
 10use futures::{Future, FutureExt, TryFutureExt};
 11use gpui::{
 12    AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
 13};
 14use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 15use project::Project;
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use ui::{IconName, Tooltip, prelude::*};
 19use web_search::WebSearchRegistry;
 20use workspace::Workspace;
 21
 22#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 23pub struct WebSearchToolInput {
 24    /// The search term or question to query on the web.
 25    query: String,
 26}
 27
 28pub struct WebSearchTool;
 29
 30impl Tool for WebSearchTool {
 31    fn name(&self) -> String {
 32        "web_search".into()
 33    }
 34
 35    fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
 36        false
 37    }
 38
 39    fn may_perform_edits(&self) -> bool {
 40        false
 41    }
 42
 43    fn description(&self) -> String {
 44        "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
 45    }
 46
 47    fn icon(&self) -> IconName {
 48        IconName::Globe
 49    }
 50
 51    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 52        json_schema_for::<WebSearchToolInput>(format)
 53    }
 54
 55    fn ui_text(&self, _input: &serde_json::Value) -> String {
 56        "Searching the Web".to_string()
 57    }
 58
 59    fn run(
 60        self: Arc<Self>,
 61        input: serde_json::Value,
 62        _request: Arc<LanguageModelRequest>,
 63        _project: Entity<Project>,
 64        _action_log: Entity<ActionLog>,
 65        _model: Arc<dyn LanguageModel>,
 66        _window: Option<AnyWindowHandle>,
 67        cx: &mut App,
 68    ) -> ToolResult {
 69        let input = match serde_json::from_value::<WebSearchToolInput>(input) {
 70            Ok(input) => input,
 71            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 72        };
 73        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 74            return Task::ready(Err(anyhow!("Web search is not available."))).into();
 75        };
 76
 77        let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
 78        let output = cx.background_spawn({
 79            let search_task = search_task.clone();
 80            async move {
 81                let response = search_task.await.map_err(|err| anyhow!(err))?;
 82                Ok(ToolResultOutput {
 83                    content: ToolResultContent::Text(
 84                        serde_json::to_string(&response)
 85                            .context("Failed to serialize search results")?,
 86                    ),
 87                    output: Some(serde_json::to_value(response)?),
 88                })
 89            }
 90        });
 91
 92        ToolResult {
 93            output,
 94            card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
 95        }
 96    }
 97
 98    fn deserialize_card(
 99        self: Arc<Self>,
100        output: serde_json::Value,
101        _project: Entity<Project>,
102        _window: &mut Window,
103        cx: &mut App,
104    ) -> Option<assistant_tool::AnyToolCard> {
105        let output = serde_json::from_value::<WebSearchResponse>(output).ok()?;
106        let card = cx.new(|cx| WebSearchToolCard::new(Task::ready(Ok(output)), cx));
107        Some(card.into())
108    }
109}
110
111#[derive(RegisterComponent)]
112struct WebSearchToolCard {
113    response: Option<Result<WebSearchResponse>>,
114    _task: Task<()>,
115}
116
117impl WebSearchToolCard {
118    fn new(
119        search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
120        cx: &mut Context<Self>,
121    ) -> Self {
122        let _task = cx.spawn(async move |this, cx| {
123            let response = search_task.await.map_err(|err| anyhow!(err));
124            this.update(cx, |this, cx| {
125                this.response = Some(response);
126                cx.notify();
127            })
128            .ok();
129        });
130
131        Self {
132            response: None,
133            _task,
134        }
135    }
136}
137
138impl ToolCard for WebSearchToolCard {
139    fn render(
140        &mut self,
141        _status: &ToolUseStatus,
142        _window: &mut Window,
143        _workspace: WeakEntity<Workspace>,
144        cx: &mut Context<Self>,
145    ) -> impl IntoElement {
146        let icon = IconName::ToolWeb;
147
148        let header = match self.response.as_ref() {
149            Some(Ok(response)) => {
150                let text: SharedString = if response.results.len() == 1 {
151                    "1 result".into()
152                } else {
153                    format!("{} results", response.results.len()).into()
154                };
155                ToolCallCardHeader::new(icon, "Searched the Web").with_secondary_text(text)
156            }
157            Some(Err(error)) => {
158                ToolCallCardHeader::new(icon, "Web Search").with_error(error.to_string())
159            }
160            None => ToolCallCardHeader::new(icon, "Searching the Web").loading(),
161        };
162
163        let content = self.response.as_ref().and_then(|response| match response {
164            Ok(response) => Some(
165                v_flex()
166                    .overflow_hidden()
167                    .ml_1p5()
168                    .pl(px(5.))
169                    .border_l_1()
170                    .border_color(cx.theme().colors().border_variant)
171                    .gap_1()
172                    .children(response.results.iter().enumerate().map(|(index, result)| {
173                        let title = result.title.clone();
174                        let url = SharedString::from(result.url.clone());
175
176                        Button::new(("result", index), title)
177                            .label_size(LabelSize::Small)
178                            .color(Color::Muted)
179                            .icon(IconName::ArrowUpRight)
180                            .icon_size(IconSize::XSmall)
181                            .icon_position(IconPosition::End)
182                            .truncate(true)
183                            .tooltip({
184                                let url = url.clone();
185                                move |window, cx| {
186                                    Tooltip::with_meta(
187                                        "Web Search Result",
188                                        None,
189                                        url.clone(),
190                                        window,
191                                        cx,
192                                    )
193                                }
194                            })
195                            .on_click({
196                                let url = url.clone();
197                                move |_, _, cx| cx.open_url(&url)
198                            })
199                    }))
200                    .into_any(),
201            ),
202            Err(_) => None,
203        });
204
205        v_flex().mb_3().gap_1().child(header).children(content)
206    }
207}
208
209impl Component for WebSearchToolCard {
210    fn scope() -> ComponentScope {
211        ComponentScope::Agent
212    }
213
214    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
215        let in_progress_search = cx.new(|cx| WebSearchToolCard {
216            response: None,
217            _task: cx.spawn(async move |_this, cx| {
218                loop {
219                    cx.background_executor()
220                        .timer(Duration::from_secs(60))
221                        .await
222                }
223            }),
224        });
225
226        let successful_search = cx.new(|_cx| WebSearchToolCard {
227            response: Some(Ok(example_search_response())),
228            _task: Task::ready(()),
229        });
230
231        let error_search = cx.new(|_cx| WebSearchToolCard {
232            response: Some(Err(anyhow!("Failed to resolve https://google.com"))),
233            _task: Task::ready(()),
234        });
235
236        Some(
237            v_flex()
238                .gap_6()
239                .children(vec![example_group(vec![
240                    single_example(
241                        "In Progress",
242                        div()
243                            .size_full()
244                            .child(in_progress_search.update(cx, |tool, cx| {
245                                tool.render(
246                                    &ToolUseStatus::Pending,
247                                    window,
248                                    WeakEntity::new_invalid(),
249                                    cx,
250                                )
251                                .into_any_element()
252                            }))
253                            .into_any_element(),
254                    ),
255                    single_example(
256                        "Successful",
257                        div()
258                            .size_full()
259                            .child(successful_search.update(cx, |tool, cx| {
260                                tool.render(
261                                    &ToolUseStatus::Finished("".into()),
262                                    window,
263                                    WeakEntity::new_invalid(),
264                                    cx,
265                                )
266                                .into_any_element()
267                            }))
268                            .into_any_element(),
269                    ),
270                    single_example(
271                        "Error",
272                        div()
273                            .size_full()
274                            .child(error_search.update(cx, |tool, cx| {
275                                tool.render(
276                                    &ToolUseStatus::Error("".into()),
277                                    window,
278                                    WeakEntity::new_invalid(),
279                                    cx,
280                                )
281                                .into_any_element()
282                            }))
283                            .into_any_element(),
284                    ),
285                ])])
286                .into_any_element(),
287        )
288    }
289}
290
291fn example_search_response() -> WebSearchResponse {
292    WebSearchResponse {
293        results: vec![
294            WebSearchResult {
295                title: "Alo".to_string(),
296                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
297                text: "Alo is a popular restaurant in Toronto.".to_string(),
298            },
299            WebSearchResult {
300                title: "Alo".to_string(),
301                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
302                text: "Information about Alo restaurant in Toronto.".to_string(),
303            },
304            WebSearchResult {
305                title: "Edulis".to_string(),
306                url: "https://www.google.com/maps/search/Edulis%2C+Toronto%2C+Canada".to_string(),
307                text: "Details about Edulis restaurant in Toronto.".to_string(),
308            },
309            WebSearchResult {
310                title: "Sushi Masaki Saito".to_string(),
311                url: "https://www.google.com/maps/search/Sushi+Masaki+Saito%2C+Toronto%2C+Canada"
312                    .to_string(),
313                text: "Information about Sushi Masaki Saito in Toronto.".to_string(),
314            },
315            WebSearchResult {
316                title: "Shoushin".to_string(),
317                url: "https://www.google.com/maps/search/Shoushin%2C+Toronto%2C+Canada".to_string(),
318                text: "Details about Shoushin restaurant in Toronto.".to_string(),
319            },
320            WebSearchResult {
321                title: "Restaurant 20 Victoria".to_string(),
322                url:
323                    "https://www.google.com/maps/search/Restaurant+20+Victoria%2C+Toronto%2C+Canada"
324                        .to_string(),
325                text: "Information about Restaurant 20 Victoria in Toronto.".to_string(),
326            },
327        ],
328    }
329}