Allow cancellation of tool uses (#26906)

Antonio Scandurra created

Release Notes:

- N/A

Change summary

crates/assistant2/src/active_thread.rs          | 12 +++++-
crates/assistant2/src/message_editor.rs         | 10 +++---
crates/assistant2/src/thread.rs                 | 31 ++++++++++++------
crates/assistant2/src/tool_use.rs               | 16 +++++++++
crates/assistant_eval/src/headless_assistant.rs |  8 ++--
5 files changed, 55 insertions(+), 22 deletions(-)

Detailed changes

crates/assistant2/src/active_thread.rs 🔗

@@ -116,7 +116,7 @@ impl ActiveThread {
     pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool {
         self.last_error.take();
         self.thread
-            .update(cx, |thread, _cx| thread.cancel_last_completion())
+            .update(cx, |thread, cx| thread.cancel_last_completion(cx))
     }
 
     pub fn last_error(&self) -> Option<ThreadError> {
@@ -343,8 +343,11 @@ impl ActiveThread {
                 });
             }
             ThreadEvent::ToolFinished {
-                pending_tool_use, ..
+                pending_tool_use,
+                canceled,
+                ..
             } => {
+                let canceled = *canceled;
                 if let Some(tool_use) = pending_tool_use {
                     self.render_scripting_tool_use_markdown(
                         tool_use.id.clone(),
@@ -396,7 +399,10 @@ impl ActiveThread {
 
                             this.update(&mut cx, |this, cx| {
                                 this.thread.update(cx, |thread, cx| {
-                                    thread.send_tool_results_to_model(model, updated_context, cx);
+                                    thread.attach_tool_results(updated_context, cx);
+                                    if !canceled {
+                                        thread.send_to_model(model, RequestKind::Chat, cx);
+                                    }
                                 });
                             })
                         })

crates/assistant2/src/message_editor.rs 🔗

@@ -158,7 +158,7 @@ impl MessageEditor {
             return;
         }
 
-        if self.thread.read(cx).is_streaming() {
+        if self.thread.read(cx).is_generating() {
             return;
         }
 
@@ -328,7 +328,7 @@ impl Render for MessageEditor {
         let focus_handle = self.editor.focus_handle(cx);
         let inline_context_picker = self.inline_context_picker.clone();
         let bg_color = cx.theme().colors().editor_background;
-        let is_streaming_completion = self.thread.read(cx).is_streaming();
+        let is_generating = self.thread.read(cx).is_generating();
         let is_model_selected = self.is_model_selected(cx);
         let is_editor_empty = self.is_editor_empty(cx);
         let submit_label_color = if is_editor_empty {
@@ -352,7 +352,7 @@ impl Render for MessageEditor {
 
         v_flex()
             .size_full()
-            .when(is_streaming_completion, |parent| {
+            .when(is_generating, |parent| {
                 let focus_handle = self.editor.focus_handle(cx).clone();
                 parent.child(
                     h_flex().py_3().w_full().justify_center().child(
@@ -625,7 +625,7 @@ impl Render for MessageEditor {
                                                 .disabled(
                                                     is_editor_empty
                                                         || !is_model_selected
-                                                        || is_streaming_completion,
+                                                        || is_generating,
                                                 )
                                                 .child(
                                                     h_flex()
@@ -660,7 +660,7 @@ impl Render for MessageEditor {
                                                         "Type a message to submit",
                                                     ))
                                                 })
-                                                .when(is_streaming_completion, |button| {
+                                                .when(is_generating, |button| {
                                                     button.tooltip(Tooltip::text(
                                                         "Cancel to submit a new message",
                                                     ))

crates/assistant2/src/thread.rs 🔗

@@ -240,7 +240,7 @@ impl Thread {
         self.messages.iter()
     }
 
-    pub fn is_streaming(&self) -> bool {
+    pub fn is_generating(&self) -> bool {
         !self.pending_completions.is_empty() || !self.all_tools_finished()
     }
 
@@ -267,8 +267,8 @@ impl Thread {
             .into_iter()
             .chain(self.scripting_tool_use.pending_tool_uses());
 
-        // If the only pending tool uses left are the ones with errors, then that means that we've finished running all
-        // of the pending tools.
+        // If the only pending tool uses left are the ones with errors, then
+        // that means that we've finished running all of the pending tools.
         all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
     }
 
@@ -683,7 +683,7 @@ impl Thread {
                                 )));
                             }
 
-                            thread.cancel_last_completion();
+                            thread.cancel_last_completion(cx);
                         }
                     }
                     cx.emit(ThreadEvent::DoneStreaming);
@@ -833,6 +833,7 @@ impl Thread {
                         cx.emit(ThreadEvent::ToolFinished {
                             tool_use_id,
                             pending_tool_use,
+                            canceled: false,
                         });
                     })
                     .ok();
@@ -862,6 +863,7 @@ impl Thread {
                         cx.emit(ThreadEvent::ToolFinished {
                             tool_use_id,
                             pending_tool_use,
+                            canceled: false,
                         });
                     })
                     .ok();
@@ -872,9 +874,8 @@ impl Thread {
             .run_pending_tool(tool_use_id, insert_output_task);
     }
 
-    pub fn send_tool_results_to_model(
+    pub fn attach_tool_results(
         &mut self,
-        model: Arc<dyn LanguageModel>,
         updated_context: Vec<ContextSnapshot>,
         cx: &mut Context<Self>,
     ) {
@@ -893,17 +894,25 @@ impl Thread {
             Vec::new(),
             cx,
         );
-        self.send_to_model(model, RequestKind::Chat, cx);
     }
 
     /// Cancels the last pending completion, if there are any pending.
     ///
     /// Returns whether a completion was canceled.
-    pub fn cancel_last_completion(&mut self) -> bool {
-        if let Some(_last_completion) = self.pending_completions.pop() {
+    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
+        if self.pending_completions.pop().is_some() {
             true
         } else {
-            false
+            let mut canceled = false;
+            for pending_tool_use in self.tool_use.cancel_pending() {
+                canceled = true;
+                cx.emit(ThreadEvent::ToolFinished {
+                    tool_use_id: pending_tool_use.id.clone(),
+                    pending_tool_use: Some(pending_tool_use),
+                    canceled: true,
+                });
+            }
+            canceled
         }
     }
 
@@ -1114,6 +1123,8 @@ pub enum ThreadEvent {
         tool_use_id: LanguageModelToolUseId,
         /// The pending tool use that corresponds to this tool.
         pending_tool_use: Option<PendingToolUse>,
+        /// Whether the tool was canceled by the user.
+        canceled: bool,
     },
 }
 

crates/assistant2/src/tool_use.rs 🔗

@@ -118,6 +118,22 @@ impl ToolUseState {
         this
     }
 
+    pub fn cancel_pending(&mut self) -> Vec<PendingToolUse> {
+        let mut pending_tools = Vec::new();
+        for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() {
+            self.tool_results.insert(
+                tool_use_id.clone(),
+                LanguageModelToolResult {
+                    tool_use_id,
+                    content: "Tool canceled by user".into(),
+                    is_error: true,
+                },
+            );
+            pending_tools.push(tool_use.clone());
+        }
+        pending_tools
+    }
+
     pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
         self.pending_tool_uses_by_id.values().collect()
     }

crates/assistant_eval/src/headless_assistant.rs 🔗

@@ -1,5 +1,5 @@
 use anyhow::anyhow;
-use assistant2::{Thread, ThreadEvent, ThreadStore};
+use assistant2::{RequestKind, Thread, ThreadEvent, ThreadStore};
 use assistant_tool::ToolWorkingSet;
 use client::{Client, UserStore};
 use collections::HashMap;
@@ -103,6 +103,7 @@ impl HeadlessAssistant {
             ThreadEvent::ToolFinished {
                 tool_use_id,
                 pending_tool_use,
+                ..
             } => {
                 if let Some(pending_tool_use) = pending_tool_use {
                     println!(
@@ -121,9 +122,8 @@ impl HeadlessAssistant {
                     let model_registry = LanguageModelRegistry::read_global(cx);
                     if let Some(model) = model_registry.active_model() {
                         thread.update(cx, |thread, cx| {
-                            // Currently evals do not support specifying context.
-                            let updated_context = vec![];
-                            thread.send_tool_results_to_model(model, updated_context, cx);
+                            thread.attach_tool_results(vec![], cx);
+                            thread.send_to_model(model, RequestKind::Chat, cx);
                         });
                     }
                 }