assistant_tool: Reduce locking in `ToolWorkingSet` (#26605)

Marshall Bowers created

This PR updates the `ToolWorkingSet` to reduce the amount of locking we
need to do.

A number of the methods have had corresponding versions moved to the
`ToolWorkingSetState` so that we can take out the lock once and do a
number of operations without needing to continually acquire and release
the lock.

Release Notes:

- N/A

Change summary

crates/assistant_tool/src/tool_working_set.rs | 169 ++++++++++++--------
1 file changed, 99 insertions(+), 70 deletions(-)

Detailed changes

crates/assistant_tool/src/tool_working_set.rs 🔗

@@ -46,72 +46,31 @@ impl ToolWorkingSet {
     }
 
     pub fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
-        let mut tools = ToolRegistry::global(cx).tools();
-        tools.extend(
-            self.state
-                .lock()
-                .context_server_tools_by_id
-                .values()
-                .cloned(),
-        );
+        self.state.lock().tools(cx)
+    }
 
-        tools
+    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
+        self.state.lock().tools_by_source(cx)
     }
 
     pub fn are_all_tools_enabled(&self) -> bool {
         let state = self.state.lock();
-
         state.disabled_tools_by_source.is_empty() && !state.is_scripting_tool_disabled
     }
 
+    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
+        self.state.lock().enabled_tools(cx)
+    }
+
     pub fn enable_all_tools(&self) {
         let mut state = self.state.lock();
-
         state.disabled_tools_by_source.clear();
-        state.is_scripting_tool_disabled = false;
+        state.enable_scripting_tool();
     }
 
     pub fn disable_all_tools(&self, cx: &App) {
-        let tools = self.tools_by_source(cx);
-
-        for (source, tools) in tools {
-            let tool_names = tools
-                .into_iter()
-                .map(|tool| tool.name().into())
-                .collect::<Vec<_>>();
-
-            self.disable(source, &tool_names);
-        }
-
-        self.disable_scripting_tool();
-    }
-
-    pub fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
-        let all_tools = self.tools(cx);
-
-        all_tools
-            .into_iter()
-            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
-            .collect()
-    }
-
-    pub fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
-        let mut tools_by_source = IndexMap::default();
-
-        for tool in self.tools(cx) {
-            tools_by_source
-                .entry(tool.source())
-                .or_insert_with(Vec::new)
-                .push(tool);
-        }
-
-        for tools in tools_by_source.values_mut() {
-            tools.sort_by_key(|tool| tool.name());
-        }
-
-        tools_by_source.sort_unstable_keys();
-
-        tools_by_source
+        let mut state = self.state.lock();
+        state.disable_all_tools(cx);
     }
 
     pub fn insert(&self, tool: Arc<dyn Tool>) -> ToolId {
@@ -126,33 +85,21 @@ impl ToolWorkingSet {
     }
 
     pub fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
-        !self.is_disabled(source, name)
+        self.state.lock().is_enabled(source, name)
     }
 
     pub fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
-        let state = self.state.lock();
-        state
-            .disabled_tools_by_source
-            .get(source)
-            .map_or(false, |disabled_tools| disabled_tools.contains(name))
+        self.state.lock().is_disabled(source, name)
     }
 
     pub fn enable(&self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
         let mut state = self.state.lock();
-        state
-            .disabled_tools_by_source
-            .entry(source)
-            .or_default()
-            .retain(|name| !tools_to_enable.contains(name));
+        state.enable(source, tools_to_enable);
     }
 
     pub fn disable(&self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
         let mut state = self.state.lock();
-        state
-            .disabled_tools_by_source
-            .entry(source)
-            .or_default()
-            .extend(tools_to_disable.into_iter().cloned());
+        state.disable(source, tools_to_disable);
     }
 
     pub fn remove(&self, tool_ids_to_remove: &[ToolId]) {
@@ -170,12 +117,12 @@ impl ToolWorkingSet {
 
     pub fn enable_scripting_tool(&self) {
         let mut state = self.state.lock();
-        state.is_scripting_tool_disabled = false;
+        state.enable_scripting_tool();
     }
 
     pub fn disable_scripting_tool(&self) {
         let mut state = self.state.lock();
-        state.is_scripting_tool_disabled = true;
+        state.disable_scripting_tool();
     }
 }
 
@@ -188,4 +135,86 @@ impl WorkingSetState {
                 .map(|tool| (tool.name(), tool.clone())),
         );
     }
+
+    fn tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
+        let mut tools = ToolRegistry::global(cx).tools();
+        tools.extend(self.context_server_tools_by_id.values().cloned());
+
+        tools
+    }
+
+    fn tools_by_source(&self, cx: &App) -> IndexMap<ToolSource, Vec<Arc<dyn Tool>>> {
+        let mut tools_by_source = IndexMap::default();
+
+        for tool in self.tools(cx) {
+            tools_by_source
+                .entry(tool.source())
+                .or_insert_with(Vec::new)
+                .push(tool);
+        }
+
+        for tools in tools_by_source.values_mut() {
+            tools.sort_by_key(|tool| tool.name());
+        }
+
+        tools_by_source.sort_unstable_keys();
+
+        tools_by_source
+    }
+
+    fn enabled_tools(&self, cx: &App) -> Vec<Arc<dyn Tool>> {
+        let all_tools = self.tools(cx);
+
+        all_tools
+            .into_iter()
+            .filter(|tool| self.is_enabled(&tool.source(), &tool.name().into()))
+            .collect()
+    }
+
+    fn is_enabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
+        !self.is_disabled(source, name)
+    }
+
+    fn is_disabled(&self, source: &ToolSource, name: &Arc<str>) -> bool {
+        self.disabled_tools_by_source
+            .get(source)
+            .map_or(false, |disabled_tools| disabled_tools.contains(name))
+    }
+
+    fn enable(&mut self, source: ToolSource, tools_to_enable: &[Arc<str>]) {
+        self.disabled_tools_by_source
+            .entry(source)
+            .or_default()
+            .retain(|name| !tools_to_enable.contains(name));
+    }
+
+    fn disable(&mut self, source: ToolSource, tools_to_disable: &[Arc<str>]) {
+        self.disabled_tools_by_source
+            .entry(source)
+            .or_default()
+            .extend(tools_to_disable.into_iter().cloned());
+    }
+
+    fn disable_all_tools(&mut self, cx: &App) {
+        let tools = self.tools_by_source(cx);
+
+        for (source, tools) in tools {
+            let tool_names = tools
+                .into_iter()
+                .map(|tool| tool.name().into())
+                .collect::<Vec<_>>();
+
+            self.disable(source, &tool_names);
+        }
+
+        self.disable_scripting_tool();
+    }
+
+    fn enable_scripting_tool(&mut self) {
+        self.is_scripting_tool_disabled = false;
+    }
+
+    fn disable_scripting_tool(&mut self) {
+        self.is_scripting_tool_disabled = true;
+    }
 }