agent: Make `ToolWorkingSet` an `Entity` (#28757)

Bennet Bo Fenner created

Motivation is to emit events when enabled tools change, want to use this
in #28755

Release Notes:

- N/A

Change summary

crates/agent/src/agent_diff.rs                                    |   3 
crates/agent/src/assistant_configuration.rs                       |   6 
crates/agent/src/assistant_configuration/manage_profiles_modal.rs |   4 
crates/agent/src/assistant_configuration/tool_picker.rs           |   6 
crates/agent/src/assistant_panel.rs                               |   2 
crates/agent/src/profile_selector.rs                              |   2 
crates/agent/src/thread.rs                                        |  16 
crates/agent/src/thread_store.rs                                  | 131 
crates/agent/src/tool_use.rs                                      |  22 
crates/assistant_tool/src/tool_working_set.rs                     | 185 
crates/eval/src/example.rs                                        |   4 
11 files changed, 181 insertions(+), 200 deletions(-)

Detailed changes

crates/agent/src/agent_diff.rs 🔗

@@ -894,6 +894,7 @@ mod tests {
     use super::*;
     use crate::{ThreadStore, thread_store};
     use assistant_settings::AssistantSettings;
+    use assistant_tool::ToolWorkingSet;
     use context_server::ContextServerSettings;
     use editor::EditorSettings;
     use gpui::TestAppContext;
@@ -937,7 +938,7 @@ mod tests {
             .update(|cx| {
                 ThreadStore::load(
                     project.clone(),
-                    Arc::default(),
+                    cx.new(|_| ToolWorkingSet::default()),
                     Arc::new(PromptBuilder::new(None).unwrap()),
                     cx,
                 )

crates/agent/src/assistant_configuration.rs 🔗

@@ -29,7 +29,7 @@ pub struct AssistantConfiguration {
     configuration_views_by_provider: HashMap<LanguageModelProviderId, AnyView>,
     context_server_manager: Entity<ContextServerManager>,
     expanded_context_server_tools: HashMap<Arc<str>, bool>,
-    tools: Arc<ToolWorkingSet>,
+    tools: Entity<ToolWorkingSet>,
     _registry_subscription: Subscription,
 }
 
@@ -37,7 +37,7 @@ impl AssistantConfiguration {
     pub fn new(
         fs: Arc<dyn Fs>,
         context_server_manager: Entity<ContextServerManager>,
-        tools: Arc<ToolWorkingSet>,
+        tools: Entity<ToolWorkingSet>,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
@@ -226,7 +226,7 @@ impl AssistantConfiguration {
 
     fn render_context_servers_section(&mut self, cx: &mut Context<Self>) -> impl IntoElement {
         let context_servers = self.context_server_manager.read(cx).all_servers().clone();
-        let tools_by_source = self.tools.tools_by_source(cx);
+        let tools_by_source = self.tools.read(cx).tools_by_source(cx);
         let empty = Vec::new();
 
         const SUBHEADING: &str = "Connect to context servers via the Model Context Protocol either via Zed extensions or directly.";

crates/agent/src/assistant_configuration/manage_profiles_modal.rs 🔗

@@ -84,7 +84,7 @@ pub struct NewProfileMode {
 
 pub struct ManageProfilesModal {
     fs: Arc<dyn Fs>,
-    tools: Arc<ToolWorkingSet>,
+    tools: Entity<ToolWorkingSet>,
     thread_store: WeakEntity<ThreadStore>,
     focus_handle: FocusHandle,
     mode: Mode,
@@ -117,7 +117,7 @@ impl ManageProfilesModal {
 
     pub fn new(
         fs: Arc<dyn Fs>,
-        tools: Arc<ToolWorkingSet>,
+        tools: Entity<ToolWorkingSet>,
         thread_store: WeakEntity<ThreadStore>,
         window: &mut Window,
         cx: &mut Context<Self>,

crates/agent/src/assistant_configuration/tool_picker.rs 🔗

@@ -60,7 +60,7 @@ pub struct ToolPickerDelegate {
 impl ToolPickerDelegate {
     pub fn new(
         fs: Arc<dyn Fs>,
-        tool_set: Arc<ToolWorkingSet>,
+        tool_set: Entity<ToolWorkingSet>,
         thread_store: WeakEntity<ThreadStore>,
         profile_id: AgentProfileId,
         profile: AgentProfile,
@@ -68,7 +68,7 @@ impl ToolPickerDelegate {
     ) -> Self {
         let mut tool_entries = Vec::new();
 
-        for (source, tools) in tool_set.tools_by_source(cx) {
+        for (source, tools) in tool_set.read(cx).tools_by_source(cx) {
             tool_entries.extend(tools.into_iter().map(|tool| ToolEntry {
                 name: tool.name().into(),
                 source: source.clone(),
@@ -192,7 +192,7 @@ impl PickerDelegate for ToolPickerDelegate {
         if active_profile_id == &self.profile_id {
             self.thread_store
                 .update(cx, |this, cx| {
-                    this.load_profile(&self.profile, cx);
+                    this.load_profile(self.profile.clone(), cx);
                 })
                 .log_err();
         }

crates/agent/src/assistant_panel.rs 🔗

@@ -203,7 +203,7 @@ impl AssistantPanel {
         cx: AsyncWindowContext,
     ) -> Task<Result<Entity<Self>>> {
         cx.spawn(async move |cx| {
-            let tools = Arc::new(ToolWorkingSet::default());
+            let tools = cx.new(|_| ToolWorkingSet::default())?;
             let thread_store = workspace
                 .update(cx, |workspace, cx| {
                     let project = workspace.project().clone();

crates/agent/src/profile_selector.rs 🔗

@@ -86,7 +86,7 @@ impl ProfileSelector {
 
                             thread_store
                                 .update(cx, |this, cx| {
-                                    this.load_profile_by_id(&profile_id, cx);
+                                    this.load_profile_by_id(profile_id.clone(), cx);
                                 })
                                 .log_err();
                         }

crates/agent/src/thread.rs 🔗

@@ -254,7 +254,7 @@ pub struct Thread {
     pending_completions: Vec<PendingCompletion>,
     project: Entity<Project>,
     prompt_builder: Arc<PromptBuilder>,
-    tools: Arc<ToolWorkingSet>,
+    tools: Entity<ToolWorkingSet>,
     tool_use: ToolUseState,
     action_log: Entity<ActionLog>,
     last_restore_checkpoint: Option<LastRestoreCheckpoint>,
@@ -278,7 +278,7 @@ pub struct ExceededWindowError {
 impl Thread {
     pub fn new(
         project: Entity<Project>,
-        tools: Arc<ToolWorkingSet>,
+        tools: Entity<ToolWorkingSet>,
         prompt_builder: Arc<PromptBuilder>,
         system_prompt: SharedProjectContext,
         cx: &mut Context<Self>,
@@ -322,7 +322,7 @@ impl Thread {
         id: ThreadId,
         serialized: SerializedThread,
         project: Entity<Project>,
-        tools: Arc<ToolWorkingSet>,
+        tools: Entity<ToolWorkingSet>,
         prompt_builder: Arc<PromptBuilder>,
         project_context: SharedProjectContext,
         cx: &mut Context<Self>,
@@ -458,7 +458,7 @@ impl Thread {
         !self.pending_completions.is_empty() || !self.all_tools_finished()
     }
 
-    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
+    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
         &self.tools
     }
 
@@ -846,6 +846,7 @@ impl Thread {
                 let mut tools = Vec::new();
                 tools.extend(
                     self.tools()
+                        .read(cx)
                         .enabled_tools(cx)
                         .into_iter()
                         .filter_map(|tool| {
@@ -1354,7 +1355,7 @@ impl Thread {
             .collect::<Vec<_>>();
 
         for tool_use in pending_tool_uses.iter() {
-            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
+            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
                 if tool.needs_confirmation(&tool_use.input, cx)
                     && !AssistantSettings::get_global(cx).always_allow_tool_actions
                 {
@@ -1406,7 +1407,7 @@ impl Thread {
     ) -> Task<()> {
         let tool_name: Arc<str> = tool.name().into();
 
-        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
+        let run_tool = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
             Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
         } else {
             tool.run(
@@ -1521,6 +1522,7 @@ impl Thread {
 
         let enabled_tool_names: Vec<String> = self
             .tools()
+            .read(cx)
             .enabled_tools(cx)
             .iter()
             .map(|tool| tool.name().to_string())
@@ -2341,7 +2343,7 @@ fn main() {{
             .update(|_, cx| {
                 ThreadStore::load(
                     project.clone(),
-                    Arc::default(),
+                    cx.new(|_| ToolWorkingSet::default()),
                     Arc::new(PromptBuilder::new(None).unwrap()),
                     cx,
                 )

crates/agent/src/thread_store.rs 🔗

@@ -56,7 +56,7 @@ impl SharedProjectContext {
 
 pub struct ThreadStore {
     project: Entity<Project>,
-    tools: Arc<ToolWorkingSet>,
+    tools: Entity<ToolWorkingSet>,
     prompt_builder: Arc<PromptBuilder>,
     context_server_manager: Entity<ContextServerManager>,
     context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
@@ -74,7 +74,7 @@ impl EventEmitter<RulesLoadingError> for ThreadStore {}
 impl ThreadStore {
     pub fn load(
         project: Entity<Project>,
-        tools: Arc<ToolWorkingSet>,
+        tools: Entity<ToolWorkingSet>,
         prompt_builder: Arc<PromptBuilder>,
         cx: &mut App,
     ) -> Task<Entity<Self>> {
@@ -88,7 +88,7 @@ impl ThreadStore {
 
     fn new(
         project: Entity<Project>,
-        tools: Arc<ToolWorkingSet>,
+        tools: Entity<ToolWorkingSet>,
         prompt_builder: Arc<PromptBuilder>,
         cx: &mut Context<Self>,
     ) -> Self {
@@ -248,7 +248,7 @@ impl ThreadStore {
         self.context_server_manager.clone()
     }
 
-    pub fn tools(&self) -> Arc<ToolWorkingSet> {
+    pub fn tools(&self) -> Entity<ToolWorkingSet> {
         self.tools.clone()
     }
 
@@ -355,52 +355,60 @@ impl ThreadStore {
         })
     }
 
-    fn load_default_profile(&self, cx: &Context<Self>) {
+    fn load_default_profile(&self, cx: &mut Context<Self>) {
         let assistant_settings = AssistantSettings::get_global(cx);
 
-        self.load_profile_by_id(&assistant_settings.default_profile, cx);
+        self.load_profile_by_id(assistant_settings.default_profile.clone(), cx);
     }
 
-    pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
+    pub fn load_profile_by_id(&self, profile_id: AgentProfileId, cx: &mut Context<Self>) {
         let assistant_settings = AssistantSettings::get_global(cx);
 
-        if let Some(profile) = assistant_settings.profiles.get(profile_id) {
-            self.load_profile(profile, cx);
+        if let Some(profile) = assistant_settings.profiles.get(&profile_id) {
+            self.load_profile(profile.clone(), cx);
         }
     }
 
-    pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
-        self.tools.disable_all_tools();
-        self.tools.enable(
-            ToolSource::Native,
-            &profile
-                .tools
-                .iter()
-                .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
-                .collect::<Vec<_>>(),
-        );
+    pub fn load_profile(&self, profile: AgentProfile, cx: &mut Context<Self>) {
+        self.tools.update(cx, |tools, cx| {
+            tools.disable_all_tools(cx);
+            tools.enable(
+                ToolSource::Native,
+                &profile
+                    .tools
+                    .iter()
+                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
+                    .collect::<Vec<_>>(),
+                cx,
+            );
+        });
 
         if profile.enable_all_context_servers {
             for context_server in self.context_server_manager.read(cx).all_servers() {
-                self.tools.enable_source(
-                    ToolSource::ContextServer {
-                        id: context_server.id().into(),
-                    },
-                    cx,
-                );
+                self.tools.update(cx, |tools, cx| {
+                    tools.enable_source(
+                        ToolSource::ContextServer {
+                            id: context_server.id().into(),
+                        },
+                        cx,
+                    );
+                });
             }
         } else {
             for (context_server_id, preset) in &profile.context_servers {
-                self.tools.enable(
-                    ToolSource::ContextServer {
-                        id: context_server_id.clone().into(),
-                    },
-                    &preset
-                        .tools
-                        .iter()
-                        .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
-                        .collect::<Vec<_>>(),
-                )
+                self.tools.update(cx, |tools, cx| {
+                    tools.enable(
+                        ToolSource::ContextServer {
+                            id: context_server_id.clone().into(),
+                        },
+                        &preset
+                            .tools
+                            .iter()
+                            .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
+                            .collect::<Vec<_>>(),
+                        cx,
+                    )
+                })
             }
         }
     }
@@ -434,29 +442,36 @@ impl ThreadStore {
 
                             if protocol.capable(context_server::protocol::ServerCapability::Tools) {
                                 if let Some(tools) = protocol.list_tools().await.log_err() {
-                                    let tool_ids = tools
-                                        .tools
-                                        .into_iter()
-                                        .map(|tool| {
-                                            log::info!(
-                                                "registering context server tool: {:?}",
-                                                tool.name
-                                            );
-                                            tool_working_set.insert(Arc::new(
-                                                ContextServerTool::new(
-                                                    context_server_manager.clone(),
-                                                    server.id(),
-                                                    tool,
-                                                ),
-                                            ))
+                                    let tool_ids = tool_working_set
+                                        .update(cx, |tool_working_set, _| {
+                                            tools
+                                                .tools
+                                                .into_iter()
+                                                .map(|tool| {
+                                                    log::info!(
+                                                        "registering context server tool: {:?}",
+                                                        tool.name
+                                                    );
+                                                    tool_working_set.insert(Arc::new(
+                                                        ContextServerTool::new(
+                                                            context_server_manager.clone(),
+                                                            server.id(),
+                                                            tool,
+                                                        ),
+                                                    ))
+                                                })
+                                                .collect::<Vec<_>>()
                                         })
-                                        .collect::<Vec<_>>();
+                                        .log_err();
 
-                                    this.update(cx, |this, cx| {
-                                        this.context_server_tool_ids.insert(server_id, tool_ids);
-                                        this.load_default_profile(cx);
-                                    })
-                                    .log_err();
+                                    if let Some(tool_ids) = tool_ids {
+                                        this.update(cx, |this, cx| {
+                                            this.context_server_tool_ids
+                                                .insert(server_id, tool_ids);
+                                            this.load_default_profile(cx);
+                                        })
+                                        .log_err();
+                                    }
                                 }
                             }
                         }
@@ -466,7 +481,9 @@ impl ThreadStore {
             }
             context_server::manager::Event::ServerStopped { server_id } => {
                 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
-                    tool_working_set.remove(&tool_ids);
+                    tool_working_set.update(cx, |tool_working_set, _| {
+                        tool_working_set.remove(&tool_ids);
+                    });
                     self.load_default_profile(cx);
                 }
             }

crates/agent/src/tool_use.rs 🔗

@@ -5,7 +5,7 @@ use assistant_tool::{Tool, ToolWorkingSet};
 use collections::HashMap;
 use futures::FutureExt as _;
 use futures::future::Shared;
-use gpui::{App, SharedString, Task};
+use gpui::{App, Entity, SharedString, Task};
 use language_model::{
     LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
     LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
@@ -49,7 +49,7 @@ impl ToolUseStatus {
 }
 
 pub struct ToolUseState {
-    tools: Arc<ToolWorkingSet>,
+    tools: Entity<ToolWorkingSet>,
     tool_uses_by_assistant_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
     tool_uses_by_user_message: HashMap<MessageId, Vec<LanguageModelToolUseId>>,
     tool_results: HashMap<LanguageModelToolUseId, LanguageModelToolResult>,
@@ -59,7 +59,7 @@ pub struct ToolUseState {
 pub const USING_TOOL_MARKER: &str = "<using_tool>";
 
 impl ToolUseState {
-    pub fn new(tools: Arc<ToolWorkingSet>) -> Self {
+    pub fn new(tools: Entity<ToolWorkingSet>) -> Self {
         Self {
             tools,
             tool_uses_by_assistant_message: HashMap::default(),
@@ -73,7 +73,7 @@ impl ToolUseState {
     ///
     /// Accepts a function to filter the tools that should be used to populate the state.
     pub fn from_serialized_messages(
-        tools: Arc<ToolWorkingSet>,
+        tools: Entity<ToolWorkingSet>,
         messages: &[SerializedMessage],
         mut filter_by_tool_name: impl FnMut(&str) -> bool,
     ) -> Self {
@@ -199,12 +199,12 @@ impl ToolUseState {
                 }
             })();
 
-            let (icon, needs_confirmation) = if let Some(tool) = self.tools.tool(&tool_use.name, cx)
-            {
-                (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
-            } else {
-                (IconName::Cog, false)
-            };
+            let (icon, needs_confirmation) =
+                if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
+                    (tool.icon(), tool.needs_confirmation(&tool_use.input, cx))
+                } else {
+                    (IconName::Cog, false)
+                };
 
             tool_uses.push(ToolUse {
                 id: tool_use.id.clone(),
@@ -226,7 +226,7 @@ impl ToolUseState {
         input: &serde_json::Value,
         cx: &App,
     ) -> SharedString {
-        if let Some(tool) = self.tools.tool(tool_name, cx) {
+        if let Some(tool) = self.tools.read(cx).tool(tool_name, cx) {
             tool.ui_text(input).into()
         } else {
             format!("Unknown tool {tool_name:?}").into()

crates/assistant_tool/src/tool_working_set.rs 🔗

@@ -1,8 +1,7 @@
 use std::sync::Arc;
 
 use collections::{HashMap, HashSet, IndexMap};
-use gpui::App;
-use parking_lot::Mutex;
+use gpui::{App, Context, EventEmitter};
 
 use crate::{Tool, ToolRegistry, ToolSource};
 
@@ -12,11 +11,6 @@ pub struct ToolId(usize);
 /// A working set of tools for use in one instance of the Assistant Panel.
 #[derive(Default)]
 pub struct ToolWorkingSet {
-    state: Mutex<WorkingSetState>,
-}
-
-#[derive(Default)]
-struct WorkingSetState {
     context_server_tools_by_id: HashMap<ToolId, Arc<dyn Tool>>,
     context_server_tools_by_name: HashMap<String, Arc<dyn Tool>>,
     enabled_sources: HashSet<ToolSource>,
@@ -24,99 +18,27 @@ struct WorkingSetState {
     next_tool_id: ToolId,
 }
 
+pub enum ToolWorkingSetEvent {
+    EnabledToolsChanged,
+}
+
+impl EventEmitter<ToolWorkingSetEvent> for ToolWorkingSet {}
+
 impl ToolWorkingSet {
     pub fn tool(&self, name: &str, cx: &App) -> Option<Arc<dyn Tool>> {
-        self.state
-            .lock()
-            .context_server_tools_by_name
+        self.context_server_tools_by_name
             .get(name)
             .cloned()
             .or_else(|| ToolRegistry::global(cx).tool(name))
     }
 
     pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
-        self.state.lock().tools(cx)
-    }
-
-    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
-        self.state.lock().tools_by_source(cx)
-    }
-
-    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
-        self.state.lock().enabled_tools(cx)
-    }
-
-    pub fn disable_all_tools(&self) {
-        let mut state = self.state.lock();
-        state.disable_all_tools();
-    }
-
-    pub fn enable_source(&self, source: ToolSource, cx: &App) {
-        let mut state = self.state.lock();
-        state.enable_source(source, cx);
-    }
-
-    pub fn disable_source(&self, source: &ToolSource) {
-        let mut state = self.state.lock();
-        state.disable_source(source);
-    }
-
-    pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
-        let mut state = self.state.lock();
-        let tool_id = state.next_tool_id;
-        state.next_tool_id.0 += 1;
-        state
-            .context_server_tools_by_id
-            .insert(tool_id, tool.clone());
-        state.tools_changed();
-        tool_id
-    }
-
-    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
-        self.state.lock().is_enabled(source, name)
-    }
-
-    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
-        self.state.lock().is_disabled(source, name)
-    }
-
-    pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
-        let mut state = self.state.lock();
-        state.enable(source, tools_to_enable);
-    }
-
-    pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
-        let mut state = self.state.lock();
-        state.disable(source, tools_to_disable);
-    }
-
-    pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
-        let mut state = self.state.lock();
-        state
-            .context_server_tools_by_id
-            .retain(|id, _| !tool_ids_to_remove.contains(id));
-        state.tools_changed();
-    }
-}
-
-impl WorkingSetState {
-    fn tools_changed(&mut self) {
-        self.context_server_tools_by_name.clear();
-        self.context_server_tools_by_name.extend(
-            self.context_server_tools_by_id
-                .values()
-                .map(|tool| (tool.name(), tool.clone())),
-        );
-    }
-
-    fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
         let mut tools = ToolRegistry::global(cx).tools();
         tools.extend(self.context_server_tools_by_id.values().cloned());
-
         tools
     }
 
-    fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
+    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
         let mut tools_by_source = IndexMap::default();
 
         for tool in self.tools(cx) {
@@ -135,7 +57,7 @@ impl WorkingSetState {
         tools_by_source
     }
 
-    fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
+    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
         let all_tools = self.tools(cx);
 
         all_tools
@@ -144,51 +66,90 @@ impl WorkingSetState {
             .collect()
     }
 
-    fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
+    pub fn disable_all_tools(&mut self, cx: &mut Context<Self>) {
+        self.enabled_tools_by_source.clear();
+        cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
+    }
+
+    pub fn enable_source(&mut self, source: ToolSource, cx: &mut Context<Self>) {
+        self.enabled_sources.insert(source.clone());
+
+        let tools_by_source = self.tools_by_source(cx);
+        if let Some(tools) = tools_by_source.get(&source) {
+            self.enabled_tools_by_source.insert(
+                source,
+                tools
+                    .into_iter()
+                    .map(|tool| tool.name().into())
+                    .collect::<HashSet<_>>(),
+            );
+        }
+        cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
+    }
+
+    pub fn disable_source(&mut self, source: &ToolSource, cx: &mut Context<Self>) {
+        self.enabled_sources.remove(source);
+        self.enabled_tools_by_source.remove(source);
+        cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
+    }
+
+    pub fn insert(&mut self, tool: Arc<dyn Tool>) -> ToolId {
+        let tool_id = self.next_tool_id;
+        self.next_tool_id.0 += 1;
+        self.context_server_tools_by_id
+            .insert(tool_id, tool.clone());
+        self.tools_changed();
+        tool_id
+    }
+
+    pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
         self.enabled_tools_by_source
             .get(source)
             .map_or(false, |enabled_tools| enabled_tools.contains(name))
     }
 
-    fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
+    pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
         !self.is_enabled(source, name)
     }
 
-    fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
+    pub fn enable(
+        &mut self,
+        source: ToolSource,
+        tools_to_enable: &[Arc<str>],
+        cx: &mut Context<Self>,
+    ) {
         self.enabled_tools_by_source
             .entry(source)
             .or_default()
             .extend(tools_to_enable.into_iter().cloned());
+        cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
     }
 
-    fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
+    pub fn disable(
+        &mut self,
+        source: ToolSource,
+        tools_to_disable: &[Arc<str>],
+        cx: &mut Context<Self>,
+    ) {
         self.enabled_tools_by_source
             .entry(source)
             .or_default()
             .retain(|name| !tools_to_disable.contains(name));
+        cx.emit(ToolWorkingSetEvent::EnabledToolsChanged);
     }
 
-    fn enable_source(&mut self, source: ToolSource, cx: &App) {
-        self.enabled_sources.insert(source.clone());
-
-        let tools_by_source = self.tools_by_source(cx);
-        if let Some(tools) = tools_by_source.get(&source) {
-            self.enabled_tools_by_source.insert(
-                source,
-                tools
-                    .into_iter()
-                    .map(|tool| tool.name().into())
-                    .collect::<HashSet<_>>(),
-            );
-        }
-    }
-
-    fn disable_source(&mut self, source: &ToolSource) {
-        self.enabled_sources.remove(source);
-        self.enabled_tools_by_source.remove(source);
+    pub fn remove(&mut self, tool_ids_to_remove: &[ToolId]) {
+        self.context_server_tools_by_id
+            .retain(|id, _| !tool_ids_to_remove.contains(id));
+        self.tools_changed();
     }
 
-    fn disable_all_tools(&mut self) {
-        self.enabled_tools_by_source.clear();
+    fn tools_changed(&mut self) {
+        self.context_server_tools_by_name.clear();
+        self.context_server_tools_by_name.extend(
+            self.context_server_tools_by_id
+                .values()
+                .map(|tool| (tool.name(), tool.clone())),
+        );
     }
 }

crates/eval/src/example.rs 🔗

@@ -6,7 +6,7 @@ use collections::HashMap;
 use dap::DapRegistry;
 use futures::channel::mpsc;
 use futures::{FutureExt, StreamExt as _, select_biased};
-use gpui::{App, AsyncApp, Entity, Task};
+use gpui::{App, AppContext as _, AsyncApp, Entity, Task};
 use handlebars::Handlebars;
 use language::{DiagnosticSeverity, OffsetRangeExt};
 use language_model::{
@@ -181,7 +181,7 @@ impl Example {
             project.create_worktree(&worktree_path, true, cx)
         });
 
-        let tools = Arc::new(ToolWorkingSet::default());
+        let tools = cx.new(|_| ToolWorkingSet::default());
         let thread_store =
             ThreadStore::load(project.clone(), tools, app_state.prompt_builder.clone(), cx);
         let this = self.clone();