diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index a63dabf1fb25258b6f4255a5c67682165371b255..56e33fda47f095eef1873f7a0724b021e88a0bdc 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1866,10 +1866,14 @@ impl AcpThread { .checkpoint .as_ref() .map(|c| c.git_checkpoint.clone()); + + // Cancel any in-progress generation before restoring + let cancel_task = self.cancel(cx); let rewind = self.rewind(id.clone(), cx); let git_store = self.project.read(cx).git_store().clone(); cx.spawn(async move |_, cx| { + cancel_task.await; rewind.await?; if let Some(checkpoint) = checkpoint { git_store @@ -1894,9 +1898,25 @@ impl AcpThread { cx.update(|cx| truncate.run(id.clone(), cx))?.await?; this.update(cx, |this, cx| { if let Some((ix, _)) = this.user_message_mut(&id) { + // Collect all terminals from entries that will be removed + let terminals_to_remove: Vec = this.entries[ix..] + .iter() + .flat_map(|entry| entry.terminals()) + .filter_map(|terminal| terminal.read(cx).id().clone().into()) + .collect(); + let range = ix..this.entries.len(); this.entries.truncate(ix); cx.emit(AcpThreadEvent::EntriesRemoved(range)); + + // Kill and remove the terminals + for terminal_id in terminals_to_remove { + if let Some(terminal) = this.terminals.remove(&terminal_id) { + terminal.update(cx, |terminal, cx| { + terminal.kill(cx); + }); + } + } } this.action_log().update(cx, |action_log, cx| { action_log.reject_all_edits(Some(telemetry), cx) @@ -3803,4 +3823,314 @@ mod tests { } }); } + + /// Tests that restoring a checkpoint properly cleans up terminals that were + /// created after that checkpoint, and cancels any in-progress generation. + /// + /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes + /// that were started after that checkpoint should be terminated, and any in-progress + /// AI generation should be canceled. + #[gpui::test] + async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + let project = Project::test(fs, [], cx).await; + let connection = Rc::new(FakeAgentConnection::new()); + let thread = cx + .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx)) + .await + .unwrap(); + + // Send first user message to create a checkpoint + cx.update(|cx| { + thread.update(cx, |thread, cx| { + thread.send(vec!["first message".into()], cx) + }) + }) + .await + .unwrap(); + + // Send second message (creates another checkpoint) - we'll restore to this one + cx.update(|cx| { + thread.update(cx, |thread, cx| { + thread.send(vec!["second message".into()], cx) + }) + }) + .await + .unwrap(); + + // Create 2 terminals BEFORE the checkpoint that have completed running + let terminal_id_1 = acp::TerminalId(uuid::Uuid::new_v4().to_string().into()); + let mock_terminal_1 = cx.new(|cx| { + let builder = ::terminal::TerminalBuilder::new_display_only( + ::terminal::terminal_settings::CursorShape::default(), + ::terminal::terminal_settings::AlternateScroll::On, + None, + 0, + ) + .unwrap(); + builder.subscribe(cx) + }); + + thread.update(cx, |thread, cx| { + thread.on_terminal_provider_event( + TerminalProviderEvent::Created { + terminal_id: terminal_id_1.clone(), + label: "echo 'first'".to_string(), + cwd: Some(PathBuf::from("/test")), + output_byte_limit: None, + terminal: mock_terminal_1.clone(), + }, + cx, + ); + }); + + thread.update(cx, |thread, cx| { + thread.on_terminal_provider_event( + TerminalProviderEvent::Output { + terminal_id: terminal_id_1.clone(), + data: b"first\n".to_vec(), + }, + cx, + ); + }); + + thread.update(cx, |thread, cx| { + thread.on_terminal_provider_event( + TerminalProviderEvent::Exit { + terminal_id: terminal_id_1.clone(), + status: acp::TerminalExitStatus { + exit_code: Some(0), + signal: None, + meta: None, + }, + }, + cx, + ); + }); + + let terminal_id_2 = acp::TerminalId(uuid::Uuid::new_v4().to_string().into()); + let mock_terminal_2 = cx.new(|cx| { + let builder = ::terminal::TerminalBuilder::new_display_only( + ::terminal::terminal_settings::CursorShape::default(), + ::terminal::terminal_settings::AlternateScroll::On, + None, + 0, + ) + .unwrap(); + builder.subscribe(cx) + }); + + thread.update(cx, |thread, cx| { + thread.on_terminal_provider_event( + TerminalProviderEvent::Created { + terminal_id: terminal_id_2.clone(), + label: "echo 'second'".to_string(), + cwd: Some(PathBuf::from("/test")), + output_byte_limit: None, + terminal: mock_terminal_2.clone(), + }, + cx, + ); + }); + + thread.update(cx, |thread, cx| { + thread.on_terminal_provider_event( + TerminalProviderEvent::Output { + terminal_id: terminal_id_2.clone(), + data: b"second\n".to_vec(), + }, + cx, + ); + }); + + thread.update(cx, |thread, cx| { + thread.on_terminal_provider_event( + TerminalProviderEvent::Exit { + terminal_id: terminal_id_2.clone(), + status: acp::TerminalExitStatus { + exit_code: Some(0), + signal: None, + meta: None, + }, + }, + cx, + ); + }); + + // Get the second message ID to restore to + let second_message_id = thread.read_with(cx, |thread, _| { + // At this point we have: + // - Index 0: First user message (with checkpoint) + // - Index 1: Second user message (with checkpoint) + // No assistant responses because FakeAgentConnection just returns EndTurn + let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else { + panic!("expected user message at index 1"); + }; + message.id.clone().unwrap() + }); + + // Create a terminal AFTER the checkpoint we'll restore to. + // This simulates the AI agent starting a long-running terminal command. + let terminal_id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into()); + let mock_terminal = cx.new(|cx| { + let builder = ::terminal::TerminalBuilder::new_display_only( + ::terminal::terminal_settings::CursorShape::default(), + ::terminal::terminal_settings::AlternateScroll::On, + None, + 0, + ) + .unwrap(); + builder.subscribe(cx) + }); + + // Register the terminal as created + thread.update(cx, |thread, cx| { + thread.on_terminal_provider_event( + TerminalProviderEvent::Created { + terminal_id: terminal_id.clone(), + label: "sleep 1000".to_string(), + cwd: Some(PathBuf::from("/test")), + output_byte_limit: None, + terminal: mock_terminal.clone(), + }, + cx, + ); + }); + + // Simulate the terminal producing output (still running) + thread.update(cx, |thread, cx| { + thread.on_terminal_provider_event( + TerminalProviderEvent::Output { + terminal_id: terminal_id.clone(), + data: b"terminal is running...\n".to_vec(), + }, + cx, + ); + }); + + // Create a tool call entry that references this terminal + // This represents the agent requesting a terminal command + thread.update(cx, |thread, cx| { + thread + .handle_session_update( + acp::SessionUpdate::ToolCall(acp::ToolCall { + id: acp::ToolCallId("terminal-tool-1".into()), + title: "Running command".into(), + kind: acp::ToolKind::Execute, + status: acp::ToolCallStatus::InProgress, + content: vec![acp::ToolCallContent::Terminal { + terminal_id: terminal_id.clone(), + }], + locations: vec![], + raw_input: Some( + serde_json::json!({"command": "sleep 1000", "cd": "/test"}), + ), + raw_output: None, + meta: None, + }), + cx, + ) + .unwrap(); + }); + + // Verify terminal exists and is in the thread + let terminal_exists_before = + thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id)); + assert!( + terminal_exists_before, + "Terminal should exist before checkpoint restore" + ); + + // Verify the terminal's underlying task is still running (not completed) + let terminal_running_before = thread.read_with(cx, |thread, _cx| { + let terminal_entity = thread.terminals.get(&terminal_id).unwrap(); + terminal_entity.read_with(cx, |term, _cx| { + term.output().is_none() // output is None means it's still running + }) + }); + assert!( + terminal_running_before, + "Terminal should be running before checkpoint restore" + ); + + // Verify we have the expected entries before restore + let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len()); + assert!( + entry_count_before > 1, + "Should have multiple entries before restore" + ); + + // Restore the checkpoint to the second message. + // This should: + // 1. Cancel any in-progress generation (via the cancel() call) + // 2. Remove the terminal that was created after that point + thread + .update(cx, |thread, cx| { + thread.restore_checkpoint(second_message_id, cx) + }) + .await + .unwrap(); + + // Verify that no send_task is in progress after restore + // (cancel() clears the send_task) + let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some()); + assert!( + !has_send_task_after, + "Should not have a send_task after restore (cancel should have cleared it)" + ); + + // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0) + let entry_count = thread.read_with(cx, |thread, _| thread.entries.len()); + assert_eq!( + entry_count, 1, + "Should have 1 entry after restore (only the first user message)" + ); + + // Verify the 2 completed terminals from before the checkpoint still exist + let terminal_1_exists = thread.read_with(cx, |thread, _| { + thread.terminals.contains_key(&terminal_id_1) + }); + assert!( + terminal_1_exists, + "Terminal 1 (from before checkpoint) should still exist" + ); + + let terminal_2_exists = thread.read_with(cx, |thread, _| { + thread.terminals.contains_key(&terminal_id_2) + }); + assert!( + terminal_2_exists, + "Terminal 2 (from before checkpoint) should still exist" + ); + + // Verify they're still in completed state + let terminal_1_completed = thread.read_with(cx, |thread, _cx| { + let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap(); + terminal_entity.read_with(cx, |term, _cx| term.output().is_some()) + }); + assert!(terminal_1_completed, "Terminal 1 should still be completed"); + + let terminal_2_completed = thread.read_with(cx, |thread, _cx| { + let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap(); + terminal_entity.read_with(cx, |term, _cx| term.output().is_some()) + }); + assert!(terminal_2_completed, "Terminal 2 should still be completed"); + + // Verify the running terminal (created after checkpoint) was removed + let terminal_3_exists = + thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id)); + assert!( + !terminal_3_exists, + "Terminal 3 (created after checkpoint) should have been removed" + ); + + // Verify total count is 2 (the two from before the checkpoint) + let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len()); + assert_eq!( + terminal_count, 2, + "Should have exactly 2 terminals (the completed ones from before checkpoint)" + ); + } }