assistant2: Add tool selector (#26480)

Marshall Bowers created

This PR adds a tool selector to Assistant 2 to facilitate customizing
the tools that the model sees:

<img width="1297" alt="Screenshot 2025-03-11 at 4 25 31 PM"
src="https://github.com/user-attachments/assets/7a656343-83bc-4546-9430-6a5f7ff1fd08"
/>

Release Notes:

- N/A

Change summary

crates/assistant2/src/assistant.rs            |  1 
crates/assistant2/src/message_editor.rs       | 40 ++++++-----
crates/assistant2/src/thread.rs               | 17 ++---
crates/assistant2/src/tool_selector.rs        | 70 +++++++++++++++++++++
crates/assistant_tool/src/assistant_tool.rs   |  2 
crates/assistant_tool/src/tool_working_set.rs | 69 +++++++++++++++++++-
6 files changed, 166 insertions(+), 33 deletions(-)

Detailed changes

crates/assistant2/src/message_editor.rs 🔗

@@ -28,6 +28,7 @@ use crate::context_store::{refresh_context_store_text, ContextStore};
 use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
 use crate::thread::{RequestKind, Thread};
 use crate::thread_store::ThreadStore;
+use crate::tool_selector::ToolSelector;
 use crate::{Chat, ChatMode, RemoveAllContext, ToggleContextPicker};
 
 pub struct MessageEditor {
@@ -39,6 +40,7 @@ pub struct MessageEditor {
     inline_context_picker: Entity<ContextPicker>,
     inline_context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
     model_selector: Entity<AssistantModelSelector>,
+    tool_selector: Entity<ToolSelector>,
     use_tools: bool,
     edits_expanded: bool,
     _subscriptions: Vec<Subscription>,
@@ -53,6 +55,7 @@ impl MessageEditor {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
+        let tools = thread.read(cx).tools().clone();
         let context_store = cx.new(|_cx| ContextStore::new(workspace.clone()));
         let context_picker_menu_handle = PopoverMenuHandle::default();
         let inline_context_picker_menu_handle = PopoverMenuHandle::default();
@@ -118,6 +121,7 @@ impl MessageEditor {
                     cx,
                 )
             }),
+            tool_selector: cx.new(|cx| ToolSelector::new(tools, cx)),
             use_tools: false,
             edits_expanded: false,
             _subscriptions: subscriptions,
@@ -538,23 +542,25 @@ impl Render for MessageEditor {
                                 h_flex()
                                     .justify_between()
                                     .child(
-                                        Switch::new("use-tools", self.use_tools.into())
-                                            .label("Tools")
-                                            .on_click(cx.listener(
-                                                |this, selection, _window, _cx| {
-                                                    this.use_tools = match selection {
-                                                        ToggleState::Selected => true,
-                                                        ToggleState::Unselected
-                                                        | ToggleState::Indeterminate => false,
-                                                    };
-                                                },
-                                            ))
-                                            .key_binding(KeyBinding::for_action_in(
-                                                &ChatMode,
-                                                &focus_handle,
-                                                window,
-                                                cx,
-                                            )),
+                                        h_flex().gap_2().child(self.tool_selector.clone()).child(
+                                            Switch::new("use-tools", self.use_tools.into())
+                                                .label("Tools")
+                                                .on_click(cx.listener(
+                                                    |this, selection, _window, _cx| {
+                                                        this.use_tools = match selection {
+                                                            ToggleState::Selected => true,
+                                                            ToggleState::Unselected
+                                                            | ToggleState::Indeterminate => false,
+                                                        };
+                                                    },
+                                                ))
+                                                .key_binding(KeyBinding::for_action_in(
+                                                    &ChatMode,
+                                                    &focus_handle,
+                                                    window,
+                                                    cx,
+                                                )),
+                                        ),
                                     )
                                     .child(
                                         h_flex().gap_1().child(self.model_selector.clone()).child(

crates/assistant2/src/thread.rs 🔗

@@ -355,16 +355,13 @@ impl Thread {
                 input_schema: ScriptingTool::input_schema(),
             });
 
-            tools.extend(
-                self.tools()
-                    .tools(cx)
-                    .into_iter()
-                    .map(|tool| LanguageModelRequestTool {
-                        name: tool.name(),
-                        description: tool.description(),
-                        input_schema: tool.input_schema(),
-                    }),
-            );
+            tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
+                LanguageModelRequestTool {
+                    name: tool.name(),
+                    description: tool.description(),
+                    input_schema: tool.input_schema(),
+                }
+            }));
 
             request.tools = tools;
         }

crates/assistant2/src/tool_selector.rs 🔗

