1use std::sync::Arc;
2
3use crate::{AgentTool, ToolCallEventStream};
4use agent_client_protocol as acp;
5use anyhow::{Result, anyhow};
6use cloud_llm_client::WebSearchResponse;
7use gpui::{App, AppContext, Task};
8use language_model::LanguageModelToolResultContent;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use ui::prelude::*;
12use web_search::WebSearchRegistry;
13
// NOTE(review): the `JsonSchema` derive (schemars) turns the `///` comments on this struct
// and its field into the schema's `description` text, so they are runtime-visible, not
// just rustdoc — edit them with care. Presumably this is what the model sees as the tool /
// parameter description; confirm against the `AgentTool` description plumbing. Also
// verify how the trailing `\` on the second `///` line is rendered in the emitted schema.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training. \
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    // Passed as-is to the active provider's `search` call in `WebSearchTool::run`.
    /// The search term or question to query on the web.
    query: String,
}
22
23#[derive(Debug, Serialize, Deserialize)]
24#[serde(transparent)]
25pub struct WebSearchToolOutput(WebSearchResponse);
26
27impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
28 fn from(value: WebSearchToolOutput) -> Self {
29 serde_json::to_string(&value.0)
30 .expect("Failed to serialize WebSearchResponse")
31 .into()
32 }
33}
34
/// Agent tool that searches the web through the active provider registered in
/// [`WebSearchRegistry`] and reports each result back as a resource link.
pub struct WebSearchTool;
36
37impl AgentTool for WebSearchTool {
38 type Input = WebSearchToolInput;
39 type Output = WebSearchToolOutput;
40
41 fn name(&self) -> SharedString {
42 "web_search".into()
43 }
44
45 fn kind(&self) -> acp::ToolKind {
46 acp::ToolKind::Fetch
47 }
48
49 fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
50 "Searching the Web".into()
51 }
52
53 fn run(
54 self: Arc<Self>,
55 input: Self::Input,
56 event_stream: ToolCallEventStream,
57 cx: &mut App,
58 ) -> Task<Result<Self::Output>> {
59 let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
60 return Task::ready(Err(anyhow!("Web search is not available.")));
61 };
62
63 let search_task = provider.search(input.query, cx);
64 cx.background_spawn(async move {
65 let response = match search_task.await {
66 Ok(response) => response,
67 Err(err) => {
68 event_stream.update_fields(acp::ToolCallUpdateFields {
69 title: Some("Web Search Failed".to_string()),
70 ..Default::default()
71 });
72 return Err(err);
73 }
74 };
75
76 let result_text = if response.results.len() == 1 {
77 "1 result".to_string()
78 } else {
79 format!("{} results", response.results.len())
80 };
81 event_stream.update_fields(acp::ToolCallUpdateFields {
82 title: Some(format!("Searched the web: {result_text}")),
83 content: Some(
84 response
85 .results
86 .iter()
87 .map(|result| acp::ToolCallContent::Content {
88 content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
89 name: result.title.clone(),
90 uri: result.url.clone(),
91 title: Some(result.title.clone()),
92 description: Some(result.text.clone()),
93 mime_type: None,
94 annotations: None,
95 size: None,
96 }),
97 })
98 .collect(),
99 ),
100 ..Default::default()
101 });
102 Ok(WebSearchToolOutput(response))
103 })
104 }
105}