@@ -1200,17 +1200,21 @@ impl AcpThread {
} else {
None
};
- self.push_entry(
- AgentThreadEntry::UserMessage(UserMessage {
- id: message_id.clone(),
- content: block,
- chunks: message,
- checkpoint: None,
- }),
- cx,
- );
self.run_turn(cx, async move |this, cx| {
+ this.update(cx, |this, cx| {
+ this.push_entry(
+ AgentThreadEntry::UserMessage(UserMessage {
+ id: message_id.clone(),
+ content: block,
+ chunks: message,
+ checkpoint: None,
+ }),
+ cx,
+ );
+ })
+ .ok();
+
let old_checkpoint = git_store
.update(cx, |git, cx| git.checkpoint(cx))?
.await
@@ -201,7 +201,7 @@ mod test_support {
struct Session {
thread: WeakEntity<AcpThread>,
- response_tx: Option<oneshot::Sender<()>>,
+ response_tx: Option<oneshot::Sender<acp::StopReason>>,
}
impl StubAgentConnection {
@@ -242,12 +242,12 @@ mod test_support {
.unwrap()
.thread
.update(cx, |thread, cx| {
- thread.handle_session_update(update.clone(), cx).unwrap();
+ thread.handle_session_update(update, cx).unwrap();
})
.unwrap();
}
- pub fn end_turn(&self, session_id: acp::SessionId) {
+ pub fn end_turn(&self, session_id: acp::SessionId, stop_reason: acp::StopReason) {
self.sessions
.lock()
.get_mut(&session_id)
@@ -255,7 +255,7 @@ mod test_support {
.response_tx
.take()
.expect("No pending turn")
- .send(())
+ .send(stop_reason)
.unwrap();
}
}
@@ -308,10 +308,8 @@ mod test_support {
let (tx, rx) = oneshot::channel();
response_tx.replace(tx);
cx.spawn(async move |_| {
- rx.await?;
- Ok(acp::PromptResponse {
- stop_reason: acp::StopReason::EndTurn,
- })
+ let stop_reason = rx.await?;
+ Ok(acp::PromptResponse { stop_reason })
})
} else {
for update in self.next_prompt_updates.lock().drain(..) {
@@ -353,8 +351,17 @@ mod test_support {
}
}
- fn cancel(&self, _session_id: &acp::SessionId, _cx: &mut App) {
- unimplemented!()
+ fn cancel(&self, session_id: &acp::SessionId, _cx: &mut App) {
+ if let Some(end_turn_tx) = self
+ .sessions
+ .lock()
+ .get_mut(session_id)
+ .unwrap()
+ .response_tx
+ .take()
+ {
+ end_turn_tx.send(acp::StopReason::Canceled).unwrap();
+ }
}
fn session_editor(
@@ -4283,7 +4283,7 @@ pub(crate) mod tests {
},
cx,
);
- connection.end_turn(session_id);
+ connection.end_turn(session_id, acp::StopReason::EndTurn);
});
thread_view.read_with(cx, |view, _cx| {
@@ -4302,4 +4302,137 @@ pub(crate) mod tests {
);
});
}
+
+ #[gpui::test]
+ async fn test_interrupt(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let connection = StubAgentConnection::new();
+
+ let (thread_view, cx) =
+ setup_thread_view(StubAgentServer::new(connection.clone()), cx).await;
+ add_to_workspace(thread_view.clone(), cx);
+
+ let message_editor = cx.read(|cx| thread_view.read(cx).message_editor.clone());
+ message_editor.update_in(cx, |editor, window, cx| {
+ editor.set_text("Message 1", window, cx);
+ });
+ thread_view.update_in(cx, |thread_view, window, cx| {
+ thread_view.send(window, cx);
+ });
+
+ let (thread, session_id) = thread_view.read_with(cx, |view, cx| {
+ let thread = view.thread().unwrap();
+
+ (thread.clone(), thread.read(cx).session_id().clone())
+ });
+
+ cx.run_until_parked();
+
+ cx.update(|_, cx| {
+ connection.send_update(
+ session_id.clone(),
+ acp::SessionUpdate::AgentMessageChunk {
+ content: "Message 1 resp".into(),
+ },
+ cx,
+ );
+ });
+
+ cx.run_until_parked();
+
+ thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc::indoc! {"
+ ## User
+
+ Message 1
+
+ ## Assistant
+
+ Message 1 resp
+
+ "}
+ )
+ });
+
+ message_editor.update_in(cx, |editor, window, cx| {
+ editor.set_text("Message 2", window, cx);
+ });
+ thread_view.update_in(cx, |thread_view, window, cx| {
+ thread_view.send(window, cx);
+ });
+
+ cx.update(|_, cx| {
+ // Simulate a response sent after beginning to cancel
+ connection.send_update(
+ session_id.clone(),
+ acp::SessionUpdate::AgentMessageChunk {
+ content: "onse".into(),
+ },
+ cx,
+ );
+ });
+
+ cx.run_until_parked();
+
+ // Last Message 1 response should appear before Message 2
+ thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc::indoc! {"
+ ## User
+
+ Message 1
+
+ ## Assistant
+
+ Message 1 response
+
+ ## User
+
+ Message 2
+
+ "}
+ )
+ });
+
+ cx.update(|_, cx| {
+ connection.send_update(
+ session_id.clone(),
+ acp::SessionUpdate::AgentMessageChunk {
+ content: "Message 2 response".into(),
+ },
+ cx,
+ );
+ connection.end_turn(session_id.clone(), acp::StopReason::EndTurn);
+ });
+
+ cx.run_until_parked();
+
+ thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc::indoc! {"
+ ## User
+
+ Message 1
+
+ ## Assistant
+
+ Message 1 response
+
+ ## User
+
+ Message 2
+
+ ## Assistant
+
+ Message 2 response
+
+ "}
+ )
+ });
+ }
}