tool_picker.rs

  1use std::{collections::BTreeMap, sync::Arc};
  2
  3use agent_settings::{
  4    AgentProfileContent, AgentProfileId, AgentProfileSettings, AgentSettings, AgentSettingsContent,
  5    ContextServerPresetContent,
  6};
  7use assistant_tool::{ToolSource, ToolWorkingSet};
  8use fs::Fs;
  9use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
 10use picker::{Picker, PickerDelegate};
 11use settings::update_settings_file;
 12use ui::{ListItem, ListItemSpacing, prelude::*};
 13use util::ResultExt as _;
 14
 15pub struct ToolPicker {
 16    picker: Entity<Picker<ToolPickerDelegate>>,
 17}
 18
 19#[derive(Clone, Copy, Debug, PartialEq)]
 20pub enum ToolPickerMode {
 21    BuiltinTools,
 22    McpTools,
 23}
 24
 25impl ToolPicker {
 26    pub fn builtin_tools(
 27        delegate: ToolPickerDelegate,
 28        window: &mut Window,
 29        cx: &mut Context<Self>,
 30    ) -> Self {
 31        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
 32        Self { picker }
 33    }
 34
 35    pub fn mcp_tools(
 36        delegate: ToolPickerDelegate,
 37        window: &mut Window,
 38        cx: &mut Context<Self>,
 39    ) -> Self {
 40        let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
 41        Self { picker }
 42    }
 43}
 44
 45impl EventEmitter<DismissEvent> for ToolPicker {}
 46
 47impl Focusable for ToolPicker {
 48    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 49        self.picker.focus_handle(cx)
 50    }
 51}
 52
 53impl Render for ToolPicker {
 54    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 55        v_flex().w(rems(34.)).child(self.picker.clone())
 56    }
 57}
 58
 59#[derive(Debug, Clone)]
 60pub enum PickerItem {
 61    Tool {
 62        server_id: Option<Arc<str>>,
 63        name: Arc<str>,
 64    },
 65    ContextServer {
 66        server_id: Arc<str>,
 67    },
 68}
 69
 70pub struct ToolPickerDelegate {
 71    tool_picker: WeakEntity<ToolPicker>,
 72    fs: Arc<dyn Fs>,
 73    items: Arc<Vec<PickerItem>>,
 74    profile_id: AgentProfileId,
 75    profile_settings: AgentProfileSettings,
 76    filtered_items: Vec<PickerItem>,
 77    selected_index: usize,
 78    mode: ToolPickerMode,
 79}
 80
 81impl ToolPickerDelegate {
 82    pub fn new(
 83        mode: ToolPickerMode,
 84        fs: Arc<dyn Fs>,
 85        tool_set: Entity<ToolWorkingSet>,
 86        profile_id: AgentProfileId,
 87        profile_settings: AgentProfileSettings,
 88        cx: &mut Context<ToolPicker>,
 89    ) -> Self {
 90        let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
 91
 92        Self {
 93            tool_picker: cx.entity().downgrade(),
 94            fs,
 95            items,
 96            profile_id,
 97            profile_settings,
 98            filtered_items: Vec::new(),
 99            selected_index: 0,
100            mode,
101        }
102    }
103
104    fn resolve_items(
105        mode: ToolPickerMode,
106        tool_set: &Entity<ToolWorkingSet>,
107        cx: &mut App,
108    ) -> Vec<PickerItem> {
109        let mut items = Vec::new();
110        for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
111            match source {
112                ToolSource::Native => {
113                    if mode == ToolPickerMode::BuiltinTools {
114                        items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
115                            name: tool.name().into(),
116                            server_id: None,
117                        }));
118                    }
119                }
120                ToolSource::ContextServer { id } => {
121                    if mode == ToolPickerMode::McpTools && !tools.is_empty() {
122                        let server_id: Arc<str> = id.clone().into();
123                        items.push(PickerItem::ContextServer {
124                            server_id: server_id.clone(),
125                        });
126                        items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
127                            name: tool.name().into(),
128                            server_id: Some(server_id.clone()),
129                        }));
130                    }
131                }
132            }
133        }
134        items
135    }
136}
137
138impl PickerDelegate for ToolPickerDelegate {
139    type ListItem = AnyElement;
140
141    fn match_count(&self) -> usize {
142        self.filtered_items.len()
143    }
144
145    fn selected_index(&self) -> usize {
146        self.selected_index
147    }
148
149    fn set_selected_index(
150        &mut self,
151        ix: usize,
152        _window: &mut Window,
153        _cx: &mut Context<Picker<Self>>,
154    ) {
155        self.selected_index = ix;
156    }
157
158    fn can_select(
159        &mut self,
160        ix: usize,
161        _window: &mut Window,
162        _cx: &mut Context<Picker<Self>>,
163    ) -> bool {
164        let item = &self.filtered_items[ix];
165        match item {
166            PickerItem::Tool { .. } => true,
167            PickerItem::ContextServer { .. } => false,
168        }
169    }
170
171    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
172        match self.mode {
173            ToolPickerMode::BuiltinTools => "Search built-in tools…",
174            ToolPickerMode::McpTools => "Search MCP tools…",
175        }
176        .into()
177    }
178
179    fn update_matches(
180        &mut self,
181        query: String,
182        window: &mut Window,
183        cx: &mut Context<Picker<Self>>,
184    ) -> Task<()> {
185        let all_items = self.items.clone();
186
187        cx.spawn_in(window, async move |this, cx| {
188            let filtered_items = cx
189                .background_spawn(async move {
190                    let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
191                        BTreeMap::default();
192
193                    for item in all_items.iter() {
194                        if let PickerItem::Tool { server_id, name } = item.clone() {
195                            if name.contains(&query) {
196                                tools_by_provider.entry(server_id).or_default().push(name);
197                            }
198                        }
199                    }
200
201                    let mut items = Vec::new();
202
203                    for (server_id, names) in tools_by_provider {
204                        if let Some(server_id) = server_id.clone() {
205                            items.push(PickerItem::ContextServer { server_id });
206                        }
207                        for name in names {
208                            items.push(PickerItem::Tool {
209                                server_id: server_id.clone(),
210                                name,
211                            });
212                        }
213                    }
214
215                    items
216                })
217                .await;
218
219            this.update(cx, |this, _cx| {
220                this.delegate.filtered_items = filtered_items;
221                this.delegate.selected_index = this
222                    .delegate
223                    .selected_index
224                    .min(this.delegate.filtered_items.len().saturating_sub(1));
225            })
226            .log_err();
227        })
228    }
229
230    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
231        if self.filtered_items.is_empty() {
232            self.dismissed(window, cx);
233            return;
234        }
235
236        let item = &self.filtered_items[self.selected_index];
237
238        let PickerItem::Tool {
239            name: tool_name,
240            server_id,
241        } = item
242        else {
243            return;
244        };
245
246        let is_currently_enabled = if let Some(server_id) = server_id.clone() {
247            let preset = self
248                .profile_settings
249                .context_servers
250                .entry(server_id)
251                .or_default();
252            let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
253            *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
254            is_enabled
255        } else {
256            let is_enabled = *self
257                .profile_settings
258                .tools
259                .entry(tool_name.clone())
260                .or_default();
261            *self
262                .profile_settings
263                .tools
264                .entry(tool_name.clone())
265                .or_default() = !is_enabled;
266            is_enabled
267        };
268
269        update_settings_file::<AgentSettings>(self.fs.clone(), cx, {
270            let profile_id = self.profile_id.clone();
271            let default_profile = self.profile_settings.clone();
272            let server_id = server_id.clone();
273            let tool_name = tool_name.clone();
274            move |settings: &mut AgentSettingsContent, _cx| {
275                settings
276                    .v2_setting(|v2_settings| {
277                        let profiles = v2_settings.profiles.get_or_insert_default();
278                        let profile =
279                            profiles
280                                .entry(profile_id)
281                                .or_insert_with(|| AgentProfileContent {
282                                    name: default_profile.name.into(),
283                                    tools: default_profile.tools,
284                                    enable_all_context_servers: Some(
285                                        default_profile.enable_all_context_servers,
286                                    ),
287                                    context_servers: default_profile
288                                        .context_servers
289                                        .into_iter()
290                                        .map(|(server_id, preset)| {
291                                            (
292                                                server_id,
293                                                ContextServerPresetContent {
294                                                    tools: preset.tools,
295                                                },
296                                            )
297                                        })
298                                        .collect(),
299                                });
300
301                        if let Some(server_id) = server_id {
302                            let preset = profile.context_servers.entry(server_id).or_default();
303                            *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
304                        } else {
305                            *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
306                        }
307
308                        Ok(())
309                    })
310                    .ok();
311            }
312        });
313    }
314
315    fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
316        self.tool_picker
317            .update(cx, |_this, cx| cx.emit(DismissEvent))
318            .log_err();
319    }
320
321    fn render_match(
322        &self,
323        ix: usize,
324        selected: bool,
325        _window: &mut Window,
326        cx: &mut Context<Picker<Self>>,
327    ) -> Option<Self::ListItem> {
328        let item = &self.filtered_items[ix];
329        match item {
330            PickerItem::ContextServer { server_id, .. } => Some(
331                div()
332                    .px_2()
333                    .pb_1()
334                    .when(ix > 1, |this| {
335                        this.mt_1()
336                            .pt_2()
337                            .border_t_1()
338                            .border_color(cx.theme().colors().border_variant)
339                    })
340                    .child(
341                        Label::new(server_id)
342                            .size(LabelSize::XSmall)
343                            .color(Color::Muted),
344                    )
345                    .into_any_element(),
346            ),
347            PickerItem::Tool { name, server_id } => {
348                let is_enabled = if let Some(server_id) = server_id {
349                    self.profile_settings
350                        .context_servers
351                        .get(server_id.as_ref())
352                        .and_then(|preset| preset.tools.get(name))
353                        .copied()
354                        .unwrap_or(self.profile_settings.enable_all_context_servers)
355                } else {
356                    self.profile_settings
357                        .tools
358                        .get(name)
359                        .copied()
360                        .unwrap_or(false)
361                };
362
363                Some(
364                    ListItem::new(ix)
365                        .inset(true)
366                        .spacing(ListItemSpacing::Sparse)
367                        .toggle_state(selected)
368                        .child(Label::new(name.clone()))
369                        .end_slot::<Icon>(is_enabled.then(|| {
370                            Icon::new(IconName::Check)
371                                .size(IconSize::Small)
372                                .color(Color::Success)
373                        }))
374                        .into_any_element(),
375                )
376            }
377        }
378    }
379}