tool_picker.rs

  1use std::{collections::BTreeMap, sync::Arc};
  2
  3use agent_settings::{
  4    AgentProfileContent, AgentProfileId, AgentProfileSettings, AgentSettings, AgentSettingsContent,
  5    ContextServerPresetContent,
  6};
  7use assistant_tool::{ToolSource, ToolWorkingSet};
  8use fs::Fs;
  9use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
 10use picker::{Picker, PickerDelegate};
 11use settings::update_settings_file;
 12use ui::{ListItem, ListItemSpacing, prelude::*};
 13use util::ResultExt as _;
 14
/// View that hosts a non-modal [`Picker`] for toggling which tools are enabled in an agent
/// profile. Built with [`ToolPicker::builtin_tools`] or [`ToolPicker::mcp_tools`].
pub struct ToolPicker {
    /// The inner picker entity; rendering and focus are delegated to it.
    picker: Entity<Picker<ToolPickerDelegate>>,
}
 18
 19#[derive(Clone, Copy, Debug, PartialEq)]
 20pub enum ToolPickerMode {
 21    BuiltinTools,
 22    McpTools,
 23}
 24
 25impl ToolPicker {
 26    pub fn builtin_tools(
 27        delegate: ToolPickerDelegate,
 28        window: &mut Window,
 29        cx: &mut Context<Self>,
 30    ) -> Self {
 31        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
 32        Self { picker }
 33    }
 34
 35    pub fn mcp_tools(
 36        delegate: ToolPickerDelegate,
 37        window: &mut Window,
 38        cx: &mut Context<Self>,
 39    ) -> Self {
 40        let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
 41        Self { picker }
 42    }
 43}
 44
/// Allows `ToolPickerDelegate::dismissed` to close the picker by emitting [`DismissEvent`]
/// from the owning [`ToolPicker`].
impl EventEmitter<DismissEvent> for ToolPicker {}
 46
/// Focus for the [`ToolPicker`] view is the inner picker's focus handle.
impl Focusable for ToolPicker {
    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
        self.picker.focus_handle(cx)
    }
}
 52
 53impl Render for ToolPicker {
 54    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 55        v_flex().w(rems(34.)).child(self.picker.clone())
 56    }
 57}
 58
/// One row in the tool picker.
#[derive(Debug, Clone)]
pub enum PickerItem {
    /// A selectable tool row that can be toggled on or off.
    Tool {
        /// `None` for native (built-in) tools; `Some(id)` for tools provided by the context
        /// server with that id.
        server_id: Option<Arc<str>>,
        /// The tool's name, used as the key in the profile's tool maps.
        name: Arc<str>,
    },
    /// A non-selectable header row (see `can_select`) that precedes the tools of one
    /// context server.
    ContextServer {
        /// Id of the context server; also rendered as the header label.
        server_id: Arc<str>,
    },
}
 69
