1use std::sync::Arc;
2
3use crate::{AgentTool, ToolCallEventStream};
4use agent_client_protocol as acp;
5use anyhow::{Result, anyhow};
6use cloud_llm_client::WebSearchResponse;
7use gpui::{App, AppContext, Task};
8use language_model::{
9 LanguageModelProviderId, LanguageModelToolResultContent, ZED_CLOUD_PROVIDER_ID,
10};
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use ui::prelude::*;
14use web_search::WebSearchRegistry;
15
// NOTE(review): the `///` comments in this block presumably double as the JSON-schema
// descriptions produced by the `JsonSchema` derive (i.e. they are model-facing text);
// keep maintainer-only notes in plain `//` comments like this one. This type is the
// `Input` of `WebSearchTool`.
/// Search the web for information using your query.
/// Use this when you need real-time information, facts, or data that might not be in your training.
/// Results will include snippets and links from relevant web pages.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    // Moved into the active provider's `search` call by `WebSearchTool::run`.
    /// The search term or question to query on the web.
    query: String,
}
24
/// Result of [`WebSearchTool`]: the provider's [`WebSearchResponse`] in a newtype.
///
/// `#[serde(transparent)]` makes this (de)serialize exactly like the inner response,
/// with no extra wrapping layer.
#[derive(Debug, Serialize, Deserialize)]
#[serde(transparent)]
pub struct WebSearchToolOutput(WebSearchResponse);
28
29impl From<WebSearchToolOutput> for LanguageModelToolResultContent {
30 fn from(value: WebSearchToolOutput) -> Self {
31 serde_json::to_string(&value.0)
32 .expect("Failed to serialize WebSearchResponse")
33 .into()
34 }
35}
36
/// Agent tool that queries the active [`WebSearchRegistry`] provider for web results.
pub struct WebSearchTool;
38
39impl AgentTool for WebSearchTool {
40 type Input = WebSearchToolInput;
41 type Output = WebSearchToolOutput;
42
43 fn name() -> &'static str {
44 "web_search"
45 }
46
47 fn kind() -> acp::ToolKind {
48 acp::ToolKind::Fetch
49 }
50
51 fn initial_title(
52 &self,
53 _input: Result<Self::Input, serde_json::Value>,
54 _cx: &mut App,
55 ) -> SharedString {
56 "Searching the Web".into()
57 }
58
59 /// We currently only support Zed Cloud as a provider.
60 fn supports_provider(provider: &LanguageModelProviderId) -> bool {
61 provider == &ZED_CLOUD_PROVIDER_ID
62 }
63
64 fn run(
65 self: Arc<Self>,
66 input: Self::Input,
67 event_stream: ToolCallEventStream,
68 cx: &mut App,
69 ) -> Task<Result<Self::Output>> {
70 let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
71 return Task::ready(Err(anyhow!("Web search is not available.")));
72 };
73
74 let search_task = provider.search(input.query, cx);
75 cx.background_spawn(async move {
76 let response = match search_task.await {
77 Ok(response) => response,
78 Err(err) => {
79 event_stream
80 .update_fields(acp::ToolCallUpdateFields::new().title("Web Search Failed"));
81 return Err(err);
82 }
83 };
84
85 emit_update(&response, &event_stream);
86 Ok(WebSearchToolOutput(response))
87 })
88 }
89
90 fn replay(
91 &self,
92 _input: Self::Input,
93 output: Self::Output,
94 event_stream: ToolCallEventStream,
95 _cx: &mut App,
96 ) -> Result<()> {
97 emit_update(&output.0, &event_stream);
98 Ok(())
99 }
100}
101
102fn emit_update(response: &WebSearchResponse, event_stream: &ToolCallEventStream) {
103 let result_text = if response.results.len() == 1 {
104 "1 result".to_string()
105 } else {
106 format!("{} results", response.results.len())
107 };
108 event_stream.update_fields(
109 acp::ToolCallUpdateFields::new()
110 .title(format!("Searched the web: {result_text}"))
111 .content(
112 response
113 .results
114 .iter()
115 .map(|result| {
116 acp::ToolCallContent::Content(acp::Content::new(
117 acp::ContentBlock::ResourceLink(
118 acp::ResourceLink::new(result.title.clone(), result.url.clone())
119 .title(result.title.clone())
120 .description(result.text.clone()),
121 ),
122 ))
123 })
124 .collect(),
125 ),
126 );
127}