tool_picker.rs

  1use std::{collections::BTreeMap, sync::Arc};
  2
  3use agent_settings::{
  4    AgentProfile, AgentProfileContent, AgentProfileId, AgentSettings, AgentSettingsContent,
  5    ContextServerPresetContent,
  6};
  7use assistant_tool::{ToolSource, ToolWorkingSet};
  8use fs::Fs;
  9use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
 10use picker::{Picker, PickerDelegate};
 11use settings::{Settings as _, update_settings_file};
 12use ui::{ListItem, ListItemSpacing, prelude::*};
 13use util::ResultExt as _;
 14
 15use crate::ThreadStore;
 16
 17pub struct ToolPicker {
 18    picker: Entity<Picker<ToolPickerDelegate>>,
 19}
 20
 21#[derive(Clone, Copy, Debug, PartialEq)]
 22pub enum ToolPickerMode {
 23    BuiltinTools,
 24    McpTools,
 25}
 26
 27impl ToolPicker {
 28    pub fn builtin_tools(
 29        delegate: ToolPickerDelegate,
 30        window: &mut Window,
 31        cx: &mut Context<Self>,
 32    ) -> Self {
 33        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
 34        Self { picker }
 35    }
 36
 37    pub fn mcp_tools(
 38        delegate: ToolPickerDelegate,
 39        window: &mut Window,
 40        cx: &mut Context<Self>,
 41    ) -> Self {
 42        let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
 43        Self { picker }
 44    }
 45}
 46
 47impl EventEmitter<DismissEvent> for ToolPicker {}
 48
 49impl Focusable for ToolPicker {
 50    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 51        self.picker.focus_handle(cx)
 52    }
 53}
 54
 55impl Render for ToolPicker {
 56    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 57        v_flex().w(rems(34.)).child(self.picker.clone())
 58    }
 59}
 60
 61#[derive(Debug, Clone)]
 62pub enum PickerItem {
 63    Tool {
 64        server_id: Option<Arc<str>>,
 65        name: Arc<str>,
 66    },
 67    ContextServer {
 68        server_id: Arc<str>,
 69    },
 70}
 71
 72pub struct ToolPickerDelegate {
 73    tool_picker: WeakEntity<ToolPicker>,
 74    thread_store: WeakEntity<ThreadStore>,
 75    fs: Arc<dyn Fs>,
 76    items: Arc<Vec<PickerItem>>,
 77    profile_id: AgentProfileId,
 78    profile: AgentProfile,
 79    filtered_items: Vec<PickerItem>,
 80    selected_index: usize,
 81    mode: ToolPickerMode,
 82}
 83
 84impl ToolPickerDelegate {
 85    pub fn new(
 86        mode: ToolPickerMode,
 87        fs: Arc<dyn Fs>,
 88        tool_set: Entity<ToolWorkingSet>,
 89        thread_store: WeakEntity<ThreadStore>,
 90        profile_id: AgentProfileId,
 91        profile: AgentProfile,
 92        cx: &mut Context<ToolPicker>,
 93    ) -> Self {
 94        let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
 95
 96        Self {
 97            tool_picker: cx.entity().downgrade(),
 98            thread_store,
 99            fs,
100            items,
101            profile_id,
102            profile,
103            filtered_items: Vec::new(),
104            selected_index: 0,
105            mode,
106        }
107    }
108
109    fn resolve_items(
110        mode: ToolPickerMode,
111        tool_set: &Entity<ToolWorkingSet>,
112        cx: &mut App,
113    ) -> Vec<PickerItem> {
114        let mut items = Vec::new();
115        for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
116            match source {
117                ToolSource::Native => {
118                    if mode == ToolPickerMode::BuiltinTools {
119                        items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
120                            name: tool.name().into(),
121                            server_id: None,
122                        }));
123                    }
124                }
125                ToolSource::ContextServer { id } => {
126                    if mode == ToolPickerMode::McpTools && !tools.is_empty() {
127                        let server_id: Arc<str> = id.clone().into();
128                        items.push(PickerItem::ContextServer {
129                            server_id: server_id.clone(),
130                        });
131                        items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
132                            name: tool.name().into(),
133                            server_id: Some(server_id.clone()),
134                        }));
135                    }
136                }
137            }
138        }
139        items
140    }
141}
142
143impl PickerDelegate for ToolPickerDelegate {
144    type ListItem = AnyElement;
145
146    fn match_count(&self) -> usize {
147        self.filtered_items.len()
148    }
149
150    fn selected_index(&self) -> usize {
151        self.selected_index
152    }
153
154    fn set_selected_index(
155        &mut self,
156        ix: usize,
157        _window: &mut Window,
158        _cx: &mut Context<Picker<Self>>,
159    ) {
160        self.selected_index = ix;
161    }
162
163    fn can_select(
164        &mut self,
165        ix: usize,
166        _window: &mut Window,
167        _cx: &mut Context<Picker<Self>>,
168    ) -> bool {
169        let item = &self.filtered_items[ix];
170        match item {
171            PickerItem::Tool { .. } => true,
172            PickerItem::ContextServer { .. } => false,
173        }
174    }
175
176    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
177        match self.mode {
178            ToolPickerMode::BuiltinTools => "Search built-in tools…",
179            ToolPickerMode::McpTools => "Search MCP tools…",
180        }
181        .into()
182    }
183
184    fn update_matches(
185        &mut self,
186        query: String,
187        window: &mut Window,
188        cx: &mut Context<Picker<Self>>,
189    ) -> Task<()> {
190        let all_items = self.items.clone();
191
192        cx.spawn_in(window, async move |this, cx| {
193            let filtered_items = cx
194                .background_spawn(async move {
195                    let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
196                        BTreeMap::default();
197
198                    for item in all_items.iter() {
199                        if let PickerItem::Tool { server_id, name } = item.clone() {
200                            if name.contains(&query) {
201                                tools_by_provider.entry(server_id).or_default().push(name);
202                            }
203                        }
204                    }
205
206                    let mut items = Vec::new();
207
208                    for (server_id, names) in tools_by_provider {
209                        if let Some(server_id) = server_id.clone() {
210                            items.push(PickerItem::ContextServer { server_id });
211                        }
212                        for name in names {
213                            items.push(PickerItem::Tool {
214                                server_id: server_id.clone(),
215                                name,
216                            });
217                        }
218                    }
219
220                    items
221                })
222                .await;
223
224            this.update(cx, |this, _cx| {
225                this.delegate.filtered_items = filtered_items;
226                this.delegate.selected_index = this
227                    .delegate
228                    .selected_index
229                    .min(this.delegate.filtered_items.len().saturating_sub(1));
230            })
231            .log_err();
232        })
233    }
234
235    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
236        if self.filtered_items.is_empty() {
237            self.dismissed(window, cx);
238            return;
239        }
240
241        let item = &self.filtered_items[self.selected_index];
242
243        let PickerItem::Tool {
244            name: tool_name,
245            server_id,
246        } = item
247        else {
248            return;
249        };
250
251        let is_currently_enabled = if let Some(server_id) = server_id.clone() {
252            let preset = self.profile.context_servers.entry(server_id).or_default();
253            let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
254            *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
255            is_enabled
256        } else {
257            let is_enabled = *self.profile.tools.entry(tool_name.clone()).or_default();
258            *self.profile.tools.entry(tool_name.clone()).or_default() = !is_enabled;
259            is_enabled
260        };
261
262        let active_profile_id = &AgentSettings::get_global(cx).default_profile;
263        if active_profile_id == &self.profile_id {
264            self.thread_store
265                .update(cx, |this, cx| {
266                    this.load_profile(self.profile.clone(), cx);
267                })
268                .log_err();
269        }
270
271        update_settings_file::<AgentSettings>(self.fs.clone(), cx, {
272            let profile_id = self.profile_id.clone();
273            let default_profile = self.profile.clone();
274            let server_id = server_id.clone();
275            let tool_name = tool_name.clone();
276            move |settings: &mut AgentSettingsContent, _cx| {
277                settings
278                    .v2_setting(|v2_settings| {
279                        let profiles = v2_settings.profiles.get_or_insert_default();
280                        let profile =
281                            profiles
282                                .entry(profile_id)
283                                .or_insert_with(|| AgentProfileContent {
284                                    name: default_profile.name.into(),
285                                    tools: default_profile.tools,
286                                    enable_all_context_servers: Some(
287                                        default_profile.enable_all_context_servers,
288                                    ),
289                                    context_servers: default_profile
290                                        .context_servers
291                                        .into_iter()
292                                        .map(|(server_id, preset)| {
293                                            (
294                                                server_id,
295                                                ContextServerPresetContent {
296                                                    tools: preset.tools,
297                                                },
298                                            )
299                                        })
300                                        .collect(),
301                                });
302
303                        if let Some(server_id) = server_id {
304                            let preset = profile.context_servers.entry(server_id).or_default();
305                            *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
306                        } else {
307                            *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
308                        }
309
310                        Ok(())
311                    })
312                    .ok();
313            }
314        });
315    }
316
317    fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
318        self.tool_picker
319            .update(cx, |_this, cx| cx.emit(DismissEvent))
320            .log_err();
321    }
322
323    fn render_match(
324        &self,
325        ix: usize,
326        selected: bool,
327        _window: &mut Window,
328        cx: &mut Context<Picker<Self>>,
329    ) -> Option<Self::ListItem> {
330        let item = &self.filtered_items[ix];
331        match item {
332            PickerItem::ContextServer { server_id, .. } => Some(
333                div()
334                    .px_2()
335                    .pb_1()
336                    .when(ix > 1, |this| {
337                        this.mt_1()
338                            .pt_2()
339                            .border_t_1()
340                            .border_color(cx.theme().colors().border_variant)
341                    })
342                    .child(
343                        Label::new(server_id)
344                            .size(LabelSize::XSmall)
345                            .color(Color::Muted),
346                    )
347                    .into_any_element(),
348            ),
349            PickerItem::Tool { name, server_id } => {
350                let is_enabled = if let Some(server_id) = server_id {
351                    self.profile
352                        .context_servers
353                        .get(server_id.as_ref())
354                        .and_then(|preset| preset.tools.get(name))
355                        .copied()
356                        .unwrap_or(self.profile.enable_all_context_servers)
357                } else {
358                    self.profile.tools.get(name).copied().unwrap_or(false)
359                };
360
361                Some(
362                    ListItem::new(ix)
363                        .inset(true)
364                        .spacing(ListItemSpacing::Sparse)
365                        .toggle_state(selected)
366                        .child(Label::new(name.clone()))
367                        .end_slot::<Icon>(is_enabled.then(|| {
368                            Icon::new(IconName::Check)
369                                .size(IconSize::Small)
370                                .color(Color::Success)
371                        }))
372                        .into_any_element(),
373                )
374            }
375        }
376    }
377}