/// [`PickerDelegate`] that lists tools for one agent profile and toggles their enabled state,
/// both in memory and in the settings file.
pub struct ToolPickerDelegate {
    /// Weak handle to the owning view, used to emit `DismissEvent` when dismissed.
    tool_picker: WeakEntity<ToolPicker>,
    /// Filesystem handle passed to `update_settings_file` when a toggle is persisted.
    fs: Arc<dyn Fs>,
    /// Full, unfiltered row list from `resolve_items`. Wrapped in `Arc` so it can be shared
    /// with the background filtering task without copying.
    items: Arc<Vec<PickerItem>>,
    /// Id of the profile being edited.
    profile_id: AgentProfileId,
    /// In-memory copy of the profile's settings. `confirm` toggles it and `render_match`
    /// reads it to draw check marks.
    profile_settings: AgentProfileSettings,
    /// Rows currently shown, i.e. the result of the last `update_matches`.
    filtered_items: Vec<PickerItem>,
    /// Index into `filtered_items` of the highlighted row.
    selected_index: usize,
    /// Which family of tools this delegate lists.
    mode: ToolPickerMode,
}
 80
 81impl ToolPickerDelegate {
 82    pub fn new(
 83        mode: ToolPickerMode,
 84        fs: Arc<dyn Fs>,
 85        tool_set: Entity<ToolWorkingSet>,
 86        profile_id: AgentProfileId,
 87        profile_settings: AgentProfileSettings,
 88        cx: &mut Context<ToolPicker>,
 89    ) -> Self {
 90        let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
 91
 92        Self {
 93            tool_picker: cx.entity().downgrade(),
 94            fs,
 95            items,
 96            profile_id,
 97            profile_settings,
 98            filtered_items: Vec::new(),
 99            selected_index: 0,
100            mode,
101        }
102    }
103
104    fn resolve_items(
105        mode: ToolPickerMode,
106        tool_set: &Entity<ToolWorkingSet>,
107        cx: &mut App,
108    ) -> Vec<PickerItem> {
109        let mut items = Vec::new();
110        for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
111            match source {
112                ToolSource::Native => {
113                    if mode == ToolPickerMode::BuiltinTools {
114                        items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
115                            name: tool.name().into(),
116                            server_id: None,
117                        }));
118                    }
119                }
120                ToolSource::ContextServer { id } => {
121                    if mode == ToolPickerMode::McpTools && !tools.is_empty() {
122                        let server_id: Arc<str> = id.clone().into();
123                        items.push(PickerItem::ContextServer {
124                            server_id: server_id.clone(),
125                        });
126                        items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
127                            name: tool.name().into(),
128                            server_id: Some(server_id.clone()),
129                        }));
130                    }
131                }
132            }
133        }
134        items
135    }
136}
137
138impl PickerDelegate for ToolPickerDelegate {
139    type ListItem = AnyElement;
140
141    fn match_count(&self) -> usize {
142        self.filtered_items.len()
143    }
144
145    fn selected_index(&self) -> usize {
146        self.selected_index
147    }
148
149    fn set_selected_index(
150        &mut self,
151        ix: usize,
152        _window: &mut Window,
153        _cx: &mut Context<Picker<Self>>,
154    ) {
155        self.selected_index = ix;
156    }
157
158    fn can_select(
159        &mut self,
160        ix: usize,
161        _window: &mut Window,
162        _cx: &mut Context<Picker<Self>>,
163    ) -> bool {
164        let item = &self.filtered_items[ix];
165        match item {
166            PickerItem::Tool { .. } => true,
167            PickerItem::ContextServer { .. } => false,
168        }
169    }
170
171    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
172        match self.mode {
173            ToolPickerMode::BuiltinTools => "Search built-in tools…",
174            ToolPickerMode::McpTools => "Search MCP tools…",
175        }
176        .into()
177    }
178
179    fn update_matches(
180        &mut self,
181        query: String,
182        window: &mut Window,
183        cx: &mut Context<Picker<Self>>,
184    ) -> Task<()> {
185        let all_items = self.items.clone();
186
187        cx.spawn_in(window, async move |this, cx| {
188            let filtered_items = cx
189                .background_spawn(async move {
190                    let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
191                        BTreeMap::default();
192
193                    for item in all_items.iter() {
194                        if let PickerItem::Tool { server_id, name } = item.clone()
195                            && name.contains(&query) {
196                                tools_by_provider.entry(server_id).or_default().push(name);
197                            }
198                    }
199
200                    let mut items = Vec::new();
201
202                    for (server_id, names) in tools_by_provider {
203                        if let Some(server_id) = server_id.clone() {
204                            items.push(PickerItem::ContextServer { server_id });
205                        }
206                        for name in names {
207                            items.push(PickerItem::Tool {
208                                server_id: server_id.clone(),
209                                name,
210                            });
211                        }
212                    }
213
214                    items
215                })
216                .await;
217
218            this.update(cx, |this, _cx| {
219                this.delegate.filtered_items = filtered_items;
220                this.delegate.selected_index = this
221                    .delegate
222                    .selected_index
223                    .min(this.delegate.filtered_items.len().saturating_sub(1));
224            })
225            .log_err();
226        })
227    }
228
229    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
230        if self.filtered_items.is_empty() {
231            self.dismissed(window, cx);
232            return;
233        }
234
235        let item = &self.filtered_items[self.selected_index];
236
237        let PickerItem::Tool {
238            name: tool_name,
239            server_id,
240        } = item
241        else {
242            return;
243        };
244
245        let is_currently_enabled = if let Some(server_id) = server_id.clone() {
246            let preset = self
247                .profile_settings
248                .context_servers
249                .entry(server_id)
250                .or_default();
251            let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
252            *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
253            is_enabled
254        } else {
255            let is_enabled = *self
256                .profile_settings
257                .tools
258                .entry(tool_name.clone())
259                .or_default();
260            *self
261                .profile_settings
262                .tools
263                .entry(tool_name.clone())
264                .or_default() = !is_enabled;
265            is_enabled
266        };
267
268        update_settings_file::<AgentSettings>(self.fs.clone(), cx, {
269            let profile_id = self.profile_id.clone();
270            let default_profile = self.profile_settings.clone();
271            let server_id = server_id.clone();
272            let tool_name = tool_name.clone();
273            move |settings: &mut AgentSettingsContent, _cx| {
274                let profiles = settings.profiles.get_or_insert_default();
275                let profile = profiles
276                    .entry(profile_id)
277                    .or_insert_with(|| AgentProfileContent {
278                        name: default_profile.name.into(),
279                        tools: default_profile.tools,
280                        enable_all_context_servers: Some(
281                            default_profile.enable_all_context_servers,
282                        ),
283                        context_servers: default_profile
284                            .context_servers
285                            .into_iter()
286                            .map(|(server_id, preset)| {
287                                (
288                                    server_id,
289                                    ContextServerPresetContent {
290                                        tools: preset.tools,
291                                    },
292                                )
293                            })
294                            .collect(),
295                    });
296
297                if let Some(server_id) = server_id {
298                    let preset = profile.context_servers.entry(server_id).or_default();
299                    *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
300                } else {
301                    *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
302                }
303            }
304        });
305    }
306
307    fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
308        self.tool_picker
309            .update(cx, |_this, cx| cx.emit(DismissEvent))
310            .log_err();
311    }
312
313    fn render_match(
314        &self,
315        ix: usize,
316        selected: bool,
317        _window: &mut Window,
318        cx: &mut Context<Picker<Self>>,
319    ) -> Option<Self::ListItem> {
320        let item = &self.filtered_items[ix];
321        match item {
322            PickerItem::ContextServer { server_id, .. } => Some(
323                div()
324                    .px_2()
325                    .pb_1()
326                    .when(ix > 1, |this| {
327                        this.mt_1()
328                            .pt_2()
329                            .border_t_1()
330                            .border_color(cx.theme().colors().border_variant)
331                    })
332                    .child(
333                        Label::new(server_id)
334                            .size(LabelSize::XSmall)
335                            .color(Color::Muted),
336                    )
337                    .into_any_element(),
338            ),
339            PickerItem::Tool { name, server_id } => {
340                let is_enabled = if let Some(server_id) = server_id {
341                    self.profile_settings
342                        .context_servers
343                        .get(server_id.as_ref())
344                        .and_then(|preset| preset.tools.get(name))
345                        .copied()
346                        .unwrap_or(self.profile_settings.enable_all_context_servers)
347                } else {
348                    self.profile_settings
349                        .tools
350                        .get(name)
351                        .copied()
352                        .unwrap_or(false)
353                };
354
355                Some(
356                    ListItem::new(ix)
357                        .inset(true)
358                        .spacing(ListItemSpacing::Sparse)
359                        .toggle_state(selected)
360                        .child(Label::new(name.clone()))
361                        .end_slot::<Icon>(is_enabled.then(|| {
362                            Icon::new(IconName::Check)
363                                .size(IconSize::Small)
364                                .color(Color::Success)
365                        }))
366                        .into_any_element(),
367                )
368            }
369        }
370    }
371}