assistant: Add basic tool invocation (#17368)

Marshall Bowers created

This PR adds the initial groundwork for invoking tools in response to
tool uses from the model.

Tool uses are run when the model responds with a `stop_reason` of
`tool_use`.

Currently the tool results are just inserted as text into the user
message. We'll want to include these as `tool_result` content on the
message, but Claude seems to understand it regardless.

Release Notes:

- N/A

Change summary

crates/assistant/src/assistant_panel.rs    | 22 +++++++
crates/assistant/src/context.rs            | 75 ++++++++++++++++++++---
crates/assistant_tool/src/tool_registry.rs |  5 +
3 files changed, 92 insertions(+), 10 deletions(-)

Detailed changes

crates/assistant/src/assistant_panel.rs 🔗

@@ -20,6 +20,7 @@ use crate::{
 };
 use anyhow::{anyhow, Result};
 use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
+use assistant_tool::ToolRegistry;
 use client::{proto, Client, Status};
 use collections::{BTreeSet, HashMap, HashSet};
 use editor::{
@@ -2091,6 +2092,27 @@ impl ContextEditor {
                     }
                 }
             }
+            ContextEvent::UsePendingTools => {
+                let pending_tool_uses = self
+                    .context
+                    .read(cx)
+                    .pending_tool_uses()
+                    .into_iter()
+                    .filter(|tool_use| tool_use.status.is_idle())
+                    .cloned()
+                    .collect::<Vec<_>>();
+
+                for tool_use in pending_tool_uses {
+                    let tool_registry = ToolRegistry::global(cx);
+                    if let Some(tool) = tool_registry.tool(&tool_use.name) {
+                        let task = tool.run(tool_use.input, self.workspace.clone(), cx);
+
+                        self.context.update(cx, |context, cx| {
+                            context.insert_tool_output(tool_use.id.clone(), task, cx);
+                        });
+                    }
+                }
+            }
             ContextEvent::Operation(_) => {}
             ContextEvent::ShowAssistError(error_message) => {
                 self.error_message = Some(error_message.clone());

crates/assistant/src/context.rs 🔗

@@ -29,7 +29,7 @@ use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, P
 use language_model::{
     LanguageModel, LanguageModelCacheConfiguration, LanguageModelCompletionEvent,
     LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    LanguageModelRequestTool, MessageContent, Role,
+    LanguageModelRequestTool, MessageContent, Role, StopReason,
 };
 use open_ai::Model as OpenAiModel;
 use paths::{context_images_dir, contexts_dir};
@@ -306,6 +306,7 @@ pub enum ContextEvent {
         run_commands_in_output: bool,
         expand_result: bool,
     },
+    UsePendingTools,
     Operation(ContextOperation),
 }
 
@@ -416,6 +417,7 @@ impl Message {
 
             range_start = *image_offset;
         }
+
         if range_start != self.offset_range.end {
             if let Some(text) =
                 Self::collect_text_content(buffer, range_start..self.offset_range.end)
@@ -492,7 +494,7 @@ pub struct Context {
     edits_since_last_parse: language::Subscription,
     finished_slash_commands: HashSet<SlashCommandId>,
     slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
-    pending_tool_uses_by_id: HashMap<String, PendingToolUse>,
+    pending_tool_uses_by_id: HashMap<Arc<str>, PendingToolUse>,
     message_anchors: Vec<MessageAnchor>,
     images: HashMap<u64, (Arc<RenderImage>, Shared<Task<Option<LanguageModelImage>>>)>,
     image_anchors: Vec<ImageAnchor>,
@@ -1012,7 +1014,7 @@ impl Context {
         self.pending_tool_uses_by_id.values().collect()
     }
 
-    pub fn get_tool_use_by_id(&self, id: &String) -> Option<&PendingToolUse> {
+    pub fn get_tool_use_by_id(&self, id: &Arc<str>) -> Option<&PendingToolUse> {
         self.pending_tool_uses_by_id.get(id)
     }
 
@@ -1919,6 +1921,45 @@ impl Context {
         }
     }
 
+    pub fn insert_tool_output(
+        &mut self,
+        tool_id: Arc<str>,
+        output: Task<Result<String>>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        let insert_output_task = cx.spawn(|this, mut cx| {
+            let tool_id = tool_id.clone();
+            async move {
+                let output = output.await;
+                this.update(&mut cx, |this, cx| match output {
+                    Ok(mut output) => {
+                        if !output.ends_with('\n') {
+                            output.push('\n');
+                        }
+
+                        this.buffer.update(cx, |buffer, cx| {
+                            let buffer_end = buffer.len().to_offset(buffer);
+
+                            buffer.edit([(buffer_end..buffer_end, output)], None, cx);
+                        });
+                    }
+                    Err(err) => {
+                        if let Some(tool_use) = this.pending_tool_uses_by_id.get_mut(&tool_id) {
+                            tool_use.status = PendingToolUseStatus::Error(err.to_string());
+                        }
+                    }
+                })
+                .ok();
+            }
+        });
+
+        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_id) {
+            tool_use.status = PendingToolUseStatus::Running {
+                _task: insert_output_task.shared(),
+            };
+        }
+    }
+
     pub fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
         self.count_remaining_tokens(cx);
     }
