tool_picker.rs

  1use std::{collections::BTreeMap, sync::Arc};
  2
  3use agent::ContextServerRegistry;
  4use agent_settings::{AgentProfileId, AgentProfileSettings};
  5use fs::Fs;
  6use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
  7use picker::{Picker, PickerDelegate};
  8use settings::{AgentProfileContent, ContextServerPresetContent, update_settings_file};
  9use ui::{ListItem, ListItemSpacing, prelude::*};
 10use util::ResultExt as _;
 11
 12pub struct ToolPicker {
 13    picker: Entity<Picker<ToolPickerDelegate>>,
 14}
 15
 16#[derive(Clone, Copy, Debug, PartialEq)]
 17enum ToolPickerMode {
 18    BuiltinTools,
 19    McpTools,
 20}
 21
 22impl ToolPicker {
 23    pub fn builtin_tools(
 24        delegate: ToolPickerDelegate,
 25        window: &mut Window,
 26        cx: &mut Context<Self>,
 27    ) -> Self {
 28        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
 29        Self { picker }
 30    }
 31
 32    pub fn mcp_tools(
 33        delegate: ToolPickerDelegate,
 34        window: &mut Window,
 35        cx: &mut Context<Self>,
 36    ) -> Self {
 37        let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
 38        Self { picker }
 39    }
 40}
 41
 42impl EventEmitter<DismissEvent> for ToolPicker {}
 43
 44impl Focusable for ToolPicker {
 45    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 46        self.picker.focus_handle(cx)
 47    }
 48}
 49
 50impl Render for ToolPicker {
 51    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 52        v_flex().w(rems(34.)).child(self.picker.clone())
 53    }
 54}
 55
 56#[derive(Debug, Clone)]
 57pub enum PickerItem {
 58    Tool {
 59        server_id: Option<Arc<str>>,
 60        name: Arc<str>,
 61    },
 62    ContextServer {
 63        server_id: Arc<str>,
 64    },
 65}
 66
 67pub struct ToolPickerDelegate {
 68    tool_picker: WeakEntity<ToolPicker>,
 69    fs: Arc<dyn Fs>,
 70    items: Arc<Vec<PickerItem>>,
 71    profile_id: AgentProfileId,
 72    profile_settings: AgentProfileSettings,
 73    filtered_items: Vec<PickerItem>,
 74    selected_index: usize,
 75    mode: ToolPickerMode,
 76}
 77
 78impl ToolPickerDelegate {
 79    pub fn builtin_tools(
 80        tool_names: Vec<Arc<str>>,
 81        fs: Arc<dyn Fs>,
 82        profile_id: AgentProfileId,
 83        profile_settings: AgentProfileSettings,
 84        cx: &mut Context<ToolPicker>,
 85    ) -> Self {
 86        Self::new(
 87            Arc::new(
 88                tool_names
 89                    .into_iter()
 90                    .map(|name| PickerItem::Tool {
 91                        name,
 92                        server_id: None,
 93                    })
 94                    .collect(),
 95            ),
 96            ToolPickerMode::BuiltinTools,
 97            fs,
 98            profile_id,
 99            profile_settings,
100            cx,
101        )
102    }
103
104    pub fn mcp_tools(
105        registry: &Entity<ContextServerRegistry>,
106        fs: Arc<dyn Fs>,
107        profile_id: AgentProfileId,
108        profile_settings: AgentProfileSettings,
109        cx: &mut Context<ToolPicker>,
110    ) -> Self {
111        let mut items = Vec::new();
112
113        for (id, tools) in registry.read(cx).servers() {
114            let server_id = id.clone().0;
115            items.push(PickerItem::ContextServer {
116                server_id: server_id.clone(),
117            });
118            items.extend(tools.keys().map(|tool_name| PickerItem::Tool {
119                name: tool_name.clone().into(),
120                server_id: Some(server_id.clone()),
121            }));
122        }
123
124        Self::new(
125            Arc::new(items),
126            ToolPickerMode::McpTools,
127            fs,
128            profile_id,
129            profile_settings,
130            cx,
131        )
132    }
133
134    fn new(
135        items: Arc<Vec<PickerItem>>,
136        mode: ToolPickerMode,
137        fs: Arc<dyn Fs>,
138        profile_id: AgentProfileId,
139        profile_settings: AgentProfileSettings,
140        cx: &mut Context<ToolPicker>,
141    ) -> Self {
142        Self {
143            tool_picker: cx.entity().downgrade(),
144            mode,
145            fs,
146            items,
147            profile_id,
148            profile_settings,
149            filtered_items: Vec::new(),
150            selected_index: 0,
151        }
152    }
153}
154
155impl PickerDelegate for ToolPickerDelegate {
156    type ListItem = AnyElement;
157
158    fn match_count(&self) -> usize {
159        self.filtered_items.len()
160    }
161
162    fn selected_index(&self) -> usize {
163        self.selected_index
164    }
165
166    fn set_selected_index(
167        &mut self,
168        ix: usize,
169        _window: &mut Window,
170        _cx: &mut Context<Picker<Self>>,
171    ) {
172        self.selected_index = ix;
173    }
174
175    fn can_select(
176        &mut self,
177        ix: usize,
178        _window: &mut Window,
179        _cx: &mut Context<Picker<Self>>,
180    ) -> bool {
181        let item = &self.filtered_items[ix];
182        match item {
183            PickerItem::Tool { .. } => true,
184            PickerItem::ContextServer { .. } => false,
185        }
186    }
187
188    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
189        match self.mode {
190            ToolPickerMode::BuiltinTools => "Search built-in tools…",
191            ToolPickerMode::McpTools => "Search MCP tools…",
192        }
193        .into()
194    }
195
196    fn update_matches(
197        &mut self,
198        query: String,
199        window: &mut Window,
200        cx: &mut Context<Picker<Self>>,
201    ) -> Task<()> {
202        let all_items = self.items.clone();
203
204        cx.spawn_in(window, async move |this, cx| {
205            let filtered_items = cx
206                .background_spawn(async move {
207                    let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
208                        BTreeMap::default();
209
210                    for item in all_items.iter() {
211                        if let PickerItem::Tool { server_id, name } = item.clone()
212                            && name.contains(&query)
213                        {
214                            tools_by_provider.entry(server_id).or_default().push(name);
215                        }
216                    }
217
218                    let mut items = Vec::new();
219
220                    for (server_id, names) in tools_by_provider {
221                        if let Some(server_id) = server_id.clone() {
222                            items.push(PickerItem::ContextServer { server_id });
223                        }
224                        for name in names {
225                            items.push(PickerItem::Tool {
226                                server_id: server_id.clone(),
227                                name,
228                            });
229                        }
230                    }
231
232                    items
233                })
234                .await;
235
236            this.update(cx, |this, _cx| {
237                this.delegate.filtered_items = filtered_items;
238                this.delegate.selected_index = this
239                    .delegate
240                    .selected_index
241                    .min(this.delegate.filtered_items.len().saturating_sub(1));
242            })
243            .log_err();
244        })
245    }
246
247    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
248        if self.filtered_items.is_empty() {
249            self.dismissed(window, cx);
250            return;
251        }
252
253        let item = &self.filtered_items[self.selected_index];
254
255        let PickerItem::Tool {
256            name: tool_name,
257            server_id,
258        } = item
259        else {
260            return;
261        };
262
263        let is_currently_enabled = if let Some(server_id) = server_id.clone() {
264            let preset = self
265                .profile_settings
266                .context_servers
267                .entry(server_id)
268                .or_default();
269            let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
270            *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
271            is_enabled
272        } else {
273            let is_enabled = *self
274                .profile_settings
275                .tools
276                .entry(tool_name.clone())
277                .or_default();
278            *self
279                .profile_settings
280                .tools
281                .entry(tool_name.clone())
282                .or_default() = !is_enabled;
283            is_enabled
284        };
285
286        update_settings_file(self.fs.clone(), cx, {
287            let profile_id = self.profile_id.clone();
288            let default_profile = self.profile_settings.clone();
289            let server_id = server_id.clone();
290            let tool_name = tool_name.clone();
291            move |settings, _cx| {
292                let profiles = settings
293                    .agent
294                    .get_or_insert_default()
295                    .profiles
296                    .get_or_insert_default();
297                let profile = profiles
298                    .entry(profile_id.0)
299                    .or_insert_with(|| AgentProfileContent {
300                        name: default_profile.name.into(),
301                        tools: default_profile.tools,
302                        enable_all_context_servers: Some(
303                            default_profile.enable_all_context_servers,
304                        ),
305                        context_servers: default_profile
306                            .context_servers
307                            .into_iter()
308                            .map(|(server_id, preset)| {
309                                (
310                                    server_id,
311                                    ContextServerPresetContent {
312                                        tools: preset.tools,
313                                    },
314                                )
315                            })
316                            .collect(),
317                    });
318
319                if let Some(server_id) = server_id {
320                    let preset = profile.context_servers.entry(server_id).or_default();
321                    *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
322                } else {
323                    *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
324                }
325            }
326        });
327    }
328
329    fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
330        self.tool_picker
331            .update(cx, |_this, cx| cx.emit(DismissEvent))
332            .log_err();
333    }
334
335    fn render_match(
336        &self,
337        ix: usize,
338        selected: bool,
339        _window: &mut Window,
340        cx: &mut Context<Picker<Self>>,
341    ) -> Option<Self::ListItem> {
342        let item = &self.filtered_items.get(ix)?;
343        match item {
344            PickerItem::ContextServer { server_id, .. } => Some(
345                div()
346                    .px_2()
347                    .pb_1()
348                    .when(ix > 1, |this| {
349                        this.mt_1()
350                            .pt_2()
351                            .border_t_1()
352                            .border_color(cx.theme().colors().border_variant)
353                    })
354                    .child(
355                        Label::new(server_id)
356                            .size(LabelSize::XSmall)
357                            .color(Color::Muted),
358                    )
359                    .into_any_element(),
360            ),
361            PickerItem::Tool { name, server_id } => {
362                let is_enabled = if let Some(server_id) = server_id {
363                    self.profile_settings
364                        .context_servers
365                        .get(server_id.as_ref())
366                        .and_then(|preset| preset.tools.get(name))
367                        .copied()
368                        .unwrap_or(self.profile_settings.enable_all_context_servers)
369                } else {
370                    self.profile_settings
371                        .tools
372                        .get(name)
373                        .copied()
374                        .unwrap_or(false)
375                };
376
377                Some(
378                    ListItem::new(ix)
379                        .inset(true)
380                        .spacing(ListItemSpacing::Sparse)
381                        .toggle_state(selected)
382                        .child(Label::new(name.clone()))
383                        .end_slot::<Icon>(is_enabled.then(|| {
384                            Icon::new(IconName::Check)
385                                .size(IconSize::Small)
386                                .color(Color::Success)
387                        }))
388                        .into_any_element(),
389                )
390            }
391        }
392    }
393}