assistant2: Decouple scripting tool from the `Tool` trait (#26382)

Marshall Bowers created

This PR decouples the scripting tool from the `Tool` trait while still
allowing it to be used as a tool from the model's perspective.

This will allow us to evolve the scripting tool as more of a first-class
citizen while still retaining the ability to have the model call it as a
regular tool.

Release Notes:

- N/A

Change summary

Cargo.lock                                  |   3 
Cargo.toml                                  |   2 
crates/assistant2/Cargo.toml                |   1 
crates/assistant2/src/active_thread.rs      |  33 ++++-
crates/assistant2/src/thread.rs             | 122 +++++++++++++++++++---
crates/assistant2/src/tool_use.rs           |   1 
crates/scripting_tool/Cargo.toml            |   1 
crates/scripting_tool/src/scripting_tool.rs |  25 +---
crates/zed/Cargo.toml                       |   1 
crates/zed/src/main.rs                      |   1 
10 files changed, 138 insertions(+), 52 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -490,6 +490,7 @@ dependencies = [
  "proto",
  "rand 0.8.5",
  "rope",
+ "scripting_tool",
  "serde",
  "serde_json",
  "settings",
@@ -11915,7 +11916,6 @@ name = "scripting_tool"
 version = "0.1.0"
 dependencies = [
  "anyhow",
- "assistant_tool",
  "collections",
  "futures 0.3.31",
  "gpui",
@@ -16986,7 +16986,6 @@ dependencies = [
  "repl",
  "reqwest_client",
  "rope",
- "scripting_tool",
  "search",
  "serde",
  "serde_json",

Cargo.toml 🔗

@@ -8,7 +8,6 @@ members = [
     "crates/assistant",
     "crates/assistant2",
     "crates/assistant_context_editor",
-    "crates/scripting_tool",
     "crates/assistant_settings",
     "crates/assistant_slash_command",
     "crates/assistant_slash_commands",
@@ -119,6 +118,7 @@ members = [
     "crates/rope",
     "crates/rpc",
     "crates/schema_generator",
+    "crates/scripting_tool",
     "crates/search",
     "crates/semantic_index",
     "crates/semantic_version",

crates/assistant2/Cargo.toml 🔗

@@ -59,6 +59,7 @@ prompt_library.workspace = true
 prompt_store.workspace = true
 proto.workspace = true
 rope.workspace = true
+scripting_tool.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true

crates/assistant2/src/active_thread.rs 🔗

@@ -457,9 +457,13 @@ impl ActiveThread {
 
         let context = thread.context_for_message(message_id);
         let tool_uses = thread.tool_uses_for_message(message_id);
+        let scripting_tool_uses = thread.scripting_tool_uses_for_message(message_id);
 
         // Don't render user messages that are just there for returning tool results.
-        if message.role == Role::User && thread.message_has_tool_results(message_id) {
+        if message.role == Role::User
+            && (thread.message_has_tool_results(message_id)
+                || thread.message_has_scripting_tool_results(message_id))
+        {
             return Empty.into_any();
         }
 
@@ -609,16 +613,22 @@ impl ActiveThread {
                 .id(("message-container", ix))
                 .child(message_content)
                 .map(|parent| {
-                    if tool_uses.is_empty() {
+                    if tool_uses.is_empty() && scripting_tool_uses.is_empty() {
                         return parent;
                     }
 
                     parent.child(
-                        v_flex().children(
-                            tool_uses
-                                .into_iter()
-                                .map(|tool_use| self.render_tool_use(tool_use, cx)),
-                        ),
+                        v_flex()
+                            .children(
+                                tool_uses
+                                    .into_iter()
+                                    .map(|tool_use| self.render_tool_use(tool_use, cx)),
+                            )
+                            .children(
+                                scripting_tool_uses
+                                    .into_iter()
+                                    .map(|tool_use| self.render_scripting_tool_use(tool_use, cx)),
+                            ),
                     )
                 }),
             Role::System => div().id(("message-container", ix)).py_1().px_2().child(
@@ -727,6 +737,15 @@ impl ActiveThread {
                 }),
         )
     }
+
+    fn render_scripting_tool_use(
+        &self,
+        tool_use: ToolUse,
+        cx: &mut Context<Self>,
+    ) -> impl IntoElement {
+        // TODO: Add custom rendering for scripting tool uses.
+        self.render_tool_use(tool_use, cx)
+    }
 }
 
 impl Render for ActiveThread {

crates/assistant2/src/thread.rs 🔗

@@ -13,13 +13,14 @@ use language_model::{
     Role, StopReason,
 };
 use project::Project;
+use scripting_tool::ScriptingTool;
 use serde::{Deserialize, Serialize};
 use util::{post_inc, TryFutureExt as _};
 use uuid::Uuid;
 
 use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
 use crate::thread_store::SavedThread;
-use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
+use crate::tool_use::{ToolUse, ToolUseState};
 
 #[derive(Debug, Clone, Copy)]
 pub enum RequestKind {
@@ -75,6 +76,7 @@ pub struct Thread {
     project: Entity<Project>,
     tools: Arc<ToolWorkingSet>,
     tool_use: ToolUseState,
+    scripting_tool_use: ToolUseState,
 }
 
 impl Thread {
@@ -97,6 +99,7 @@ impl Thread {
             project,
             tools,
             tool_use: ToolUseState::new(),
+            scripting_tool_use: ToolUseState::new(),
         }
     }
 
@@ -115,6 +118,7 @@ impl Thread {
                 .unwrap_or(0),
         );
         let tool_use = ToolUseState::from_saved_messages(&saved.messages);
+        let scripting_tool_use = ToolUseState::new();
 
         Self {
             id,
@@ -138,6 +142,7 @@ impl Thread {
             project,
             tools,
             tool_use,
+            scripting_tool_use,
         }
     }
 
@@ -198,31 +203,46 @@ impl Thread {
         )
     }
 
-    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
-        self.tool_use.pending_tool_uses()
-    }
-
     /// Returns whether all of the tool uses have finished running.
     pub fn all_tools_finished(&self) -> bool {
+        let mut all_pending_tool_uses = self
+            .tool_use
+            .pending_tool_uses()
+            .into_iter()
+            .chain(self.scripting_tool_use.pending_tool_uses());
+
         // If the only pending tool uses left are the ones with errors, then that means that we've finished running all
         // of the pending tools.
-        self.pending_tool_uses()
-            .into_iter()
-            .all(|tool_use| tool_use.status.is_error())
+        all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
     }
 
     pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
         self.tool_use.tool_uses_for_message(id)
     }
 
+    pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
+        self.scripting_tool_use.tool_uses_for_message(id)
+    }
+
     pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
         self.tool_use.tool_results_for_message(id)
     }
 
+    pub fn scripting_tool_results_for_message(
+        &self,
+        id: MessageId,
+    ) -> Vec<&LanguageModelToolResult> {
+        self.scripting_tool_use.tool_results_for_message(id)
+    }
+
     pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
         self.tool_use.message_has_tool_results(message_id)
     }
 
+    pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
+        self.scripting_tool_use.message_has_tool_results(message_id)
+    }
+
     pub fn insert_user_message(
         &mut self,
         text: impl Into<String>,
@@ -313,16 +333,25 @@ impl Thread {
         let mut request = self.to_completion_request(request_kind, cx);
 
         if use_tools {
-            request.tools = self
-                .tools()
-                .tools(cx)
-                .into_iter()
-                .map(|tool| LanguageModelRequestTool {
-                    name: tool.name(),
-                    description: tool.description(),
-                    input_schema: tool.input_schema(),
-                })
-                .collect();
+            let mut tools = Vec::new();
+            tools.push(LanguageModelRequestTool {
+                name: ScriptingTool::NAME.into(),
+                description: ScriptingTool::DESCRIPTION.into(),
+                input_schema: ScriptingTool::input_schema(),
+            });
+
+            tools.extend(
+                self.tools()
+                    .tools(cx)
+                    .into_iter()
+                    .map(|tool| LanguageModelRequestTool {
+                        name: tool.name(),
+                        description: tool.description(),
+                        input_schema: tool.input_schema(),
+                    }),
+            );
+
+            request.tools = tools;
         }
 
         self.stream_completion(request, model, cx);
@@ -357,6 +386,8 @@ impl Thread {
                 RequestKind::Chat => {
                     self.tool_use
                         .attach_tool_results(message.id, &mut request_message);
+                    self.scripting_tool_use
+                        .attach_tool_results(message.id, &mut request_message);
                 }
                 RequestKind::Summarize => {
                     // We don't care about tool use during summarization.
@@ -373,6 +404,8 @@ impl Thread {
                 RequestKind::Chat => {
                     self.tool_use
                         .attach_tool_uses(message.id, &mut request_message);
+                    self.scripting_tool_use
+                        .attach_tool_uses(message.id, &mut request_message);
                 }
                 RequestKind::Summarize => {
                     // We don't care about tool use during summarization.
@@ -450,9 +483,15 @@ impl Thread {
                                     .iter()
                                     .rfind(|message| message.role == Role::Assistant)
                                 {
-                                    thread
-                                        .tool_use
-                                        .request_tool_use(last_assistant_message.id, tool_use);
+                                    if tool_use.name.as_ref() == ScriptingTool::NAME {
+                                        thread
+                                            .scripting_tool_use
+                                            .request_tool_use(last_assistant_message.id, tool_use);
+                                    } else {
+                                        thread
+                                            .tool_use
+                                            .request_tool_use(last_assistant_message.id, tool_use);
+                                    }
                                 }
                             }
                         }
@@ -572,6 +611,7 @@ impl Thread {
 
     pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
         let pending_tool_uses = self
+            .tool_use
             .pending_tool_uses()
             .into_iter()
             .filter(|tool_use| tool_use.status.is_idle())
@@ -585,6 +625,20 @@ impl Thread {
                 self.insert_tool_output(tool_use.id.clone(), task, cx);
             }
         }
+
+        let pending_scripting_tool_uses = self
+            .scripting_tool_use
+            .pending_tool_uses()
+            .into_iter()
+            .filter(|tool_use| tool_use.status.is_idle())
+            .cloned()
+            .collect::<Vec<_>>();
+
+        for scripting_tool_use in pending_scripting_tool_uses {
+            let task = ScriptingTool.run(scripting_tool_use.input, self.project.clone(), cx);
+
+            self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
+        }
     }
 
     pub fn insert_tool_output(
@@ -613,6 +667,32 @@ impl Thread {
             .run_pending_tool(tool_use_id, insert_output_task);
     }
 
+    pub fn insert_scripting_tool_output(
+        &mut self,
+        tool_use_id: LanguageModelToolUseId,
+        output: Task<Result<String>>,
+        cx: &mut Context<Self>,
+    ) {
+        let insert_output_task = cx.spawn(|thread, mut cx| {
+            let tool_use_id = tool_use_id.clone();
+            async move {
+                let output = output.await;
+                thread
+                    .update(&mut cx, |thread, cx| {
+                        thread
+                            .scripting_tool_use
+                            .insert_tool_output(tool_use_id.clone(), output);
+
+                        cx.emit(ThreadEvent::ToolFinished { tool_use_id });
+                    })
+                    .ok();
+            }
+        });
+
+        self.scripting_tool_use
+            .run_pending_tool(tool_use_id, insert_output_task);
+    }
+
     pub fn send_tool_results_to_model(
         &mut self,
         model: Arc<dyn LanguageModel>,

crates/assistant2/src/tool_use.rs 🔗

@@ -267,6 +267,7 @@ impl ToolUseState {
 pub struct PendingToolUse {
     pub id: LanguageModelToolUseId,
     /// The ID of the Assistant message in which the tool use was requested.
+    #[allow(unused)]
     pub assistant_message_id: MessageId,
     pub name: Arc<str>,
     pub input: serde_json::Value,

crates/scripting_tool/Cargo.toml 🔗

@@ -14,7 +14,6 @@ doctest = false
 
 [dependencies]
 anyhow.workspace = true
-assistant_tool.workspace = true
 collections.workspace = true
 futures.workspace = true
 gpui.workspace = true

crates/scripting_tool/src/scripting_tool.rs 🔗

@@ -3,40 +3,29 @@ mod session;
 use project::Project;
 use session::*;
 
-use assistant_tool::{Tool, ToolRegistry};
 use gpui::{App, AppContext as _, Entity, Task};
 use schemars::JsonSchema;
 use serde::Deserialize;
-use std::sync::Arc;
-
-pub fn init(cx: &App) {
-    let registry = ToolRegistry::global(cx);
-    registry.register_tool(ScriptingTool);
-}
 
 #[derive(Debug, Deserialize, JsonSchema)]
 struct ScriptingToolInput {
     lua_script: String,
 }
 
-struct ScriptingTool;
+pub struct ScriptingTool;
 
-impl Tool for ScriptingTool {
-    fn name(&self) -> String {
-        "lua-interpreter".into()
-    }
+impl ScriptingTool {
+    pub const NAME: &str = "lua-interpreter";
 
-    fn description(&self) -> String {
-        include_str!("scripting_tool_description.txt").into()
-    }
+    pub const DESCRIPTION: &str = include_str!("scripting_tool_description.txt");
 
-    fn input_schema(&self) -> serde_json::Value {
+    pub fn input_schema() -> serde_json::Value {
         let schema = schemars::schema_for!(ScriptingToolInput);
         serde_json::to_value(&schema).unwrap()
     }
 
-    fn run(
-        self: Arc<Self>,
+    pub fn run(
+        &self,
         input: serde_json::Value,
         project: Entity<Project>,
         cx: &mut App,

crates/zed/Cargo.toml 🔗

@@ -98,7 +98,6 @@ remote.workspace = true
 repl.workspace = true
 reqwest_client.workspace = true
 rope.workspace = true
-scripting_tool.workspace = true
 search.workspace = true
 serde.workspace = true
 serde_json.workspace = true

crates/zed/src/main.rs 🔗

@@ -476,7 +476,6 @@ fn main() {
             cx,
         );
         assistant_tools::init(cx);
-        scripting_tool::init(cx);
         repl::init(app_state.fs.clone(), cx);
         extension_host::init(
             extension_host_proxy,