tool_picker.rs

  1use std::{collections::BTreeMap, sync::Arc};
  2
  3use agent_settings::{
  4    AgentProfileContent, AgentProfileId, AgentProfileSettings, AgentSettings, AgentSettingsContent,
  5    ContextServerPresetContent,
  6};
  7use assistant_tool::{ToolSource, ToolWorkingSet};
  8use fs::Fs;
  9use gpui::{App, Context, DismissEvent, Entity, EventEmitter, Focusable, Task, WeakEntity, Window};
 10use picker::{Picker, PickerDelegate};
 11use settings::update_settings_file;
 12use ui::{ListItem, ListItemSpacing, prelude::*};
 13use util::ResultExt as _;
 14
/// Non-modal view (both constructors call `.modal(false)`) wrapping a
/// [`Picker`] that lists agent tools and lets the user toggle each one on or
/// off for a profile.
pub struct ToolPicker {
    /// The list itself; focus and rendering are delegated to it.
    picker: Entity<Picker<ToolPickerDelegate>>,
}
 18
 19#[derive(Clone, Copy, Debug, PartialEq)]
 20pub enum ToolPickerMode {
 21    BuiltinTools,
 22    McpTools,
 23}
 24
 25impl ToolPicker {
 26    pub fn builtin_tools(
 27        delegate: ToolPickerDelegate,
 28        window: &mut Window,
 29        cx: &mut Context<Self>,
 30    ) -> Self {
 31        let picker = cx.new(|cx| Picker::uniform_list(delegate, window, cx).modal(false));
 32        Self { picker }
 33    }
 34
 35    pub fn mcp_tools(
 36        delegate: ToolPickerDelegate,
 37        window: &mut Window,
 38        cx: &mut Context<Self>,
 39    ) -> Self {
 40        let picker = cx.new(|cx| Picker::list(delegate, window, cx).modal(false));
 41        Self { picker }
 42    }
 43}
 44
// `ToolPickerDelegate::dismissed` emits `DismissEvent` on this entity so the
// surrounding UI can close the picker.
impl EventEmitter<DismissEvent> for ToolPicker {}
 46
 47impl Focusable for ToolPicker {
 48    fn focus_handle(&self, cx: &App) -> gpui::FocusHandle {
 49        self.picker.focus_handle(cx)
 50    }
 51}
 52
 53impl Render for ToolPicker {
 54    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
 55        v_flex().w(rems(34.)).child(self.picker.clone())
 56    }
 57}
 58
 59#[derive(Debug, Clone)]
 60pub enum PickerItem {
 61    Tool {
 62        server_id: Option<Arc<str>>,
 63        name: Arc<str>,
 64    },
 65    ContextServer {
 66        server_id: Arc<str>,
 67    },
 68}
 69
