1use std::{collections::BTreeMap, sync::Arc};
2
3use agent::ContextServerRegistry;
4use agent_settings::{AgentProfileId, AgentProfileSettings};
5use fs::Fs;
6use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
7use picker::{Picker, PickerDelegate};
8use settings::{AgentProfileContent, ContextServerPresetContent, update_settings_file};
9use ui::{ListItem, ListItemSpacing, prelude::*};
10use util::ResultExt as _;
11
12pub struct ToolPicker {
13 picker: Entity<Picker<ToolPickerDelegate>>,
14}
15
/// Selects which family of tools a picker lists; it drives the placeholder text
/// and how the list is built.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ToolPickerMode {
    /// Built-in agent tools (no owning context server).
    BuiltinTools,
    /// Tools provided by the servers in a `ContextServerRegistry`.
    McpTools,
}
21
22impl ToolPicker {
23 pub fn builtin_tools(
24 delegate: ToolPickerDelegate,
25 window: &mut Window,
26 cx: &mut Context<Self>,
27 ) -> Self {
28 let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
29 Self { picker }
30 }
31
32 pub fn mcp_tools(
33 delegate: ToolPickerDelegate,
34 window: &mut Window,
35 cx: &mut Context<Self>,
36 ) -> Self {
37 let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
38 Self { picker }
39 }
40}
41
42impl EventEmitter<DismissEvent> for ToolPicker {}
43
44impl Focusable for ToolPicker {
45 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
46 self.picker.focus_handle(cx)
47 }
48}
49
50impl Render for ToolPicker {
51 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
52 v_flex().w(rems(34.)).child(self.picker.clone())
53 }
54}
55
/// A row in the tool picker: either a toggleable tool or a server header.
#[derive(Clone, Debug)]
pub enum PickerItem {
    /// A tool row; `server_id` is `None` for built-in tools.
    Tool {
        server_id: Option<Arc<str>>,
        name: Arc<str>,
    },
    /// A header row naming the context server whose tools follow it.
    ContextServer { server_id: Arc<str> },
}
66
67pub struct ToolPickerDelegate {
68 tool_picker: WeakEntity<ToolPicker>,
69 fs: Arc<dyn Fs>,
70 items: Arc<Vec<PickerItem>>,
71 profile_id: AgentProfileId,
72 profile_settings: AgentProfileSettings,
73 filtered_items: Vec<PickerItem>,
74 selected_index: usize,
75 mode: ToolPickerMode,
76}
77
78impl ToolPickerDelegate {
79 pub fn builtin_tools(
80 tool_names: Vec<Arc<str>>,
81 fs: Arc<dyn Fs>,
82 profile_id: AgentProfileId,
83 profile_settings: AgentProfileSettings,
84 cx: &mut Context<ToolPicker>,
85 ) -> Self {
86 Self::new(
87 Arc::new(
88 tool_names
89 .into_iter()
90 .map(|name| PickerItem::Tool {
91 name,
92 server_id: None,
93 })
94 .collect(),
95 ),
96 ToolPickerMode::BuiltinTools,
97 fs,
98 profile_id,
99 profile_settings,
100 cx,
101 )
102 }
103
104 pub fn mcp_tools(
105 registry: &Entity<ContextServerRegistry>,
106 fs: Arc<dyn Fs>,
107 profile_id: AgentProfileId,
108 profile_settings: AgentProfileSettings,
109 cx: &mut Context<ToolPicker>,
110 ) -> Self {
111 let mut items = Vec::new();
112
113 for (id, tools) in registry.read(cx).servers() {
114 let server_id = id.clone().0;
115 items.push(PickerItem::ContextServer {
116 server_id: server_id.clone(),
117 });
118 items.extend(tools.keys().map(|tool_name| PickerItem::Tool {
119 name: tool_name.clone().into(),
120 server_id: Some(server_id.clone()),
121 }));
122 }
123
124 Self::new(
125 Arc::new(items),
126 ToolPickerMode::McpTools,
127 fs,
128 profile_id,
129 profile_settings,
130 cx,
131 )
132 }
133
134 fn new(
135 items: Arc<Vec<PickerItem>>,
136 mode: ToolPickerMode,
137 fs: Arc<dyn Fs>,
138 profile_id: AgentProfileId,
139 profile_settings: AgentProfileSettings,
140 cx: &mut Context<ToolPicker>,
141 ) -> Self {
142 Self {
143 tool_picker: cx.entity().downgrade(),
144 mode,
145 fs,
146 items,
147 profile_id,
148 profile_settings,
149 filtered_items: Vec::new(),
150 selected_index: 0,
151 }
152 }
153}
154
155impl PickerDelegate for ToolPickerDelegate {
156 type ListItem = AnyElement;
157
158 fn match_count(&self) -> usize {
159 self.filtered_items.len()
160 }
161
162 fn selected_index(&self) -> usize {
163 self.selected_index
164 }
165
166 fn set_selected_index(
167 &mut self,
168 ix: usize,
169 _window: &mut Window,
170 _cx: &mut Context<Picker<Self>>,
171 ) {
172 self.selected_index = ix;
173 }
174
175 fn can_select(
176 &mut self,
177 ix: usize,
178 _window: &mut Window,
179 _cx: &mut Context<Picker<Self>>,
180 ) -> bool {
181 let item = &self.filtered_items[ix];
182 match item {
183 PickerItem::Tool { .. } => true,
184 PickerItem::ContextServer { .. } => false,
185 }
186 }
187
188 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
189 match self.mode {
190 ToolPickerMode::BuiltinTools => "Search built-in tools…",
191 ToolPickerMode::McpTools => "Search MCP tools…",
192 }
193 .into()
194 }
195
196 fn update_matches(
197 &mut self,
198 query: String,
199 window: &mut Window,
200 cx: &mut Context<Picker<Self>>,
201 ) -> Task<()> {
202 let all_items = self.items.clone();
203
204 cx.spawn_in(window, async move |this, cx| {
205 let filtered_items = cx
206 .background_spawn(async move {
207 let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
208 BTreeMap::default();
209
210 for item in all_items.iter() {
211 if let PickerItem::Tool { server_id, name } = item.clone()
212 && name.contains(&query)
213 {
214 tools_by_provider.entry(server_id).or_default().push(name);
215 }
216 }
217
218 let mut items = Vec::new();
219
220 for (server_id, names) in tools_by_provider {
221 if let Some(server_id) = server_id.clone() {
222 items.push(PickerItem::ContextServer { server_id });
223 }
224 for name in names {
225 items.push(PickerItem::Tool {
226 server_id: server_id.clone(),
227 name,
228 });
229 }
230 }
231
232 items
233 })
234 .await;
235
236 this.update(cx, |this, _cx| {
237 this.delegate.filtered_items = filtered_items;
238 this.delegate.selected_index = this
239 .delegate
240 .selected_index
241 .min(this.delegate.filtered_items.len().saturating_sub(1));
242 })
243 .log_err();
244 })
245 }
246
247 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
248 if self.filtered_items.is_empty() {
249 self.dismissed(window, cx);
250 return;
251 }
252
253 let item = &self.filtered_items[self.selected_index];
254
255 let PickerItem::Tool {
256 name: tool_name,
257 server_id,
258 } = item
259 else {
260 return;
261 };
262
263 let is_currently_enabled = if let Some(server_id) = server_id.clone() {
264 let preset = self
265 .profile_settings
266 .context_servers
267 .entry(server_id)
268 .or_default();
269 let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
270 *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
271 is_enabled
272 } else {
273 let is_enabled = *self
274 .profile_settings
275 .tools
276 .entry(tool_name.clone())
277 .or_default();
278 *self
279 .profile_settings
280 .tools
281 .entry(tool_name.clone())
282 .or_default() = !is_enabled;
283 is_enabled
284 };
285
286 update_settings_file(self.fs.clone(), cx, {
287 let profile_id = self.profile_id.clone();
288 let default_profile = self.profile_settings.clone();
289 let server_id = server_id.clone();
290 let tool_name = tool_name.clone();
291 move |settings, _cx| {
292 let profiles = settings
293 .agent
294 .get_or_insert_default()
295 .profiles
296 .get_or_insert_default();
297 let profile = profiles
298 .entry(profile_id.0)
299 .or_insert_with(|| AgentProfileContent {
300 name: default_profile.name.into(),
301 tools: default_profile.tools,
302 enable_all_context_servers: Some(
303 default_profile.enable_all_context_servers,
304 ),
305 context_servers: default_profile
306 .context_servers
307 .into_iter()
308 .map(|(server_id, preset)| {
309 (
310 server_id,
311 ContextServerPresetContent {
312 tools: preset.tools,
313 },
314 )
315 })
316 .collect(),
317 });
318
319 if let Some(server_id) = server_id {
320 let preset = profile.context_servers.entry(server_id).or_default();
321 *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
322 } else {
323 *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
324 }
325 }
326 });
327 }
328
329 fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
330 self.tool_picker
331 .update(cx, |_this, cx| cx.emit(DismissEvent))
332 .log_err();
333 }
334
335 fn render_match(
336 &self,
337 ix: usize,
338 selected: bool,
339 _window: &mut Window,
340 cx: &mut Context<Picker<Self>>,
341 ) -> Option<Self::ListItem> {
342 let item = &self.filtered_items.get(ix)?;
343 match item {
344 PickerItem::ContextServer { server_id, .. } => Some(
345 div()
346 .px_2()
347 .pb_1()
348 .when(ix > 1, |this| {
349 this.mt_1()
350 .pt_2()
351 .border_t_1()
352 .border_color(cx.theme().colors().border_variant)
353 })
354 .child(
355 Label::new(server_id)
356 .size(LabelSize::XSmall)
357 .color(Color::Muted),
358 )
359 .into_any_element(),
360 ),
361 PickerItem::Tool { name, server_id } => {
362 let is_enabled = if let Some(server_id) = server_id {
363 self.profile_settings
364 .context_servers
365 .get(server_id.as_ref())
366 .and_then(|preset| preset.tools.get(name))
367 .copied()
368 .unwrap_or(self.profile_settings.enable_all_context_servers)
369 } else {
370 self.profile_settings
371 .tools
372 .get(name)
373 .copied()
374 .unwrap_or(false)
375 };
376
377 Some(
378 ListItem::new(ix)
379 .inset(true)
380 .spacing(ListItemSpacing::Sparse)
381 .toggle_state(selected)
382 .child(Label::new(name.clone()))
383 .end_slot::<Icon>(is_enabled.then(|| {
384 Icon::new(IconName::Check)
385 .size(IconSize::Small)
386 .color(Color::Success)
387 }))
388 .into_any_element(),
389 )
390 }
391 }
392 }
393}