1use std::{collections::BTreeMap, sync::Arc};
2
3use agent_settings::{AgentProfileId, AgentProfileSettings};
4use assistant_tool::{ToolSource, ToolWorkingSet};
5use fs::Fs;
6use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
7use picker::{Picker, PickerDelegate};
8use settings::{AgentProfileContent, ContextServerPresetContent, update_settings_file};
9use ui::{ListItem, ListItemSpacing, prelude::*};
10use util::ResultExt as _;
11
12pub struct ToolPicker {
13 picker: Entity<Picker<ToolPickerDelegate>>,
14}
15
/// Selects which family of tools the picker lists.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ToolPickerMode {
    /// Only tools whose source is native (built into the agent).
    BuiltinTools,
    /// Only tools provided by context (MCP) servers, grouped per server.
    McpTools,
}
21
22impl ToolPicker {
23 pub fn builtin_tools(
24 delegate: ToolPickerDelegate,
25 window: &mut Window,
26 cx: &mut Context<Self>,
27 ) -> Self {
28 let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
29 Self { picker }
30 }
31
32 pub fn mcp_tools(
33 delegate: ToolPickerDelegate,
34 window: &mut Window,
35 cx: &mut Context<Self>,
36 ) -> Self {
37 let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
38 Self { picker }
39 }
40}
41
42impl EventEmitter<DismissEvent> for ToolPicker {}
43
44impl Focusable for ToolPicker {
45 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
46 self.picker.focus_handle(cx)
47 }
48}
49
50impl Render for ToolPicker {
51 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
52 v_flex().w(rems(34.)).child(self.picker.clone())
53 }
54}
55
/// One row of the tool picker.
#[derive(Clone, Debug)]
pub enum PickerItem {
    /// A toggleable tool.
    ///
    /// `server_id` is `None` for built-in tools and names the providing
    /// context server otherwise.
    Tool {
        server_id: Option<Arc<str>>,
        name: Arc<str>,
    },
    /// A header row naming a context server; that server's tools follow it.
    ContextServer {
        server_id: Arc<str>,
    },
}
66
67pub struct ToolPickerDelegate {
68 tool_picker: WeakEntity<ToolPicker>,
69 fs: Arc<dyn Fs>,
70 items: Arc<Vec<PickerItem>>,
71 profile_id: AgentProfileId,
72 profile_settings: AgentProfileSettings,
73 filtered_items: Vec<PickerItem>,
74 selected_index: usize,
75 mode: ToolPickerMode,
76}
77
78impl ToolPickerDelegate {
79 pub fn new(
80 mode: ToolPickerMode,
81 fs: Arc<dyn Fs>,
82 tool_set: Entity<ToolWorkingSet>,
83 profile_id: AgentProfileId,
84 profile_settings: AgentProfileSettings,
85 cx: &mut Context<ToolPicker>,
86 ) -> Self {
87 let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
88
89 Self {
90 tool_picker: cx.entity().downgrade(),
91 fs,
92 items,
93 profile_id,
94 profile_settings,
95 filtered_items: Vec::new(),
96 selected_index: 0,
97 mode,
98 }
99 }
100
101 fn resolve_items(
102 mode: ToolPickerMode,
103 tool_set: &Entity<ToolWorkingSet>,
104 cx: &mut App,
105 ) -> Vec<PickerItem> {
106 let mut items = Vec::new();
107 for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
108 match source {
109 ToolSource::Native => {
110 if mode == ToolPickerMode::BuiltinTools {
111 items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
112 name: tool.name().into(),
113 server_id: None,
114 }));
115 }
116 }
117 ToolSource::ContextServer { id } => {
118 if mode == ToolPickerMode::McpTools && !tools.is_empty() {
119 let server_id: Arc<str> = id.clone().into();
120 items.push(PickerItem::ContextServer {
121 server_id: server_id.clone(),
122 });
123 items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
124 name: tool.name().into(),
125 server_id: Some(server_id.clone()),
126 }));
127 }
128 }
129 }
130 }
131 items
132 }
133}
134
135impl PickerDelegate for ToolPickerDelegate {
136 type ListItem = AnyElement;
137
138 fn match_count(&self) -> usize {
139 self.filtered_items.len()
140 }
141
142 fn selected_index(&self) -> usize {
143 self.selected_index
144 }
145
146 fn set_selected_index(
147 &mut self,
148 ix: usize,
149 _window: &mut Window,
150 _cx: &mut Context<Picker<Self>>,
151 ) {
152 self.selected_index = ix;
153 }
154
155 fn can_select(
156 &mut self,
157 ix: usize,
158 _window: &mut Window,
159 _cx: &mut Context<Picker<Self>>,
160 ) -> bool {
161 let item = &self.filtered_items[ix];
162 match item {
163 PickerItem::Tool { .. } => true,
164 PickerItem::ContextServer { .. } => false,
165 }
166 }
167
168 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
169 match self.mode {
170 ToolPickerMode::BuiltinTools => "Search built-in tools…",
171 ToolPickerMode::McpTools => "Search MCP tools…",
172 }
173 .into()
174 }
175
176 fn update_matches(
177 &mut self,
178 query: String,
179 window: &mut Window,
180 cx: &mut Context<Picker<Self>>,
181 ) -> Task<()> {
182 let all_items = self.items.clone();
183
184 cx.spawn_in(window, async move |this, cx| {
185 let filtered_items = cx
186 .background_spawn(async move {
187 let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
188 BTreeMap::default();
189
190 for item in all_items.iter() {
191 if let PickerItem::Tool { server_id, name } = item.clone()
192 && name.contains(&query)
193 {
194 tools_by_provider.entry(server_id).or_default().push(name);
195 }
196 }
197
198 let mut items = Vec::new();
199
200 for (server_id, names) in tools_by_provider {
201 if let Some(server_id) = server_id.clone() {
202 items.push(PickerItem::ContextServer { server_id });
203 }
204 for name in names {
205 items.push(PickerItem::Tool {
206 server_id: server_id.clone(),
207 name,
208 });
209 }
210 }
211
212 items
213 })
214 .await;
215
216 this.update(cx, |this, _cx| {
217 this.delegate.filtered_items = filtered_items;
218 this.delegate.selected_index = this
219 .delegate
220 .selected_index
221 .min(this.delegate.filtered_items.len().saturating_sub(1));
222 })
223 .log_err();
224 })
225 }
226
227 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
228 if self.filtered_items.is_empty() {
229 self.dismissed(window, cx);
230 return;
231 }
232
233 let item = &self.filtered_items[self.selected_index];
234
235 let PickerItem::Tool {
236 name: tool_name,
237 server_id,
238 } = item
239 else {
240 return;
241 };
242
243 let is_currently_enabled = if let Some(server_id) = server_id.clone() {
244 let preset = self
245 .profile_settings
246 .context_servers
247 .entry(server_id)
248 .or_default();
249 let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
250 *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
251 is_enabled
252 } else {
253 let is_enabled = *self
254 .profile_settings
255 .tools
256 .entry(tool_name.clone())
257 .or_default();
258 *self
259 .profile_settings
260 .tools
261 .entry(tool_name.clone())
262 .or_default() = !is_enabled;
263 is_enabled
264 };
265
266 update_settings_file(self.fs.clone(), cx, {
267 let profile_id = self.profile_id.clone();
268 let default_profile = self.profile_settings.clone();
269 let server_id = server_id.clone();
270 let tool_name = tool_name.clone();
271 move |settings, _cx| {
272 let profiles = settings
273 .agent
274 .get_or_insert_default()
275 .profiles
276 .get_or_insert_default();
277 let profile = profiles
278 .entry(profile_id.0)
279 .or_insert_with(|| AgentProfileContent {
280 name: default_profile.name.into(),
281 tools: default_profile.tools,
282 enable_all_context_servers: Some(
283 default_profile.enable_all_context_servers,
284 ),
285 context_servers: default_profile
286 .context_servers
287 .into_iter()
288 .map(|(server_id, preset)| {
289 (
290 server_id,
291 ContextServerPresetContent {
292 tools: preset.tools,
293 },
294 )
295 })
296 .collect(),
297 });
298
299 if let Some(server_id) = server_id {
300 let preset = profile.context_servers.entry(server_id).or_default();
301 *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
302 } else {
303 *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
304 }
305 }
306 });
307 }
308
309 fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
310 self.tool_picker
311 .update(cx, |_this, cx| cx.emit(DismissEvent))
312 .log_err();
313 }
314
315 fn render_match(
316 &self,
317 ix: usize,
318 selected: bool,
319 _window: &mut Window,
320 cx: &mut Context<Picker<Self>>,
321 ) -> Option<Self::ListItem> {
322 let item = &self.filtered_items.get(ix)?;
323 match item {
324 PickerItem::ContextServer { server_id, .. } => Some(
325 div()
326 .px_2()
327 .pb_1()
328 .when(ix > 1, |this| {
329 this.mt_1()
330 .pt_2()
331 .border_t_1()
332 .border_color(cx.theme().colors().border_variant)
333 })
334 .child(
335 Label::new(server_id)
336 .size(LabelSize::XSmall)
337 .color(Color::Muted),
338 )
339 .into_any_element(),
340 ),
341 PickerItem::Tool { name, server_id } => {
342 let is_enabled = if let Some(server_id) = server_id {
343 self.profile_settings
344 .context_servers
345 .get(server_id.as_ref())
346 .and_then(|preset| preset.tools.get(name))
347 .copied()
348 .unwrap_or(self.profile_settings.enable_all_context_servers)
349 } else {
350 self.profile_settings
351 .tools
352 .get(name)
353 .copied()
354 .unwrap_or(false)
355 };
356
357 Some(
358 ListItem::new(ix)
359 .inset(true)
360 .spacing(ListItemSpacing::Sparse)
361 .toggle_state(selected)
362 .child(Label::new(name.clone()))
363 .end_slot::<Icon>(is_enabled.then(|| {
364 Icon::new(IconName::Check)
365 .size(IconSize::Small)
366 .color(Color::Success)
367 }))
368 .into_any_element(),
369 )
370 }
371 }
372 }
373}