web_search_tool.rs

  1use std::sync::Arc;
  2
  3use crate::{AgentTool, ToolCallEventStream};
  4use agent_client_protocol as acp;
  5use anyhow::{Result, anyhow};
  6use cloud_llm_client::WebSearchResponse;
  7use gpui::{App, AppContext, Task};
  8use language_model::{
  9    LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
 10};
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use ui::prelude::*;
 14use web_search::WebSearchRegistry;
 15
 16/// Search the web for information using your query.
 17/// Use this when you need real-time information, facts, or data that might not be in your training.
 18/// Results will include snippets and links from relevant web pages.
 19#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 20pub struct WebSearchToolInput {
 21    /// The search term or question to query on the web.
 22    query: String,
 23}
 24
 25#[derive(Debug, Serialize, Deserialize)]
 26#[serde(transparent)]
 27pub struct WebSearchToolOutput(WebSearchResponse);
 28
 29impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
 30    fn from(value: WebSearchToolOutput) -> Self {
 31        serde_json::to_string(&value.0)
 32            .expect("Failed to serialize WebSearchResponse")
 33            .into()
 34    }
 35}
 36
 37pub struct WebSearchTool;
 38
 39impl AgentTool for WebSearchTool {
 40    type Input = WebSearchToolInput;
 41    type Output = WebSearchToolOutput;
 42
 43    fn name() -> &'static str {
 44        "web_search"
 45    }
 46
 47    fn kind() -> acp::ToolKind {
 48        acp::ToolKind::Fetch
 49    }
 50
 51    fn initial_title(
 52        &self,
 53        _input: Result<Self::Input, serde_json::Value>,
 54        _cx: &mut App,
 55    ) -> SharedString {
 56        "Searching the Web".into()
 57    }
 58
 59    /// We currently only support Zed Cloud as a provider.
 60    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
 61        provider == &ZED_CLOUD_PROVIDER_ID
 62    }
 63
 64    fn run(
 65        self: Arc<Self>,
 66        input: Self::Input,
 67        event_stream: ToolCallEventStream,
 68        cx: &mut App,
 69    ) -> Task<Result<Self::Output>> {
 70        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 71            return Task::ready(Err(anyhow!("Web search is not available.")));
 72        };
 73
 74        let search_task = provider.search(input.query, cx);
 75        cx.background_spawn(async move {
 76            let response = match search_task.await {
 77                Ok(response) => response,
 78                Err(err) => {
 79                    event_stream.update_fields(acp::ToolCallUpdateFields {
 80                        title: Some("Web Search Failed".to_string()),
 81                        ..Default::default()
 82                    });
 83                    return Err(err);
 84                }
 85            };
 86
 87            emit_update(&response, &event_stream);
 88            Ok(WebSearchToolOutput(response))
 89        })
 90    }
 91
 92    fn replay(
 93        &self,
 94        _input: Self::Input,
 95        output: Self::Output,
 96        event_stream: ToolCallEventStream,
 97        _cx: &mut App,
 98    ) -> Result<()> {
 99        emit_update(&output.0, &event_stream);
100        Ok(())
101    }
102}
103
104fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
105    let result_text = if response.results.len() == 1 {
106        "1 result".to_string()
107    } else {
108        format!("{} results", response.results.len())
109    };
110    event_stream.update_fields(acp::ToolCallUpdateFields {
111        title: Some(format!("Searched the web: {result_text}")),
112        content: Some(
113            response
114                .results
115                .iter()
116                .map(|result| acp::ToolCallContent::Content {
117                    content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
118                        name: result.title.clone(),
119                        uri: result.url.clone(),
120                        title: Some(result.title.clone()),
121                        description: Some(result.text.clone()),
122                        mime_type: None,
123                        annotations: None,
124                        size: None,
125                    }),
126                })
127                .collect(),
128        ),
129        ..Default::default()
130    });
131}