1use std::{sync::Arc, time::Duration};
2
3use crate::schema::json_schema_for;
4use crate::ui::ToolCallCardHeader;
5use anyhow::{Context as _, Result, anyhow};
6use assistant_tool::{ActionLog, Tool, ToolCard, ToolResult, ToolUseStatus};
7use futures::{Future, FutureExt, TryFutureExt};
8use gpui::{App, AppContext, Context, Entity, IntoElement, Task, Window};
9use language_model::{LanguageModelRequestMessage, LanguageModelToolSchemaFormat};
10use project::Project;
11use schemars::JsonSchema;
12use serde::{Deserialize, Serialize};
13use ui::{IconName, Tooltip, prelude::*};
14use web_search::WebSearchRegistry;
15use zed_llm_client::WebSearchResponse;
16
// Arguments the model supplies when it calls the `web_search` tool; parsed from
// the raw JSON tool input in `WebSearchTool::run`.
//
// NOTE: deliberately documented with `//` rather than `///`: the JsonSchema
// derive feeds `///` doc comments into the generated schema (see
// `input_schema`), so a struct-level doc comment would change what the model sees.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct WebSearchToolInput {
    // The `///` line below is part of the schema description given to the model;
    // editing it changes tool behavior, not just documentation.
    /// The search term or question to query on the web.
    query: String,
}
22
23pub struct WebSearchTool;
24
25impl Tool for WebSearchTool {
26 fn name(&self) -> String {
27 "web_search".into()
28 }
29
30 fn needs_confirmation(&self, _: &serde_json::Value, _: &App) -> bool {
31 false
32 }
33
34 fn description(&self) -> String {
35 "Search the web for information using your query. Use this when you need real-time information, facts, or data that might not be in your training. Results will include snippets and links from relevant web pages.".into()
36 }
37
38 fn icon(&self) -> IconName {
39 IconName::Globe
40 }
41
42 fn input_schema(&self, format: LanguageModelToolSchemaFormat) -> Result<serde_json::Value> {
43 json_schema_for::<WebSearchToolInput>(format)
44 }
45
46 fn ui_text(&self, _input: &serde_json::Value) -> String {
47 "Searching the Web".to_string()
48 }
49
50 fn run(
51 self: Arc<Self>,
52 input: serde_json::Value,
53 _messages: &[LanguageModelRequestMessage],
54 _project: Entity<Project>,
55 _action_log: Entity<ActionLog>,
56 cx: &mut App,
57 ) -> ToolResult {
58 let input = match serde_json::from_value::<WebSearchToolInput>(input) {
59 Ok(input) => input,
60 Err(err) => return Task::ready(Err(anyhow!(err))).into(),
61 };
62 let Some(provider) = WebSearchRegistry::read_global(cx).active_provider() else {
63 return Task::ready(Err(anyhow!("Web search is not available."))).into();
64 };
65
66 let search_task = provider.search(input.query, cx).map_err(Arc::new).shared();
67 let output = cx.background_spawn({
68 let search_task = search_task.clone();
69 async move {
70 let response = search_task.await.map_err(|err| anyhow!(err))?;
71 serde_json::to_string(&response).context("Failed to serialize search results")
72 }
73 });
74
75 ToolResult {
76 output,
77 card: Some(cx.new(|cx| WebSearchToolCard::new(search_task, cx)).into()),
78 }
79 }
80}
81
82struct WebSearchToolCard {
83 response: Option<Result<WebSearchResponse>>,
84 _task: Task<()>,
85}
86
87impl WebSearchToolCard {
88 fn new(
89 search_task: impl 'static + Future<Output = Result<WebSearchResponse, Arc<anyhow::Error>>>,
90 cx: &mut Context<Self>,
91 ) -> Self {
92 let _task = cx.spawn(async move |this, cx| {
93 let response = search_task.await.map_err(|err| anyhow!(err));
94 this.update(cx, |this, cx| {
95 this.response = Some(response);
96 cx.notify();
97 })
98 .ok();
99 });
100
101 Self {
102 response: None,
103 _task,
104 }
105 }
106}
107
108impl ToolCard for WebSearchToolCard {
109 fn render(
110 &mut self,
111 _status: &ToolUseStatus,
112 _window: &mut Window,
113 cx: &mut Context<Self>,
114 ) -> impl IntoElement {
115 let header = match self.response.as_ref() {
116 Some(Ok(response)) => {
117 let text: SharedString = if response.citations.len() == 1 {
118 "1 result".into()
119 } else {
120 format!("{} results", response.citations.len()).into()
121 };
122 ToolCallCardHeader::new(IconName::Globe, "Searched the Web")
123 .with_secondary_text(text)
124 }
125 Some(Err(error)) => {
126 ToolCallCardHeader::new(IconName::Globe, "Web Search").with_error(error.to_string())
127 }
128 None => ToolCallCardHeader::new(IconName::Globe, "Searching the Web").loading(),
129 };
130
131 let content =
132 self.response.as_ref().and_then(|response| match response {
133 Ok(response) => {
134 Some(
135 v_flex()
136 .overflow_hidden()
137 .ml_1p5()
138 .pl(px(5.))
139 .border_l_1()
140 .border_color(cx.theme().colors().border_variant)
141 .gap_1()
142 .children(response.citations.iter().enumerate().map(
143 |(index, citation)| {
144 let title = citation.title.clone();
145 let url = citation.url.clone();
146
147 Button::new(("citation", index), title)
148 .label_size(LabelSize::Small)
149 .color(Color::Muted)
150 .icon(IconName::ArrowUpRight)
151 .icon_size(IconSize::XSmall)
152 .icon_position(IconPosition::End)
153 .truncate(true)
154 .tooltip({
155 let url = url.clone();
156 move |window, cx| {
157 Tooltip::with_meta(
158 "Citation Link",
159 None,
160 url.clone(),
161 window,
162 cx,
163 )
164 }
165 })
166 .on_click({
167 let url = url.clone();
168 move |_, _, cx| cx.open_url(&url)
169 })
170 },
171 ))
172 .into_any(),
173 )
174 }
175 Err(_) => None,
176 });
177
178 v_flex().mb_3().gap_1().child(header).children(content)
179 }
180}