crates/assistant2/src/assistant.rs 🔗
@@ -16,6 +16,7 @@ mod terminal_inline_assistant;
mod thread;
mod thread_history;
mod thread_store;
+mod tool_selector;
mod tool_use;
mod ui;
Marshall Bowers created
This PR adds a tool selector to Assistant 2 to facilitate customizing
the tools that the model sees:
<img width="1297" alt="Screenshot 2025-03-11 at 4 25 31 PM"
src="https://github.com/user-attachments/assets/7a656343-83bc-4546-9430-6a5f7ff1fd08"
/>
Release Notes:
- N/A
crates/assistant2/src/assistant.rs | 1
crates/assistant2/src/message_editor.rs | 40 ++++++-----
crates/assistant2/src/thread.rs | 17 ++---
crates/assistant2/src/tool_selector.rs | 70 +++++++++++++++++++++
crates/assistant_tool/src/assistant_tool.rs | 2
crates/assistant_tool/src/tool_working_set.rs | 69 +++++++++++++++++++-
6 files changed, 166 insertions(+), 33 deletions(-)
@@ -16,6 +16,7 @@ mod terminal_inline_assistant;
mod thread;
mod thread_history;
mod thread_store;
+mod tool_selector;
mod tool_use;
mod ui;
@@ -28,6 +28,7 @@ use crate::context_store::{refresh_context_store_text, ContextStore};
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
use crate::thread::{RequestKind, Thread};
use crate::thread_store::ThreadStore;
+use crate::tool_selector::ToolSelector;
use crate::{Chat, ChatMode, RemoveAllContext, ToggleContextPicker};
pub struct MessageEditor {
@@ -39,6 +40,7 @@ pub struct MessageEditor {
inline_context_picker: Entity<ContextPicker>,
inline_context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
model_selector: Entity<AssistantModelSelector>,
+ tool_selector: Entity<ToolSelector>,
use_tools: bool,
edits_expanded: bool,
_subscriptions: Vec<Subscription>,
@@ -53,6 +55,7 @@ impl MessageEditor {
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
+ let tools = thread.read(cx).tools().clone();
let context_store = cx.new(|_cx| ContextStore::new(workspace.clone()));
let context_picker_menu_handle = PopoverMenuHandle::default();
let inline_context_picker_menu_handle = PopoverMenuHandle::default();
@@ -118,6 +121,7 @@ impl MessageEditor {
cx,
)
}),
+ tool_selector: cx.new(|cx| ToolSelector::new(tools, cx)),
use_tools: false,
edits_expanded: false,
_subscriptions: subscriptions,
@@ -538,23 +542,25 @@ impl Render for MessageEditor {
h_flex()
.justify_between()
.child(
- Switch::new("use-tools", self.use_tools.into())
- .label("Tools")
- .on_click(cx.listener(
- |this, selection, _window, _cx| {
- this.use_tools = match selection {
- ToggleState::Selected => true,
- ToggleState::Unselected
- | ToggleState::Indeterminate => false,
- };
- },
- ))
- .key_binding(KeyBinding::for_action_in(
- &ChatMode,
- &focus_handle,
- window,
- cx,
- )),
+ h_flex().gap_2().child(self.tool_selector.clone()).child(
+ Switch::new("use-tools", self.use_tools.into())
+ .label("Tools")
+ .on_click(cx.listener(
+ |this, selection, _window, _cx| {
+ this.use_tools = match selection {
+ ToggleState::Selected => true,
+ ToggleState::Unselected
+ | ToggleState::Indeterminate => false,
+ };
+ },
+ ))
+ .key_binding(KeyBinding::for_action_in(
+ &ChatMode,
+ &focus_handle,
+ window,
+ cx,
+ )),
+ ),
)
.child(
h_flex().gap_1().child(self.model_selector.clone()).child(
@@ -355,16 +355,13 @@ impl Thread {
input_schema: ScriptingTool::input_schema(),
});
- tools.extend(
- self.tools()
- .tools(cx)
- .into_iter()
- .map(|tool| LanguageModelRequestTool {
- name: tool.name(),
- description: tool.description(),
- input_schema: tool.input_schema(),
- }),
- );
+ tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
+ LanguageModelRequestTool {
+ name: tool.name(),
+ description: tool.description(),
+ input_schema: tool.input_schema(),
+ }
+ }));
request.tools = tools;
}
@@ -0,0 +1,70 @@
+use std::sync::Arc;
+
+use assistant_tool::{ToolSource, ToolWorkingSet};
+use gpui::Entity;
+use ui::{prelude::*, ContextMenu, IconButtonShape, PopoverMenu, Tooltip};
+
+pub struct ToolSelector {
+ tools: Arc<ToolWorkingSet>,
+}
+
+impl ToolSelector {
+ pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
+ Self { tools }
+ }
+
+ fn build_context_menu(
+ &self,
+ window: &mut Window,
+ cx: &mut Context<Self>,
+ ) -> Entity<ContextMenu> {
+ ContextMenu::build(window, cx, |mut menu, _window, cx| {
+ let tools_by_source = self.tools.tools_by_source(cx);
+
+ for (source, tools) in tools_by_source {
+ menu = match source {
+ ToolSource::Native => menu.header("Zed"),
+ ToolSource::ContextServer { id } => menu.separator().header(id),
+ };
+
+ for tool in tools {
+ let source = tool.source();
+ let name = tool.name().into();
+ let is_enabled = self.tools.is_enabled(&source, &name);
+
+ menu =
+ menu.toggleable_entry(tool.name(), is_enabled, IconPosition::End, None, {
+ let tools = self.tools.clone();
+ move |_window, _cx| {
+ if is_enabled {
+ tools.disable(source.clone(), &[name.clone()]);
+ } else {
+ tools.enable(source.clone(), &[name.clone()]);
+ }
+ }
+ });
+ }
+ }
+
+ menu
+ })
+ }
+}
+
+impl Render for ToolSelector {
+ fn render(&mut self, _window: &mut Window, cx: &mut Context<'_, Self>) -> impl IntoElement {
+ let this = cx.entity().clone();
+ PopoverMenu::new("tool-selector")
+ .menu(move |window, cx| {
+ Some(this.update(cx, |this, cx| this.build_context_menu(window, cx)))
+ })
+ .trigger_with_tooltip(
+ IconButton::new("tool-selector-button", IconName::SettingsAlt)
+ .shape(IconButtonShape::Square)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted),
+ Tooltip::text("Customize Tools"),
+ )
+ .anchor(gpui::Corner::BottomLeft)
+ }
+}
@@ -14,7 +14,7 @@ pub fn init(cx: &mut App) {
ToolRegistry::default_global(cx);
}
-#[derive(Debug, PartialEq, Eq, Clone)]
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone)]
pub enum ToolSource {
/// A native tool built-in to Zed.
Native,
@@ -1,10 +1,10 @@
use std::sync::Arc;
-use collections::HashMap;
+use collections::{HashMap, HashSet, IndexMap};
use gpui::App;
use parking_lot::Mutex;
-use crate::{Tool, ToolRegistry};
+use crate::{Tool, ToolRegistry, ToolSource};
#[derive(Copy, Clone, PartialEq, Eq, Hash, Default)]
pub struct ToolId(usize);
@@ -19,6 +19,7 @@ pub struct ToolWorkingSet {
struct WorkingSetState {
context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
+ disabled_tools_by_source: HashMap<ToolSource, HashSet<Arc<str>>>,
next_tool_id: ToolId,
}
@@ -45,6 +46,34 @@ impl ToolWorkingSet {
tools
}
+ pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
+ let all_tools = self.tools(cx);
+
+ all_tools
+ .into_iter()
+ .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
+ .collect()
+ }
+
+ pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
+ let mut tools_by_source = IndexMap::default();
+
+ for tool in self.tools(cx) {
+ tools_by_source
+ .entry(tool.source())
+ .or_insert_with(Vec::new)
+ .push(tool);
+ }
+
+ for tools in tools_by_source.values_mut() {
+ tools.sort_by_key(|tool| tool.name());
+ }
+
+ tools_by_source.sort_unstable_keys();
+
+ tools_by_source
+ }
+
pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
let mut state = self.state.lock();
let tool_id = state.next_tool_id;
@@ -56,11 +85,41 @@ impl ToolWorkingSet {
tool_id
}
- pub fn remove(&self, command_ids_to_remove: &[ToolId]) {
+ pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
+ !self.is_disabled(source, name)
+ }
+
+ pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
+ let state = self.state.lock();
+ state
+ .disabled_tools_by_source
+ .get(source)
+ .map_or(false, |disabled_tools| disabled_tools.contains(name))
+ }
+
+ pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
+ let mut state = self.state.lock();
+ state
+ .disabled_tools_by_source
+ .entry(source)
+ .or_default()
+ .retain(|name| !tools_to_enable.contains(name));
+ }
+
+ pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
+ let mut state = self.state.lock();
+ state
+ .disabled_tools_by_source
+ .entry(source)
+ .or_default()
+ .extend(tools_to_disable.into_iter().cloned());
+ }
+
+ pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
let mut state = self.state.lock();
state
.context_server_tools_by_id
- .retain(|id, _| !command_ids_to_remove.contains(id));
+ .retain(|id, _| !tool_ids_to_remove.contains(id));
state.tools_changed();
}
}
@@ -71,7 +130,7 @@ impl WorkingSetState {
self.context_server_tools_by_name.extend(
self.context_server_tools_by_id
.values()
- .map(|command| (command.name(), command.clone())),
+ .map(|tool| (tool.name(), tool.clone())),
);
}
}