1use std::{collections::BTreeMap, sync::Arc};
2
3use agent::ContextServerRegistry;
4use agent_settings::{AgentProfileId, AgentProfileSettings};
5use fs::Fs;
6use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
7use picker::{Picker, PickerDelegate};
8use settings::{AgentProfileContent, ContextServerPresetContent, update_settings_file};
9use ui::{ListItem, ListItemSpacing, prelude::*};
10use util::ResultExt as _;
11
12pub struct ToolPicker {
13 picker: Entity<Picker<ToolPickerDelegate>>,
14}
15
/// Which family of tools a [`ToolPicker`] is listing.
///
/// Only used to pick the search placeholder text.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ToolPickerMode {
    /// Tools that ship with the agent (no owning context server).
    BuiltinTools,
    /// Tools exposed by MCP context servers.
    McpTools,
}
21
22impl ToolPicker {
23 pub fn builtin_tools(
24 delegate: ToolPickerDelegate,
25 window: &mut Window,
26 cx: &mut Context<Self>,
27 ) -> Self {
28 let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
29 Self { picker }
30 }
31
32 pub fn mcp_tools(
33 delegate: ToolPickerDelegate,
34 window: &mut Window,
35 cx: &mut Context<Self>,
36 ) -> Self {
37 let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
38 Self { picker }
39 }
40}
41
42impl EventEmitter<DismissEvent> for ToolPicker {}
43
44impl Focusable for ToolPicker {
45 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
46 self.picker.focus_handle(cx)
47 }
48}
49
50impl Render for ToolPicker {
51 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
52 v_flex().w(rems(34.)).child(self.picker.clone())
53 }
54}
55
/// One row in the tool picker list.
#[derive(Debug, Clone)]
pub enum PickerItem {
    /// A selectable tool.
    ///
    /// `server_id` is `None` for built-in tools and `Some(id)` for tools
    /// provided by the context server `id`.
    Tool {
        server_id: Option<Arc<str>>,
        name: Arc<str>,
    },
    /// A non-selectable header row naming a context server.
    ///
    /// The header is followed by that server's tools.
    ContextServer { server_id: Arc<str> },
}
66
67pub struct ToolPickerDelegate {
68 tool_picker: WeakEntity<ToolPicker>,
69 fs: Arc<dyn Fs>,
70 items: Arc<Vec<PickerItem>>,
71 profile_id: AgentProfileId,
72 profile_settings: AgentProfileSettings,
73 filtered_items: Vec<PickerItem>,
74 selected_index: usize,
75 mode: ToolPickerMode,
76}
77
78impl ToolPickerDelegate {
79 pub fn builtin_tools(
80 tool_names: Vec<Arc<str>>,
81 fs: Arc<dyn Fs>,
82 profile_id: AgentProfileId,
83 profile_settings: AgentProfileSettings,
84 cx: &mut Context<ToolPicker>,
85 ) -> Self {
86 Self::new(
87 Arc::new(
88 tool_names
89 .into_iter()
90 .map(|name| PickerItem::Tool {
91 name,
92 server_id: None,
93 })
94 .collect(),
95 ),
96 ToolPickerMode::BuiltinTools,
97 fs,
98 profile_id,
99 profile_settings,
100 cx,
101 )
102 }
103
104 pub fn mcp_tools(
105 registry: &Entity<ContextServerRegistry>,
106 fs: Arc<dyn Fs>,
107 profile_id: AgentProfileId,
108 profile_settings: AgentProfileSettings,
109 cx: &mut Context<ToolPicker>,
110 ) -> Self {
111 let mut items = Vec::new();
112
113 for (id, tools) in registry.read(cx).servers() {
114 let server_id = id.clone().0;
115 items.push(PickerItem::ContextServer {
116 server_id: server_id.clone(),
117 });
118 items.extend(tools.keys().map(|tool_name| PickerItem::Tool {
119 name: tool_name.clone().into(),
120 server_id: Some(server_id.clone()),
121 }));
122 }
123
124 Self::new(
125 Arc::new(items),
126 ToolPickerMode::McpTools,
127 fs,
128 profile_id,
129 profile_settings,
130 cx,
131 )
132 }
133
134 fn new(
135 items: Arc<Vec<PickerItem>>,
136 mode: ToolPickerMode,
137 fs: Arc<dyn Fs>,
138 profile_id: AgentProfileId,
139 profile_settings: AgentProfileSettings,
140 cx: &mut Context<ToolPicker>,
141 ) -> Self {
142 Self {
143 tool_picker: cx.entity().downgrade(),
144 mode,
145 fs,
146 items,
147 profile_id,
148 profile_settings,
149 filtered_items: Vec::new(),
150 selected_index: 0,
151 }
152 }
153}
154
155impl PickerDelegate for ToolPickerDelegate {
156 type ListItem = AnyElement;
157
158 fn match_count(&self) -> usize {
159 self.filtered_items.len()
160 }
161
162 fn selected_index(&self) -> usize {
163 self.selected_index
164 }
165
166 fn set_selected_index(
167 &mut self,
168 ix: usize,
169 _window: &mut Window,
170 _cx: &mut Context<Picker<Self>>,
171 ) {
172 self.selected_index = ix;
173 }
174
175 fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
176 let item = &self.filtered_items[ix];
177 match item {
178 PickerItem::Tool { .. } => true,
179 PickerItem::ContextServer { .. } => false,
180 }
181 }
182
183 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
184 match self.mode {
185 ToolPickerMode::BuiltinTools => "Search built-in tools…",
186 ToolPickerMode::McpTools => "Search MCP tools…",
187 }
188 .into()
189 }
190
191 fn update_matches(
192 &mut self,
193 query: String,
194 window: &mut Window,
195 cx: &mut Context<Picker<Self>>,
196 ) -> Task<()> {
197 let all_items = self.items.clone();
198
199 cx.spawn_in(window, async move |this, cx| {
200 let filtered_items = cx
201 .background_spawn(async move {
202 let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
203 BTreeMap::default();
204
205 for item in all_items.iter() {
206 if let PickerItem::Tool { server_id, name } = item.clone()
207 && name.contains(&query)
208 {
209 tools_by_provider.entry(server_id).or_default().push(name);
210 }
211 }
212
213 let mut items = Vec::new();
214
215 for (server_id, names) in tools_by_provider {
216 if let Some(server_id) = server_id.clone() {
217 items.push(PickerItem::ContextServer { server_id });
218 }
219 for name in names {
220 items.push(PickerItem::Tool {
221 server_id: server_id.clone(),
222 name,
223 });
224 }
225 }
226
227 items
228 })
229 .await;
230
231 this.update(cx, |this, _cx| {
232 this.delegate.filtered_items = filtered_items;
233 this.delegate.selected_index = this
234 .delegate
235 .selected_index
236 .min(this.delegate.filtered_items.len().saturating_sub(1));
237 })
238 .log_err();
239 })
240 }
241
242 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
243 if self.filtered_items.is_empty() {
244 self.dismissed(window, cx);
245 return;
246 }
247
248 let item = &self.filtered_items[self.selected_index];
249
250 let PickerItem::Tool {
251 name: tool_name,
252 server_id,
253 } = item
254 else {
255 return;
256 };
257
258 let is_currently_enabled = if let Some(server_id) = server_id.clone() {
259 let preset = self
260 .profile_settings
261 .context_servers
262 .entry(server_id)
263 .or_default();
264 let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
265 *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
266 is_enabled
267 } else {
268 let is_enabled = *self
269 .profile_settings
270 .tools
271 .entry(tool_name.clone())
272 .or_default();
273 *self
274 .profile_settings
275 .tools
276 .entry(tool_name.clone())
277 .or_default() = !is_enabled;
278 is_enabled
279 };
280
281 update_settings_file(self.fs.clone(), cx, {
282 let profile_id = self.profile_id.clone();
283 let default_profile = self.profile_settings.clone();
284 let server_id = server_id.clone();
285 let tool_name = tool_name.clone();
286 move |settings, _cx| {
287 let profiles = settings
288 .agent
289 .get_or_insert_default()
290 .profiles
291 .get_or_insert_default();
292 let profile = profiles
293 .entry(profile_id.0)
294 .or_insert_with(|| AgentProfileContent {
295 name: default_profile.name.into(),
296 tools: default_profile.tools,
297 enable_all_context_servers: Some(
298 default_profile.enable_all_context_servers,
299 ),
300 context_servers: default_profile
301 .context_servers
302 .into_iter()
303 .map(|(server_id, preset)| {
304 (
305 server_id,
306 ContextServerPresetContent {
307 tools: preset.tools,
308 },
309 )
310 })
311 .collect(),
312 default_model: default_profile.default_model.clone(),
313 });
314
315 if let Some(server_id) = server_id {
316 let preset = profile.context_servers.entry(server_id).or_default();
317 *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
318 } else {
319 *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
320 }
321 }
322 });
323 }
324
325 fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
326 self.tool_picker
327 .update(cx, |_this, cx| cx.emit(DismissEvent))
328 .log_err();
329 }
330
331 fn render_match(
332 &self,
333 ix: usize,
334 selected: bool,
335 _window: &mut Window,
336 cx: &mut Context<Picker<Self>>,
337 ) -> Option<Self::ListItem> {
338 let item = &self.filtered_items.get(ix)?;
339 match item {
340 PickerItem::ContextServer { server_id, .. } => Some(
341 div()
342 .px_2()
343 .pb_1()
344 .when(ix > 1, |this| {
345 this.mt_1()
346 .pt_2()
347 .border_t_1()
348 .border_color(cx.theme().colors().border_variant)
349 })
350 .child(
351 Label::new(server_id)
352 .size(LabelSize::XSmall)
353 .color(Color::Muted),
354 )
355 .into_any_element(),
356 ),
357 PickerItem::Tool { name, server_id } => {
358 let is_enabled = if let Some(server_id) = server_id {
359 self.profile_settings
360 .context_servers
361 .get(server_id.as_ref())
362 .and_then(|preset| preset.tools.get(name))
363 .copied()
364 .unwrap_or(self.profile_settings.enable_all_context_servers)
365 } else {
366 self.profile_settings
367 .tools
368 .get(name)
369 .copied()
370 .unwrap_or(false)
371 };
372
373 Some(
374 ListItem::new(ix)
375 .inset(true)
376 .spacing(ListItemSpacing::Sparse)
377 .toggle_state(selected)
378 .child(Label::new(name.clone()))
379 .end_slot::<Icon>(is_enabled.then(|| {
380 Icon::new(IconName::Check)
381 .size(IconSize::Small)
382 .color(Color::Success)
383 }))
384 .into_any_element(),
385 )
386 }
387 }
388 }
389}