web_search_tool.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use crate::schema::json_schema_for;
  4use crate::ui::ToolCallCardHeader;
  5use action_log::ActionLog;
  6use anyhow::{Context as _, Result, anyhow};
  7use assistant_tool::{
  8    Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput, ToolUseStatus,
  9};
 10use cloud_llm_client::{WebSearchResponse, WebSearchResult};
 11use futures::{Future, FutureExt, TryFutureExt};
 12use gpui::{
 13    AnyWindowHandle, App, AppContext, Context, Entity, IntoElement, Task, WeakEntity, Window,
 14};
 15use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 16use project::Project;
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19use ui::{IconName, Tooltip, prelude::*};
 20use web_search::WebSearchRegistry;
 21use workspace::Workspace;
 22
// Arguments accepted by the `web_search` tool, deserialized from the model's tool-call
// input in `WebSearchTool::run` and used to generate the tool's input schema.
// NOTE(review): schemars uses `///` doc comments as schema `description`s, so the field
// doc below is presumably model-visible; plain `//` comments are used for this header on
// purpose so the generated schema is left unchanged.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    /// The search term or question to query on the web.
    query: String,
}
 28
/// The `web_search` tool: sends the model's query to the active web search provider and
/// returns the results both as JSON text for the model and as a results card for the UI.
pub struct WebSearchTool;
 30
 31impl Tool for WebSearchTool {
 32    fn name(&self) -> String {
 33        "web_search".into()
 34    }
 35
 36    fn needs_confirmation(&self, _: &serde_json::Value, _: &Entity<Project>, _: &App) -> bool {
 37        false
 38    }
 39
 40    fn may_perform_edits(&self) -> bool {
 41        false
 42    }
 43
 44    fn description(&self) -> String {
 45        "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
 46    }
 47
 48    fn icon(&self) -> IconName {
 49        IconName::ToolWeb
 50    }
 51
 52    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 53        json_schema_for::<WebSearchToolInput>(format)
 54    }
 55
 56    fn ui_text(&self, _input: &serde_json::Value) -> String {
 57        "Searching the Web".to_string()
 58    }
 59
 60    fn run(
 61        self: Arc<Self>,
 62        input: serde_json::Value,
 63        _request: Arc<LanguageModelRequest>,
 64        _project: Entity<Project>,
 65        _action_log: Entity<ActionLog>,
 66        _model: Arc<dyn LanguageModel>,
 67        _window: Option<AnyWindowHandle>,
 68        cx: &mut App,
 69    ) -> ToolResult {
 70        let input = match serde_json::from_value::<WebSearchToolInput>(input) {
 71            Ok(input) => input,
 72            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 73        };
 74        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 75            return Task::ready(Err(anyhow!("Web search is not available."))).into();
 76        };
 77
 78        let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
 79        let output = cx.background_spawn({
 80            let search_task = search_task.clone();
 81            async move {
 82                let response = search_task.await.map_err(|err| anyhow!(err))?;
 83                Ok(ToolResultOutput {
 84                    content: ToolResultContent::Text(
 85                        serde_json::to_string(&response)
 86                            .context("Failed to serialize search results")?,
 87                    ),
 88                    output: Some(serde_json::to_value(response)?),
 89                })
 90            }
 91        });
 92
 93        ToolResult {
 94            output,
 95            card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
 96        }
 97    }
 98
 99    fn deserialize_card(
100        self: Arc<Self>,
101        output: serde_json::Value,
102        _project: Entity<Project>,
103        _window: &mut Window,
104        cx: &mut App,
105    ) -> Option<assistant_tool::AnyToolCard> {
106        let output = serde_json::from_value::<WebSearchResponse>(output).ok()?;
107        let card = cx.new(|cx| WebSearchToolCard::new(Task::ready(Ok(output)), cx));
108        Some(card.into())
109    }
110}
111
/// Tool-call card that displays the progress and outcome of a single web search.
#[derive(RegisterComponent)]
struct WebSearchToolCard {
    /// `None` while the search has not finished yet; afterwards the search outcome
    /// (results or the error that occurred).
    response: Option<Result<WebSearchResponse>>,
    /// Async work tied to this card, e.g. the task spawned in `new` that fills in
    /// `response`. It is stored so it lives as long as the card.
    /// NOTE(review): presumably dropping a gpui `Task` cancels it — confirm.
    _task: Task<()>,
}
117
118impl WebSearchToolCard {
119    fn new(
120        search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
121        cx: &mut Context<Self>,
122    ) -> Self {
123        let _task = cx.spawn(async move |this, cx| {
124            let response = search_task.await.map_err(|err| anyhow!(err));
125            this.update(cx, |this, cx| {
126                this.response = Some(response);
127                cx.notify();
128            })
129            .ok();
130        });
131
132        Self {
133            response: None,
134            _task,
135        }
136    }
137}
138
139impl ToolCard for WebSearchToolCard {
140    fn render(
141        &mut self,
142        _status: &ToolUseStatus,
143        _window: &mut Window,
144        _workspace: WeakEntity<Workspace>,
145        cx: &mut Context<Self>,
146    ) -> impl IntoElement {
147        let icon = IconName::ToolWeb;
148
149        let header = match self.response.as_ref() {
150            Some(Ok(response)) => {
151                let text: SharedString = if response.results.len() == 1 {
152                    "1 result".into()
153                } else {
154                    format!("{} results", response.results.len()).into()
155                };
156                ToolCallCardHeader::new(icon, "Searched the Web").with_secondary_text(text)
157            }
158            Some(Err(error)) => {
159                ToolCallCardHeader::new(icon, "Web Search").with_error(error.to_string())
160            }
161            None => ToolCallCardHeader::new(icon, "Searching the Web").loading(),
162        };
163
164        let content = self.response.as_ref().and_then(|response| match response {
165            Ok(response) => Some(
166                v_flex()
167                    .overflow_hidden()
168                    .ml_1p5()
169                    .pl(px(5.))
170                    .border_l_1()
171                    .border_color(cx.theme().colors().border_variant)
172                    .gap_1()
173                    .children(response.results.iter().enumerate().map(|(index, result)| {
174                        let title = result.title.clone();
175                        let url = SharedString::from(result.url.clone());
176
177                        Button::new(("result", index), title)
178                            .label_size(LabelSize::Small)
179                            .color(Color::Muted)
180                            .icon(IconName::ArrowUpRight)
181                            .icon_size(IconSize::Small)
182                            .icon_position(IconPosition::End)
183                            .truncate(true)
184                            .tooltip({
185                                let url = url.clone();
186                                move |window, cx| {
187                                    Tooltip::with_meta(
188                                        "Web Search Result",
189                                        None,
190                                        url.clone(),
191                                        window,
192                                        cx,
193                                    )
194                                }
195                            })
196                            .on_click({
197                                let url = url.clone();
198                                move |_, _, cx| cx.open_url(&url)
199                            })
200                    }))
201                    .into_any(),
202            ),
203            Err(_) => None,
204        });
205
206        v_flex().mb_3().gap_1().child(header).children(content)
207    }
208}
209
210impl Component for WebSearchToolCard {
211    fn scope() -> ComponentScope {
212        ComponentScope::Agent
213    }
214
215    fn preview(window: &mut Window, cx: &mut App) -> Option<AnyElement> {
216        let in_progress_search = cx.new(|cx| WebSearchToolCard {
217            response: None,
218            _task: cx.spawn(async move |_this, cx| {
219                loop {
220                    cx.background_executor()
221                        .timer(Duration::from_secs(60))
222                        .await
223                }
224            }),
225        });
226
227        let successful_search = cx.new(|_cx| WebSearchToolCard {
228            response: Some(Ok(example_search_response())),
229            _task: Task::ready(()),
230        });
231
232        let error_search = cx.new(|_cx| WebSearchToolCard {
233            response: Some(Err(anyhow!("Failed to resolve https://google.com"))),
234            _task: Task::ready(()),
235        });
236
237        Some(
238            v_flex()
239                .gap_6()
240                .children(vec![example_group(vec![
241                    single_example(
242                        "In Progress",
243                        div()
244                            .size_full()
245                            .child(in_progress_search.update(cx, |tool, cx| {
246                                tool.render(
247                                    &ToolUseStatus::Pending,
248                                    window,
249                                    WeakEntity::new_invalid(),
250                                    cx,
251                                )
252                                .into_any_element()
253                            }))
254                            .into_any_element(),
255                    ),
256                    single_example(
257                        "Successful",
258                        div()
259                            .size_full()
260                            .child(successful_search.update(cx, |tool, cx| {
261                                tool.render(
262                                    &ToolUseStatus::Finished("".into()),
263                                    window,
264                                    WeakEntity::new_invalid(),
265                                    cx,
266                                )
267                                .into_any_element()
268                            }))
269                            .into_any_element(),
270                    ),
271                    single_example(
272                        "Error",
273                        div()
274                            .size_full()
275                            .child(error_search.update(cx, |tool, cx| {
276                                tool.render(
277                                    &ToolUseStatus::Error("".into()),
278                                    window,
279                                    WeakEntity::new_invalid(),
280                                    cx,
281                                )
282                                .into_any_element()
283                            }))
284                            .into_any_element(),
285                    ),
286                ])])
287                .into_any_element(),
288        )
289    }
290}
291
292fn example_search_response() -> WebSearchResponse {
293    WebSearchResponse {
294        results: vec![
295            WebSearchResult {
296                title: "Alo".to_string(),
297                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
298                text: "Alo is a popular restaurant in Toronto.".to_string(),
299            },
300            WebSearchResult {
301                title: "Alo".to_string(),
302                url: "https://www.google.com/maps/search/Alo%2C+Toronto%2C+Canada".to_string(),
303                text: "Information about Alo restaurant in Toronto.".to_string(),
304            },
305            WebSearchResult {
306                title: "Edulis".to_string(),
307                url: "https://www.google.com/maps/search/Edulis%2C+Toronto%2C+Canada".to_string(),
308                text: "Details about Edulis restaurant in Toronto.".to_string(),
309            },
310            WebSearchResult {
311                title: "Sushi Masaki Saito".to_string(),
312                url: "https://www.google.com/maps/search/Sushi+Masaki+Saito%2C+Toronto%2C+Canada"
313                    .to_string(),
314                text: "Information about Sushi Masaki Saito in Toronto.".to_string(),
315            },
316            WebSearchResult {
317                title: "Shoushin".to_string(),
318                url: "https://www.google.com/maps/search/Shoushin%2C+Toronto%2C+Canada".to_string(),
319                text: "Details about Shoushin restaurant in Toronto.".to_string(),
320            },
321            WebSearchResult {
322                title: "Restaurant 20 Victoria".to_string(),
323                url:
324                    "https://www.google.com/maps/search/Restaurant+20+Victoria%2C+Toronto%2C+Canada"
325                        .to_string(),
326                text: "Information about Restaurant 20 Victoria in Toronto.".to_string(),
327            },
328        ],
329    }
330}