tool_picker.rs

  1use std::{collections::BTreeMap, sync::Arc};
  2
  3use agent_settings::{AgentProfileId, AgentProfileSettings};
  4use assistant_tool::{ToolSource, ToolWorkingSet};
  5use fs::Fs;
  6use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
  7use picker::{Picker, PickerDelegate};
  8use settings::{AgentProfileContent, ContextServerPresetContent, update_settings_file};
  9use ui::{ListItem, ListItemSpacing, prelude::*};
 10use util::ResultExt as _;
 11
 12pub struct ToolPicker {
 13    picker: Entity<Picker<ToolPickerDelegate>>,
 14}
 15
 16#[derive(Clone, Copy, Debug, PartialEq)]
 17pub enum ToolPickerMode {
 18    BuiltinTools,
 19    McpTools,
 20}
 21
 22impl ToolPicker {
 23    pub fn builtin_tools(
 24        delegate: ToolPickerDelegate,
 25        window: &mut Window,
 26        cx: &mut Context<Self>,
 27    ) -> Self {
 28        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
 29        Self { picker }
 30    }
 31
 32    pub fn mcp_tools(
 33        delegate: ToolPickerDelegate,
 34        window: &mut Window,
 35        cx: &mut Context<Self>,
 36    ) -> Self {
 37        let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
 38        Self { picker }
 39    }
 40}
 41
 42impl EventEmitter<DismissEvent> for ToolPicker {}
 43
 44impl Focusable for ToolPicker {
 45    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 46        self.picker.focus_handle(cx)
 47    }
 48}
 49
 50impl Render for ToolPicker {
 51    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 52        v_flex().w(rems(34.)).child(self.picker.clone())
 53    }
 54}
 55
 56#[derive(Debug, Clone)]
 57pub enum PickerItem {
 58    Tool {
 59        server_id: Option<Arc<str>>,
 60        name: Arc<str>,
 61    },
 62    ContextServer {
 63        server_id: Arc<str>,
 64    },
 65}
 66
 67pub struct ToolPickerDelegate {
 68    tool_picker: WeakEntity<ToolPicker>,
 69    fs: Arc<dyn Fs>,
 70    items: Arc<Vec<PickerItem>>,
 71    profile_id: AgentProfileId,
 72    profile_settings: AgentProfileSettings,
 73    filtered_items: Vec<PickerItem>,
 74    selected_index: usize,
 75    mode: ToolPickerMode,
 76}
 77
 78impl ToolPickerDelegate {
 79    pub fn new(
 80        mode: ToolPickerMode,
 81        fs: Arc<dyn Fs>,
 82        tool_set: Entity<ToolWorkingSet>,
 83        profile_id: AgentProfileId,
 84        profile_settings: AgentProfileSettings,
 85        cx: &mut Context<ToolPicker>,
 86    ) -> Self {
 87        let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
 88
 89        Self {
 90            tool_picker: cx.entity().downgrade(),
 91            fs,
 92            items,
 93            profile_id,
 94            profile_settings,
 95            filtered_items: Vec::new(),
 96            selected_index: 0,
 97            mode,
 98        }
 99    }
100
101    fn resolve_items(
102        mode: ToolPickerMode,
103        tool_set: &Entity<ToolWorkingSet>,
104        cx: &mut App,
105    ) -> Vec<PickerItem> {
106        let mut items = Vec::new();
107        for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
108            match source {
109                ToolSource::Native => {
110                    if mode == ToolPickerMode::BuiltinTools {
111                        items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
112                            name: tool.name().into(),
113                            server_id: None,
114                        }));
115                    }
116                }
117                ToolSource::ContextServer { id } => {
118                    if mode == ToolPickerMode::McpTools && !tools.is_empty() {
119                        let server_id: Arc<str> = id.clone().into();
120                        items.push(PickerItem::ContextServer {
121                            server_id: server_id.clone(),
122                        });
123                        items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
124                            name: tool.name().into(),
125                            server_id: Some(server_id.clone()),
126                        }));
127                    }
128                }
129            }
130        }
131        items
132    }
133}
134
135impl PickerDelegate for ToolPickerDelegate {
136    type ListItem = AnyElement;
137
138    fn match_count(&self) -> usize {
139        self.filtered_items.len()
140    }
141
142    fn selected_index(&self) -> usize {
143        self.selected_index
144    }
145
146    fn set_selected_index(
147        &mut self,
148        ix: usize,
149        _window: &mut Window,
150        _cx: &mut Context<Picker<Self>>,
151    ) {
152        self.selected_index = ix;
153    }
154
155    fn can_select(
156        &mut self,
157        ix: usize,
158        _window: &mut Window,
159        _cx: &mut Context<Picker<Self>>,
160    ) -> bool {
161        let item = &self.filtered_items[ix];
162        match item {
163            PickerItem::Tool { .. } => true,
164            PickerItem::ContextServer { .. } => false,
165        }
166    }
167
168    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
169        match self.mode {
170            ToolPickerMode::BuiltinTools => "Search built-in tools…",
171            ToolPickerMode::McpTools => "Search MCP tools…",
172        }
173        .into()
174    }
175
176    fn update_matches(
177        &mut self,
178        query: String,
179        window: &mut Window,
180        cx: &mut Context<Picker<Self>>,
181    ) -> Task<()> {
182        let all_items = self.items.clone();
183
184        cx.spawn_in(window, async move |this, cx| {
185            let filtered_items = cx
186                .background_spawn(async move {
187                    let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
188                        BTreeMap::default();
189
190                    for item in all_items.iter() {
191                        if let PickerItem::Tool { server_id, name } = item.clone()
192                            && name.contains(&query)
193                        {
194                            tools_by_provider.entry(server_id).or_default().push(name);
195                        }
196                    }
197
198                    let mut items = Vec::new();
199
200                    for (server_id, names) in tools_by_provider {
201                        if let Some(server_id) = server_id.clone() {
202                            items.push(PickerItem::ContextServer { server_id });
203                        }
204                        for name in names {
205                            items.push(PickerItem::Tool {
206                                server_id: server_id.clone(),
207                                name,
208                            });
209                        }
210                    }
211
212                    items
213                })
214                .await;
215
216            this.update(cx, |this, _cx| {
217                this.delegate.filtered_items = filtered_items;
218                this.delegate.selected_index = this
219                    .delegate
220                    .selected_index
221                    .min(this.delegate.filtered_items.len().saturating_sub(1));
222            })
223            .log_err();
224        })
225    }
226
227    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
228        if self.filtered_items.is_empty() {
229            self.dismissed(window, cx);
230            return;
231        }
232
233        let item = &self.filtered_items[self.selected_index];
234
235        let PickerItem::Tool {
236            name: tool_name,
237            server_id,
238        } = item
239        else {
240            return;
241        };
242
243        let is_currently_enabled = if let Some(server_id) = server_id.clone() {
244            let preset = self
245                .profile_settings
246                .context_servers
247                .entry(server_id)
248                .or_default();
249            let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
250            *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
251            is_enabled
252        } else {
253            let is_enabled = *self
254                .profile_settings
255                .tools
256                .entry(tool_name.clone())
257                .or_default();
258            *self
259                .profile_settings
260                .tools
261                .entry(tool_name.clone())
262                .or_default() = !is_enabled;
263            is_enabled
264        };
265
266        update_settings_file(self.fs.clone(), cx, {
267            let profile_id = self.profile_id.clone();
268            let default_profile = self.profile_settings.clone();
269            let server_id = server_id.clone();
270            let tool_name = tool_name.clone();
271            move |settings, _cx| {
272                let profiles = settings
273                    .agent
274                    .get_or_insert_default()
275                    .profiles
276                    .get_or_insert_default();
277                let profile = profiles
278                    .entry(profile_id.0)
279                    .or_insert_with(|| AgentProfileContent {
280                        name: default_profile.name.into(),
281                        tools: default_profile.tools,
282                        enable_all_context_servers: Some(
283                            default_profile.enable_all_context_servers,
284                        ),
285                        context_servers: default_profile
286                            .context_servers
287                            .into_iter()
288                            .map(|(server_id, preset)| {
289                                (
290                                    server_id,
291                                    ContextServerPresetContent {
292                                        tools: preset.tools,
293                                    },
294                                )
295                            })
296                            .collect(),
297                    });
298
299                if let Some(server_id) = server_id {
300                    let preset = profile.context_servers.entry(server_id).or_default();
301                    *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
302                } else {
303                    *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
304                }
305            }
306        });
307    }
308
309    fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
310        self.tool_picker
311            .update(cx, |_this, cx| cx.emit(DismissEvent))
312            .log_err();
313    }
314
315    fn render_match(
316        &self,
317        ix: usize,
318        selected: bool,
319        _window: &mut Window,
320        cx: &mut Context<Picker<Self>>,
321    ) -> Option<Self::ListItem> {
322        let item = &self.filtered_items.get(ix)?;
323        match item {
324            PickerItem::ContextServer { server_id, .. } => Some(
325                div()
326                    .px_2()
327                    .pb_1()
328                    .when(ix > 1, |this| {
329                        this.mt_1()
330                            .pt_2()
331                            .border_t_1()
332                            .border_color(cx.theme().colors().border_variant)
333                    })
334                    .child(
335                        Label::new(server_id)
336                            .size(LabelSize::XSmall)
337                            .color(Color::Muted),
338                    )
339                    .into_any_element(),
340            ),
341            PickerItem::Tool { name, server_id } => {
342                let is_enabled = if let Some(server_id) = server_id {
343                    self.profile_settings
344                        .context_servers
345                        .get(server_id.as_ref())
346                        .and_then(|preset| preset.tools.get(name))
347                        .copied()
348                        .unwrap_or(self.profile_settings.enable_all_context_servers)
349                } else {
350                    self.profile_settings
351                        .tools
352                        .get(name)
353                        .copied()
354                        .unwrap_or(false)
355                };
356
357                Some(
358                    ListItem::new(ix)
359                        .inset(true)
360                        .spacing(ListItemSpacing::Sparse)
361                        .toggle_state(selected)
362                        .child(Label::new(name.clone()))
363                        .end_slot::<Icon>(is_enabled.then(|| {
364                            Icon::new(IconName::Check)
365                                .size(IconSize::Small)
366                                .color(Color::Success)
367                        }))
368                        .into_any_element(),
369                )
370            }
371        }
372    }
373}