web_search_tool.rs

  1use std::sync::Arc;
  2
  3use crate::{AgentTool, ToolCallEventStream};
  4use agent_client_protocol as acp;
  5use anyhow::{Result, anyhow};
  6use cloud_llm_client::WebSearchResponse;
  7use gpui::{App, AppContext, Task};
  8use language_model::{
  9    LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
 10};
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use ui::prelude::*;
 14use web_search::WebSearchRegistry;
 15
// NOTE: because of `#[derive(JsonSchema)]` below, schemars uses the `///` doc comments on
// this struct and its fields as JSON-schema `description`s. This text is therefore
// presumably what the model is shown for the tool and its argument; editing it changes
// the tool schema, not just the docs.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    // Forwarded verbatim to the active web search provider by `WebSearchTool::run`.
    /// The search term or question to query on the web.
    query: String,
}
 24
 25#[derive(Debug, Serialize, Deserialize)]
 26#[serde(transparent)]
 27pub struct WebSearchToolOutput(WebSearchResponse);
 28
 29impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
 30    fn from(value: WebSearchToolOutput) -> Self {
 31        serde_json::to_string(&value.0)
 32            .expect("Failed to serialize WebSearchResponse")
 33            .into()
 34    }
 35}
 36
/// The `web_search` agent tool: runs a query through the active
/// [`WebSearchRegistry`] provider and reports the results on the tool call as
/// resource links.
pub struct WebSearchTool;
 38
 39impl AgentTool for WebSearchTool {
 40    type Input = WebSearchToolInput;
 41    type Output = WebSearchToolOutput;
 42
 43    fn name() -> &'static str {
 44        "web_search"
 45    }
 46
 47    fn kind() -> acp::ToolKind {
 48        acp::ToolKind::Fetch
 49    }
 50
 51    fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
 52        "Searching the Web".into()
 53    }
 54
 55    /// We currently only support Zed Cloud as a provider.
 56    fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
 57        provider == &ZED_CLOUD_PROVIDER_ID
 58    }
 59
 60    fn run(
 61        self: Arc<Self>,
 62        input: Self::Input,
 63        event_stream: ToolCallEventStream,
 64        cx: &mut App,
 65    ) -> Task<Result<Self::Output>> {
 66        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 67            return Task::ready(Err(anyhow!("Web search is not available.")));
 68        };
 69
 70        let search_task = provider.search(input.query, cx);
 71        cx.background_spawn(async move {
 72            let response = match search_task.await {
 73                Ok(response) => response,
 74                Err(err) => {
 75                    event_stream.update_fields(acp::ToolCallUpdateFields {
 76                        title: Some("Web Search Failed".to_string()),
 77                        ..Default::default()
 78                    });
 79                    return Err(err);
 80                }
 81            };
 82
 83            emit_update(&response, &event_stream);
 84            Ok(WebSearchToolOutput(response))
 85        })
 86    }
 87
 88    fn replay(
 89        &self,
 90        _input: Self::Input,
 91        output: Self::Output,
 92        event_stream: ToolCallEventStream,
 93        _cx: &mut App,
 94    ) -> Result<()> {
 95        emit_update(&output.0, &event_stream);
 96        Ok(())
 97    }
 98}
 99
100fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
101    let result_text = if response.results.len() == 1 {
102        "1 result".to_string()
103    } else {
104        format!("{} results", response.results.len())
105    };
106    event_stream.update_fields(acp::ToolCallUpdateFields {
107        title: Some(format!("Searched the web: {result_text}")),
108        content: Some(
109            response
110                .results
111                .iter()
112                .map(|result| acp::ToolCallContent::Content {
113                    content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
114                        name: result.title.clone(),
115                        uri: result.url.clone(),
116                        title: Some(result.title.clone()),
117                        description: Some(result.text.clone()),
118                        mime_type: None,
119                        annotations: None,
120                        size: None,
121                    }),
122                })
123                .collect(),
124        ),
125        ..Default::default()
126    });
127}