web_search_tool.rs

  1use std::sync::Arc;
  2
  3use crate::{AgentTool, ToolCallEventStream};
  4use agent_client_protocol as acp;
  5use anyhow::{Result, anyhow};
  6use cloud_llm_client::WebSearchResponse;
  7use gpui::{App, AppContext, Task};
  8use language_model::LanguageModelToolResultContent;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use ui::prelude::*;
 12use web_search::WebSearchRegistry;
 13
 14/// Search the web for information using your query.
 15/// Use this when you need real-time information, facts, or data that might not be in your training. \
 16/// Results will include snippets and links from relevant web pages.
 17#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 18pub struct WebSearchToolInput {
 19    /// The search term or question to query on the web.
 20    query: String,
 21}
 22
 23#[derive(Debug, Serialize, Deserialize)]
 24#[serde(transparent)]
 25pub struct WebSearchToolOutput(WebSearchResponse);
 26
 27impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
 28    fn from(value: WebSearchToolOutput) -> Self {
 29        serde_json::to_string(&value.0)
 30            .expect("Failed to serialize WebSearchResponse")
 31            .into()
 32    }
 33}
 34
 35pub struct WebSearchTool;
 36
 37impl AgentTool for WebSearchTool {
 38    type Input = WebSearchToolInput;
 39    type Output = WebSearchToolOutput;
 40
 41    fn name(&self) -> SharedString {
 42        "web_search".into()
 43    }
 44
 45    fn kind(&self) -> acp::ToolKind {
 46        acp::ToolKind::Fetch
 47    }
 48
 49    fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
 50        "Searching the Web".into()
 51    }
 52
 53    fn run(
 54        self: Arc<Self>,
 55        input: Self::Input,
 56        event_stream: ToolCallEventStream,
 57        cx: &mut App,
 58    ) -> Task<Result<Self::Output>> {
 59        let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
 60            return Task::ready(Err(anyhow!("Web search is not available.")));
 61        };
 62
 63        let search_task = provider.search(input.query, cx);
 64        cx.background_spawn(async move {
 65            let response = match search_task.await {
 66                Ok(response) => response,
 67                Err(err) => {
 68                    event_stream.update_fields(acp::ToolCallUpdateFields {
 69                        title: Some("Web Search Failed".to_string()),
 70                        ..Default::default()
 71                    });
 72                    return Err(err);
 73                }
 74            };
 75
 76            let result_text = if response.results.len() == 1 {
 77                "1 result".to_string()
 78            } else {
 79                format!("{} results", response.results.len())
 80            };
 81            event_stream.update_fields(acp::ToolCallUpdateFields {
 82                title: Some(format!("Searched the web: {result_text}")),
 83                content: Some(
 84                    response
 85                        .results
 86                        .iter()
 87                        .map(|result| acp::ToolCallContent::Content {
 88                            content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
 89                                name: result.title.clone(),
 90                                uri: result.url.clone(),
 91                                title: Some(result.title.clone()),
 92                                description: Some(result.text.clone()),
 93                                mime_type: None,
 94                                annotations: None,
 95                                size: None,
 96                            }),
 97                        })
 98                        .collect(),
 99                ),
100                ..Default::default()
101            });
102            Ok(WebSearchToolOutput(response))
103        })
104    }
105}