diff --git a/crates/agent/src/tests/mod.rs b/crates/agent/src/tests/mod.rs index 8d75aae7e2948ef9c0934a72da112b926f633941..23ebe41d3c42654cb8fcdc0266009416686858aa 100644 --- a/crates/agent/src/tests/mod.rs +++ b/crates/agent/src/tests/mod.rs @@ -2631,6 +2631,84 @@ async fn test_in_progress_send_canceled_by_next_send(cx: &mut TestAppContext) { assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]); } +#[gpui::test] +async fn test_retry_cancelled_promptly_on_new_send(cx: &mut TestAppContext) { + // Regression test: when a completion fails with a retryable error (e.g. upstream 500), + // the retry loop waits on a timer. If the user switches models and sends a new message + // during that delay, the old turn should exit immediately instead of retrying with the + // stale model. + let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; + let model_a = model.as_fake(); + + // Start a turn with model_a. + let events_1 = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Hello"], cx) + }) + .unwrap(); + cx.run_until_parked(); + assert_eq!(model_a.completion_count(), 1); + + // Model returns a retryable upstream 500. The turn enters the retry delay. + model_a.send_last_completion_stream_error( + LanguageModelCompletionError::UpstreamProviderError { + message: "Internal server error".to_string(), + status: http_client::StatusCode::INTERNAL_SERVER_ERROR, + retry_after: None, + }, + ); + model_a.end_last_completion_stream(); + cx.run_until_parked(); + + // The old completion was consumed; model_a has no pending requests yet because the + // retry timer hasn't fired. + assert_eq!(model_a.completion_count(), 0); + + // Switch to model_b and send a new message. This cancels the old turn. + let model_b = Arc::new(FakeLanguageModel::with_id_and_thinking( + "fake", "model-b", "Model B", false, + )); + thread.update(cx, |thread, cx| { + thread.set_model(model_b.clone(), cx); + }); + let events_2 = thread + .update(cx, |thread, cx| { + thread.send(UserMessageId::new(), ["Continue"], cx) + }) + .unwrap(); + cx.run_until_parked(); + + // model_b should have received its completion request. + assert_eq!(model_b.as_fake().completion_count(), 1); + + // Advance the clock well past the retry delay (BASE_RETRY_DELAY = 5s). + cx.executor().advance_clock(Duration::from_secs(10)); + cx.run_until_parked(); + + // model_a must NOT have received another completion request — the cancelled turn + // should have exited during the retry delay rather than retrying with the old model. + assert_eq!( + model_a.completion_count(), + 0, + "old model should not receive a retry request after cancellation" + ); + + // Complete model_b's turn. + model_b + .as_fake() + .send_last_completion_stream_text_chunk("Done!"); + model_b + .as_fake() + .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)); + model_b.as_fake().end_last_completion_stream(); + + let events_1 = events_1.collect::>().await; + assert_eq!(stop_events(events_1), vec![acp::StopReason::Cancelled]); + + let events_2 = events_2.collect::>().await; + assert_eq!(stop_events(events_2), vec![acp::StopReason::EndTurn]); +} + #[gpui::test] async fn test_subsequent_successful_sends_dont_cancel(cx: &mut TestAppContext) { let ThreadTest { model, thread, .. } = setup(cx, TestModel::Fake).await; diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index c5ca1118ace28b66d555d67aa40c718da292f644..2e693a85cd1f86d232e392860d8bd83509ce131a 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -1940,7 +1940,15 @@ impl Thread { })??; let timer = cx.background_executor().timer(retry.duration); event_stream.send_retry(retry); - timer.await; + futures::select! { + _ = timer.fuse() => {} + _ = cancellation_rx.changed().fuse() => { + if *cancellation_rx.borrow() { + log::debug!("Turn cancelled during retry delay, exiting"); + return Ok(()); + } + } + } this.update(cx, |this, _cx| { if let Some(Message::Agent(message)) = this.messages.last() { if message.tool_results.is_empty() {