agent: Cancel retries when the turn is cancelled (#50580)

Tom Houlé created

When a completion request fails with a retryable error (e.g. a 500 from
the upstream provider), the retry loop waits on a timer before trying
again. This timer did not race with the cancellation signal, so if the
user switched models and submitted a new message during the retry delay,
the old turn would continue retrying with the stale model for up to 15
seconds — making requests to the wrong provider and corrupting the
thread's message list with spurious Resume entries.

Now the retry delay races with the cancellation receiver, so the old
turn exits immediately when cancelled.

Release Notes:

- Fixed cancelled turns in a conversation that failed (e.g. 500 from the
LLM provider) bein retried even after cancellation

Change summary

crates/agent/src/tests/mod.rs | 78 +++++++++++++++++++++++++++++++++++++
crates/agent/src/thread.rs    | 10 ++++
2 files changed, 87 insertions(+), 1 deletion(-)

Detailed changes

crates/agent/src/tests/mod.rs 🔗

@@ -2631,6 +2631,84 @@ async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) {
     assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
 }
 
+#[gpui::test]
+async fn test_retry_cancelled_promptly_on_new_send(cx: &mut TestAppContext) {
+    // Regression test: when a completion fails with a retryable error (e.g. upstream 500),
+    // the retry loop waits on a timer. If the user switches models and sends a new message
+    // during that delay, the old turn should exit immediately instead of retrying with the
+    // stale model.
+    let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;
+    let model_a = model.as_fake();
+
+    // Start a turn with model_a.
+    let events_1 = thread
+        .update(cx, |thread, cx| {
+            thread.send(UserMessageId::new(), ["Hello"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+    assert_eq!(model_a.completion_count(), 1);
+
+    // Model returns a retryable upstream 500. The turn enters the retry delay.
+    model_a.send_last_completion_stream_error(
+        LanguageModelCompletionError::UpstreamProviderError {
+            message: "Internal server error".to_string(),
+            status: http_client::StatusCode::INTERNAL_SERVER_ERROR,
+            retry_after: None,
+        },
+    );
+    model_a.end_last_completion_stream();
+    cx.run_until_parked();
+
+    // The old completion was consumed; model_a has no pending requests yet because the
+    // retry timer hasn't fired.
+    assert_eq!(model_a.completion_count(), 0);
+
+    // Switch to model_b and send a new message. This cancels the old turn.
+    let model_b = Arc::new(FakeLanguageModel::with_id_and_thinking(
+        "fake", "model-b", "Model B", false,
+    ));
+    thread.update(cx, |thread, cx| {
+        thread.set_model(model_b.clone(), cx);
+    });
+    let events_2 = thread
+        .update(cx, |thread, cx| {
+            thread.send(UserMessageId::new(), ["Continue"], cx)
+        })
+        .unwrap();
+    cx.run_until_parked();
+
+    // model_b should have received its completion request.
+    assert_eq!(model_b.as_fake().completion_count(), 1);
+
+    // Advance the clock well past the retry delay (BASE_RETRY_DELAY = 5s).
+    cx.executor().advance_clock(Duration::from_secs(10));
+    cx.run_until_parked();
+
+    // model_a must NOT have received another completion request — the cancelled turn
+    // should have exited during the retry delay rather than retrying with the old model.
+    assert_eq!(
+        model_a.completion_count(),
+        0,
+        "old model should not receive a retry request after cancellation"
+    );
+
+    // Complete model_b's turn.
+    model_b
+        .as_fake()
+        .send_last_completion_stream_text_chunk("Done!");
+    model_b
+        .as_fake()
+        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
+    model_b.as_fake().end_last_completion_stream();
+
+    let events_1 = events_1.collect::<Vec<_>>().await;
+    assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]);
+
+    let events_2 = events_2.collect::<Vec<_>>().await;
+    assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]);
+}
+
 #[gpui::test]
 async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) {
     let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await;

crates/agent/src/thread.rs 🔗

@@ -1940,7 +1940,15 @@ impl Thread {
                 })??;
                 let timer = cx.background_executor().timer(retry.duration);
                 event_stream.send_retry(retry);
-                timer.await;
+                futures::select! {
+                    _ = timer.fuse() => {}
+                    _ = cancellation_rx.changed().fuse() => {
+                        if *cancellation_rx.borrow() {
+                            log::debug!("Turn cancelled during retry delay, exiting");
+                            return Ok(());
+                        }
+                    }
+                }
                 this.update(cx, |this, _cx| {
                     if let Some(Message::Agent(message)) = this.messages.last() {
                         if message.tool_results.is_empty() {