1use std::sync::Arc;
2
3use crate::{AgentTool, ToolCallEventStream};
4use agent_client_protocol as acp;
5use anyhow::{Result, anyhow};
6use cloud_llm_client::WebSearchResponse;
7use gpui::{App, AppContext, Task};
8use language_model::{
9 LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use ui::prelude::*;
14use web_search::WebSearchRegistry;
15
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
// NOTE: because this type derives `JsonSchema`, the `///` doc comments on the
// struct and its field are emitted by schemars as schema `description`s. They
// are presumably what the model reads when deciding how to call this tool, so
// treat them as runtime text rather than ordinary documentation.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    /// The search term or question to query on the web.
    // Passed unchanged to the active web search provider's `search` call.
    query: String,
}
24
25#[derive(Debug, Serialize, Deserialize)]
26#[serde(transparent)]
27pub struct WebSearchToolOutput(WebSearchResponse);
28
29impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
30 fn from(value: WebSearchToolOutput) -> Self {
31 serde_json::to_string(&value.0)
32 .expect("Failed to serialize WebSearchResponse")
33 .into()
34 }
35}
36
/// Agent tool (`web_search`) that runs a query against the active provider
/// registered in the global [`WebSearchRegistry`]. It reports the results on
/// the tool call as one resource link per result.
///
/// `supported_provider` returns `true` only for the Zed Cloud language-model
/// provider.
pub struct WebSearchTool;
38
39impl AgentTool for WebSearchTool {
40 type Input = WebSearchToolInput;
41 type Output = WebSearchToolOutput;
42
43 fn name() -> &'static str {
44 "web_search"
45 }
46
47 fn kind() -> acp::ToolKind {
48 acp::ToolKind::Fetch
49 }
50
51 fn initial_title(
52 &self,
53 _input: Result<Self::Input, serde_json::Value>,
54 _cx: &mut App,
55 ) -> SharedString {
56 "Searching the Web".into()
57 }
58
59 /// We currently only support Zed Cloud as a provider.
60 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
61 provider == &ZED_CLOUD_PROVIDER_ID
62 }
63
64 fn run(
65 self: Arc<Self>,
66 input: Self::Input,
67 event_stream: ToolCallEventStream,
68 cx: &mut App,
69 ) -> Task<Result<Self::Output>> {
70 let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
71 return Task::ready(Err(anyhow!("Web search is not available.")));
72 };
73
74 let search_task = provider.search(input.query, cx);
75 cx.background_spawn(async move {
76 let response = match search_task.await {
77 Ok(response) => response,
78 Err(err) => {
79 event_stream.update_fields(acp::ToolCallUpdateFields {
80 title: Some("Web Search Failed".to_string()),
81 ..Default::default()
82 });
83 return Err(err);
84 }
85 };
86
87 emit_update(&response, &event_stream);
88 Ok(WebSearchToolOutput(response))
89 })
90 }
91
92 fn replay(
93 &self,
94 _input: Self::Input,
95 output: Self::Output,
96 event_stream: ToolCallEventStream,
97 _cx: &mut App,
98 ) -> Result<()> {
99 emit_update(&output.0, &event_stream);
100 Ok(())
101 }
102}
103
104fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
105 let result_text = if response.results.len() == 1 {
106 "1 result".to_string()
107 } else {
108 format!("{} results", response.results.len())
109 };
110 event_stream.update_fields(acp::ToolCallUpdateFields {
111 title: Some(format!("Searched the web: {result_text}")),
112 content: Some(
113 response
114 .results
115 .iter()
116 .map(|result| acp::ToolCallContent::Content {
117 content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
118 name: result.title.clone(),
119 uri: result.url.clone(),
120 title: Some(result.title.clone()),
121 description: Some(result.text.clone()),
122 mime_type: None,
123 annotations: None,
124 size: None,
125 }),
126 })
127 .collect(),
128 ),
129 ..Default::default()
130 });
131}