1use std::{collections::BTreeMap, sync::Arc};
2
3use agent_settings::{
4 AgentProfileContent, AgentProfileId, AgentProfileSettings, AgentSettings, AgentSettingsContent,
5 ContextServerPresetContent,
6};
7use assistant_tool::{ToolSource, ToolWorkingSet};
8use fs::Fs;
9use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
10use picker::{Picker, PickerDelegate};
11use settings::update_settings_file;
12use ui::{ListItem, ListItemSpacing, prelude::*};
13use util::ResultExt as _;
14
15pub struct ToolPicker {
16 picker: Entity<Picker<ToolPickerDelegate>>,
17}
18
/// Which family of tools a [`ToolPicker`] lists.
///
/// `Eq` is derived alongside `PartialEq` because the enum is fieldless, so
/// equality is total; this lets the mode be used where `Eq` is required.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ToolPickerMode {
    /// Tools whose source is `ToolSource::Native`.
    BuiltinTools,
    /// Tools provided by context (MCP) servers, grouped per server.
    McpTools,
}
24
25impl ToolPicker {
26 pub fn builtin_tools(
27 delegate: ToolPickerDelegate,
28 window: &mut Window,
29 cx: &mut Context<Self>,
30 ) -> Self {
31 let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
32 Self { picker }
33 }
34
35 pub fn mcp_tools(
36 delegate: ToolPickerDelegate,
37 window: &mut Window,
38 cx: &mut Context<Self>,
39 ) -> Self {
40 let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
41 Self { picker }
42 }
43}
44
45impl EventEmitter<DismissEvent> for ToolPicker {}
46
47impl Focusable for ToolPicker {
48 fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
49 self.picker.focus_handle(cx)
50 }
51}
52
53impl Render for ToolPicker {
54 fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
55 v_flex().w(rems(34.)).child(self.picker.clone())
56 }
57}
58
/// A single row in the tool picker list.
#[derive(Debug, Clone)]
pub enum PickerItem {
    /// A toggleable tool row.
    ///
    /// `server_id` is `None` for built-in tools and names the owning
    /// context server for MCP tools.
    Tool {
        server_id: Option<Arc<str>>,
        name: Arc<str>,
    },
    /// A non-selectable header row that introduces one context server's tools.
    ContextServer { server_id: Arc<str> },
}
69
70pub struct ToolPickerDelegate {
71 tool_picker: WeakEntity<ToolPicker>,
72 fs: Arc<dyn Fs>,
73 items: Arc<Vec<PickerItem>>,
74 profile_id: AgentProfileId,
75 profile_settings: AgentProfileSettings,
76 filtered_items: Vec<PickerItem>,
77 selected_index: usize,
78 mode: ToolPickerMode,
79}
80
81impl ToolPickerDelegate {
82 pub fn new(
83 mode: ToolPickerMode,
84 fs: Arc<dyn Fs>,
85 tool_set: Entity<ToolWorkingSet>,
86 profile_id: AgentProfileId,
87 profile_settings: AgentProfileSettings,
88 cx: &mut Context<ToolPicker>,
89 ) -> Self {
90 let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
91
92 Self {
93 tool_picker: cx.entity().downgrade(),
94 fs,
95 items,
96 profile_id,
97 profile_settings,
98 filtered_items: Vec::new(),
99 selected_index: 0,
100 mode,
101 }
102 }
103
104 fn resolve_items(
105 mode: ToolPickerMode,
106 tool_set: &Entity<ToolWorkingSet>,
107 cx: &mut App,
108 ) -> Vec<PickerItem> {
109 let mut items = Vec::new();
110 for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
111 match source {
112 ToolSource::Native => {
113 if mode == ToolPickerMode::BuiltinTools {
114 items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
115 name: tool.name().into(),
116 server_id: None,
117 }));
118 }
119 }
120 ToolSource::ContextServer { id } => {
121 if mode == ToolPickerMode::McpTools && !tools.is_empty() {
122 let server_id: Arc<str> = id.clone().into();
123 items.push(PickerItem::ContextServer {
124 server_id: server_id.clone(),
125 });
126 items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
127 name: tool.name().into(),
128 server_id: Some(server_id.clone()),
129 }));
130 }
131 }
132 }
133 }
134 items
135 }
136}
137
138impl PickerDelegate for ToolPickerDelegate {
139 type ListItem = AnyElement;
140
141 fn match_count(&self) -> usize {
142 self.filtered_items.len()
143 }
144
145 fn selected_index(&self) -> usize {
146 self.selected_index
147 }
148
149 fn set_selected_index(
150 &mut self,
151 ix: usize,
152 _window: &mut Window,
153 _cx: &mut Context<Picker<Self>>,
154 ) {
155 self.selected_index = ix;
156 }
157
158 fn can_select(
159 &mut self,
160 ix: usize,
161 _window: &mut Window,
162 _cx: &mut Context<Picker<Self>>,
163 ) -> bool {
164 let item = &self.filtered_items[ix];
165 match item {
166 PickerItem::Tool { .. } => true,
167 PickerItem::ContextServer { .. } => false,
168 }
169 }
170
171 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
172 match self.mode {
173 ToolPickerMode::BuiltinTools => "Search built-in tools…",
174 ToolPickerMode::McpTools => "Search MCP tools…",
175 }
176 .into()
177 }
178
179 fn update_matches(
180 &mut self,
181 query: String,
182 window: &mut Window,
183 cx: &mut Context<Picker<Self>>,
184 ) -> Task<()> {
185 let all_items = self.items.clone();
186
187 cx.spawn_in(window, async move |this, cx| {
188 let filtered_items = cx
189 .background_spawn(async move {
190 let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
191 BTreeMap::default();
192
193 for item in all_items.iter() {
194 if let PickerItem::Tool { server_id, name } = item.clone()
195 && name.contains(&query)
196 {
197 tools_by_provider.entry(server_id).or_default().push(name);
198 }
199 }
200
201 let mut items = Vec::new();
202
203 for (server_id, names) in tools_by_provider {
204 if let Some(server_id) = server_id.clone() {
205 items.push(PickerItem::ContextServer { server_id });
206 }
207 for name in names {
208 items.push(PickerItem::Tool {
209 server_id: server_id.clone(),
210 name,
211 });
212 }
213 }
214
215 items
216 })
217 .await;
218
219 this.update(cx, |this, _cx| {
220 this.delegate.filtered_items = filtered_items;
221 this.delegate.selected_index = this
222 .delegate
223 .selected_index
224 .min(this.delegate.filtered_items.len().saturating_sub(1));
225 })
226 .log_err();
227 })
228 }
229
230 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
231 if self.filtered_items.is_empty() {
232 self.dismissed(window, cx);
233 return;
234 }
235
236 let item = &self.filtered_items[self.selected_index];
237
238 let PickerItem::Tool {
239 name: tool_name,
240 server_id,
241 } = item
242 else {
243 return;
244 };
245
246 let is_currently_enabled = if let Some(server_id) = server_id.clone() {
247 let preset = self
248 .profile_settings
249 .context_servers
250 .entry(server_id)
251 .or_default();
252 let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
253 *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
254 is_enabled
255 } else {
256 let is_enabled = *self
257 .profile_settings
258 .tools
259 .entry(tool_name.clone())
260 .or_default();
261 *self
262 .profile_settings
263 .tools
264 .entry(tool_name.clone())
265 .or_default() = !is_enabled;
266 is_enabled
267 };
268
269 update_settings_file::<AgentSettings>(self.fs.clone(), cx, {
270 let profile_id = self.profile_id.clone();
271 let default_profile = self.profile_settings.clone();
272 let server_id = server_id.clone();
273 let tool_name = tool_name.clone();
274 move |settings: &mut AgentSettingsContent, _cx| {
275 let profiles = settings.profiles.get_or_insert_default();
276 let profile = profiles
277 .entry(profile_id)
278 .or_insert_with(|| AgentProfileContent {
279 name: default_profile.name.into(),
280 tools: default_profile.tools,
281 enable_all_context_servers: Some(
282 default_profile.enable_all_context_servers,
283 ),
284 context_servers: default_profile
285 .context_servers
286 .into_iter()
287 .map(|(server_id, preset)| {
288 (
289 server_id,
290 ContextServerPresetContent {
291 tools: preset.tools,
292 },
293 )
294 })
295 .collect(),
296 });
297
298 if let Some(server_id) = server_id {
299 let preset = profile.context_servers.entry(server_id).or_default();
300 *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
301 } else {
302 *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
303 }
304 }
305 });
306 }
307
308 fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
309 self.tool_picker
310 .update(cx, |_this, cx| cx.emit(DismissEvent))
311 .log_err();
312 }
313
314 fn render_match(
315 &self,
316 ix: usize,
317 selected: bool,
318 _window: &mut Window,
319 cx: &mut Context<Picker<Self>>,
320 ) -> Option<Self::ListItem> {
321 let item = &self.filtered_items.get(ix)?;
322 match item {
323 PickerItem::ContextServer { server_id, .. } => Some(
324 div()
325 .px_2()
326 .pb_1()
327 .when(ix > 1, |this| {
328 this.mt_1()
329 .pt_2()
330 .border_t_1()
331 .border_color(cx.theme().colors().border_variant)
332 })
333 .child(
334 Label::new(server_id)
335 .size(LabelSize::XSmall)
336 .color(Color::Muted),
337 )
338 .into_any_element(),
339 ),
340 PickerItem::Tool { name, server_id } => {
341 let is_enabled = if let Some(server_id) = server_id {
342 self.profile_settings
343 .context_servers
344 .get(server_id.as_ref())
345 .and_then(|preset| preset.tools.get(name))
346 .copied()
347 .unwrap_or(self.profile_settings.enable_all_context_servers)
348 } else {
349 self.profile_settings
350 .tools
351 .get(name)
352 .copied()
353 .unwrap_or(false)
354 };
355
356 Some(
357 ListItem::new(ix)
358 .inset(true)
359 .spacing(ListItemSpacing::Sparse)
360 .toggle_state(selected)
361 .child(Label::new(name.clone()))
362 .end_slot::<Icon>(is_enabled.then(|| {
363 Icon::new(IconName::Check)
364 .size(IconSize::Small)
365 .color(Color::Success)
366 }))
367 .into_any_element(),
368 )
369 }
370 }
371 }
372}