@@ -0,0 +1,70 @@
+use std::sync::Arc;
+
+use assistant_tool::{ToolSource, ToolWorkingSet};
+use gpui::Entity;
+use ui::{prelude::*, ContextMenu, IconButtonShape, PopoverMenu, Tooltip};
+
+pub struct ToolSelector {
+    tools: Arc<ToolWorkingSet>,
+}
+
+impl ToolSelector {
+    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
+        Self { tools }
+    }
+
+    fn build_context_menu(
+        &self,
+        window: &mut Window,
+        cx: &mut Context<Self>,
+    ) -> Entity<ContextMenu> {
+        ContextMenu::build(window, cx, |mut menu, _window, cx| {
+            let tools_by_source = self.tools.tools_by_source(cx);
+
+            for (source, tools) in tools_by_source {
+                menu = match source {
+                    ToolSource::Native => menu.header("Zed"),
+                    ToolSource::ContextServer { id } => menu.separator().header(id),
+                };
+
+                for tool in tools {
+                    let source = tool.source();
+                    let name = tool.name().into();
+                    let is_enabled = self.tools.is_enabled(&source, &name);
+
+                    menu =
+                        menu.toggleable_entry(tool.name(), is_enabled, IconPosition::End, None, {
+                            let tools = self.tools.clone();
+                            move |_window, _cx| {
+                                if is_enabled {
+                                    tools.disable(source.clone(), &[name.clone()]);
+                                } else {
+                                    tools.enable(source.clone(), &[name.clone()]);
+                                }
+                            }
+                        });
+                }
+            }
+
+            menu
+        })
+    }
+}
+
+impl Render for ToolSelector {
+    fn render(&mut self, _window: &mut Window, cx: &mut Context<'_, Self>) -> impl IntoElement {
+        let this = cx.entity().clone();
+        PopoverMenu::new("tool-selector")
+            .menu(move |window, cx| {
+                Some(this.update(cx, |this, cx| this.build_context_menu(window, cx)))
+            })
+            .trigger_with_tooltip(
+                IconButton::new("tool-selector-button", IconName::SettingsAlt)
+                    .shape(IconButtonShape::Square)
+                    .icon_size(IconSize::Small)
+                    .icon_color(Color::Muted),
+                Tooltip::text("Customize Tools"),
+            )
+            .anchor(gpui::Corner::BottomLeft)
+    }
+}

crates/assistant_tool/src/assistant_tool.rs 🔗

@@ -14,7 +14,7 @@ pub fn init(cx: &mut App) {
     ToolRegistry::default_global(cx);
 }
 
-#[derive(Debug, PartialEq, Eq, Clone)]
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
 pub enum ToolSource {
     /// A native tool built-in to Zed.
     Native,

crates/assistant_tool/src/tool_working_set.rs 🔗

@@ -1,10 +1,10 @@
 use std::sync::Arc;
 
-use collections::HashMap;
+use collections::{HashMap, HashSet, IndexMap};
 use gpui::App;
 use parking_lot::Mutex;
 
-use crate::{Tool, ToolRegistry};
+use crate::{Tool, ToolRegistry, ToolSource};
 
 #[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
 pub struct ToolId(usize);
@@ -19,6 +19,7 @@ pub struct ToolWorkingSet {
 struct WorkingSetState {
     context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
     context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
+    disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
     next_tool_id: ToolId,
 }
 
@@ -45,6 +46,34 @@ impl ToolWorkingSet {
         tools
     }
 
+    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
+        let all_tools = self.tools(cx);
+
+        all_tools
+            .into_iter()
+            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
+            .collect()
+    }
+
+    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
+        let mut tools_by_source = IndexMap::default();
+
+        for tool in self.tools(cx) {
+            tools_by_source
+                .entry(tool.source())
+                .or_insert_with(Vec::new)
+                .push(tool);
+        }
+
+        for tools in tools_by_source.values_mut() {
+            tools.sort_by_key(|tool| tool.name());
+        }
+
+        tools_by_source.sort_unstable_keys();
+
+        tools_by_source
+    }
+
     pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
         let mut state = self.state.lock();
         let tool_id = state.next_tool_id;
@@ -56,11 +85,41 @@ impl ToolWorkingSet {
         tool_id
     }
 
-    pub fn remove(&self, command_ids_to_remove: &[ToolId]) {
+    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
+        !self.is_disabled(source, name)
+    }
+
+    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
+        let state = self.state.lock();
+        state
+            .disabled_tools_by_source
+            .get(source)
+            .map_or(false, |disabled_tools| disabled_tools.contains(name))
+    }
+
+    pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
+        let mut state = self.state.lock();
+        state
+            .disabled_tools_by_source
+            .entry(source)
+            .or_default()
+            .retain(|name| !tools_to_enable.contains(name));
+    }
+
+    pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
+        let mut state = self.state.lock();
+        state
+            .disabled_tools_by_source
+            .entry(source)
+            .or_default()
+            .extend(tools_to_disable.into_iter().cloned());
+    }
+
+    pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
         let mut state = self.state.lock();
         state
             .context_server_tools_by_id
-            .retain(|id, _| !command_ids_to_remove.contains(id));
+            .retain(|id, _| !tool_ids_to_remove.contains(id));
         state.tools_changed();
     }
 }
@@ -71,7 +130,7 @@ impl WorkingSetState {
         self.context_server_tools_by_name.extend(
             self.context_server_tools_by_id
                 .values()
-                .map(|command| (command.name(), command.clone())),
+                .map(|tool| (tool.name(), tool.clone())),
         );
     }
 }