diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index e3e55333b638c2c8040d3291ef3d21fa546a7027..3ab8e0fabf1e8224326dbf982c431cc26c6e01d2 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -116,7 +116,7 @@ impl ActiveThread { pub fn cancel_last_completion(&mut self, cx: &mut App) -> bool { self.last_error.take(); self.thread - .update(cx, |thread, _cx| thread.cancel_last_completion()) + .update(cx, |thread, cx| thread.cancel_last_completion(cx)) } pub fn last_error(&self) -> Option { @@ -343,8 +343,11 @@ impl ActiveThread { }); } ThreadEvent::ToolFinished { - pending_tool_use, .. + pending_tool_use, + canceled, + .. } => { + let canceled = *canceled; if let Some(tool_use) = pending_tool_use { self.render_scripting_tool_use_markdown( tool_use.id.clone(), @@ -396,7 +399,10 @@ impl ActiveThread { this.update(&mut cx, |this, cx| { this.thread.update(cx, |thread, cx| { - thread.send_tool_results_to_model(model, updated_context, cx); + thread.attach_tool_results(updated_context, cx); + if !canceled { + thread.send_to_model(model, RequestKind::Chat, cx); + } }); }) }) diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index 52e9b56ed23d50f79c5f212306421d1e67d1bbe4..9917982457bbaff657a4a3b22f4d041d38d5934a 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -158,7 +158,7 @@ impl MessageEditor { return; } - if self.thread.read(cx).is_streaming() { + if self.thread.read(cx).is_generating() { return; } @@ -328,7 +328,7 @@ impl Render for MessageEditor { let focus_handle = self.editor.focus_handle(cx); let inline_context_picker = self.inline_context_picker.clone(); let bg_color = cx.theme().colors().editor_background; - let is_streaming_completion = self.thread.read(cx).is_streaming(); + let is_generating = self.thread.read(cx).is_generating(); let is_model_selected = self.is_model_selected(cx); let is_editor_empty = self.is_editor_empty(cx); let submit_label_color = if is_editor_empty { @@ -352,7 +352,7 @@ impl Render for MessageEditor { v_flex() .size_full() - .when(is_streaming_completion, |parent| { + .when(is_generating, |parent| { let focus_handle = self.editor.focus_handle(cx).clone(); parent.child( h_flex().py_3().w_full().justify_center().child( @@ -625,7 +625,7 @@ impl Render for MessageEditor { .disabled( is_editor_empty || !is_model_selected - || is_streaming_completion, + || is_generating, ) .child( h_flex() @@ -660,7 +660,7 @@ impl Render for MessageEditor { "Type a message to submit", )) }) - .when(is_streaming_completion, |button| { + .when(is_generating, |button| { button.tooltip(Tooltip::text( "Cancel to submit a new message", )) diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index b9338c6e6a7185b7b6316befe12028f9d9e8c92f..ddd6732446e86bb3436d0b0f47a5852d436165dd 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -240,7 +240,7 @@ impl Thread { self.messages.iter() } - pub fn is_streaming(&self) -> bool { + pub fn is_generating(&self) -> bool { !self.pending_completions.is_empty() || !self.all_tools_finished() } @@ -267,8 +267,8 @@ impl Thread { .into_iter() .chain(self.scripting_tool_use.pending_tool_uses()); - // If the only pending tool uses left are the ones with errors, then that means that we've finished running all - // of the pending tools. + // If the only pending tool uses left are the ones with errors, then + // that means that we've finished running all of the pending tools. all_pending_tool_uses.all(|tool_use| tool_use.status.is_error()) } @@ -683,7 +683,7 @@ impl Thread { ))); } - thread.cancel_last_completion(); + thread.cancel_last_completion(cx); } } cx.emit(ThreadEvent::DoneStreaming); @@ -833,6 +833,7 @@ impl Thread { cx.emit(ThreadEvent::ToolFinished { tool_use_id, pending_tool_use, + canceled: false, }); }) .ok(); @@ -862,6 +863,7 @@ impl Thread { cx.emit(ThreadEvent::ToolFinished { tool_use_id, pending_tool_use, + canceled: false, }); }) .ok(); @@ -872,9 +874,8 @@ impl Thread { .run_pending_tool(tool_use_id, insert_output_task); } - pub fn send_tool_results_to_model( + pub fn attach_tool_results( &mut self, - model: Arc, updated_context: Vec, cx: &mut Context, ) { @@ -893,17 +894,25 @@ impl Thread { Vec::new(), cx, ); - self.send_to_model(model, RequestKind::Chat, cx); } /// Cancels the last pending completion, if there are any pending. /// /// Returns whether a completion was canceled. - pub fn cancel_last_completion(&mut self) -> bool { - if let Some(_last_completion) = self.pending_completions.pop() { + pub fn cancel_last_completion(&mut self, cx: &mut Context) -> bool { + if self.pending_completions.pop().is_some() { true } else { - false + let mut canceled = false; + for pending_tool_use in self.tool_use.cancel_pending() { + canceled = true; + cx.emit(ThreadEvent::ToolFinished { + tool_use_id: pending_tool_use.id.clone(), + pending_tool_use: Some(pending_tool_use), + canceled: true, + }); + } + canceled } } @@ -1114,6 +1123,8 @@ pub enum ThreadEvent { tool_use_id: LanguageModelToolUseId, /// The pending tool use that corresponds to this tool. pending_tool_use: Option, + /// Whether the tool was canceled by the user. + canceled: bool, }, } diff --git a/crates/assistant2/src/tool_use.rs b/crates/assistant2/src/tool_use.rs index 5170e7743875b22046d3b289f72d5e4f8ac373b0..2045aa9b7361f96743537be8f3c90179a548a042 100644 --- a/crates/assistant2/src/tool_use.rs +++ b/crates/assistant2/src/tool_use.rs @@ -118,6 +118,22 @@ impl ToolUseState { this } + pub fn cancel_pending(&mut self) -> Vec { + let mut pending_tools = Vec::new(); + for (tool_use_id, tool_use) in self.pending_tool_uses_by_id.drain() { + self.tool_results.insert( + tool_use_id.clone(), + LanguageModelToolResult { + tool_use_id, + content: "Tool canceled by user".into(), + is_error: true, + }, + ); + pending_tools.push(tool_use.clone()); + } + pending_tools + } + pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> { self.pending_tool_uses_by_id.values().collect() } diff --git a/crates/assistant_eval/src/headless_assistant.rs b/crates/assistant_eval/src/headless_assistant.rs index 008b4f63da3f4e622d0b9aaf83650163eb08d127..827f3e6a9e1846616e1be9f16883245cc2415541 100644 --- a/crates/assistant_eval/src/headless_assistant.rs +++ b/crates/assistant_eval/src/headless_assistant.rs @@ -1,5 +1,5 @@ use anyhow::anyhow; -use assistant2::{Thread, ThreadEvent, ThreadStore}; +use assistant2::{RequestKind, Thread, ThreadEvent, ThreadStore}; use assistant_tool::ToolWorkingSet; use client::{Client, UserStore}; use collections::HashMap; @@ -103,6 +103,7 @@ impl HeadlessAssistant { ThreadEvent::ToolFinished { tool_use_id, pending_tool_use, + .. } => { if let Some(pending_tool_use) = pending_tool_use { println!( @@ -121,9 +122,8 @@ impl HeadlessAssistant { let model_registry = LanguageModelRegistry::read_global(cx); if let Some(model) = model_registry.active_model() { thread.update(cx, |thread, cx| { - // Currently evals do not support specifying context. - let updated_context = vec![]; - thread.send_tool_results_to_model(model, updated_context, cx); + thread.attach_tool_results(vec![], cx); + thread.send_to_model(model, RequestKind::Chat, cx); }); } }