@@ -1990,7 +2031,7 @@ impl Context {
                                 .message_anchors
                                 .iter()
                                 .position(|message| message.id == assistant_message_id)?;
-                            this.buffer.update(cx, |buffer, cx| {
+                            let event_to_emit = this.buffer.update(cx, |buffer, cx| {
                                 let message_old_end_offset = this.message_anchors[message_ix + 1..]
                                     .iter()
                                     .find(|message| message.start.is_valid(buffer))
@@ -2000,9 +2041,11 @@ impl Context {
 
                                 match event {
                                     LanguageModelCompletionEvent::Stop(reason) => match reason {
-                                        language_model::StopReason::ToolUse => {}
-                                        language_model::StopReason::EndTurn => {}
-                                        language_model::StopReason::MaxTokens => {}
+                                        StopReason::ToolUse => {
+                                            return Some(ContextEvent::UsePendingTools);
+                                        }
+                                        StopReason::EndTurn => {}
+                                        StopReason::MaxTokens => {}
                                     },
                                     LanguageModelCompletionEvent::Text(chunk) => {
                                         buffer.edit(
@@ -2041,10 +2084,11 @@ impl Context {
                                         let source_range = buffer.anchor_after(start_ix)
                                             ..buffer.anchor_after(end_ix);
 
+                                        let tool_use_id: Arc<str> = tool_use.id.into();
                                         this.pending_tool_uses_by_id.insert(
-                                            tool_use.id.clone(),
+                                            tool_use_id.clone(),
                                             PendingToolUse {
-                                                id: tool_use.id,
+                                                id: tool_use_id,
                                                 name: tool_use.name,
                                                 input: tool_use.input,
                                                 status: PendingToolUseStatus::Idle,
@@ -2053,9 +2097,14 @@ impl Context {
                                         );
                                     }
                                 }
+
+                                None
                             });
 
                             cx.emit(ContextEvent::StreamedCompletion);
+                            if let Some(event) = event_to_emit {
+                                cx.emit(event);
+                            }
 
                             Some(())
                         })?;
@@ -2821,7 +2870,7 @@ impl FeatureFlag for ToolUseFeatureFlag {
 
 #[derive(Debug, Clone)]
 pub struct PendingToolUse {
-    pub id: String,
+    pub id: Arc<str>,
     pub name: String,
     pub input: serde_json::Value,
     pub status: PendingToolUseStatus,
@@ -2835,6 +2884,12 @@ pub enum PendingToolUseStatus {
     Error(String),
 }
 
+impl PendingToolUseStatus {
+    pub fn is_idle(&self) -> bool {
+        matches!(self, PendingToolUseStatus::Idle)
+    }
+}
+
 #[derive(Serialize, Deserialize)]
 pub struct SavedMessage {
     pub id: MessageId,

crates/assistant_tool/src/tool_registry.rs 🔗

@@ -66,4 +66,9 @@ impl ToolRegistry {
     pub fn tools(&self) -> Vec<Arc<dyn Tool>> {
         self.state.read().tools.values().cloned().collect()
     }
+
+    /// Returns the [`Tool`] with the given name.
+    pub fn tool(&self, name: &str) -> Option<Arc<dyn Tool>> {
+        self.state.read().tools.get(name).cloned()
+    }
 }