web_search_tool.rs

  1use std::sync::Arc;
  2
  3use crate::{
  4    AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision,
  5    decide_permission_from_settings,
  6};
  7use agent_client_protocol as acp;
  8use agent_settings::AgentSettings;
  9use anyhow::Result;
 10use cloud_llm_client::WebSearchResponse;
 11use futures::FutureExt as _;
 12use gpui::{App, Task};
 13use language_model::{
 14    LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
 15};
 16use schemars::JsonSchema;
 17use serde::{Deserialize, Serialize};
 18use settings::Settings;
 19use ui::prelude::*;
 20use util::markdown::MarkdownInlineCode;
 21use web_search::WebSearchRegistry;
 22
// NOTE(review): this type derives `JsonSchema`, and schemars is documented to copy `///`
// doc comments into the generated schema as descriptions. The doc text below is therefore
// presumably surfaced at runtime (e.g. as the tool/field description) — confirm before
// rewording it.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    /// The search term or question to query on the web.
    query: String,
}
 31
/// Outcome of a web search tool call: either the provider's response or an error message.
///
/// Because of `#[serde(untagged)]`, no variant name is written when serializing; the payload
/// is emitted as-is (the response itself, or an object with a single `error` field).
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum WebSearchToolOutput {
    /// The search completed and produced this response.
    Success(WebSearchResponse),
    /// The search did not produce a response; `error` holds the message describing why.
    Error { error: String },
}
 38
 39impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
 40    fn from(value: WebSearchToolOutput) -> Self {
 41        match value {
 42            WebSearchToolOutput::Success(response) => serde_json::to_string(&response)
 43                .unwrap_or_else(|e| format!("Failed to serialize web search response: {e}"))
 44                .into(),
 45            WebSearchToolOutput::Error { error } => error.into(),
 46        }
 47    }
 48}
 49
/// Agent tool (its `AgentTool::NAME` is `"web_search"`) that runs a query through the active
/// provider of the global `WebSearchRegistry` and reports the results on the tool call's
/// event stream.
pub struct WebSearchTool;
 51
 52impl AgentTool for WebSearchTool {
 53    type Input = WebSearchToolInput;
 54    type Output = WebSearchToolOutput;
 55
 56    const NAME: &'static str = "web_search";
 57
 58    fn kind() -> acp::ToolKind {
 59        acp::ToolKind::Fetch
 60    }
 61
 62    fn initial_title(
 63        &self,
 64        _input: Result<Self::Input, serde_json::Value>,
 65        _cx: &mut App,
 66    ) -> SharedString {
 67        "Searching the Web".into()
 68    }
 69
 70    /// We currently only support Zed Cloud as a provider.
 71    fn supports_provider(provider: &LanguageModelProviderId) -> bool {
 72        provider == &ZED_CLOUD_PROVIDER_ID
 73    }
 74
 75    fn run(
 76        self: Arc<Self>,
 77        input: ToolInput<Self::Input>,
 78        event_stream: ToolCallEventStream,
 79        cx: &mut App,
 80    ) -> Task<Result<Self::Output, Self::Output>> {
 81        cx.spawn(async move |cx| {
 82            let input = input
 83                .recv()
 84                .await
 85                .map_err(|e| WebSearchToolOutput::Error {
 86                    error: format!("Failed to receive tool input: {e}"),
 87                })?;
 88
 89            let (authorize, search_task) = cx.update(|cx| {
 90                let decision = decide_permission_from_settings(
 91                    Self::NAME,
 92                    std::slice::from_ref(&input.query),
 93                    AgentSettings::get_global(cx),
 94                );
 95
 96                let authorize = match decision {
 97                    ToolPermissionDecision::Allow => None,
 98                    ToolPermissionDecision::Deny(reason) => {
 99                        return Err(WebSearchToolOutput::Error { error: reason });
100                    }
101                    ToolPermissionDecision::Confirm => {
102                        let context =
103                            crate::ToolPermissionContext::new(Self::NAME, vec![input.query.clone()]);
104                        Some(event_stream.authorize(
105                            format!("Search the web for {}", MarkdownInlineCode(&input.query)),
106                            context,
107                            cx,
108                        ))
109                    }
110                };
111
112                let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
113                    return Err(WebSearchToolOutput::Error {
114                        error: "Web search is not available.".to_string(),
115                    });
116                };
117
118                let search_task = provider.search(input.query, cx);
119                Ok((authorize, search_task))
120            })?;
121
122            if let Some(authorize) = authorize {
123                authorize.await.map_err(|e| WebSearchToolOutput::Error { error: e.to_string() })?;
124            }
125
126            let response = futures::select! {
127                result = search_task.fuse() => {
128                    match result {
129                        Ok(response) => response,
130                        Err(err) => {
131                            event_stream
132                                .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
133                            return Err(WebSearchToolOutput::Error { error: err.to_string() });
134                        }
135                    }
136                }
137                _ = event_stream.cancelled_by_user().fuse() => {
138                    return Err(WebSearchToolOutput::Error { error: "Web search cancelled by user".to_string() });
139                }
140            };
141
142            emit_update(&response, &event_stream);
143            Ok(WebSearchToolOutput::Success(response))
144        })
145    }
146
147    fn replay(
148        &self,
149        _input: Self::Input,
150        output: Self::Output,
151        event_stream: ToolCallEventStream,
152        _cx: &mut App,
153    ) -> Result<()> {
154        if let WebSearchToolOutput::Success(response) = &output {
155            emit_update(response, &event_stream);
156        }
157        Ok(())
158    }
159}
160
161fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
162    let result_text = if response.results.len() == 1 {
163        "1 result".to_string()
164    } else {
165        format!("{} results", response.results.len())
166    };
167    event_stream.update_fields(
168        acp::ToolCallUpdateFields::new()
169            .title(format!("Searched the web: {result_text}"))
170            .content(
171                response
172                    .results
173                    .iter()
174                    .map(|result| {
175                        acp::ToolCallContent::Content(acp::Content::new(
176                            acp::ContentBlock::ResourceLink(
177                                acp::ResourceLink::new(result.title.clone(), result.url.clone())
178                                    .title(result.title.clone())
179                                    .description(result.text.clone()),
180                            ),
181                        ))
182                    })
183                    .collect::<Vec<_>>(),
184            ),
185    );
186}