tool_picker.rs

  1use std::{collections::BTreeMap, sync::Arc};
  2
  3use agent::ContextServerRegistry;
  4use agent_settings::{AgentProfileId, AgentProfileSettings};
  5use fs::Fs;
  6use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
  7use picker::{Picker, PickerDelegate};
  8use settings::{AgentProfileContent, ContextServerPresetContent, update_settings_file};
  9use ui::{ListItem, ListItemSpacing, prelude::*};
 10use util::ResultExt as _;
 11
 12pub struct ToolPicker {
 13    picker: Entity<Picker<ToolPickerDelegate>>,
 14}
 15
 16#[derive(Clone, Copy, Debug, PartialEq)]
 17enum ToolPickerMode {
 18    BuiltinTools,
 19    McpTools,
 20}
 21
 22impl ToolPicker {
 23    pub fn builtin_tools(
 24        delegate: ToolPickerDelegate,
 25        window: &mut Window,
 26        cx: &mut Context<Self>,
 27    ) -> Self {
 28        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
 29        Self { picker }
 30    }
 31
 32    pub fn mcp_tools(
 33        delegate: ToolPickerDelegate,
 34        window: &mut Window,
 35        cx: &mut Context<Self>,
 36    ) -> Self {
 37        let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
 38        Self { picker }
 39    }
 40}
 41
 42impl EventEmitter<DismissEvent> for ToolPicker {}
 43
 44impl Focusable for ToolPicker {
 45    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 46        self.picker.focus_handle(cx)
 47    }
 48}
 49
 50impl Render for ToolPicker {
 51    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 52        v_flex().w(rems(34.)).child(self.picker.clone())
 53    }
 54}
 55
 56#[derive(Debug, Clone)]
 57pub enum PickerItem {
 58    Tool {
 59        server_id: Option<Arc<str>>,
 60        name: Arc<str>,
 61    },
 62    ContextServer {
 63        server_id: Arc<str>,
 64    },
 65}
 66
 67pub struct ToolPickerDelegate {
 68    tool_picker: WeakEntity<ToolPicker>,
 69    fs: Arc<dyn Fs>,
 70    items: Arc<Vec<PickerItem>>,
 71    profile_id: AgentProfileId,
 72    profile_settings: AgentProfileSettings,
 73    filtered_items: Vec<PickerItem>,
 74    selected_index: usize,
 75    mode: ToolPickerMode,
 76}
 77
 78impl ToolPickerDelegate {
 79    pub fn builtin_tools(
 80        tool_names: Vec<Arc<str>>,
 81        fs: Arc<dyn Fs>,
 82        profile_id: AgentProfileId,
 83        profile_settings: AgentProfileSettings,
 84        cx: &mut Context<ToolPicker>,
 85    ) -> Self {
 86        Self::new(
 87            Arc::new(
 88                tool_names
 89                    .into_iter()
 90                    .map(|name| PickerItem::Tool {
 91                        name,
 92                        server_id: None,
 93                    })
 94                    .collect(),
 95            ),
 96            ToolPickerMode::BuiltinTools,
 97            fs,
 98            profile_id,
 99            profile_settings,
100            cx,
101        )
102    }
103
104    pub fn mcp_tools(
105        registry: &Entity<ContextServerRegistry>,
106        fs: Arc<dyn Fs>,
107        profile_id: AgentProfileId,
108        profile_settings: AgentProfileSettings,
109        cx: &mut Context<ToolPicker>,
110    ) -> Self {
111        let mut items = Vec::new();
112
113        for (id, tools) in registry.read(cx).servers() {
114            let server_id = id.clone().0;
115            items.push(PickerItem::ContextServer {
116                server_id: server_id.clone(),
117            });
118            items.extend(tools.keys().map(|tool_name| PickerItem::Tool {
119                name: tool_name.clone().into(),
120                server_id: Some(server_id.clone()),
121            }));
122        }
123
124        Self::new(
125            Arc::new(items),
126            ToolPickerMode::McpTools,
127            fs,
128            profile_id,
129            profile_settings,
130            cx,
131        )
132    }
133
134    fn new(
135        items: Arc<Vec<PickerItem>>,
136        mode: ToolPickerMode,
137        fs: Arc<dyn Fs>,
138        profile_id: AgentProfileId,
139        profile_settings: AgentProfileSettings,
140        cx: &mut Context<ToolPicker>,
141    ) -> Self {
142        Self {
143            tool_picker: cx.entity().downgrade(),
144            mode,
145            fs,
146            items,
147            profile_id,
148            profile_settings,
149            filtered_items: Vec::new(),
150            selected_index: 0,
151        }
152    }
153}
154
155impl PickerDelegate for ToolPickerDelegate {
156    type ListItem = AnyElement;
157
158    fn match_count(&self) -> usize {
159        self.filtered_items.len()
160    }
161
162    fn selected_index(&self) -> usize {
163        self.selected_index
164    }
165
166    fn set_selected_index(
167        &mut self,
168        ix: usize,
169        _window: &mut Window,
170        _cx: &mut Context<Picker<Self>>,
171    ) {
172        self.selected_index = ix;
173    }
174
175    fn can_select(&self, ix: usize, _window: &mut Window, _cx: &mut Context<Picker<Self>>) -> bool {
176        let item = &self.filtered_items[ix];
177        match item {
178            PickerItem::Tool { .. } => true,
179            PickerItem::ContextServer { .. } => false,
180        }
181    }
182
183    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
184        match self.mode {
185            ToolPickerMode::BuiltinTools => "Search built-in tools…",
186            ToolPickerMode::McpTools => "Search MCP tools…",
187        }
188        .into()
189    }
190
191    fn update_matches(
192        &mut self,
193        query: String,
194        window: &mut Window,
195        cx: &mut Context<Picker<Self>>,
196    ) -> Task<()> {
197        let all_items = self.items.clone();
198
199        cx.spawn_in(window, async move |this, cx| {
200            let filtered_items = cx
201                .background_spawn(async move {
202                    let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
203                        BTreeMap::default();
204
205                    for item in all_items.iter() {
206                        if let PickerItem::Tool { server_id, name } = item.clone()
207                            && name.contains(&query)
208                        {
209                            tools_by_provider.entry(server_id).or_default().push(name);
210                        }
211                    }
212
213                    let mut items = Vec::new();
214
215                    for (server_id, names) in tools_by_provider {
216                        if let Some(server_id) = server_id.clone() {
217                            items.push(PickerItem::ContextServer { server_id });
218                        }
219                        for name in names {
220                            items.push(PickerItem::Tool {
221                                server_id: server_id.clone(),
222                                name,
223                            });
224                        }
225                    }
226
227                    items
228                })
229                .await;
230
231            this.update(cx, |this, _cx| {
232                this.delegate.filtered_items = filtered_items;
233                this.delegate.selected_index = this
234                    .delegate
235                    .selected_index
236                    .min(this.delegate.filtered_items.len().saturating_sub(1));
237            })
238            .log_err();
239        })
240    }
241
242    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
243        if self.filtered_items.is_empty() {
244            self.dismissed(window, cx);
245            return;
246        }
247
248        let item = &self.filtered_items[self.selected_index];
249
250        let PickerItem::Tool {
251            name: tool_name,
252            server_id,
253        } = item
254        else {
255            return;
256        };
257
258        let is_currently_enabled = if let Some(server_id) = server_id.clone() {
259            let preset = self
260                .profile_settings
261                .context_servers
262                .entry(server_id)
263                .or_default();
264            let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
265            *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
266            is_enabled
267        } else {
268            let is_enabled = *self
269                .profile_settings
270                .tools
271                .entry(tool_name.clone())
272                .or_default();
273            *self
274                .profile_settings
275                .tools
276                .entry(tool_name.clone())
277                .or_default() = !is_enabled;
278            is_enabled
279        };
280
281        update_settings_file(self.fs.clone(), cx, {
282            let profile_id = self.profile_id.clone();
283            let default_profile = self.profile_settings.clone();
284            let server_id = server_id.clone();
285            let tool_name = tool_name.clone();
286            move |settings, _cx| {
287                let profiles = settings
288                    .agent
289                    .get_or_insert_default()
290                    .profiles
291                    .get_or_insert_default();
292                let profile = profiles
293                    .entry(profile_id.0)
294                    .or_insert_with(|| AgentProfileContent {
295                        name: default_profile.name.into(),
296                        tools: default_profile.tools,
297                        enable_all_context_servers: Some(
298                            default_profile.enable_all_context_servers,
299                        ),
300                        context_servers: default_profile
301                            .context_servers
302                            .into_iter()
303                            .map(|(server_id, preset)| {
304                                (
305                                    server_id,
306                                    ContextServerPresetContent {
307                                        tools: preset.tools,
308                                    },
309                                )
310                            })
311                            .collect(),
312                        default_model: default_profile.default_model.clone(),
313                    });
314
315                if let Some(server_id) = server_id {
316                    let preset = profile.context_servers.entry(server_id).or_default();
317                    *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
318                } else {
319                    *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
320                }
321            }
322        });
323    }
324
325    fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
326        self.tool_picker
327            .update(cx, |_this, cx| cx.emit(DismissEvent))
328            .log_err();
329    }
330
331    fn render_match(
332        &self,
333        ix: usize,
334        selected: bool,
335        _window: &mut Window,
336        cx: &mut Context<Picker<Self>>,
337    ) -> Option<Self::ListItem> {
338        let item = &self.filtered_items.get(ix)?;
339        match item {
340            PickerItem::ContextServer { server_id, .. } => Some(
341                div()
342                    .px_2()
343                    .pb_1()
344                    .when(ix > 1, |this| {
345                        this.mt_1()
346                            .pt_2()
347                            .border_t_1()
348                            .border_color(cx.theme().colors().border_variant)
349                    })
350                    .child(
351                        Label::new(server_id)
352                            .size(LabelSize::XSmall)
353                            .color(Color::Muted),
354                    )
355                    .into_any_element(),
356            ),
357            PickerItem::Tool { name, server_id } => {
358                let is_enabled = if let Some(server_id) = server_id {
359                    self.profile_settings
360                        .context_servers
361                        .get(server_id.as_ref())
362                        .and_then(|preset| preset.tools.get(name))
363                        .copied()
364                        .unwrap_or(self.profile_settings.enable_all_context_servers)
365                } else {
366                    self.profile_settings
367                        .tools
368                        .get(name)
369                        .copied()
370                        .unwrap_or(false)
371                };
372
373                Some(
374                    ListItem::new(ix)
375                        .inset(true)
376                        .spacing(ListItemSpacing::Sparse)
377                        .toggle_state(selected)
378                        .child(Label::new(name.clone()))
379                        .end_slot::<Icon>(is_enabled.then(|| {
380                            Icon::new(IconName::Check)
381                                .size(IconSize::Small)
382                                .color(Color::Success)
383                        }))
384                        .into_any_element(),
385                )
386            }
387        }
388    }
389}