agent_ui: Enable the message queue feature for external agents (#47379)

Danilo Leal created

Closes https://github.com/zed-industries/zed/issues/47330

This PR moves the queue logic out of the native Thread into the shared
UI layer (thread view) to enable it for external agents. There's a key
difference in behavior, though, between native and external agents:
queued messages in the former will be sent in the next turn boundary,
given we can easily tell this, whereas for the latter, queued messages
will be sent by the end of the generation. We'd need an ACP-level change
to provide exactly the same UX between both types of agents, and I
figured that's better to have _some_ version of the feature for external
agents as opposed to not having it all due to this difference.

Release Notes:

- Agent: Made the message queue feature available for external agents as
well.

Change summary

crates/agent/src/tests/mod.rs          |  21 +--
crates/agent/src/thread.rs             |  66 +---------
crates/agent_ui/src/acp/thread_view.rs | 157 ++++++++++++++++-----------
3 files changed, 111 insertions(+), 133 deletions(-)

Detailed changes

crates/agent/src/tests/mod.rs 🔗

@@ -5703,14 +5703,9 @@ async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
     fake_model
         .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
 
-    // Queue a message before ending the stream
+    // Signal that a message is queued before ending the stream
     thread.update(cx, |thread, _cx| {
-        thread.queue_message(
-            vec![acp::ContentBlock::Text(acp::TextContent::new(
-                "This is my queued message".to_string(),
-            ))],
-            vec![],
-        );
+        thread.set_has_queued_message(true);
     });
 
     // Now end the stream - tool will run, and the boundary check should see the queue
@@ -5741,14 +5736,12 @@ async fn test_queued_message_ends_turn_at_boundary(cx: &mut TestAppContext) {
         "Turn should have ended after tool completion due to queued message"
     );
 
-    // Verify the queued message is still there
+    // Verify the queued message flag is still set
     thread.update(cx, |thread, _cx| {
-        let queued = thread.queued_messages();
-        assert_eq!(queued.len(), 1, "Should still have one queued message");
-        assert!(matches!(
-            &queued[0].content[0],
-            acp::ContentBlock::Text(t) if t.text == "This is my queued message"
-        ));
+        assert!(
+            thread.has_queued_message(),
+            "Should still have queued message flag set"
+        );
     });
 
     // Thread should be idle now

crates/agent/src/thread.rs 🔗

@@ -31,7 +31,6 @@ use futures::{
 use gpui::{
     App, AppContext, AsyncApp, Context, Entity, EventEmitter, SharedString, Task, WeakEntity,
 };
-use language::Buffer;
 use language_model::{
     LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
     LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
@@ -716,11 +715,6 @@ enum CompletionError {
     Other(#[from] anyhow::Error),
 }
 
-pub struct QueuedMessage {
-    pub content: Vec<acp::ContentBlock>,
-    pub tracked_buffers: Vec<Entity<Buffer>>,
-}
-
 pub struct Thread {
     id: acp::SessionId,
     prompt_id: PromptId,
@@ -735,7 +729,9 @@ pub struct Thread {
     /// Survives across multiple requests as the model performs tool calls and
     /// we run tools, report their results.
     running_turn: Option<RunningTurn>,
-    queued_messages: Vec<QueuedMessage>,
+    /// Flag indicating the UI has a queued message waiting to be sent.
+    /// Used to signal that the turn should end at the next message boundary.
+    has_queued_message: bool,
     pending_message: Option<AgentMessage>,
     tools: BTreeMap<SharedString, Arc<dyn AnyAgentTool>>,
     request_token_usage: HashMap<UserMessageId, language_model::TokenUsage>,
@@ -795,7 +791,7 @@ impl Thread {
             messages: Vec::new(),
             user_store: project.read(cx).user_store(),
             running_turn: None,
-            queued_messages: Vec::new(),
+            has_queued_message: false,
             pending_message: None,
             tools: BTreeMap::default(),
             request_token_usage: HashMap::default(),
@@ -862,7 +858,7 @@ impl Thread {
             messages: Vec::new(),
             user_store: project.read(cx).user_store(),
             running_turn: None,
-            queued_messages: Vec::new(),
+            has_queued_message: false,
             pending_message: None,
             tools,
             request_token_usage: HashMap::default(),
@@ -1060,7 +1056,7 @@ impl Thread {
             messages: db_thread.messages,
             user_store: project.read(cx).user_store(),
             running_turn: None,
-            queued_messages: Vec::new(),
+            has_queued_message: false,
             pending_message: None,
             tools: BTreeMap::default(),
             request_token_usage: db_thread.request_token_usage.clone(),
@@ -1298,52 +1294,12 @@ impl Thread {
         })
     }
 
-    pub fn queue_message(
-        &mut self,
-        content: Vec<acp::ContentBlock>,
-        tracked_buffers: Vec<Entity<Buffer>>,
-    ) {
-        self.queued_messages.push(QueuedMessage {
-            content,
-            tracked_buffers,
-        });
-    }
-
-    pub fn queued_messages(&self) -> &[QueuedMessage] {
-        &self.queued_messages
-    }
-
-    pub fn remove_queued_message(&mut self, index: usize) -> Option<QueuedMessage> {
-        if index < self.queued_messages.len() {
-            Some(self.queued_messages.remove(index))
-        } else {
-            None
-        }
-    }
-
-    pub fn update_queued_message(
-        &mut self,
-        index: usize,
-        content: Vec<acp::ContentBlock>,
-        tracked_buffers: Vec<Entity<Buffer>>,
-    ) -> bool {
-        if index < self.queued_messages.len() {
-            self.queued_messages[index] = QueuedMessage {
-                content,
-                tracked_buffers,
-            };
-            true
-        } else {
-            false
-        }
-    }
-
-    pub fn clear_queued_messages(&mut self) {
-        self.queued_messages.clear();
+    pub fn set_has_queued_message(&mut self, has_queued: bool) {
+        self.has_queued_message = has_queued;
     }
 
-    fn has_queued_messages(&self) -> bool {
-        !self.queued_messages.is_empty()
+    pub fn has_queued_message(&self) -> bool {
+        self.has_queued_message
     }
 
     fn update_token_usage(&mut self, update: language_model::TokenUsage, cx: &mut Context<Self>) {
@@ -1760,7 +1716,7 @@ impl Thread {
             } else if end_turn {
                 return Ok(());
             } else {
-                let has_queued = this.update(cx, |this, _| this.has_queued_messages())?;
+                let has_queued = this.update(cx, |this, _| this.has_queued_message())?;
                 if has_queued {
                     log::debug!("Queued message found, ending turn at message boundary");
                     return Ok(());

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -88,6 +88,11 @@ use crate::{
 const STOPWATCH_THRESHOLD: Duration = Duration::from_secs(30);
 const TOKEN_THRESHOLD: u64 = 250;
 
+pub struct QueuedMessage {
+    pub content: Vec<acp::ContentBlock>,
+    pub tracked_buffers: Vec<Entity<Buffer>>,
+}
+
 #[derive(Copy, Clone, Debug, PartialEq, Eq)]
 enum ThreadFeedback {
     Positive,
@@ -364,6 +369,7 @@ pub struct AcpThreadView {
     editor_expanded: bool,
     should_be_following: bool,
     editing_message: Option<usize>,
+    local_queued_messages: Vec<QueuedMessage>,
     queued_message_editors: Vec<Entity<MessageEditor>>,
     queued_message_editor_subscriptions: Vec<Subscription>,
     last_synced_queue_length: usize,
@@ -596,6 +602,7 @@ impl AcpThreadView {
             expanded_subagents: HashSet::default(),
             subagent_scroll_handles: RefCell::new(HashMap::default()),
             editing_message: None,
+            local_queued_messages: Vec::new(),
             queued_message_editors: Vec::new(),
             queued_message_editor_subscriptions: Vec::new(),
             last_synced_queue_length: 0,
@@ -1399,9 +1406,7 @@ impl AcpThreadView {
         let is_editor_empty = self.message_editor.read(cx).is_empty(cx);
         let is_generating = thread.read(cx).status() != ThreadStatus::Idle;
 
-        let has_queued = self
-            .as_native_thread(cx)
-            .is_some_and(|t| !t.read(cx).queued_messages().is_empty());
+        let has_queued = self.has_queued_messages();
         if is_editor_empty && self.can_fast_track_queue && has_queued {
             self.can_fast_track_queue = false;
             self.send_queued_message_at_index(0, true, window, cx);
@@ -1737,11 +1742,7 @@ impl AcpThreadView {
             }
 
             this.update_in(cx, |this, window, cx| {
-                if let Some(thread) = this.as_native_thread(cx) {
-                    thread.update(cx, |thread, _| {
-                        thread.queue_message(content, tracked_buffers);
-                    });
-                }
+                this.add_to_queue(content, tracked_buffers, cx);
                 // Enable fast-track: user can press Enter again to send this queued message immediately
                 this.can_fast_track_queue = true;
                 message_editor.update(cx, |message_editor, cx| {
@@ -1761,13 +1762,7 @@ impl AcpThreadView {
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
-        let Some(native_thread) = self.as_native_thread(cx) else {
-            return;
-        };
-
-        let Some(queued) =
-            native_thread.update(cx, |thread, _| thread.remove_queued_message(index))
-        else {
+        let Some(queued) = self.remove_from_queue(index, cx) else {
             return;
         };
         let content = queued.content;
@@ -2062,9 +2057,7 @@ impl AcpThreadView {
                     // Reset the flag so future completions can process normally.
                     self.user_interrupted_generation = false;
                 } else {
-                    let has_queued = self
-                        .as_native_thread(cx)
-                        .is_some_and(|t| !t.read(cx).queued_messages().is_empty());
+                    let has_queued = self.has_queued_messages();
                     // Don't auto-send if the first message editor is currently focused
                     let is_first_editor_focused = self
                         .queued_message_editors
@@ -5197,9 +5190,7 @@ impl AcpThreadView {
         let telemetry = ActionLogTelemetry::from(thread);
         let changed_buffers = action_log.read(cx).changed_buffers(cx);
         let plan = thread.plan();
-        let queue_is_empty = self
-            .as_native_thread(cx)
-            .map_or(true, |t| t.read(cx).queued_messages().is_empty());
+        let queue_is_empty = !self.has_queued_messages();
 
         if changed_buffers.is_empty() && plan.is_empty() && queue_is_empty {
             return None;
@@ -5876,9 +5867,7 @@ impl AcpThreadView {
         _window: &mut Window,
         cx: &Context<Self>,
     ) -> impl IntoElement {
-        let queue_count = self
-            .as_native_thread(cx)
-            .map_or(0, |t| t.read(cx).queued_messages().len());
+        let queue_count = self.queued_messages_len();
         let title: SharedString = if queue_count == 1 {
             "1 Queued Message".into()
         } else {
@@ -5909,9 +5898,7 @@ impl AcpThreadView {
                     .label_size(LabelSize::Small)
                     .key_binding(KeyBinding::for_action(&ClearMessageQueue, cx))
                     .on_click(cx.listener(|this, _, _, cx| {
-                        if let Some(thread) = this.as_native_thread(cx) {
-                            thread.update(cx, |thread, _| thread.clear_queued_messages());
-                        }
+                        this.clear_queue(cx);
                         this.can_fast_track_queue = false;
                         cx.notify();
                     })),
@@ -6086,11 +6073,7 @@ impl AcpThreadView {
                                                 }
                                             })
                                             .on_click(cx.listener(move |this, _, _, cx| {
-                                                if let Some(thread) = this.as_native_thread(cx) {
-                                                    thread.update(cx, |thread, _| {
-                                                        thread.remove_queued_message(index);
-                                                    });
-                                                }
+                                                this.remove_from_queue(index, cx);
                                                 cx.notify();
                                             })),
                                     )
@@ -6245,15 +6228,83 @@ impl AcpThreadView {
             .thread(acp_thread.session_id(), cx)
     }
 
+    fn queued_messages_len(&self) -> usize {
+        self.local_queued_messages.len()
+    }
+
+    fn has_queued_messages(&self) -> bool {
+        !self.local_queued_messages.is_empty()
+    }
+
+    /// Syncs the has_queued_message flag to the native thread (if applicable).
+    /// This flag tells the native thread to end its turn at the next message boundary.
+    fn sync_queue_flag_to_native_thread(&self, cx: &mut Context<Self>) {
+        if let Some(native_thread) = self.as_native_thread(cx) {
+            let has_queued = !self.local_queued_messages.is_empty();
+            native_thread.update(cx, |thread, _| {
+                thread.set_has_queued_message(has_queued);
+            });
+        }
+    }
+
+    fn add_to_queue(
+        &mut self,
+        content: Vec<acp::ContentBlock>,
+        tracked_buffers: Vec<Entity<Buffer>>,
+        cx: &mut Context<Self>,
+    ) {
+        self.local_queued_messages.push(QueuedMessage {
+            content,
+            tracked_buffers,
+        });
+        self.sync_queue_flag_to_native_thread(cx);
+    }
+
+    fn remove_from_queue(&mut self, index: usize, cx: &mut Context<Self>) -> Option<QueuedMessage> {
+        if index < self.local_queued_messages.len() {
+            let removed = self.local_queued_messages.remove(index);
+            self.sync_queue_flag_to_native_thread(cx);
+            Some(removed)
+        } else {
+            None
+        }
+    }
+
+    fn update_queued_message(
+        &mut self,
+        index: usize,
+        content: Vec<acp::ContentBlock>,
+        tracked_buffers: Vec<Entity<Buffer>>,
+        _cx: &mut Context<Self>,
+    ) -> bool {
+        if index < self.local_queued_messages.len() {
+            self.local_queued_messages[index] = QueuedMessage {
+                content,
+                tracked_buffers,
+            };
+            true
+        } else {
+            false
+        }
+    }
+
+    fn clear_queue(&mut self, cx: &mut Context<Self>) {
+        self.local_queued_messages.clear();
+        self.sync_queue_flag_to_native_thread(cx);
+    }
+
+    fn queued_message_contents(&self) -> Vec<Vec<acp::ContentBlock>> {
+        self.local_queued_messages
+            .iter()
+            .map(|q| q.content.clone())
+            .collect()
+    }
+
     fn save_queued_message_at_index(&mut self, index: usize, cx: &mut Context<Self>) {
         let Some(editor) = self.queued_message_editors.get(index) else {
             return;
         };
 
-        let Some(_native_thread) = self.as_native_thread(cx) else {
-            return;
-        };
-
         let contents_task = editor.update(cx, |editor, cx| editor.contents(false, cx));
 
         cx.spawn(async move |this, cx| {
@@ -6262,11 +6313,7 @@ impl AcpThreadView {
             };
 
             this.update(cx, |this, cx| {
-                if let Some(native_thread) = this.as_native_thread(cx) {
-                    native_thread.update(cx, |thread, _| {
-                        thread.update_queued_message(index, content, tracked_buffers);
-                    });
-                }
+                this.update_queued_message(index, content, tracked_buffers, cx);
                 cx.notify();
             })?;
 
@@ -6276,26 +6323,14 @@ impl AcpThreadView {
     }
 
     fn sync_queued_message_editors(&mut self, window: &mut Window, cx: &mut Context<Self>) {
-        let Some(native_thread) = self.as_native_thread(cx) else {
-            self.queued_message_editors.clear();
-            self.queued_message_editor_subscriptions.clear();
-            self.last_synced_queue_length = 0;
-            return;
-        };
-
-        let thread = native_thread.read(cx);
-        let needed_count = thread.queued_messages().len();
+        let needed_count = self.queued_messages_len();
         let current_count = self.queued_message_editors.len();
 
         if current_count == needed_count && needed_count == self.last_synced_queue_length {
             return;
         }
 
-        let queued_messages: Vec<_> = thread
-            .queued_messages()
-            .iter()
-            .map(|q| q.content.clone())
-            .collect();
+        let queued_messages = self.queued_message_contents();
 
         if current_count > needed_count {
             self.queued_message_editors.truncate(needed_count);
@@ -8348,12 +8383,8 @@ impl Render for AcpThreadView {
                 this.send_queued_message_at_index(0, true, window, cx);
             }))
             .on_action(cx.listener(|this, _: &RemoveFirstQueuedMessage, _, cx| {
-                if let Some(thread) = this.as_native_thread(cx) {
-                    thread.update(cx, |thread, _| {
-                        thread.remove_queued_message(0);
-                    });
-                    cx.notify();
-                }
+                this.remove_from_queue(0, cx);
+                cx.notify();
             }))
             .on_action(cx.listener(|this, _: &EditFirstQueuedMessage, window, cx| {
                 if let Some(editor) = this.queued_message_editors.first() {
@@ -8361,9 +8392,7 @@ impl Render for AcpThreadView {
                 }
             }))
             .on_action(cx.listener(|this, _: &ClearMessageQueue, _, cx| {
-                if let Some(thread) = this.as_native_thread(cx) {
-                    thread.update(cx, |thread, _| thread.clear_queued_messages());
-                }
+                this.clear_queue(cx);
                 this.can_fast_track_queue = false;
                 cx.notify();
             }))