1use std::sync::Arc;
2
3use crate::{AgentTool, ToolCallEventStream};
4use agent_client_protocol as acp;
5use anyhow::{Result, anyhow};
6use cloud_llm_client::WebSearchResponse;
7use gpui::{App, AppContext, Task};
8use language_model::{
9 LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use ui::prelude::*;
14use web_search::WebSearchRegistry;
15
// Input accepted by `WebSearchTool`.
//
// NOTE(review): because `JsonSchema` is derived, schemars copies the `///` doc comments
// below (on the struct and on `query`) into the generated schema's `description` fields.
// They most likely double as model-facing prompt text, so edit them as such. Plain `//`
// comments like this one are not picked up by the schema.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    /// The search term or question to query on the web.
    // Passed unchanged to the active provider's `search` call in `WebSearchTool::run`.
    query: String,
}
24
25#[derive(Debug, Serialize, Deserialize)]
26#[serde(transparent)]
27pub struct WebSearchToolOutput(WebSearchResponse);
28
29impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
30 fn from(value: WebSearchToolOutput) -> Self {
31 serde_json::to_string(&value.0)
32 .expect("Failed to serialize WebSearchResponse")
33 .into()
34 }
35}
36
37pub struct WebSearchTool;
38
39impl AgentTool for WebSearchTool {
40 type Input = WebSearchToolInput;
41 type Output = WebSearchToolOutput;
42
43 fn name() -> &'static str {
44 "web_search"
45 }
46
47 fn kind() -> acp::ToolKind {
48 acp::ToolKind::Fetch
49 }
50
51 fn initial_title(&self, _input: Result<Self::Input, serde_json::Value>) -> SharedString {
52 "Searching the Web".into()
53 }
54
55 /// We currently only support Zed Cloud as a provider.
56 fn supported_provider(&self, provider: &LanguageModelProviderId) -> bool {
57 provider == &ZED_CLOUD_PROVIDER_ID
58 }
59
60 fn run(
61 self: Arc<Self>,
62 input: Self::Input,
63 event_stream: ToolCallEventStream,
64 cx: &mut App,
65 ) -> Task<Result<Self::Output>> {
66 let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
67 return Task::ready(Err(anyhow!("Web search is not available.")));
68 };
69
70 let search_task = provider.search(input.query, cx);
71 cx.background_spawn(async move {
72 let response = match search_task.await {
73 Ok(response) => response,
74 Err(err) => {
75 event_stream.update_fields(acp::ToolCallUpdateFields {
76 title: Some("Web Search Failed".to_string()),
77 ..Default::default()
78 });
79 return Err(err);
80 }
81 };
82
83 emit_update(&response, &event_stream);
84 Ok(WebSearchToolOutput(response))
85 })
86 }
87
88 fn replay(
89 &self,
90 _input: Self::Input,
91 output: Self::Output,
92 event_stream: ToolCallEventStream,
93 _cx: &mut App,
94 ) -> Result<()> {
95 emit_update(&output.0, &event_stream);
96 Ok(())
97 }
98}
99
100fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
101 let result_text = if response.results.len() == 1 {
102 "1 result".to_string()
103 } else {
104 format!("{} results", response.results.len())
105 };
106 event_stream.update_fields(acp::ToolCallUpdateFields {
107 title: Some(format!("Searched the web: {result_text}")),
108 content: Some(
109 response
110 .results
111 .iter()
112 .map(|result| acp::ToolCallContent::Content {
113 content: acp::ContentBlock::ResourceLink(acp::ResourceLink {
114 name: result.title.clone(),
115 uri: result.url.clone(),
116 title: Some(result.title.clone()),
117 description: Some(result.text.clone()),
118 mime_type: None,
119 annotations: None,
120 size: None,
121 }),
122 })
123 .collect(),
124 ),
125 ..Default::default()
126 });
127}