1use std::{sync::Arc, time::Duration};
2
3use crate::schema::json_schema_for;
4use anyhow::{Context as _, Result, anyhow};
5use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
6use futures::{FutureExt, TryFutureExt};
7use gpui::{
8 Animation, AnimationExt, App, AppContext, Context, Entity, IntoElement, Task, Window,
9 pulsating_between,
10};
11use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
12use project::Project;
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use ui::{IconName, Tooltip, prelude::*};
16use web_search::WebSearchRegistry;
17use zed_llm_client::WebSearchResponse;
18
// Arguments the model supplies when it calls the `web_search` tool.
//
// NOTE: this type derives `JsonSchema` and is passed to `json_schema_for` in
// `input_schema`, so its `///` doc comments become `description`s in the
// schema sent to the model. Treat them as model-facing text and edit with care;
// plain `//` comments (like these) do not affect the schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    /// The search term or question to query on the web.
    query: String,
}
24
/// Assistant tool that forwards a query to the active web search provider
/// registered in [`WebSearchRegistry`] and returns the serialized response.
pub struct WebSearchTool;
26
27impl Tool for WebSearchTool {
28 fn name(&self) -> String {
29 "web_search".into()
30 }
31
32 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
33 false
34 }
35
36 fn description(&self) -> String {
37 "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
38 }
39
40 fn icon(&self) -> IconName {
41 IconName::Globe
42 }
43
44 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
45 json_schema_for::<WebSearchToolInput>(format)
46 }
47
48 fn ui_text(&self, _input: &serde_json::Value) -> String {
49 "Web Search".to_string()
50 }
51
52 fn run(
53 self: Arc<Self>,
54 input: serde_json::Value,
55 _messages: &[LanguageModelRequestMessage],
56 _project: Entity<Project>,
57 _action_log: Entity<ActionLog>,
58 cx: &mut App,
59 ) -> ToolResult {
60 let input = match serde_json::from_value::<WebSearchToolInput>(input) {
61 Ok(input) => input,
62 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
63 };
64 let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
65 return Task::ready(Err(anyhow!("Web search is not available."))).into();
66 };
67
68 let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
69 let output = cx.background_spawn({
70 let search_task = search_task.clone();
71 async move {
72 let response = search_task.await.map_err(|err| anyhow!(err))?;
73 serde_json::to_string(&response).context("Failed to serialize search results")
74 }
75 });
76
77 ToolResult {
78 output,
79 card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
80 }
81 }
82}
83
/// UI card shown for a single `web_search` tool invocation.
struct WebSearchToolCard {
    /// `None` until the search future resolves; then the search outcome
    /// (set by the task spawned in `WebSearchToolCard::new`).
    response: Option<Result<WebSearchResponse>>,
    /// The task spawned in `new` that fills in `response`; stored on the card
    /// so it is not dropped while the card exists.
    _task: Task<()>,
}
88
89impl WebSearchToolCard {
90 fn new(
91 search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
92 cx: &mut Context<Self>,
93 ) -> Self {
94 let _task = cx.spawn(async move |this, cx| {
95 let response = search_task.await.map_err(|err| anyhow!(err));
96 this.update(cx, |this, cx| {
97 this.response = Some(response);
98 cx.notify();
99 })
100 .ok();
101 });
102
103 Self {
104 response: None,
105 _task,
106 }
107 }
108}
109
110impl ToolCard for WebSearchToolCard {
111 fn render(
112 &mut self,
113 _status: &ToolUseStatus,
114 _window: &mut Window,
115 cx: &mut Context<Self>,
116 ) -> impl IntoElement {
117 let header = h_flex()
118 .id("tool-label-container")
119 .gap_1p5()
120 .max_w_full()
121 .overflow_x_scroll()
122 .child(
123 Icon::new(IconName::Globe)
124 .size(IconSize::XSmall)
125 .color(Color::Muted),
126 )
127 .child(match self.response.as_ref() {
128 Some(Ok(response)) => {
129 let text: SharedString = if response.citations.len() == 1 {
130 "1 result".into()
131 } else {
132 format!("{} results", response.citations.len()).into()
133 };
134 h_flex()
135 .gap_1p5()
136 .child(Label::new("Searched the Web").size(LabelSize::Small))
137 .child(
138 div()
139 .size(px(3.))
140 .rounded_full()
141 .bg(cx.theme().colors().text),
142 )
143 .child(Label::new(text).size(LabelSize::Small))
144 .into_any_element()
145 }
146 Some(Err(error)) => div()
147 .id("web-search-error")
148 .child(Label::new("Web Search failed").size(LabelSize::Small))
149 .tooltip(Tooltip::text(error.to_string()))
150 .into_any_element(),
151
152 None => Label::new("Searching the Web…")
153 .size(LabelSize::Small)
154 .with_animation(
155 "web-search-label",
156 Animation::new(Duration::from_secs(2))
157 .repeat()
158 .with_easing(pulsating_between(0.6, 1.)),
159 |label, delta| label.alpha(delta),
160 )
161 .into_any_element(),
162 })
163 .into_any();
164
165 let content =
166 self.response.as_ref().and_then(|response| match response {
167 Ok(response) => {
168 Some(
169 v_flex()
170 .ml_1p5()
171 .pl_1p5()
172 .border_l_1()
173 .border_color(cx.theme().colors().border_variant)
174 .gap_1()
175 .children(response.citations.iter().enumerate().map(
176 |(index, citation)| {
177 let title = citation.title.clone();
178 let url = citation.url.clone();
179
180 Button::new(("citation", index), title)
181 .label_size(LabelSize::Small)
182 .color(Color::Muted)
183 .icon(IconName::ArrowUpRight)
184 .icon_size(IconSize::XSmall)
185 .icon_position(IconPosition::End)
186 .truncate(true)
187 .tooltip({
188 let url = url.clone();
189 move |window, cx| {
190 Tooltip::with_meta(
191 "Citation Link",
192 None,
193 url.clone(),
194 window,
195 cx,
196 )
197 }
198 })
199 .on_click({
200 let url = url.clone();
201 move |_, _, cx| cx.open_url(&url)
202 })
203 },
204 ))
205 .into_any(),
206 )
207 }
208 Err(_) => None,
209 });
210
211 v_flex().my_2().gap_1().child(header).children(content)
212 }
213}