// web_search_tool.rs

  1use std::sync::Arc;
  2
  3use crate::{AgentTool, ToolCallEventStream};
  4use agent_client_protocol as acp;
  5use anyhow::{Result, anyhow};
  6use cloud_llm_client::WebSearchResponse;
  7use futures::FutureExt as _;
  8use gpui::{App, AppContext, Task};
  9use language_model::{
 10    LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
 11};
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use ui::prelude::*;
 15use web_search::WebSearchRegistry;
 16
 17/// Search the web for information using your query.
 18/// Use this when you need real-time information, facts, or data that might not be in your training.
 19/// Results will include snippets and links from relevant web pages.
 20#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 21pub struct WebSearchToolInput {
 22    /// The search term or question to query on the web.
 23    query: String,
 24}
 25
 26#[derive(Debug, Serialize, Deserialize)]
 27#[serde(transparent)]
 28pub struct WebSearchToolOutput(WebSearchResponse);
 29
 30impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
 31    fn from(value: WebSearchToolOutput) -> Self {
 32        serde_json::to_string(&value.0)
 33            .expect("Failed to serialize WebSearchResponse")
 34            .into()
 35    }
 36}
 37
 38pub struct WebSearchTool;
 39
 40impl AgentTool for WebSearchTool {
 41    type Input = WebSearchToolInput;
 42    type Output = WebSearchToolOutput;
 43
 44    fn name() -> &'static str {
 45        "web_search"
 46    }
 47
 48    fn kind() -> acp::ToolKind {
 49        acp::ToolKind::Fetch
 50    }
 51
 52    fn initial_title(
 53        &self,
 54        _input: Result<Self::Input, serde_json::Value>,
 55        _cx: &mut App,
 56    ) -> SharedString {
 57        "Searching the Web".into()
 58    }
 59
 60    /// We currently only support Zed Cloud as a provider.
 61    fn supports_provider(provider: &LanguageModelProviderId) -> bool {
 62        provider == &ZED_CLOUD_PROVIDER_ID
 63    }
 64
 65    fn run(
 66        self: Arc<Self>,
 67        input: Self::Input,
 68        event_stream: ToolCallEventStream,
 69        cx: &mut App,
 70    ) -> Task<Result<Self::Output>> {
 71        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 72            return Task::ready(Err(anyhow!("Web search is not available.")));
 73        };
 74
 75        let search_task = provider.search(input.query, cx);
 76        cx.background_spawn(async move {
 77            let response = futures::select! {
 78                result = search_task.fuse() => {
 79                    match result {
 80                        Ok(response) => response,
 81                        Err(err) => {
 82                            event_stream
 83                                .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
 84                            return Err(err);
 85                        }
 86                    }
 87                }
 88                _ = event_stream.cancelled_by_user().fuse() => {
 89                    anyhow::bail!("Web search cancelled by user");
 90                }
 91            };
 92
 93            emit_update(&response, &event_stream);
 94            Ok(WebSearchToolOutput(response))
 95        })
 96    }
 97
 98    fn replay(
 99        &self,
100        _input: Self::Input,
101        output: Self::Output,
102        event_stream: ToolCallEventStream,
103        _cx: &mut App,
104    ) -> Result<()> {
105        emit_update(&output.0, &event_stream);
106        Ok(())
107    }
108}
109
110fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
111    let result_text = if response.results.len() == 1 {
112        "1 result".to_string()
113    } else {
114        format!("{} results", response.results.len())
115    };
116    event_stream.update_fields(
117        acp::ToolCallUpdateFields::new()
118            .title(format!("Searched the web: {result_text}"))
119            .content(
120                response
121                    .results
122                    .iter()
123                    .map(|result| {
124                        acp::ToolCallContent::Content(acp::Content::new(
125                            acp::ContentBlock::ResourceLink(
126                                acp::ResourceLink::new(result.title.clone(), result.url.clone())
127                                    .title(result.title.clone())
128                                    .description(result.text.clone()),
129                            ),
130                        ))
131                    })
132                    .collect::<Vec<_>>(),
133            ),
134    );
135}