web_search_tool.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use crate::schema::json_schema_for;
  4use anyhow::{Context as _, Result, anyhow};
  5use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
  6use futures::{FutureExt, TryFutureExt};
  7use gpui::{
  8    Animation, AnimationExt, App, AppContext, Context, Entity, IntoElement, Task, Window,
  9    pulsating_between,
 10};
 11use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 12use project::Project;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use ui::{IconName, Tooltip, prelude::*};
 16use web_search::WebSearchRegistry;
 17use zed_llm_client::WebSearchResponse;
 18
 19#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 20pub struct WebSearchToolInput {
 21    /// The search term or question to query on the web.
 22    query: String,
 23}
 24
 25pub struct WebSearchTool;
 26
 27impl Tool for WebSearchTool {
 28    fn name(&self) -> String {
 29        "web_search".into()
 30    }
 31
 32    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 33        false
 34    }
 35
 36    fn description(&self) -> String {
 37        "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
 38    }
 39
 40    fn icon(&self) -> IconName {
 41        IconName::Globe
 42    }
 43
 44    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 45        json_schema_for::<WebSearchToolInput>(format)
 46    }
 47
 48    fn ui_text(&self, _input: &serde_json::Value) -> String {
 49        "Web Search".to_string()
 50    }
 51
 52    fn run(
 53        self: Arc<Self>,
 54        input: serde_json::Value,
 55        _messages: &[LanguageModelRequestMessage],
 56        _project: Entity<Project>,
 57        _action_log: Entity<ActionLog>,
 58        cx: &mut App,
 59    ) -> ToolResult {
 60        let input = match serde_json::from_value::<WebSearchToolInput>(input) {
 61            Ok(input) => input,
 62            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 63        };
 64        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 65            return Task::ready(Err(anyhow!("Web search is not available."))).into();
 66        };
 67
 68        let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
 69        let output = cx.background_spawn({
 70            let search_task = search_task.clone();
 71            async move {
 72                let response = search_task.await.map_err(|err| anyhow!(err))?;
 73                serde_json::to_string(&response).context("Failed to serialize search results")
 74            }
 75        });
 76
 77        ToolResult {
 78            output,
 79            card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
 80        }
 81    }
 82}
 83
 84struct WebSearchToolCard {
 85    response: Option<Result<WebSearchResponse>>,
 86    _task: Task<()>,
 87}
 88
 89impl WebSearchToolCard {
 90    fn new(
 91        search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
 92        cx: &mut Context<Self>,
 93    ) -> Self {
 94        let _task = cx.spawn(async move |this, cx| {
 95            let response = search_task.await.map_err(|err| anyhow!(err));
 96            this.update(cx, |this, cx| {
 97                this.response = Some(response);
 98                cx.notify();
 99            })
100            .ok();
101        });
102
103        Self {
104            response: None,
105            _task,
106        }
107    }
108}
109
110impl ToolCard for WebSearchToolCard {
111    fn render(
112        &mut self,
113        _status: &ToolUseStatus,
114        _window: &mut Window,
115        cx: &mut Context<Self>,
116    ) -> impl IntoElement {
117        let header = h_flex()
118            .id("tool-label-container")
119            .gap_1p5()
120            .max_w_full()
121            .overflow_x_scroll()
122            .child(
123                Icon::new(IconName::Globe)
124                    .size(IconSize::XSmall)
125                    .color(Color::Muted),
126            )
127            .child(match self.response.as_ref() {
128                Some(Ok(response)) => {
129                    let text: SharedString = if response.citations.len() == 1 {
130                        "1 result".into()
131                    } else {
132                        format!("{} results", response.citations.len()).into()
133                    };
134                    h_flex()
135                        .gap_1p5()
136                        .child(Label::new("Searched the Web").size(LabelSize::Small))
137                        .child(
138                            div()
139                                .size(px(3.))
140                                .rounded_full()
141                                .bg(cx.theme().colors().text),
142                        )
143                        .child(Label::new(text).size(LabelSize::Small))
144                        .into_any_element()
145                }
146                Some(Err(error)) => div()
147                    .id("web-search-error")
148                    .child(Label::new("Web Search failed").size(LabelSize::Small))
149                    .tooltip(Tooltip::text(error.to_string()))
150                    .into_any_element(),
151
152                None => Label::new("Searching the Web…")
153                    .size(LabelSize::Small)
154                    .with_animation(
155                        "web-search-label",
156                        Animation::new(Duration::from_secs(2))
157                            .repeat()
158                            .with_easing(pulsating_between(0.6, 1.)),
159                        |label, delta| label.alpha(delta),
160                    )
161                    .into_any_element(),
162            })
163            .into_any();
164
165        let content =
166            self.response.as_ref().and_then(|response| match response {
167                Ok(response) => {
168                    Some(
169                        v_flex()
170                            .ml_1p5()
171                            .pl_1p5()
172                            .border_l_1()
173                            .border_color(cx.theme().colors().border_variant)
174                            .gap_1()
175                            .children(response.citations.iter().enumerate().map(
176                                |(index, citation)| {
177                                    let title = citation.title.clone();
178                                    let url = citation.url.clone();
179
180                                    Button::new(("citation", index), title)
181                                        .label_size(LabelSize::Small)
182                                        .color(Color::Muted)
183                                        .icon(IconName::ArrowUpRight)
184                                        .icon_size(IconSize::XSmall)
185                                        .icon_position(IconPosition::End)
186                                        .truncate(true)
187                                        .tooltip({
188                                            let url = url.clone();
189                                            move |window, cx| {
190                                                Tooltip::with_meta(
191                                                    "Citation Link",
192                                                    None,
193                                                    url.clone(),
194                                                    window,
195                                                    cx,
196                                                )
197                                            }
198                                        })
199                                        .on_click({
200                                            let url = url.clone();
201                                            move |_, _, cx| cx.open_url(&url)
202                                        })
203                                },
204                            ))
205                            .into_any(),
206                    )
207                }
208                Err(_) => None,
209            });
210
211        v_flex().my_2().gap_1().child(header).children(content)
212    }
213}