/// [`PickerDelegate`] behind [`ToolPicker`]: holds the candidate rows for one
/// [`ToolPickerMode`] and the profile whose tool toggles are being edited.
pub struct ToolPickerDelegate {
    /// The owning view; `dismissed` emits `DismissEvent` through it.
    tool_picker: WeakEntity<ToolPicker>,
    /// Passed to `update_settings_file` when a toggle is persisted.
    fs: Arc<dyn Fs>,
    /// Every row for the current mode, resolved once in `new`; shared with the
    /// background filtering task in `update_matches`.
    items: Arc<Vec<PickerItem>>,
    /// Key of the profile entry written in the settings file.
    profile_id: AgentProfileId,
    /// Local copy of the profile. `confirm` updates it on every toggle and
    /// `render_match` reads it to draw the checkmarks.
    profile_settings: AgentProfileSettings,
    /// Rows matching the current query (with server headers), as displayed.
    filtered_items: Vec<PickerItem>,
    /// Index into `filtered_items` of the highlighted row.
    selected_index: usize,
    /// Chooses the placeholder text and which tools `resolve_items` keeps.
    mode: ToolPickerMode,
}
 80
 81impl ToolPickerDelegate {
 82    pub fn new(
 83        mode: ToolPickerMode,
 84        fs: Arc<dyn Fs>,
 85        tool_set: Entity<ToolWorkingSet>,
 86        profile_id: AgentProfileId,
 87        profile_settings: AgentProfileSettings,
 88        cx: &mut Context<ToolPicker>,
 89    ) -> Self {
 90        let items = Arc::new(Self::resolve_items(mode, &tool_set, cx));
 91
 92        Self {
 93            tool_picker: cx.entity().downgrade(),
 94            fs,
 95            items,
 96            profile_id,
 97            profile_settings,
 98            filtered_items: Vec::new(),
 99            selected_index: 0,
100            mode,
101        }
102    }
103
104    fn resolve_items(
105        mode: ToolPickerMode,
106        tool_set: &Entity<ToolWorkingSet>,
107        cx: &mut App,
108    ) -> Vec<PickerItem> {
109        let mut items = Vec::new();
110        for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
111            match source {
112                ToolSource::Native => {
113                    if mode == ToolPickerMode::BuiltinTools {
114                        items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
115                            name: tool.name().into(),
116                            server_id: None,
117                        }));
118                    }
119                }
120                ToolSource::ContextServer { id } => {
121                    if mode == ToolPickerMode::McpTools && !tools.is_empty() {
122                        let server_id: Arc<str> = id.clone().into();
123                        items.push(PickerItem::ContextServer {
124                            server_id: server_id.clone(),
125                        });
126                        items.extend(tools.into_iter().map(|tool| PickerItem::Tool {
127                            name: tool.name().into(),
128                            server_id: Some(server_id.clone()),
129                        }));
130                    }
131                }
132            }
133        }
134        items
135    }
136}
137
138impl PickerDelegate for ToolPickerDelegate {
139    type ListItem = AnyElement;
140
141    fn match_count(&self) -> usize {
142        self.filtered_items.len()
143    }
144
145    fn selected_index(&self) -> usize {
146        self.selected_index
147    }
148
149    fn set_selected_index(
150        &mut self,
151        ix: usize,
152        _window: &mut Window,
153        _cx: &mut Context<Picker<Self>>,
154    ) {
155        self.selected_index = ix;
156    }
157
158    fn can_select(
159        &mut self,
160        ix: usize,
161        _window: &mut Window,
162        _cx: &mut Context<Picker<Self>>,
163    ) -> bool {
164        let item = &self.filtered_items[ix];
165        match item {
166            PickerItem::Tool { .. } => true,
167            PickerItem::ContextServer { .. } => false,
168        }
169    }
170
171    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
172        match self.mode {
173            ToolPickerMode::BuiltinTools => "Search built-in tools…",
174            ToolPickerMode::McpTools => "Search MCP tools…",
175        }
176        .into()
177    }
178
179    fn update_matches(
180        &mut self,
181        query: String,
182        window: &mut Window,
183        cx: &mut Context<Picker<Self>>,
184    ) -> Task<()> {
185        let all_items = self.items.clone();
186
187        cx.spawn_in(window, async move |this, cx| {
188            let filtered_items = cx
189                .background_spawn(async move {
190                    let mut tools_by_provider: BTreeMap<Option<Arc<str>>, Vec<Arc<str>>> =
191                        BTreeMap::default();
192
193                    for item in all_items.iter() {
194                        if let PickerItem::Tool { server_id, name } = item.clone()
195                            && name.contains(&query)
196                        {
197                            tools_by_provider.entry(server_id).or_default().push(name);
198                        }
199                    }
200
201                    let mut items = Vec::new();
202
203                    for (server_id, names) in tools_by_provider {
204                        if let Some(server_id) = server_id.clone() {
205                            items.push(PickerItem::ContextServer { server_id });
206                        }
207                        for name in names {
208                            items.push(PickerItem::Tool {
209                                server_id: server_id.clone(),
210                                name,
211                            });
212                        }
213                    }
214
215                    items
216                })
217                .await;
218
219            this.update(cx, |this, _cx| {
220                this.delegate.filtered_items = filtered_items;
221                this.delegate.selected_index = this
222                    .delegate
223                    .selected_index
224                    .min(this.delegate.filtered_items.len().saturating_sub(1));
225            })
226            .log_err();
227        })
228    }
229
230    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
231        if self.filtered_items.is_empty() {
232            self.dismissed(window, cx);
233            return;
234        }
235
236        let item = &self.filtered_items[self.selected_index];
237
238        let PickerItem::Tool {
239            name: tool_name,
240            server_id,
241        } = item
242        else {
243            return;
244        };
245
246        let is_currently_enabled = if let Some(server_id) = server_id.clone() {
247            let preset = self
248                .profile_settings
249                .context_servers
250                .entry(server_id)
251                .or_default();
252            let is_enabled = *preset.tools.entry(tool_name.clone()).or_default();
253            *preset.tools.entry(tool_name.clone()).or_default() = !is_enabled;
254            is_enabled
255        } else {
256            let is_enabled = *self
257                .profile_settings
258                .tools
259                .entry(tool_name.clone())
260                .or_default();
261            *self
262                .profile_settings
263                .tools
264                .entry(tool_name.clone())
265                .or_default() = !is_enabled;
266            is_enabled
267        };
268
269        update_settings_file::<AgentSettings>(self.fs.clone(), cx, {
270            let profile_id = self.profile_id.clone();
271            let default_profile = self.profile_settings.clone();
272            let server_id = server_id.clone();
273            let tool_name = tool_name.clone();
274            move |settings: &mut AgentSettingsContent, _cx| {
275                let profiles = settings.profiles.get_or_insert_default();
276                let profile = profiles
277                    .entry(profile_id)
278                    .or_insert_with(|| AgentProfileContent {
279                        name: default_profile.name.into(),
280                        tools: default_profile.tools,
281                        enable_all_context_servers: Some(
282                            default_profile.enable_all_context_servers,
283                        ),
284                        context_servers: default_profile
285                            .context_servers
286                            .into_iter()
287                            .map(|(server_id, preset)| {
288                                (
289                                    server_id,
290                                    ContextServerPresetContent {
291                                        tools: preset.tools,
292                                    },
293                                )
294                            })
295                            .collect(),
296                    });
297
298                if let Some(server_id) = server_id {
299                    let preset = profile.context_servers.entry(server_id).or_default();
300                    *preset.tools.entry(tool_name).or_default() = !is_currently_enabled;
301                } else {
302                    *profile.tools.entry(tool_name).or_default() = !is_currently_enabled;
303                }
304            }
305        });
306    }
307
308    fn dismissed(&mut self, _window: &mut Window, cx: &mut Context<Picker<Self>>) {
309        self.tool_picker
310            .update(cx, |_this, cx| cx.emit(DismissEvent))
311            .log_err();
312    }
313
314    fn render_match(
315        &self,
316        ix: usize,
317        selected: bool,
318        _window: &mut Window,
319        cx: &mut Context<Picker<Self>>,
320    ) -> Option<Self::ListItem> {
321        let item = &self.filtered_items.get(ix)?;
322        match item {
323            PickerItem::ContextServer { server_id, .. } => Some(
324                div()
325                    .px_2()
326                    .pb_1()
327                    .when(ix > 1, |this| {
328                        this.mt_1()
329                            .pt_2()
330                            .border_t_1()
331                            .border_color(cx.theme().colors().border_variant)
332                    })
333                    .child(
334                        Label::new(server_id)
335                            .size(LabelSize::XSmall)
336                            .color(Color::Muted),
337                    )
338                    .into_any_element(),
339            ),
340            PickerItem::Tool { name, server_id } => {
341                let is_enabled = if let Some(server_id) = server_id {
342                    self.profile_settings
343                        .context_servers
344                        .get(server_id.as_ref())
345                        .and_then(|preset| preset.tools.get(name))
346                        .copied()
347                        .unwrap_or(self.profile_settings.enable_all_context_servers)
348                } else {
349                    self.profile_settings
350                        .tools
351                        .get(name)
352                        .copied()
353                        .unwrap_or(false)
354                };
355
356                Some(
357                    ListItem::new(ix)
358                        .inset(true)
359                        .spacing(ListItemSpacing::Sparse)
360                        .toggle_state(selected)
361                        .child(Label::new(name.clone()))
362                        .end_slot::<Icon>(is_enabled.then(|| {
363                            Icon::new(IconName::Check)
364                                .size(IconSize::Small)
365                                .color(Color::Success)
366                        }))
367                        .into_any_element(),
368                )
369            }
370        }
371    }
372}