web_search_tool.rs

  1use std::sync::Arc;
  2
  3use crate::{
  4    AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings,
  5};
  6use agent_client_protocol as acp;
  7use agent_settings::AgentSettings;
  8use anyhow::Result;
  9use cloud_llm_client::WebSearchResponse;
 10use futures::FutureExt as _;
 11use gpui::{App, AppContext, Task};
 12use language_model::{
 13    LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
 14};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use settings::Settings;
 18use ui::prelude::*;
 19use util::markdown::MarkdownInlineCode;
 20use web_search::WebSearchRegistry;
 21
// NOTE: the `///` comments below are not only documentation. The `JsonSchema` derive copies doc
// comments into the generated schema as descriptions (presumably what the model sees as this
// tool's description and its parameter docs), so editing their wording changes runtime behavior.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    // Private field: instances are produced by deserializing the tool call's JSON arguments and
    // the query is only read inside this module (permission check and the search itself).
    /// The search term or question to query on the web.
    query: String,
}
 30
/// Output of [`WebSearchTool`]: either the provider's response or an error message.
///
/// `#[serde(untagged)]` means the JSON carries no variant tag: `Success` serializes as the bare
/// `WebSearchResponse`, and `Error` as an object with a single `error` field. When deserializing,
/// serde tries the variants in declaration order, so `Success` is attempted first.
// NOTE(review): keep `Success` first, and make sure a `WebSearchResponse` can never parse from an
// `{"error": ...}` object — otherwise errors would round-trip as successes. TODO confirm against
// the `WebSearchResponse` definition.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum WebSearchToolOutput {
    /// The search completed; holds the provider's response.
    Success(WebSearchResponse),
    /// The search was denied, cancelled, unavailable, or failed; `error` is the message.
    Error { error: String },
}
 37
 38impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
 39    fn from(value: WebSearchToolOutput) -> Self {
 40        match value {
 41            WebSearchToolOutput::Success(response) => serde_json::to_string(&response)
 42                .unwrap_or_else(|e| format!("Failed to serialize web search response: {e}"))
 43                .into(),
 44            WebSearchToolOutput::Error { error } => error.into(),
 45        }
 46    }
 47}
 48
/// Agent tool that runs a web search through the active [`WebSearchRegistry`] provider.
pub struct WebSearchTool;
 50
 51impl AgentTool for WebSearchTool {
 52    type Input = WebSearchToolInput;
 53    type Output = WebSearchToolOutput;
 54
 55    const NAME: &'static str = "web_search";
 56
 57    fn kind() -> acp::ToolKind {
 58        acp::ToolKind::Fetch
 59    }
 60
 61    fn initial_title(
 62        &self,
 63        _input: Result<Self::Input, serde_json::Value>,
 64        _cx: &mut App,
 65    ) -> SharedString {
 66        "Searching the Web".into()
 67    }
 68
 69    /// We currently only support Zed Cloud as a provider.
 70    fn supports_provider(provider: &LanguageModelProviderId) -> bool {
 71        provider == &ZED_CLOUD_PROVIDER_ID
 72    }
 73
 74    fn run(
 75        self: Arc<Self>,
 76        input: Self::Input,
 77        event_stream: ToolCallEventStream,
 78        cx: &mut App,
 79    ) -> Task<Result<Self::Output, Self::Output>> {
 80        let settings = AgentSettings::get_global(cx);
 81        let decision = decide_permission_from_settings(
 82            Self::NAME,
 83            std::slice::from_ref(&input.query),
 84            settings,
 85        );
 86
 87        let authorize = match decision {
 88            ToolPermissionDecision::Allow => None,
 89            ToolPermissionDecision::Deny(reason) => {
 90                return Task::ready(Err(WebSearchToolOutput::Error { error: reason }));
 91            }
 92            ToolPermissionDecision::Confirm => {
 93                let context =
 94                    crate::ToolPermissionContext::new(Self::NAME, vec![input.query.clone()]);
 95                Some(event_stream.authorize(
 96                    format!("Search the web for {}", MarkdownInlineCode(&input.query)),
 97                    context,
 98                    cx,
 99                ))
100            }
101        };
102
103        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
104            return Task::ready(Err(WebSearchToolOutput::Error {
105                error: "Web search is not available.".to_string(),
106            }));
107        };
108
109        let search_task = provider.search(input.query, cx);
110        cx.background_spawn(async move {
111            if let Some(authorize) = authorize {
112                authorize.await.map_err(|e| WebSearchToolOutput::Error { error: e.to_string() })?;
113            }
114
115            let response = futures::select! {
116                result = search_task.fuse() => {
117                    match result {
118                        Ok(response) => response,
119                        Err(err) => {
120                            event_stream
121                                .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
122                            return Err(WebSearchToolOutput::Error { error: err.to_string() });
123                        }
124                    }
125                }
126                _ = event_stream.cancelled_by_user().fuse() => {
127                    return Err(WebSearchToolOutput::Error { error: "Web search cancelled by user".to_string() });
128                }
129            };
130
131            emit_update(&response, &event_stream);
132            Ok(WebSearchToolOutput::Success(response))
133        })
134    }
135
136    fn replay(
137        &self,
138        _input: Self::Input,
139        output: Self::Output,
140        event_stream: ToolCallEventStream,
141        _cx: &mut App,
142    ) -> Result<()> {
143        if let WebSearchToolOutput::Success(response) = &output {
144            emit_update(response, &event_stream);
145        }
146        Ok(())
147    }
148}
149
150fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
151    let result_text = if response.results.len() == 1 {
152        "1 result".to_string()
153    } else {
154        format!("{} results", response.results.len())
155    };
156    event_stream.update_fields(
157        acp::ToolCallUpdateFields::new()
158            .title(format!("Searched the web: {result_text}"))
159            .content(
160                response
161                    .results
162                    .iter()
163                    .map(|result| {
164                        acp::ToolCallContent::Content(acp::Content::new(
165                            acp::ContentBlock::ResourceLink(
166                                acp::ResourceLink::new(result.title.clone(), result.url.clone())
167                                    .title(result.title.clone())
168                                    .description(result.text.clone()),
169                            ),
170                        ))
171                    })
172                    .collect::<Vec<_>>(),
173            ),
174    );
175}