web_search_tool.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use crate::schema::json_schema_for;
  4use crate::ui::ToolCallCardHeader;
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
  7use futures::{Future, FutureExt, TryFutureExt};
  8use gpui::{App, AppContext, Context, Entity, IntoElement, Task, Window};
  9use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
 10use project::Project;
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use ui::{IconName, Tooltip, prelude::*};
 14use web_search::WebSearchRegistry;
 15use zed_llm_client::WebSearchResponse;
 16
 17#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 18pub struct WebSearchToolInput {
 19    /// The search term or question to query on the web.
 20    query: String,
 21}
 22
 23pub struct WebSearchTool;
 24
 25impl Tool for WebSearchTool {
 26    fn name(&self) -> String {
 27        "web_search".into()
 28    }
 29
 30    fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
 31        false
 32    }
 33
 34    fn description(&self) -> String {
 35        "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
 36    }
 37
 38    fn icon(&self) -> IconName {
 39        IconName::Globe
 40    }
 41
 42    fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
 43        json_schema_for::<WebSearchToolInput>(format)
 44    }
 45
 46    fn ui_text(&self, _input: &serde_json::Value) -> String {
 47        "Searching the Web".to_string()
 48    }
 49
 50    fn run(
 51        self: Arc<Self>,
 52        input: serde_json::Value,
 53        _messages: &[LanguageModelRequestMessage],
 54        _project: Entity<Project>,
 55        _action_log: Entity<ActionLog>,
 56        cx: &mut App,
 57    ) -> ToolResult {
 58        let input = match serde_json::from_value::<WebSearchToolInput>(input) {
 59            Ok(input) => input,
 60            Err(err) => return Task::ready(Err(anyhow!(err))).into(),
 61        };
 62        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 63            return Task::ready(Err(anyhow!("Web search is not available."))).into();
 64        };
 65
 66        let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
 67        let output = cx.background_spawn({
 68            let search_task = search_task.clone();
 69            async move {
 70                let response = search_task.await.map_err(|err| anyhow!(err))?;
 71                serde_json::to_string(&response).context("Failed to serialize search results")
 72            }
 73        });
 74
 75        ToolResult {
 76            output,
 77            card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
 78        }
 79    }
 80}
 81
 82struct WebSearchToolCard {
 83    response: Option<Result<WebSearchResponse>>,
 84    _task: Task<()>,
 85}
 86
 87impl WebSearchToolCard {
 88    fn new(
 89        search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
 90        cx: &mut Context<Self>,
 91    ) -> Self {
 92        let _task = cx.spawn(async move |this, cx| {
 93            let response = search_task.await.map_err(|err| anyhow!(err));
 94            this.update(cx, |this, cx| {
 95                this.response = Some(response);
 96                cx.notify();
 97            })
 98            .ok();
 99        });
100
101        Self {
102            response: None,
103            _task,
104        }
105    }
106}
107
108impl ToolCard for WebSearchToolCard {
109    fn render(
110        &mut self,
111        _status: &ToolUseStatus,
112        _window: &mut Window,
113        cx: &mut Context<Self>,
114    ) -> impl IntoElement {
115        let header = match self.response.as_ref() {
116            Some(Ok(response)) => {
117                let text: SharedString = if response.citations.len() == 1 {
118                    "1 result".into()
119                } else {
120                    format!("{} results", response.citations.len()).into()
121                };
122                ToolCallCardHeader::new(IconName::Globe, "Searched the Web")
123                    .with_secondary_text(text)
124            }
125            Some(Err(error)) => {
126                ToolCallCardHeader::new(IconName::Globe, "Web Search").with_error(error.to_string())
127            }
128            None => ToolCallCardHeader::new(IconName::Globe, "Searching the Web").loading(),
129        };
130
131        let content =
132            self.response.as_ref().and_then(|response| match response {
133                Ok(response) => {
134                    Some(
135                        v_flex()
136                            .overflow_hidden()
137                            .ml_1p5()
138                            .pl(px(5.))
139                            .border_l_1()
140                            .border_color(cx.theme().colors().border_variant)
141                            .gap_1()
142                            .children(response.citations.iter().enumerate().map(
143                                |(index, citation)| {
144                                    let title = citation.title.clone();
145                                    let url = citation.url.clone();
146
147                                    Button::new(("citation", index), title)
148                                        .label_size(LabelSize::Small)
149                                        .color(Color::Muted)
150                                        .icon(IconName::ArrowUpRight)
151                                        .icon_size(IconSize::XSmall)
152                                        .icon_position(IconPosition::End)
153                                        .truncate(true)
154                                        .tooltip({
155                                            let url = url.clone();
156                                            move |window, cx| {
157                                                Tooltip::with_meta(
158                                                    "Citation Link",
159                                                    None,
160                                                    url.clone(),
161                                                    window,
162                                                    cx,
163                                                )
164                                            }
165                                        })
166                                        .on_click({
167                                            let url = url.clone();
168                                            move |_, _, cx| cx.open_url(&url)
169                                        })
170                                },
171                            ))
172                            .into_any(),
173                    )
174                }
175                Err(_) => None,
176            });
177
178        v_flex().mb_3().gap_1().child(header).children(content)
179    }
180}