1use std::{collections::BTreeMap, sync::Arc};
2
3use agent_settings::{
4 AgentProfileContent, AgentProfileId, AgentProfileSettings, AgentSettings, AgentSettingsContent,
5 ContextServerPresetContent,
6};
7use assistant_tool::{ToolSource, ToolWorkingSet};
8use fs::Fs;
9use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
10use picker::{Picker, PickerDelegate};
11use settings::update_settings_file;
12use ui::{ListItem, ListItemSpacing, prelude::*};
13use util::ResultExt as _;
14
15pub struct ToolPicker {
16 picker: Entity<Picker<ToolPickerDelegate>>,
17}
18
/// Selects which family of tools a tool picker lists.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ToolPickerMode {
    /// List only tools that come from the native (built-in) tool source.
    BuiltinTools,
    /// List only tools contributed by context servers, grouped per server.
    McpTools,
}
24
25impl ToolPicker {
26 pub fn builtin_tools(
27 delegate: ToolPickerDelegate,
28 window: &mut Window,
29 cx: &mut Context<Self>,
30 ) -> Self {
31 let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
32 Self { picker }
33 }
34
35 pub fn mcp_tools(
36 delegate: ToolPickerDelegate,
37 window: &mut Window,
38 cx: &mut Context<Self>,
39 ) -> Self {
40 let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
41 Self { picker }
42 }
43}
44
45impl EventEmitter<DismissEvent> for ToolPicker {}
46
47impl Focusable for ToolPicker {
48 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
49 self.picker.focus_handle(cx)
50 }
51}
52
53impl Render for ToolPicker {
54 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
55 v_flex().w(rems(34.)).child(self.picker.clone())
56 }
57}
58
/// A single row shown by the tool picker.
#[derive(Debug, Clone)]
pub enum PickerItem {
    /// A tool that can be toggled on or off.
    Tool {
        /// Context server that provides the tool; `None` for native tools.
        server_id: Option<Arc<str>>,
        /// Name of the tool, also used as its label and settings key.
        name: Arc<str>,
    },
    /// A header row naming the context server whose tools follow it.
    ContextServer {
        /// Identifier of the context server, used as the header label.
        server_id: Arc<str>,
    },
}
69
70pub struct ToolPickerDelegate {
71 tool_picker: WeakEntity<ToolPicker>,
72 fs: Arc<dyn Fs>,
73 items: Arc<Vec<PickerItem>>,
74 profile_id: AgentProfileId,
75 profile_settings: AgentProfileSettings,
76 filtered_items: Vec<PickerItem>,
77 selected_index: usize,
78 mode: ToolPickerMode,
79}
80
81impl ToolPickerDelegate {
82 pub fn new(
83 mode: ToolPickerMode,
84 fs: Arc<dyn Fs>,
85 tool_set: Entity<ToolWorkingSet>,
86 profile_id: AgentProfileId,
87 profile_settings: AgentProfileSettings,
88 cx: &mut Context<ToolPicker>,
89 ) -> Self {
90 let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
91
92 Self {
93 tool_picker: cx.entity().downgrade(),
94 fs,
95 items,
96 profile_id,
97 profile_settings,
98 filtered_items: Vec::new(),
99 selected_index: 0,
100 mode,
101 }
102 }
103
104 fn resolve_items(
105 mode: ToolPickerMode,
106 tool_set: &Entity<ToolWorkingSet>,
107 cx: &mut App,
108 ) -> Vec<PickerItem> {
109 let mut items = Vec::new();
110 for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
111 match source {
112 ToolSource::Native => {
113 if mode == ToolPickerMode::BuiltinTools {
114 items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
115 name: tool.name().into(),
116 server_id: None,
117 }));
118 }
119 }
120 ToolSource::ContextServer { id } => {
121 if mode == ToolPickerMode::McpTools && !tools.is_empty() {
122 let server_id: Arc<str> = id.clone().into();
123 items.push(PickerItem::ContextServer {
124 server_id: server_id.clone(),
125 });
126 items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
127 name: tool.name().into(),
128 server_id: Some(server_id.clone()),
129 }));
130 }
131 }
132 }
133 }
134 items
135 }
136}
137
138impl PickerDelegate for ToolPickerDelegate {
139 type ListItem = AnyElement;
140
141 fn match_count(&self) -> usize {
142 self.filtered_items.len()
143 }
144
145 fn selected_index(&self) -> usize {
146 self.selected_index
147 }
148
149 fn set_selected_index(
150 &mut self,
151 ix: usize,
152 _window: &mut Window,
153 _cx: &mut Context<Picker<Self>>,
154 ) {
155 self.selected_index = ix;
156 }
157
158 fn can_select(
159 &mut self,
160 ix: usize,
161 _window: &mut Window,
162 _cx: &mut Context<Picker<Self>>,
163 ) -> bool {
164 let item = &self.filtered_items[ix];
165 match item {
166 PickerItem::Tool { .. } => true,
167 PickerItem::ContextServer { .. } => false,
168 }
169 }
170
171 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
172 match self.mode {
173 ToolPickerMode::BuiltinTools => "Search built-in tools…",
174 ToolPickerMode::McpTools => "Search MCP tools…",
175 }
176 .into()
177 }
178
179 fn update_matches(
180 &mut self,
181 query: String,
182 window: &mut Window,
183 cx: &mut Context<Picker<Self>>,
184 ) -> Task<()> {
185 let all_items = self.items.clone();
186
187 cx.spawn_in(window, async move |this, cx| {
188 let filtered_items = cx
189 .background_spawn(async move {
190 let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
191 BTreeMap::default();
192
193 for item in all_items.iter() {
194 if let PickerItem::Tool { server_id, name } = item.clone()
195 && name.contains(&query) {
196 tools_by_provider.entry(server_id).or_default().push(name);
197 }
198 }
199
200 let mut items = Vec::new();
201
202 for (server_id, names) in tools_by_provider {
203 if let Some(server_id) = server_id.clone() {
204 items.push(PickerItem::ContextServer { server_id });
205 }
206 for name in names {
207 items.push(PickerItem::Tool {
208 server_id: server_id.clone(),
209 name,
210 });
211 }
212 }
213
214 items
215 })
216 .await;
217
218 this.update(cx, |this, _cx| {
219 this.delegate.filtered_items = filtered_items;
220 this.delegate.selected_index = this
221 .delegate
222 .selected_index
223 .min(this.delegate.filtered_items.len().saturating_sub(1));
224 })
225 .log_err();
226 })
227 }
228
229 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
230 if self.filtered_items.is_empty() {
231 self.dismissed(window, cx);
232 return;
233 }
234
235 let item = &self.filtered_items[self.selected_index];
236
237 let PickerItem::Tool {
238 name: tool_name,
239 server_id,
240 } = item
241 else {
242 return;
243 };
244
245 let is_currently_enabled = if let Some(server_id) = server_id.clone() {
246 let preset = self
247 .profile_settings
248 .context_servers
249 .entry(server_id)
250 .or_default();
251 let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
252 *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
253 is_enabled
254 } else {
255 let is_enabled = *self
256 .profile_settings
257 .tools
258 .entry(tool_name.clone())
259 .or_default();
260 *self
261 .profile_settings
262 .tools
263 .entry(tool_name.clone())
264 .or_default() = !is_enabled;
265 is_enabled
266 };
267
268 update_settings_file::<AgentSettings>(self.fs.clone(), cx, {
269 let profile_id = self.profile_id.clone();
270 let default_profile = self.profile_settings.clone();
271 let server_id = server_id.clone();
272 let tool_name = tool_name.clone();
273 move |settings: &mut AgentSettingsContent, _cx| {
274 let profiles = settings.profiles.get_or_insert_default();
275 let profile = profiles
276 .entry(profile_id)
277 .or_insert_with(|| AgentProfileContent {
278 name: default_profile.name.into(),
279 tools: default_profile.tools,
280 enable_all_context_servers: Some(
281 default_profile.enable_all_context_servers,
282 ),
283 context_servers: default_profile
284 .context_servers
285 .into_iter()
286 .map(|(server_id, preset)| {
287 (
288 server_id,
289 ContextServerPresetContent {
290 tools: preset.tools,
291 },
292 )
293 })
294 .collect(),
295 });
296
297 if let Some(server_id) = server_id {
298 let preset = profile.context_servers.entry(server_id).or_default();
299 *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
300 } else {
301 *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
302 }
303 }
304 });
305 }
306
307 fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
308 self.tool_picker
309 .update(cx, |_this, cx| cx.emit(DismissEvent))
310 .log_err();
311 }
312
313 fn render_match(
314 &self,
315 ix: usize,
316 selected: bool,
317 _window: &mut Window,
318 cx: &mut Context<Picker<Self>>,
319 ) -> Option<Self::ListItem> {
320 let item = &self.filtered_items[ix];
321 match item {
322 PickerItem::ContextServer { server_id, .. } => Some(
323 div()
324 .px_2()
325 .pb_1()
326 .when(ix > 1, |this| {
327 this.mt_1()
328 .pt_2()
329 .border_t_1()
330 .border_color(cx.theme().colors().border_variant)
331 })
332 .child(
333 Label::new(server_id)
334 .size(LabelSize::XSmall)
335 .color(Color::Muted),
336 )
337 .into_any_element(),
338 ),
339 PickerItem::Tool { name, server_id } => {
340 let is_enabled = if let Some(server_id) = server_id {
341 self.profile_settings
342 .context_servers
343 .get(server_id.as_ref())
344 .and_then(|preset| preset.tools.get(name))
345 .copied()
346 .unwrap_or(self.profile_settings.enable_all_context_servers)
347 } else {
348 self.profile_settings
349 .tools
350 .get(name)
351 .copied()
352 .unwrap_or(false)
353 };
354
355 Some(
356 ListItem::new(ix)
357 .inset(true)
358 .spacing(ListItemSpacing::Sparse)
359 .toggle_state(selected)
360 .child(Label::new(name.clone()))
361 .end_slot::<Icon>(is_enabled.then(|| {
362 Icon::new(IconName::Check)
363 .size(IconSize::Small)
364 .color(Color::Success)
365 }))
366 .into_any_element(),
367 )
368 }
369 }
370 }
371}