diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index a4706f6a752b0ae2fd251320106da998819b0b47..1fe3df375fa1bdecb906f1b963e71a3f0cecfd56 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -5703,14 +5703,9 @@ async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) { fake_model .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)); - // Queue a message before ending the stream + // Signal that a message is queued before ending the stream thread.update(cx, |thread, _cx| { - thread.queue_message( - vec![acp::ContentBlock::Text(acp::TextContent::new( - "This is my queued message".to_string(), - ))], - vec![], - ); + thread.set_has_queued_message(true); }); // Now end the stream - tool will run, and the boundary check should see the queue @@ -5741,14 +5736,12 @@ async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) { "Turn should have ended after tool completion due to queued message" ); - // Verify the queued message is still there + // Verify the queued message flag is still set thread.update(cx, |thread, _cx| { - let queued = thread.queued_messages(); - assert_eq!(queued.len(), 1, "Should still have one queued message"); - assert!(matches!( - &queued[0].content[0], - acp::ContentBlock::Text(t) if t.text == "This is my queued message" - )); + assert!( + thread.has_queued_message(), + "Should still have queued message flag set" + ); }); // Thread should be idle now diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index dbcbe8eda0358ff71997dfe695a871efad954ac6..80f0829d3ed0706aaf06ff34e9dea586c66a5842 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -31,7 +31,6 @@ use futures::{ use gpui::{ App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity, }; -use language::Buffer; use language_model::{ LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest, @@ -716,11 +715,6 @@ enum CompletionError { Other(#[from] anyhow::Error), } -pub struct QueuedMessage { - pub content: Vec, - pub tracked_buffers: Vec>, -} - pub struct Thread { id: acp::SessionId, prompt_id: PromptId, @@ -735,7 +729,9 @@ pub struct Thread { /// Survives across multiple requests as the model performs tool calls and /// we run tools, report their results. running_turn: Option, - queued_messages: Vec, + /// Flag indicating the UI has a queued message waiting to be sent. + /// Used to signal that the turn should end at the next message boundary. + has_queued_message: bool, pending_message: Option, tools: BTreeMap>, request_token_usage: HashMap, @@ -795,7 +791,7 @@ impl Thread { messages: Vec::new(), user_store: project.read(cx).user_store(), running_turn: None, - queued_messages: Vec::new(), + has_queued_message: false, pending_message: None, tools: BTreeMap::default(), request_token_usage: HashMap::default(), @@ -862,7 +858,7 @@ impl Thread { messages: Vec::new(), user_store: project.read(cx).user_store(), running_turn: None, - queued_messages: Vec::new(), + has_queued_message: false, pending_message: None, tools, request_token_usage: HashMap::default(), @@ -1060,7 +1056,7 @@ impl Thread { messages: db_thread.messages, user_store: project.read(cx).user_store(), running_turn: None, - queued_messages: Vec::new(), + has_queued_message: false, pending_message: None, tools: BTreeMap::default(), request_token_usage: db_thread.request_token_usage.clone(), @@ -1298,52 +1294,12 @@ impl Thread { }) } - pub fn queue_message( - &mut self, - content: Vec, - tracked_buffers: Vec>, - ) { - self.queued_messages.push(QueuedMessage { - content, - tracked_buffers, - }); - } - - pub fn queued_messages(&self) -> &[QueuedMessage] { - &self.queued_messages - } - - pub fn remove_queued_message(&mut self, index: usize) -> Option { - if index < self.queued_messages.len() { - Some(self.queued_messages.remove(index)) - } else { - None - } - } - - pub fn update_queued_message( - &mut self, - index: usize, - content: Vec, - tracked_buffers: Vec>, - ) -> bool { - if index < self.queued_messages.len() { - self.queued_messages[index] = QueuedMessage { - content, - tracked_buffers, - }; - true - } else { - false - } - } - - pub fn clear_queued_messages(&mut self) { - self.queued_messages.clear(); + pub fn set_has_queued_message(&mut self, has_queued: bool) { + self.has_queued_message = has_queued; } - fn has_queued_messages(&self) -> bool { - !self.queued_messages.is_empty() + pub fn has_queued_message(&self) -> bool { + self.has_queued_message } fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context) { @@ -1760,7 +1716,7 @@ impl Thread { } else if end_turn { return Ok(()); } else { - let has_queued = this.update(cx, |this, _| this.has_queued_messages())?; + let has_queued = this.update(cx, |this, _| this.has_queued_message())?; if has_queued { log::debug!("Queued message found, ending turn at message boundary"); return Ok(()); diff --git a/crates/agent_ui/src/acp/thread_view.rs b/crates/agent_ui/src/acp/thread_view.rs index e2aa6547539390f3ed018e220b12a212e494b680..07444999d1402d15dfcb8d563a8fcdc82764d7a9 100644 --- a/crates/agent_ui/src/acp/thread_view.rs +++ b/crates/agent_ui/src/acp/thread_view.rs @@ -88,6 +88,11 @@ use crate::{ const STOPWATCH_THRESHOLD: Duration = Duration::from_secs(30); const TOKEN_THRESHOLD: u64 = 250; +pub struct QueuedMessage { + pub content: Vec, + pub tracked_buffers: Vec>, +} + #[derive(Copy, Clone, Debug, PartialEq, Eq)] enum ThreadFeedback { Positive, @@ -364,6 +369,7 @@ pub struct AcpThreadView { editor_expanded: bool, should_be_following: bool, editing_message: Option, + local_queued_messages: Vec, queued_message_editors: Vec>, queued_message_editor_subscriptions: Vec, last_synced_queue_length: usize, @@ -596,6 +602,7 @@ impl AcpThreadView { expanded_subagents: HashSet::default(), subagent_scroll_handles: RefCell::new(HashMap::default()), editing_message: None, + local_queued_messages: Vec::new(), queued_message_editors: Vec::new(), queued_message_editor_subscriptions: Vec::new(), last_synced_queue_length: 0, @@ -1399,9 +1406,7 @@ impl AcpThreadView { let is_editor_empty = self.message_editor.read(cx).is_empty(cx); let is_generating = thread.read(cx).status() != ThreadStatus::Idle; - let has_queued = self - .as_native_thread(cx) - .is_some_and(|t| !t.read(cx).queued_messages().is_empty()); + let has_queued = self.has_queued_messages(); if is_editor_empty && self.can_fast_track_queue && has_queued { self.can_fast_track_queue = false; self.send_queued_message_at_index(0, true, window, cx); @@ -1737,11 +1742,7 @@ impl AcpThreadView { } this.update_in(cx, |this, window, cx| { - if let Some(thread) = this.as_native_thread(cx) { - thread.update(cx, |thread, _| { - thread.queue_message(content, tracked_buffers); - }); - } + this.add_to_queue(content, tracked_buffers, cx); // Enable fast-track: user can press Enter again to send this queued message immediately this.can_fast_track_queue = true; message_editor.update(cx, |message_editor, cx| { @@ -1761,13 +1762,7 @@ impl AcpThreadView { window: &mut Window, cx: &mut Context, ) { - let Some(native_thread) = self.as_native_thread(cx) else { - return; - }; - - let Some(queued) = - native_thread.update(cx, |thread, _| thread.remove_queued_message(index)) - else { + let Some(queued) = self.remove_from_queue(index, cx) else { return; }; let content = queued.content; @@ -2062,9 +2057,7 @@ impl AcpThreadView { // Reset the flag so future completions can process normally. self.user_interrupted_generation = false; } else { - let has_queued = self - .as_native_thread(cx) - .is_some_and(|t| !t.read(cx).queued_messages().is_empty()); + let has_queued = self.has_queued_messages(); // Don't auto-send if the first message editor is currently focused let is_first_editor_focused = self .queued_message_editors @@ -5197,9 +5190,7 @@ impl AcpThreadView { let telemetry = ActionLogTelemetry::from(thread); let changed_buffers = action_log.read(cx).changed_buffers(cx); let plan = thread.plan(); - let queue_is_empty = self - .as_native_thread(cx) - .map_or(true, |t| t.read(cx).queued_messages().is_empty()); + let queue_is_empty = !self.has_queued_messages(); if changed_buffers.is_empty() && plan.is_empty() && queue_is_empty { return None; @@ -5876,9 +5867,7 @@ impl AcpThreadView { _window: &mut Window, cx: &Context, ) -> impl IntoElement { - let queue_count = self - .as_native_thread(cx) - .map_or(0, |t| t.read(cx).queued_messages().len()); + let queue_count = self.queued_messages_len(); let title: SharedString = if queue_count == 1 { "1 Queued Message".into() } else { @@ -5909,9 +5898,7 @@ impl AcpThreadView { .label_size(LabelSize::Small) .key_binding(KeyBinding::for_action(&ClearMessageQueue, cx)) .on_click(cx.listener(|this, _, _, cx| { - if let Some(thread) = this.as_native_thread(cx) { - thread.update(cx, |thread, _| thread.clear_queued_messages()); - } + this.clear_queue(cx); this.can_fast_track_queue = false; cx.notify(); })), @@ -6086,11 +6073,7 @@ impl AcpThreadView { } }) .on_click(cx.listener(move |this, _, _, cx| { - if let Some(thread) = this.as_native_thread(cx) { - thread.update(cx, |thread, _| { - thread.remove_queued_message(index); - }); - } + this.remove_from_queue(index, cx); cx.notify(); })), ) @@ -6245,15 +6228,83 @@ impl AcpThreadView { .thread(acp_thread.session_id(), cx) } + fn queued_messages_len(&self) -> usize { + self.local_queued_messages.len() + } + + fn has_queued_messages(&self) -> bool { + !self.local_queued_messages.is_empty() + } + + /// Syncs the has_queued_message flag to the native thread (if applicable). + /// This flag tells the native thread to end its turn at the next message boundary. + fn sync_queue_flag_to_native_thread(&self, cx: &mut Context) { + if let Some(native_thread) = self.as_native_thread(cx) { + let has_queued = !self.local_queued_messages.is_empty(); + native_thread.update(cx, |thread, _| { + thread.set_has_queued_message(has_queued); + }); + } + } + + fn add_to_queue( + &mut self, + content: Vec, + tracked_buffers: Vec>, + cx: &mut Context, + ) { + self.local_queued_messages.push(QueuedMessage { + content, + tracked_buffers, + }); + self.sync_queue_flag_to_native_thread(cx); + } + + fn remove_from_queue(&mut self, index: usize, cx: &mut Context) -> Option { + if index < self.local_queued_messages.len() { + let removed = self.local_queued_messages.remove(index); + self.sync_queue_flag_to_native_thread(cx); + Some(removed) + } else { + None + } + } + + fn update_queued_message( + &mut self, + index: usize, + content: Vec, + tracked_buffers: Vec>, + _cx: &mut Context, + ) -> bool { + if index < self.local_queued_messages.len() { + self.local_queued_messages[index] = QueuedMessage { + content, + tracked_buffers, + }; + true + } else { + false + } + } + + fn clear_queue(&mut self, cx: &mut Context) { + self.local_queued_messages.clear(); + self.sync_queue_flag_to_native_thread(cx); + } + + fn queued_message_contents(&self) -> Vec> { + self.local_queued_messages + .iter() + .map(|q| q.content.clone()) + .collect() + } + fn save_queued_message_at_index(&mut self, index: usize, cx: &mut Context) { let Some(editor) = self.queued_message_editors.get(index) else { return; }; - let Some(_native_thread) = self.as_native_thread(cx) else { - return; - }; - let contents_task = editor.update(cx, |editor, cx| editor.contents(false, cx)); cx.spawn(async move |this, cx| { @@ -6262,11 +6313,7 @@ impl AcpThreadView { }; this.update(cx, |this, cx| { - if let Some(native_thread) = this.as_native_thread(cx) { - native_thread.update(cx, |thread, _| { - thread.update_queued_message(index, content, tracked_buffers); - }); - } + this.update_queued_message(index, content, tracked_buffers, cx); cx.notify(); })?; @@ -6276,26 +6323,14 @@ impl AcpThreadView { } fn sync_queued_message_editors(&mut self, window: &mut Window, cx: &mut Context) { - let Some(native_thread) = self.as_native_thread(cx) else { - self.queued_message_editors.clear(); - self.queued_message_editor_subscriptions.clear(); - self.last_synced_queue_length = 0; - return; - }; - - let thread = native_thread.read(cx); - let needed_count = thread.queued_messages().len(); + let needed_count = self.queued_messages_len(); let current_count = self.queued_message_editors.len(); if current_count == needed_count && needed_count == self.last_synced_queue_length { return; } - let queued_messages: Vec<_> = thread - .queued_messages() - .iter() - .map(|q| q.content.clone()) - .collect(); + let queued_messages = self.queued_message_contents(); if current_count > needed_count { self.queued_message_editors.truncate(needed_count); @@ -8348,12 +8383,8 @@ impl Render for AcpThreadView { this.send_queued_message_at_index(0, true, window, cx); })) .on_action(cx.listener(|this, _: &RemoveFirstQueuedMessage, _, cx| { - if let Some(thread) = this.as_native_thread(cx) { - thread.update(cx, |thread, _| { - thread.remove_queued_message(0); - }); - cx.notify(); - } + this.remove_from_queue(0, cx); + cx.notify(); })) .on_action(cx.listener(|this, _: &EditFirstQueuedMessage, window, cx| { if let Some(editor) = this.queued_message_editors.first() { @@ -8361,9 +8392,7 @@ impl Render for AcpThreadView { } })) .on_action(cx.listener(|this, _: &ClearMessageQueue, _, cx| { - if let Some(thread) = this.as_native_thread(cx) { - thread.update(cx, |thread, _| thread.clear_queued_messages()); - } + this.clear_queue(cx); this.can_fast_track_queue = false; cx.notify(); }))