web_search_tool.rs

  1use std::sync::Arc;
  2
  3use crate::{AgentTool, ToolCallEventStream};
  4use agent_client_protocol as acp;
  5use anyhow::{Result, anyhow};
  6use cloud_llm_client::WebSearchResponse;
  7use gpui::{App, AppContext, Task};
  8use language_model::{
  9    LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
 10};
 11use schemars::JsonSchema;
 12use serde::{Deserialize, Serialize};
 13use ui::prelude::*;
 14use web_search::WebSearchRegistry;
 15
// NOTE(review): because of the `JsonSchema` derive, schemars turns the `///` docs on this
// struct and its field into schema descriptions; presumably that schema is what the model
// sees for this tool, so treat the `///` wording as runtime data and edit it with care.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    // Moved into the provider's `search` call by `WebSearchTool::run`.
    /// The search term or question to query on the web.
    query: String,
}
 24
 25#[derive(Debug, Serialize, Deserialize)]
 26#[serde(transparent)]
 27pub struct WebSearchToolOutput(WebSearchResponse);
 28
 29impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
 30    fn from(value: WebSearchToolOutput) -> Self {
 31        serde_json::to_string(&value.0)
 32            .expect("Failed to serialize WebSearchResponse")
 33            .into()
 34    }
 35}
 36
 37pub struct WebSearchTool;
 38
 39impl AgentTool for WebSearchTool {
 40    type Input = WebSearchToolInput;
 41    type Output = WebSearchToolOutput;
 42
 43    fn name() -> &'static str {
 44        "web_search"
 45    }
 46
 47    fn kind() -> acp::ToolKind {
 48        acp::ToolKind::Fetch
 49    }
 50
 51    fn initial_title(
 52        &self,
 53        _input: Result<Self::Input, serde_json::Value>,
 54        _cx: &mut App,
 55    ) -> SharedString {
 56        "Searching the Web".into()
 57    }
 58
 59    /// We currently only support Zed Cloud as a provider.
 60    fn supports_provider(provider: &LanguageModelProviderId) -> bool {
 61        provider == &ZED_CLOUD_PROVIDER_ID
 62    }
 63
 64    fn run(
 65        self: Arc<Self>,
 66        input: Self::Input,
 67        event_stream: ToolCallEventStream,
 68        cx: &mut App,
 69    ) -> Task<Result<Self::Output>> {
 70        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 71            return Task::ready(Err(anyhow!("Web search is not available.")));
 72        };
 73
 74        let search_task = provider.search(input.query, cx);
 75        cx.background_spawn(async move {
 76            let response = match search_task.await {
 77                Ok(response) => response,
 78                Err(err) => {
 79                    event_stream
 80                        .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
 81                    return Err(err);
 82                }
 83            };
 84
 85            emit_update(&response, &event_stream);
 86            Ok(WebSearchToolOutput(response))
 87        })
 88    }
 89
 90    fn replay(
 91        &self,
 92        _input: Self::Input,
 93        output: Self::Output,
 94        event_stream: ToolCallEventStream,
 95        _cx: &mut App,
 96    ) -> Result<()> {
 97        emit_update(&output.0, &event_stream);
 98        Ok(())
 99    }
100}
101
102fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
103    let result_text = if response.results.len() == 1 {
104        "1 result".to_string()
105    } else {
106        format!("{} results", response.results.len())
107    };
108    event_stream.update_fields(
109        acp::ToolCallUpdateFields::new()
110            .title(format!("Searched the web: {result_text}"))
111            .content(
112                response
113                    .results
114                    .iter()
115                    .map(|result| {
116                        acp::ToolCallContent::Content(acp::Content::new(
117                            acp::ContentBlock::ResourceLink(
118                                acp::ResourceLink::new(result.title.clone(), result.url.clone())
119                                    .title(result.title.clone())
120                                    .description(result.text.clone()),
121                            ),
122                        ))
123                    })
124                    .collect(),
125            ),
126    );
127}