web_search_tool.rs

  1use std::sync::Arc;
  2
  3use crate::{
  4    AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings,
  5};
  6use agent_client_protocol as acp;
  7use agent_settings::AgentSettings;
  8use anyhow::{Result, anyhow};
  9use cloud_llm_client::WebSearchResponse;
 10use futures::FutureExt as _;
 11use gpui::{App, AppContext, Task};
 12use language_model::{
 13    LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
 14};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use settings::Settings;
 18use ui::prelude::*;
 19use util::markdown::MarkdownInlineCode;
 20use web_search::WebSearchRegistry;
 21
// NOTE(review): the `///` comments on this type and on `query` are consumed by the
// `JsonSchema` derive (schemars turns doc comments into schema descriptions), so they
// presumably reach the model as the tool/parameter description — confirm against the
// `AgentTool` plumbing before rewording them, and treat them as prompt text.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    /// The search term or question to query on the web.
    query: String,
}
 30
/// Output of [`WebSearchTool`]: a newtype around the provider's [`WebSearchResponse`].
///
/// `#[serde(transparent)]` makes the wrapper serialize and deserialize exactly like the
/// wrapped response, with no extra nesting.
#[derive(Debug, Serialize, Deserialize)]
#[serde(transparent)]
pub struct WebSearchToolOutput(WebSearchResponse);
 34
 35impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
 36    fn from(value: WebSearchToolOutput) -> Self {
 37        serde_json::to_string(&value.0)
 38            .expect("Failed to serialize WebSearchResponse")
 39            .into()
 40    }
 41}
 42
/// Agent tool, registered under the name `web_search`, that runs a query through the
/// active provider of the global [`WebSearchRegistry`].
pub struct WebSearchTool;
 44
 45impl AgentTool for WebSearchTool {
 46    type Input = WebSearchToolInput;
 47    type Output = WebSearchToolOutput;
 48
 49    const NAME: &'static str = "web_search";
 50
 51    fn kind() -> acp::ToolKind {
 52        acp::ToolKind::Fetch
 53    }
 54
 55    fn initial_title(
 56        &self,
 57        _input: Result<Self::Input, serde_json::Value>,
 58        _cx: &mut App,
 59    ) -> SharedString {
 60        "Searching the Web".into()
 61    }
 62
 63    /// We currently only support Zed Cloud as a provider.
 64    fn supports_provider(provider: &LanguageModelProviderId) -> bool {
 65        provider == &ZED_CLOUD_PROVIDER_ID
 66    }
 67
 68    fn run(
 69        self: Arc<Self>,
 70        input: Self::Input,
 71        event_stream: ToolCallEventStream,
 72        cx: &mut App,
 73    ) -> Task<Result<Self::Output>> {
 74        let settings = AgentSettings::get_global(cx);
 75        let decision = decide_permission_from_settings(
 76            Self::NAME,
 77            std::slice::from_ref(&input.query),
 78            settings,
 79        );
 80
 81        let authorize = match decision {
 82            ToolPermissionDecision::Allow => None,
 83            ToolPermissionDecision::Deny(reason) => {
 84                return Task::ready(Err(anyhow!("{}", reason)));
 85            }
 86            ToolPermissionDecision::Confirm => {
 87                let context = crate::ToolPermissionContext {
 88                    tool_name: Self::NAME.to_string(),
 89                    input_values: vec![input.query.clone()],
 90                };
 91                Some(event_stream.authorize(
 92                    format!("Search the web for {}", MarkdownInlineCode(&input.query)),
 93                    context,
 94                    cx,
 95                ))
 96            }
 97        };
 98
 99        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
100            return Task::ready(Err(anyhow!("Web search is not available.")));
101        };
102
103        let search_task = provider.search(input.query, cx);
104        cx.background_spawn(async move {
105            if let Some(authorize) = authorize {
106                authorize.await?;
107            }
108
109            let response = futures::select! {
110                result = search_task.fuse() => {
111                    match result {
112                        Ok(response) => response,
113                        Err(err) => {
114                            event_stream
115                                .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
116                            return Err(err);
117                        }
118                    }
119                }
120                _ = event_stream.cancelled_by_user().fuse() => {
121                    anyhow::bail!("Web search cancelled by user");
122                }
123            };
124
125            emit_update(&response, &event_stream);
126            Ok(WebSearchToolOutput(response))
127        })
128    }
129
130    fn replay(
131        &self,
132        _input: Self::Input,
133        output: Self::Output,
134        event_stream: ToolCallEventStream,
135        _cx: &mut App,
136    ) -> Result<()> {
137        emit_update(&output.0, &event_stream);
138        Ok(())
139    }
140}
141
142fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
143    let result_text = if response.results.len() == 1 {
144        "1 result".to_string()
145    } else {
146        format!("{} results", response.results.len())
147    };
148    event_stream.update_fields(
149        acp::ToolCallUpdateFields::new()
150            .title(format!("Searched the web: {result_text}"))
151            .content(
152                response
153                    .results
154                    .iter()
155                    .map(|result| {
156                        acp::ToolCallContent::Content(acp::Content::new(
157                            acp::ContentBlock::ResourceLink(
158                                acp::ResourceLink::new(result.title.clone(), result.url.clone())
159                                    .title(result.title.clone())
160                                    .description(result.text.clone()),
161                            ),
162                        ))
163                    })
164                    .collect::<Vec<_>>(),
165            ),
166    );
167}