Cleanup tool registry API surface (#11637)

Kyle Kelley and Max created

Fast followups to #11629 

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>

Change summary

crates/assistant2/src/assistant2.rs               |  20 -
crates/assistant_tooling/src/assistant_tooling.rs |   4 
crates/assistant_tooling/src/tool_registry.rs     | 169 +++++++---------
3 files changed, 88 insertions(+), 105 deletions(-)

Detailed changes

crates/assistant2/src/assistant2.rs 🔗

@@ -11,8 +11,7 @@ use crate::ui::UserOrAssistant;
 use ::ui::{div, prelude::*, Color, Tooltip, ViewContext};
 use anyhow::{Context, Result};
 use assistant_tooling::{
-    tool_running_placeholder, AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry,
-    UserAttachment,
+    AttachmentRegistry, ProjectContext, ToolFunctionCall, ToolRegistry, UserAttachment,
 };
 use attachments::ActiveEditorAttachmentTool;
 use client::{proto, Client, UserStore};
@@ -130,16 +129,13 @@ impl AssistantPanel {
 
                 let mut tool_registry = ToolRegistry::new();
                 tool_registry
-                    .register(ProjectIndexTool::new(project_index.clone()), cx)
+                    .register(ProjectIndexTool::new(project_index.clone()))
                     .unwrap();
                 tool_registry
-                    .register(
-                        CreateBufferTool::new(workspace.clone(), project.clone()),
-                        cx,
-                    )
+                    .register(CreateBufferTool::new(workspace.clone(), project.clone()))
                     .unwrap();
                 tool_registry
-                    .register(AnnotationTool::new(workspace.clone(), project.clone()), cx)
+                    .register(AnnotationTool::new(workspace.clone(), project.clone()))
                     .unwrap();
 
                 let mut attachment_registry = AttachmentRegistry::new();
@@ -588,9 +584,9 @@ impl AssistantChat {
                         cx.notify();
                     } else {
                         if let Some(current_message) = messages.last_mut() {
-                            for tool_call in current_message.tool_calls.iter() {
+                            for tool_call in current_message.tool_calls.iter_mut() {
                                 tool_tasks
-                                    .extend(this.tool_registry.execute_tool_call(&tool_call, cx));
+                                    .extend(this.tool_registry.execute_tool_call(tool_call, cx));
                             }
                         }
                     }
@@ -847,7 +843,7 @@ impl AssistantChat {
                     let tools = message
                         .tool_calls
                         .iter()
-                        .map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
+                        .filter_map(|tool_call| self.tool_registry.render_tool_call(tool_call, cx))
                         .collect::<Vec<AnyElement>>();
 
                     if !tools.is_empty() {
@@ -856,7 +852,7 @@ impl AssistantChat {
                 }
 
                 if message_elements.is_empty() {
-                    message_elements.push(tool_running_placeholder());
+                    message_elements.push(::ui::Label::new("Researching...").into_any_element())
                 }
 
                 div()

crates/assistant_tooling/src/assistant_tooling.rs 🔗

@@ -8,6 +8,6 @@ pub use attachment_registry::{
 };
 pub use project_context::ProjectContext;
 pub use tool_registry::{
-    tool_running_placeholder, LanguageModelTool, SavedToolFunctionCall, SavedToolFunctionCallState,
-    ToolFunctionCall, ToolFunctionCallState, ToolFunctionDefinition, ToolOutput, ToolRegistry,
+    LanguageModelTool, SavedToolFunctionCall, ToolFunctionCall, ToolFunctionDefinition, ToolOutput,
+    ToolRegistry,
 };

crates/assistant_tooling/src/tool_registry.rs 🔗

@@ -9,10 +9,8 @@ use std::{
     any::TypeId,
     collections::HashMap,
     fmt::Display,
-    sync::{
-        atomic::{AtomicBool, Ordering::SeqCst},
-        Arc,
-    },
+    mem,
+    sync::atomic::{AtomicBool, Ordering::SeqCst},
 };
 use ui::ViewContext;
 
@@ -29,7 +27,7 @@ pub struct ToolFunctionCall {
 }
 
 #[derive(Default)]
-pub enum ToolFunctionCallState {
+enum ToolFunctionCallState {
     #[default]
     Initializing,
     NoSuchTool,
@@ -37,10 +35,10 @@ pub enum ToolFunctionCallState {
     ExecutedTool(Box<dyn ToolView>),
 }
 
-pub trait ToolView {
+trait ToolView {
     fn view(&self) -> AnyView;
     fn generate(&self, project: &mut ProjectContext, cx: &mut WindowContext) -> String;
-    fn set_input(&self, input: &str, cx: &mut WindowContext);
+    fn try_set_input(&self, input: &str, cx: &mut WindowContext);
     fn execute(&self, cx: &mut WindowContext) -> Task<Result<()>>;
     fn serialize_output(&self, cx: &mut WindowContext) -> Result<Box<RawValue>>;
     fn deserialize_output(&self, raw_value: &RawValue, cx: &mut WindowContext) -> Result<()>;
@@ -48,14 +46,14 @@ pub trait ToolView {
 
 #[derive(Default, Serialize, Deserialize)]
 pub struct SavedToolFunctionCall {
-    pub id: String,
-    pub name: String,
-    pub arguments: String,
-    pub state: SavedToolFunctionCallState,
+    id: String,
+    name: String,
+    arguments: String,
+    state: SavedToolFunctionCallState,
 }
 
 #[derive(Default, Serialize, Deserialize)]
-pub enum SavedToolFunctionCallState {
+enum SavedToolFunctionCallState {
     #[default]
     Initializing,
     NoSuchTool,
@@ -63,7 +61,7 @@ pub enum SavedToolFunctionCallState {
     ExecutedTool(Box<RawValue>),
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 pub struct ToolFunctionDefinition {
     pub name: String,
     pub description: String,
@@ -100,18 +98,6 @@ pub trait LanguageModelTool {
     fn view(&self, cx: &mut WindowContext) -> View<Self::View>;
 }
 
-pub fn tool_running_placeholder() -> AnyElement {
-    ui::Label::new("Researching...").into_any_element()
-}
-
-pub fn unknown_tool_placeholder() -> AnyElement {
-    ui::Label::new("Unknown tool").into_any_element()
-}
-
-pub fn no_such_tool_placeholder() -> AnyElement {
-    ui::Label::new("No such tool").into_any_element()
-}
-
 pub trait ToolOutput: Render {
     /// The input type that will be passed in to `execute` when the tool is called
     /// by the language model.
@@ -172,11 +158,6 @@ impl ToolRegistry {
             .collect()
     }
 
-    pub fn view_for_tool(&self, name: &str, cx: &mut WindowContext) -> Option<Box<dyn ToolView>> {
-        let tool = self.registered_tools.get(name)?;
-        Some((tool.build_view)(cx))
-    }
-
     pub fn update_tool_call(
         &self,
         call: &mut ToolFunctionCall,
@@ -189,7 +170,8 @@ impl ToolRegistry {
         }
         if let Some(arguments) = arguments {
             if call.arguments.is_empty() {
-                if let Some(view) = self.view_for_tool(&call.name, cx) {
+                if let Some(tool) = self.registered_tools.get(&call.name) {
+                    let view = (tool.build_view)(cx);
                     call.state = ToolFunctionCallState::KnownTool(view);
                 } else {
                     call.state = ToolFunctionCallState::NoSuchTool;
@@ -199,7 +181,7 @@ impl ToolRegistry {
 
             if let ToolFunctionCallState::KnownTool(view) = &call.state {
                 if let Ok(repaired_arguments) = repair(call.arguments.clone()) {
-                    view.set_input(&repaired_arguments, cx)
+                    view.try_set_input(&repaired_arguments, cx)
                 }
             }
         }
@@ -207,11 +189,13 @@ impl ToolRegistry {
 
     pub fn execute_tool_call(
         &self,
-        tool_call: &ToolFunctionCall,
+        tool_call: &mut ToolFunctionCall,
         cx: &mut WindowContext,
     ) -> Option<Task<Result<()>>> {
-        if let ToolFunctionCallState::KnownTool(view) = &tool_call.state {
-            Some(view.execute(cx))
+        if let ToolFunctionCallState::KnownTool(view) = mem::take(&mut tool_call.state) {
+            let task = view.execute(cx);
+            tool_call.state = ToolFunctionCallState::ExecutedTool(view);
+            Some(task)
         } else {
             None
         }
@@ -221,12 +205,14 @@ impl ToolRegistry {
         &self,
         tool_call: &ToolFunctionCall,
         _cx: &mut WindowContext,
-    ) -> AnyElement {
+    ) -> Option<AnyElement> {
         match &tool_call.state {
-            ToolFunctionCallState::NoSuchTool => no_such_tool_placeholder(),
-            ToolFunctionCallState::Initializing => unknown_tool_placeholder(),
+            ToolFunctionCallState::NoSuchTool => {
+                Some(ui::Label::new("No such tool").into_any_element())
+            }
+            ToolFunctionCallState::Initializing => None,
             ToolFunctionCallState::KnownTool(view) | ToolFunctionCallState::ExecutedTool(view) => {
-                view.view().into_any_element()
+                Some(view.view().into_any_element())
             }
         }
     }
@@ -287,12 +273,12 @@ impl ToolRegistry {
                 SavedToolFunctionCallState::KnownTool => {
                     log::error!("Deserialized tool that had not executed");
                     let view = (tool.build_view)(cx);
-                    view.set_input(&call.arguments, cx);
+                    view.try_set_input(&call.arguments, cx);
                     ToolFunctionCallState::KnownTool(view)
                 }
                 SavedToolFunctionCallState::ExecutedTool(output) => {
                     let view = (tool.build_view)(cx);
-                    view.set_input(&call.arguments, cx);
+                    view.try_set_input(&call.arguments, cx);
                     view.deserialize_output(output, cx)?;
                     ToolFunctionCallState::ExecutedTool(view)
                 }
@@ -300,13 +286,8 @@ impl ToolRegistry {
         })
     }
 
-    pub fn register<T: 'static + LanguageModelTool>(
-        &mut self,
-        tool: T,
-        _cx: &mut WindowContext,
-    ) -> Result<()> {
+    pub fn register<T: 'static + LanguageModelTool>(&mut self, tool: T) -> Result<()> {
         let name = tool.name();
-        let tool = Arc::new(tool);
         let registered_tool = RegisteredTool {
             type_id: TypeId::of::<T>(),
             definition: tool.definition(),
@@ -332,7 +313,7 @@ impl<T: ToolOutput> ToolView for View<T> {
         self.update(cx, |view, cx| view.generate(project, cx))
     }
 
-    fn set_input(&self, input: &str, cx: &mut WindowContext) {
+    fn try_set_input(&self, input: &str, cx: &mut WindowContext) {
         if let Ok(input) = serde_json::from_str::<T::Input>(input) {
             self.update(cx, |view, cx| {
                 view.set_input(input, cx);
@@ -372,7 +353,6 @@ mod test {
     use super::*;
     use gpui::{div, prelude::*, Render, TestAppContext};
     use gpui::{EmptyView, View};
-    use schemars::schema_for;
     use schemars::JsonSchema;
     use serde::{Deserialize, Serialize};
     use serde_json::json;
@@ -483,57 +463,64 @@ mod test {
 
     #[gpui::test]
     async fn test_openai_weather_example(cx: &mut TestAppContext) {
-        cx.background_executor.run_until_parked();
         let (_, cx) = cx.add_window_view(|_cx| EmptyView);
 
-        let tool = WeatherTool {
-            current_weather: WeatherResult {
-                location: "San Francisco".to_string(),
-                temperature: 21.0,
-                unit: "Celsius".to_string(),
-            },
-        };
-
-        let tools = vec![tool.definition()];
-        assert_eq!(tools.len(), 1);
-
-        let expected = ToolFunctionDefinition {
-            name: "get_current_weather".to_string(),
-            description: "Fetches the current weather for a given location.".to_string(),
-            parameters: schema_for!(WeatherQuery),
-        };
-
-        assert_eq!(tools[0].name, expected.name);
-        assert_eq!(tools[0].description, expected.description);
-
-        let expected_schema = serde_json::to_value(&tools[0].parameters).unwrap();
+        let mut registry = ToolRegistry::new();
+        registry
+            .register(WeatherTool {
+                current_weather: WeatherResult {
+                    location: "San Francisco".to_string(),
+                    temperature: 21.0,
+                    unit: "Celsius".to_string(),
+                },
+            })
+            .unwrap();
 
+        let definitions = registry.definitions();
         assert_eq!(
-            expected_schema,
-            json!({
-                "$schema": "http://json-schema.org/draft-07/schema#",
-                "title": "WeatherQuery",
-                "type": "object",
-                "properties": {
-                    "location": {
-                        "type": "string"
+            definitions,
+            [ToolFunctionDefinition {
+                name: "get_current_weather".to_string(),
+                description: "Fetches the current weather for a given location.".to_string(),
+                parameters: serde_json::from_value(json!({
+                    "$schema": "http://json-schema.org/draft-07/schema#",
+                    "title": "WeatherQuery",
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string"
+                        },
+                        "unit": {
+                            "type": "string"
+                        }
                     },
-                    "unit": {
-                        "type": "string"
-                    }
-                },
-                "required": ["location", "unit"]
-            })
+                    "required": ["location", "unit"]
+                }))
+                .unwrap(),
+            }]
         );
 
-        let view = cx.update(|cx| tool.view(cx));
+        let mut call = ToolFunctionCall {
+            id: "the-id".to_string(),
+            name: "get_cur".to_string(),
+            ..Default::default()
+        };
 
-        cx.update(|cx| {
-            view.set_input(&r#"{"location": "San Francisco", "unit": "Celsius"}"#, cx);
+        let task = cx.update(|cx| {
+            registry.update_tool_call(
+                &mut call,
+                Some("rent_weather"),
+                Some(r#"{"location": "San Francisco","#),
+                cx,
+            );
+            registry.update_tool_call(&mut call, None, Some(r#" "unit": "Celsius"}"#), cx);
+            registry.execute_tool_call(&mut call, cx).unwrap()
         });
+        task.await.unwrap();
 
-        let finished = cx.update(|cx| view.execute(cx)).await;
-
-        assert!(finished.is_ok());
+        match &call.state {
+            ToolFunctionCallState::ExecutedTool(_view) => {}
+            _ => panic!(),
+        }
     }
 }