acp: Stay in edit mode when current completion ends (#36413)

Agus Zubiaga created

When a turn ends and the checkpoint is updated, `AcpThread` emits
`EntryUpdated` with the index of the user message. This was causing the
message editor to be recreated and, therefore, lose focus.

Release Notes:

- N/A

Change summary

crates/acp_thread/src/acp_thread.rs         |   1 
crates/acp_thread/src/connection.rs         | 119 +++++++++++++++-------
crates/agent_ui/src/acp/entry_view_state.rs |  66 +++++++-----
crates/agent_ui/src/acp/thread_view.rs      |  96 +++++++++++++++++
4 files changed, 213 insertions(+), 69 deletions(-)

Detailed changes

crates/acp_thread/src/connection.rs 🔗

@@ -186,7 +186,7 @@ mod test_support {
     use std::sync::Arc;
 
     use collections::HashMap;
-    use futures::future::try_join_all;
+    use futures::{channel::oneshot, future::try_join_all};
     use gpui::{AppContext as _, WeakEntity};
     use parking_lot::Mutex;
 
@@ -194,11 +194,16 @@ mod test_support {
 
     #[derive(Clone, Default)]
     pub struct StubAgentConnection {
-        sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+        sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
         permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
         next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
     }
 
+    struct Session {
+        thread: WeakEntity<AcpThread>,
+        response_tx: Option<oneshot::Sender<()>>,
+    }
+
     impl StubAgentConnection {
         pub fn new() -> Self {
             Self {
@@ -226,15 +231,33 @@ mod test_support {
             update: acp::SessionUpdate,
             cx: &mut App,
         ) {
+            assert!(
+                self.next_prompt_updates.lock().is_empty(),
+                "Use either send_update or set_next_prompt_updates"
+            );
+
             self.sessions
                 .lock()
                 .get(&session_id)
                 .unwrap()
+                .thread
                 .update(cx, |thread, cx| {
                     thread.handle_session_update(update.clone(), cx).unwrap();
                 })
                 .unwrap();
         }
+
+        pub fn end_turn(&self, session_id: acp::SessionId) {
+            self.sessions
+                .lock()
+                .get_mut(&session_id)
+                .unwrap()
+                .response_tx
+                .take()
+                .expect("No pending turn")
+                .send(())
+                .unwrap();
+        }
     }
 
     impl AgentConnection for StubAgentConnection {
@@ -251,7 +274,13 @@ mod test_support {
             let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
             let thread =
                 cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
-            self.sessions.lock().insert(session_id, thread.downgrade());
+            self.sessions.lock().insert(
+                session_id,
+                Session {
+                    thread: thread.downgrade(),
+                    response_tx: None,
+                },
+            );
             Task::ready(Ok(thread))
         }
 
@@ -269,43 +298,59 @@ mod test_support {
             params: acp::PromptRequest,
             cx: &mut App,
         ) -> Task<gpui::Result<acp::PromptResponse>> {
-            let sessions = self.sessions.lock();
-            let thread = sessions.get(&params.session_id).unwrap();
+            let mut sessions = self.sessions.lock();
+            let Session {
+                thread,
+                response_tx,
+            } = sessions.get_mut(&params.session_id).unwrap();
             let mut tasks = vec![];
-            for update in self.next_prompt_updates.lock().drain(..) {
-                let thread = thread.clone();
-                let update = update.clone();
-                let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
-                    && let Some(options) = self.permission_requests.get(&tool_call.id)
-                {
-                    Some((tool_call.clone(), options.clone()))
-                } else {
-                    None
-                };
-                let task = cx.spawn(async move |cx| {
-                    if let Some((tool_call, options)) = permission_request {
-                        let permission = thread.update(cx, |thread, cx| {
-                            thread.request_tool_call_authorization(
-                                tool_call.clone().into(),
-                                options.clone(),
-                                cx,
-                            )
+            if self.next_prompt_updates.lock().is_empty() {
+                let (tx, rx) = oneshot::channel();
+                response_tx.replace(tx);
+                cx.spawn(async move |_| {
+                    rx.await?;
+                    Ok(acp::PromptResponse {
+                        stop_reason: acp::StopReason::EndTurn,
+                    })
+                })
+            } else {
+                for update in self.next_prompt_updates.lock().drain(..) {
+                    let thread = thread.clone();
+                    let update = update.clone();
+                    let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
+                        &update
+                        && let Some(options) = self.permission_requests.get(&tool_call.id)
+                    {
+                        Some((tool_call.clone(), options.clone()))
+                    } else {
+                        None
+                    };
+                    let task = cx.spawn(async move |cx| {
+                        if let Some((tool_call, options)) = permission_request {
+                            let permission = thread.update(cx, |thread, cx| {
+                                thread.request_tool_call_authorization(
+                                    tool_call.clone().into(),
+                                    options.clone(),
+                                    cx,
+                                )
+                            })?;
+                            permission?.await?;
+                        }
+                        thread.update(cx, |thread, cx| {
+                            thread.handle_session_update(update.clone(), cx).unwrap();
                         })?;
-                        permission?.await?;
-                    }
-                    thread.update(cx, |thread, cx| {
-                        thread.handle_session_update(update.clone(), cx).unwrap();
-                    })?;
-                    anyhow::Ok(())
-                });
-                tasks.push(task);
-            }
-            cx.spawn(async move |_| {
-                try_join_all(tasks).await?;
-                Ok(acp::PromptResponse {
-                    stop_reason: acp::StopReason::EndTurn,
+                        anyhow::Ok(())
+                    });
+                    tasks.push(task);
+                }
+
+                cx.spawn(async move |_| {
+                    try_join_all(tasks).await?;
+                    Ok(acp::PromptResponse {
+                        stop_reason: acp::StopReason::EndTurn,
+                    })
                 })
-            })
+            }
         }
 
         fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {

crates/agent_ui/src/acp/entry_view_state.rs 🔗

@@ -5,8 +5,8 @@ use agent::{TextThreadStore, ThreadStore};
 use collections::HashMap;
 use editor::{Editor, EditorMode, MinimapVisibility};
 use gpui::{
-    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement,
-    WeakEntity, Window,
+    AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
+    TextStyleRefinement, WeakEntity, Window,
 };
 use language::language_settings::SoftWrap;
 use project::Project;
@@ -61,34 +61,44 @@ impl EntryViewState {
             AgentThreadEntry::UserMessage(message) => {
                 let has_id = message.id.is_some();
                 let chunks = message.chunks.clone();
-                let message_editor = cx.new(|cx| {
-                    let mut editor = MessageEditor::new(
-                        self.workspace.clone(),
-                        self.project.clone(),
-                        self.thread_store.clone(),
-                        self.text_thread_store.clone(),
-                        "Edit message - @ to include context",
-                        editor::EditorMode::AutoHeight {
-                            min_lines: 1,
-                            max_lines: None,
-                        },
-                        window,
-                        cx,
-                    );
-                    if !has_id {
-                        editor.set_read_only(true, cx);
+                if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
+                    if !editor.focus_handle(cx).is_focused(window) {
+                        // Only update if we are not editing.
+                        // If we are, cancelling the edit will set the message to the newest content.
+                        editor.update(cx, |editor, cx| {
+                            editor.set_message(chunks, window, cx);
+                        });
                     }
-                    editor.set_message(chunks, window, cx);
-                    editor
-                });
-                cx.subscribe(&message_editor, move |_, editor, event, cx| {
-                    cx.emit(EntryViewEvent {
-                        entry_index: index,
-                        view_event: ViewEvent::MessageEditorEvent(editor, *event),
+                } else {
+                    let message_editor = cx.new(|cx| {
+                        let mut editor = MessageEditor::new(
+                            self.workspace.clone(),
+                            self.project.clone(),
+                            self.thread_store.clone(),
+                            self.text_thread_store.clone(),
+                            "Edit message - @ to include context",
+                            editor::EditorMode::AutoHeight {
+                                min_lines: 1,
+                                max_lines: None,
+                            },
+                            window,
+                            cx,
+                        );
+                        if !has_id {
+                            editor.set_read_only(true, cx);
+                        }
+                        editor.set_message(chunks, window, cx);
+                        editor
+                    });
+                    cx.subscribe(&message_editor, move |_, editor, event, cx| {
+                        cx.emit(EntryViewEvent {
+                            entry_index: index,
+                            view_event: ViewEvent::MessageEditorEvent(editor, *event),
+                        })
                     })
-                })
-                .detach();
-                self.set_entry(index, Entry::UserMessage(message_editor));
+                    .detach();
+                    self.set_entry(index, Entry::UserMessage(message_editor));
+                }
             }
             AgentThreadEntry::ToolCall(tool_call) => {
                 let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();

crates/agent_ui/src/acp/thread_view.rs 🔗

@@ -3606,7 +3606,7 @@ pub(crate) mod tests {
     async fn test_drop(cx: &mut TestAppContext) {
         init_test(cx);
 
-        let (thread_view, _cx) = setup_thread_view(StubAgentServer::default(), cx).await;
+        let (thread_view, _cx) = setup_thread_view(StubAgentServer::default_response(), cx).await;
         let weak_view = thread_view.downgrade();
         drop(thread_view);
         assert!(!weak_view.is_upgradable());
@@ -3616,7 +3616,7 @@ pub(crate) mod tests {
     async fn test_notification_for_stop_event(cx: &mut TestAppContext) {
         init_test(cx);
 
-        let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await;
+        let (thread_view, cx) = setup_thread_view(StubAgentServer::default_response(), cx).await;
 
         let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
         message_editor.update_in(cx, |editor, window, cx| {
@@ -3800,8 +3800,12 @@ pub(crate) mod tests {
     }
 
     impl StubAgentServer<StubAgentConnection> {
-        fn default() -> Self {
-            Self::new(StubAgentConnection::default())
+        fn default_response() -> Self {
+            let conn = StubAgentConnection::new();
+            conn.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk {
+                content: "Default response".into(),
+            }]);
+            Self::new(conn)
         }
     }
 
@@ -4214,4 +4218,88 @@ pub(crate) mod tests {
             assert_eq!(new_editor.read(cx).text(cx), "Edited message content");
         })
     }
+
+    #[gpui::test]
+    async fn test_message_editing_while_generating(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let connection = StubAgentConnection::new();
+
+        let (thread_view, cx) =
+            setup_thread_view(StubAgentServer::new(connection.clone()), cx).await;
+        add_to_workspace(thread_view.clone(), cx);
+
+        let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
+        message_editor.update_in(cx, |editor, window, cx| {
+            editor.set_text("Original message to edit", window, cx);
+        });
+        thread_view.update_in(cx, |thread_view, window, cx| {
+            thread_view.send(window, cx);
+        });
+
+        cx.run_until_parked();
+
+        let (user_message_editor, session_id) = thread_view.read_with(cx, |view, cx| {
+            let thread = view.thread().unwrap().read(cx);
+            assert_eq!(thread.entries().len(), 1);
+
+            let editor = view
+                .entry_view_state
+                .read(cx)
+                .entry(0)
+                .unwrap()
+                .message_editor()
+                .unwrap()
+                .clone();
+
+            (editor, thread.session_id().clone())
+        });
+
+        // Focus
+        cx.focus(&user_message_editor);
+
+        thread_view.read_with(cx, |view, _cx| {
+            assert_eq!(view.editing_message, Some(0));
+        });
+
+        // Edit
+        user_message_editor.update_in(cx, |editor, window, cx| {
+            editor.set_text("Edited message content", window, cx);
+        });
+
+        thread_view.read_with(cx, |view, _cx| {
+            assert_eq!(view.editing_message, Some(0));
+        });
+
+        // Finish streaming response
+        cx.update(|_, cx| {
+            connection.send_update(
+                session_id.clone(),
+                acp::SessionUpdate::AgentMessageChunk {
+                    content: acp::ContentBlock::Text(acp::TextContent {
+                        text: "Response".into(),
+                        annotations: None,
+                    }),
+                },
+                cx,
+            );
+            connection.end_turn(session_id);
+        });
+
+        thread_view.read_with(cx, |view, _cx| {
+            assert_eq!(view.editing_message, Some(0));
+        });
+
+        cx.run_until_parked();
+
+        // Should still be editing
+        cx.update(|window, cx| {
+            assert!(user_message_editor.focus_handle(cx).is_focused(window));
+            assert_eq!(thread_view.read(cx).editing_message, Some(0));
+            assert_eq!(
+                user_message_editor.read(cx).text(cx),
+                "Edited message content"
+            );
+        });
+    }
 }