crates/acp_thread/src/acp_thread.rs 🔗
@@ -670,6 +670,7 @@ pub struct AcpThread {
session_id: acp::SessionId,
}
+#[derive(Debug)]
pub enum AcpThreadEvent {
NewEntry,
EntryUpdated(usize),
Agus Zubiaga created
When a turn ends and the checkpoint is updated, `AcpThread` emits
`EntryUpdated` with the index of the user message. This was causing the
message editor to be recreated and, therefore, lose focus.
Release Notes:
- N/A
crates/acp_thread/src/acp_thread.rs | 1
crates/acp_thread/src/connection.rs | 119 +++++++++++++++-------
crates/agent_ui/src/acp/entry_view_state.rs | 66 +++++++-----
crates/agent_ui/src/acp/thread_view.rs | 96 +++++++++++++++++
4 files changed, 213 insertions(+), 69 deletions(-)
@@ -670,6 +670,7 @@ pub struct AcpThread {
session_id: acp::SessionId,
}
+#[derive(Debug)]
pub enum AcpThreadEvent {
NewEntry,
EntryUpdated(usize),
@@ -186,7 +186,7 @@ mod test_support {
use std::sync::Arc;
use collections::HashMap;
- use futures::future::try_join_all;
+ use futures::{channel::oneshot, future::try_join_all};
use gpui::{AppContext as _, WeakEntity};
use parking_lot::Mutex;
@@ -194,11 +194,16 @@ mod test_support {
#[derive(Clone, Default)]
pub struct StubAgentConnection {
- sessions: Arc<Mutex<HashMap<acp::SessionId, WeakEntity<AcpThread>>>>,
+ sessions: Arc<Mutex<HashMap<acp::SessionId, Session>>>,
permission_requests: HashMap<acp::ToolCallId, Vec<acp::PermissionOption>>,
next_prompt_updates: Arc<Mutex<Vec<acp::SessionUpdate>>>,
}
+ struct Session {
+ thread: WeakEntity<AcpThread>,
+ response_tx: Option<oneshot::Sender<()>>,
+ }
+
impl StubAgentConnection {
pub fn new() -> Self {
Self {
@@ -226,15 +231,33 @@ mod test_support {
update: acp::SessionUpdate,
cx: &mut App,
) {
+ assert!(
+ self.next_prompt_updates.lock().is_empty(),
+ "Use either send_update or set_next_prompt_updates"
+ );
+
self.sessions
.lock()
.get(&session_id)
.unwrap()
+ .thread
.update(cx, |thread, cx| {
thread.handle_session_update(update.clone(), cx).unwrap();
})
.unwrap();
}
+
+ pub fn end_turn(&self, session_id: acp::SessionId) {
+ self.sessions
+ .lock()
+ .get_mut(&session_id)
+ .unwrap()
+ .response_tx
+ .take()
+ .expect("No pending turn")
+ .send(())
+ .unwrap();
+ }
}
impl AgentConnection for StubAgentConnection {
@@ -251,7 +274,13 @@ mod test_support {
let session_id = acp::SessionId(self.sessions.lock().len().to_string().into());
let thread =
cx.new(|cx| AcpThread::new("Test", self.clone(), project, session_id.clone(), cx));
- self.sessions.lock().insert(session_id, thread.downgrade());
+ self.sessions.lock().insert(
+ session_id,
+ Session {
+ thread: thread.downgrade(),
+ response_tx: None,
+ },
+ );
Task::ready(Ok(thread))
}
@@ -269,43 +298,59 @@ mod test_support {
params: acp::PromptRequest,
cx: &mut App,
) -> Task<gpui::Result<acp::PromptResponse>> {
- let sessions = self.sessions.lock();
- let thread = sessions.get(¶ms.session_id).unwrap();
+ let mut sessions = self.sessions.lock();
+ let Session {
+ thread,
+ response_tx,
+ } = sessions.get_mut(¶ms.session_id).unwrap();
let mut tasks = vec![];
- for update in self.next_prompt_updates.lock().drain(..) {
- let thread = thread.clone();
- let update = update.clone();
- let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) = &update
- && let Some(options) = self.permission_requests.get(&tool_call.id)
- {
- Some((tool_call.clone(), options.clone()))
- } else {
- None
- };
- let task = cx.spawn(async move |cx| {
- if let Some((tool_call, options)) = permission_request {
- let permission = thread.update(cx, |thread, cx| {
- thread.request_tool_call_authorization(
- tool_call.clone().into(),
- options.clone(),
- cx,
- )
+ if self.next_prompt_updates.lock().is_empty() {
+ let (tx, rx) = oneshot::channel();
+ response_tx.replace(tx);
+ cx.spawn(async move |_| {
+ rx.await?;
+ Ok(acp::PromptResponse {
+ stop_reason: acp::StopReason::EndTurn,
+ })
+ })
+ } else {
+ for update in self.next_prompt_updates.lock().drain(..) {
+ let thread = thread.clone();
+ let update = update.clone();
+ let permission_request = if let acp::SessionUpdate::ToolCall(tool_call) =
+ &update
+ && let Some(options) = self.permission_requests.get(&tool_call.id)
+ {
+ Some((tool_call.clone(), options.clone()))
+ } else {
+ None
+ };
+ let task = cx.spawn(async move |cx| {
+ if let Some((tool_call, options)) = permission_request {
+ let permission = thread.update(cx, |thread, cx| {
+ thread.request_tool_call_authorization(
+ tool_call.clone().into(),
+ options.clone(),
+ cx,
+ )
+ })?;
+ permission?.await?;
+ }
+ thread.update(cx, |thread, cx| {
+ thread.handle_session_update(update.clone(), cx).unwrap();
})?;
- permission?.await?;
- }
- thread.update(cx, |thread, cx| {
- thread.handle_session_update(update.clone(), cx).unwrap();
- })?;
- anyhow::Ok(())
- });
- tasks.push(task);
- }
- cx.spawn(async move |_| {
- try_join_all(tasks).await?;
- Ok(acp::PromptResponse {
- stop_reason: acp::StopReason::EndTurn,
+ anyhow::Ok(())
+ });
+ tasks.push(task);
+ }
+
+ cx.spawn(async move |_| {
+ try_join_all(tasks).await?;
+ Ok(acp::PromptResponse {
+ stop_reason: acp::StopReason::EndTurn,
+ })
})
- })
+ }
}
fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
@@ -5,8 +5,8 @@ use agent::{TextThreadStore, ThreadStore};
use collections::HashMap;
use editor::{Editor, EditorMode, MinimapVisibility};
use gpui::{
- AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, TextStyleRefinement,
- WeakEntity, Window,
+ AnyEntity, App, AppContext as _, Entity, EntityId, EventEmitter, Focusable,
+ TextStyleRefinement, WeakEntity, Window,
};
use language::language_settings::SoftWrap;
use project::Project;
@@ -61,34 +61,44 @@ impl EntryViewState {
AgentThreadEntry::UserMessage(message) => {
let has_id = message.id.is_some();
let chunks = message.chunks.clone();
- let message_editor = cx.new(|cx| {
- let mut editor = MessageEditor::new(
- self.workspace.clone(),
- self.project.clone(),
- self.thread_store.clone(),
- self.text_thread_store.clone(),
- "Edit message - @ to include context",
- editor::EditorMode::AutoHeight {
- min_lines: 1,
- max_lines: None,
- },
- window,
- cx,
- );
- if !has_id {
- editor.set_read_only(true, cx);
+ if let Some(Entry::UserMessage(editor)) = self.entries.get_mut(index) {
+ if !editor.focus_handle(cx).is_focused(window) {
+ // Only update if we are not editing.
+ // If we are, cancelling the edit will set the message to the newest content.
+ editor.update(cx, |editor, cx| {
+ editor.set_message(chunks, window, cx);
+ });
}
- editor.set_message(chunks, window, cx);
- editor
- });
- cx.subscribe(&message_editor, move |_, editor, event, cx| {
- cx.emit(EntryViewEvent {
- entry_index: index,
- view_event: ViewEvent::MessageEditorEvent(editor, *event),
+ } else {
+ let message_editor = cx.new(|cx| {
+ let mut editor = MessageEditor::new(
+ self.workspace.clone(),
+ self.project.clone(),
+ self.thread_store.clone(),
+ self.text_thread_store.clone(),
+ "Edit message - @ to include context",
+ editor::EditorMode::AutoHeight {
+ min_lines: 1,
+ max_lines: None,
+ },
+ window,
+ cx,
+ );
+ if !has_id {
+ editor.set_read_only(true, cx);
+ }
+ editor.set_message(chunks, window, cx);
+ editor
+ });
+ cx.subscribe(&message_editor, move |_, editor, event, cx| {
+ cx.emit(EntryViewEvent {
+ entry_index: index,
+ view_event: ViewEvent::MessageEditorEvent(editor, *event),
+ })
})
- })
- .detach();
- self.set_entry(index, Entry::UserMessage(message_editor));
+ .detach();
+ self.set_entry(index, Entry::UserMessage(message_editor));
+ }
}
AgentThreadEntry::ToolCall(tool_call) => {
let terminals = tool_call.terminals().cloned().collect::<Vec<_>>();
@@ -3606,7 +3606,7 @@ pub(crate) mod tests {
async fn test_drop(cx: &mut TestAppContext) {
init_test(cx);
- let (thread_view, _cx) = setup_thread_view(StubAgentServer::default(), cx).await;
+ let (thread_view, _cx) = setup_thread_view(StubAgentServer::default_response(), cx).await;
let weak_view = thread_view.downgrade();
drop(thread_view);
assert!(!weak_view.is_upgradable());
@@ -3616,7 +3616,7 @@ pub(crate) mod tests {
async fn test_notification_for_stop_event(cx: &mut TestAppContext) {
init_test(cx);
- let (thread_view, cx) = setup_thread_view(StubAgentServer::default(), cx).await;
+ let (thread_view, cx) = setup_thread_view(StubAgentServer::default_response(), cx).await;
let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
message_editor.update_in(cx, |editor, window, cx| {
@@ -3800,8 +3800,12 @@ pub(crate) mod tests {
}
impl StubAgentServer<StubAgentConnection> {
- fn default() -> Self {
- Self::new(StubAgentConnection::default())
+ fn default_response() -> Self {
+ let conn = StubAgentConnection::new();
+ conn.set_next_prompt_updates(vec![acp::SessionUpdate::AgentMessageChunk {
+ content: "Default response".into(),
+ }]);
+ Self::new(conn)
}
}
@@ -4214,4 +4218,88 @@ pub(crate) mod tests {
assert_eq!(new_editor.read(cx).text(cx), "Edited message content");
})
}
+
+ #[gpui::test]
+ async fn test_message_editing_while_generating(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let connection = StubAgentConnection::new();
+
+ let (thread_view, cx) =
+ setup_thread_view(StubAgentServer::new(connection.clone()), cx).await;
+ add_to_workspace(thread_view.clone(), cx);
+
+ let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
+ message_editor.update_in(cx, |editor, window, cx| {
+ editor.set_text("Original message to edit", window, cx);
+ });
+ thread_view.update_in(cx, |thread_view, window, cx| {
+ thread_view.send(window, cx);
+ });
+
+ cx.run_until_parked();
+
+ let (user_message_editor, session_id) = thread_view.read_with(cx, |view, cx| {
+ let thread = view.thread().unwrap().read(cx);
+ assert_eq!(thread.entries().len(), 1);
+
+ let editor = view
+ .entry_view_state
+ .read(cx)
+ .entry(0)
+ .unwrap()
+ .message_editor()
+ .unwrap()
+ .clone();
+
+ (editor, thread.session_id().clone())
+ });
+
+ // Focus
+ cx.focus(&user_message_editor);
+
+ thread_view.read_with(cx, |view, _cx| {
+ assert_eq!(view.editing_message, Some(0));
+ });
+
+ // Edit
+ user_message_editor.update_in(cx, |editor, window, cx| {
+ editor.set_text("Edited message content", window, cx);
+ });
+
+ thread_view.read_with(cx, |view, _cx| {
+ assert_eq!(view.editing_message, Some(0));
+ });
+
+ // Finish streaming response
+ cx.update(|_, cx| {
+ connection.send_update(
+ session_id.clone(),
+ acp::SessionUpdate::AgentMessageChunk {
+ content: acp::ContentBlock::Text(acp::TextContent {
+ text: "Response".into(),
+ annotations: None,
+ }),
+ },
+ cx,
+ );
+ connection.end_turn(session_id);
+ });
+
+ thread_view.read_with(cx, |view, _cx| {
+ assert_eq!(view.editing_message, Some(0));
+ });
+
+ cx.run_until_parked();
+
+ // Should still be editing
+ cx.update(|window, cx| {
+ assert!(user_message_editor.focus_handle(cx).is_focused(window));
+ assert_eq!(thread_view.read(cx).editing_message, Some(0));
+ assert_eq!(
+ user_message_editor.read(cx).text(cx),
+ "Edited message content"
+ );
+ });
+ }
}