Stop thread on Restore Checkpoint (#42537)

Richard Feldman created

Closes #35142

In addition to cleaning up the terminals, also stops the conversation.

Release Notes:

- Restoring a checkpoint now stops the agent conversation.

Change summary

crates/acp_thread/src/acp_thread.rs | 330 +++++++++++++++++++++++++++++++
1 file changed, 330 insertions(+)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -1866,10 +1866,14 @@ impl AcpThread {
             .checkpoint
             .as_ref()
             .map(|c| c.git_checkpoint.clone());
+
+        // Cancel any in-progress generation before restoring
+        let cancel_task = self.cancel(cx);
         let rewind = self.rewind(id.clone(), cx);
         let git_store = self.project.read(cx).git_store().clone();
 
         cx.spawn(async move |_, cx| {
+            cancel_task.await;
             rewind.await?;
             if let Some(checkpoint) = checkpoint {
                 git_store
@@ -1894,9 +1898,25 @@ impl AcpThread {
             cx.update(|cx| truncate.run(id.clone(), cx))?.await?;
             this.update(cx, |this, cx| {
                 if let Some((ix, _)) = this.user_message_mut(&id) {
+                    // Collect all terminals from entries that will be removed
+                    let terminals_to_remove: Vec<acp::TerminalId> = this.entries[ix..]
+                        .iter()
+                        .flat_map(|entry| entry.terminals())
+                        .filter_map(|terminal| terminal.read(cx).id().clone().into())
+                        .collect();
+
                     let range = ix..this.entries.len();
                     this.entries.truncate(ix);
                     cx.emit(AcpThreadEvent::EntriesRemoved(range));
+
+                    // Kill and remove the terminals
+                    for terminal_id in terminals_to_remove {
+                        if let Some(terminal) = this.terminals.remove(&terminal_id) {
+                            terminal.update(cx, |terminal, cx| {
+                                terminal.kill(cx);
+                            });
+                        }
+                    }
                 }
                 this.action_log().update(cx, |action_log, cx| {
                     action_log.reject_all_edits(Some(telemetry), cx)
@@ -3803,4 +3823,314 @@ mod tests {
             }
         });
     }
+
+    /// Tests that restoring a checkpoint properly cleans up terminals that were
+    /// created after that checkpoint, and cancels any in-progress generation.
+    ///
+    /// Reproduces issue #35142: When a checkpoint is restored, any terminal processes
+    /// that were started after that checkpoint should be terminated, and any in-progress
+    /// AI generation should be canceled.
+    #[gpui::test]
+    async fn test_restore_checkpoint_kills_terminal(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        let project = Project::test(fs, [], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new());
+        let thread = cx
+            .update(|cx| connection.new_thread(project, Path::new(path!("/test")), cx))
+            .await
+            .unwrap();
+
+        // Send first user message to create a checkpoint
+        cx.update(|cx| {
+            thread.update(cx, |thread, cx| {
+                thread.send(vec!["first message".into()], cx)
+            })
+        })
+        .await
+        .unwrap();
+
+        // Send second message (creates another checkpoint) - we'll restore to this one
+        cx.update(|cx| {
+            thread.update(cx, |thread, cx| {
+                thread.send(vec!["second message".into()], cx)
+            })
+        })
+        .await
+        .unwrap();
+
+        // Create 2 terminals BEFORE the checkpoint that have completed running
+        let terminal_id_1 = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
+        let mock_terminal_1 = cx.new(|cx| {
+            let builder = ::terminal::TerminalBuilder::new_display_only(
+                ::terminal::terminal_settings::CursorShape::default(),
+                ::terminal::terminal_settings::AlternateScroll::On,
+                None,
+                0,
+            )
+            .unwrap();
+            builder.subscribe(cx)
+        });
+
+        thread.update(cx, |thread, cx| {
+            thread.on_terminal_provider_event(
+                TerminalProviderEvent::Created {
+                    terminal_id: terminal_id_1.clone(),
+                    label: "echo 'first'".to_string(),
+                    cwd: Some(PathBuf::from("/test")),
+                    output_byte_limit: None,
+                    terminal: mock_terminal_1.clone(),
+                },
+                cx,
+            );
+        });
+
+        thread.update(cx, |thread, cx| {
+            thread.on_terminal_provider_event(
+                TerminalProviderEvent::Output {
+                    terminal_id: terminal_id_1.clone(),
+                    data: b"first\n".to_vec(),
+                },
+                cx,
+            );
+        });
+
+        thread.update(cx, |thread, cx| {
+            thread.on_terminal_provider_event(
+                TerminalProviderEvent::Exit {
+                    terminal_id: terminal_id_1.clone(),
+                    status: acp::TerminalExitStatus {
+                        exit_code: Some(0),
+                        signal: None,
+                        meta: None,
+                    },
+                },
+                cx,
+            );
+        });
+
+        let terminal_id_2 = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
+        let mock_terminal_2 = cx.new(|cx| {
+            let builder = ::terminal::TerminalBuilder::new_display_only(
+                ::terminal::terminal_settings::CursorShape::default(),
+                ::terminal::terminal_settings::AlternateScroll::On,
+                None,
+                0,
+            )
+            .unwrap();
+            builder.subscribe(cx)
+        });
+
+        thread.update(cx, |thread, cx| {
+            thread.on_terminal_provider_event(
+                TerminalProviderEvent::Created {
+                    terminal_id: terminal_id_2.clone(),
+                    label: "echo 'second'".to_string(),
+                    cwd: Some(PathBuf::from("/test")),
+                    output_byte_limit: None,
+                    terminal: mock_terminal_2.clone(),
+                },
+                cx,
+            );
+        });
+
+        thread.update(cx, |thread, cx| {
+            thread.on_terminal_provider_event(
+                TerminalProviderEvent::Output {
+                    terminal_id: terminal_id_2.clone(),
+                    data: b"second\n".to_vec(),
+                },
+                cx,
+            );
+        });
+
+        thread.update(cx, |thread, cx| {
+            thread.on_terminal_provider_event(
+                TerminalProviderEvent::Exit {
+                    terminal_id: terminal_id_2.clone(),
+                    status: acp::TerminalExitStatus {
+                        exit_code: Some(0),
+                        signal: None,
+                        meta: None,
+                    },
+                },
+                cx,
+            );
+        });
+
+        // Get the second message ID to restore to
+        let second_message_id = thread.read_with(cx, |thread, _| {
+            // At this point we have:
+            // - Index 0: First user message (with checkpoint)
+            // - Index 1: Second user message (with checkpoint)
+            // No assistant responses because FakeAgentConnection just returns EndTurn
+            let AgentThreadEntry::UserMessage(message) = &thread.entries[1] else {
+                panic!("expected user message at index 1");
+            };
+            message.id.clone().unwrap()
+        });
+
+        // Create a terminal AFTER the checkpoint we'll restore to.
+        // This simulates the AI agent starting a long-running terminal command.
+        let terminal_id = acp::TerminalId(uuid::Uuid::new_v4().to_string().into());
+        let mock_terminal = cx.new(|cx| {
+            let builder = ::terminal::TerminalBuilder::new_display_only(
+                ::terminal::terminal_settings::CursorShape::default(),
+                ::terminal::terminal_settings::AlternateScroll::On,
+                None,
+                0,
+            )
+            .unwrap();
+            builder.subscribe(cx)
+        });
+
+        // Register the terminal as created
+        thread.update(cx, |thread, cx| {
+            thread.on_terminal_provider_event(
+                TerminalProviderEvent::Created {
+                    terminal_id: terminal_id.clone(),
+                    label: "sleep 1000".to_string(),
+                    cwd: Some(PathBuf::from("/test")),
+                    output_byte_limit: None,
+                    terminal: mock_terminal.clone(),
+                },
+                cx,
+            );
+        });
+
+        // Simulate the terminal producing output (still running)
+        thread.update(cx, |thread, cx| {
+            thread.on_terminal_provider_event(
+                TerminalProviderEvent::Output {
+                    terminal_id: terminal_id.clone(),
+                    data: b"terminal is running...\n".to_vec(),
+                },
+                cx,
+            );
+        });
+
+        // Create a tool call entry that references this terminal
+        // This represents the agent requesting a terminal command
+        thread.update(cx, |thread, cx| {
+            thread
+                .handle_session_update(
+                    acp::SessionUpdate::ToolCall(acp::ToolCall {
+                        id: acp::ToolCallId("terminal-tool-1".into()),
+                        title: "Running command".into(),
+                        kind: acp::ToolKind::Execute,
+                        status: acp::ToolCallStatus::InProgress,
+                        content: vec![acp::ToolCallContent::Terminal {
+                            terminal_id: terminal_id.clone(),
+                        }],
+                        locations: vec![],
+                        raw_input: Some(
+                            serde_json::json!({"command": "sleep 1000", "cd": "/test"}),
+                        ),
+                        raw_output: None,
+                        meta: None,
+                    }),
+                    cx,
+                )
+                .unwrap();
+        });
+
+        // Verify terminal exists and is in the thread
+        let terminal_exists_before =
+            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
+        assert!(
+            terminal_exists_before,
+            "Terminal should exist before checkpoint restore"
+        );
+
+        // Verify the terminal's underlying task is still running (not completed)
+        let terminal_running_before = thread.read_with(cx, |thread, _cx| {
+            let terminal_entity = thread.terminals.get(&terminal_id).unwrap();
+            terminal_entity.read_with(cx, |term, _cx| {
+                term.output().is_none() // output is None means it's still running
+            })
+        });
+        assert!(
+            terminal_running_before,
+            "Terminal should be running before checkpoint restore"
+        );
+
+        // Verify we have the expected entries before restore
+        let entry_count_before = thread.read_with(cx, |thread, _| thread.entries.len());
+        assert!(
+            entry_count_before > 1,
+            "Should have multiple entries before restore"
+        );
+
+        // Restore the checkpoint to the second message.
+        // This should:
+        // 1. Cancel any in-progress generation (via the cancel() call)
+        // 2. Remove the terminal that was created after that point
+        thread
+            .update(cx, |thread, cx| {
+                thread.restore_checkpoint(second_message_id, cx)
+            })
+            .await
+            .unwrap();
+
+        // Verify that no send_task is in progress after restore
+        // (cancel() clears the send_task)
+        let has_send_task_after = thread.read_with(cx, |thread, _| thread.send_task.is_some());
+        assert!(
+            !has_send_task_after,
+            "Should not have a send_task after restore (cancel should have cleared it)"
+        );
+
+        // Verify the entries were truncated (restoring to index 1 truncates at 1, keeping only index 0)
+        let entry_count = thread.read_with(cx, |thread, _| thread.entries.len());
+        assert_eq!(
+            entry_count, 1,
+            "Should have 1 entry after restore (only the first user message)"
+        );
+
+        // Verify the 2 completed terminals from before the checkpoint still exist
+        let terminal_1_exists = thread.read_with(cx, |thread, _| {
+            thread.terminals.contains_key(&terminal_id_1)
+        });
+        assert!(
+            terminal_1_exists,
+            "Terminal 1 (from before checkpoint) should still exist"
+        );
+
+        let terminal_2_exists = thread.read_with(cx, |thread, _| {
+            thread.terminals.contains_key(&terminal_id_2)
+        });
+        assert!(
+            terminal_2_exists,
+            "Terminal 2 (from before checkpoint) should still exist"
+        );
+
+        // Verify they're still in completed state
+        let terminal_1_completed = thread.read_with(cx, |thread, _cx| {
+            let terminal_entity = thread.terminals.get(&terminal_id_1).unwrap();
+            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
+        });
+        assert!(terminal_1_completed, "Terminal 1 should still be completed");
+
+        let terminal_2_completed = thread.read_with(cx, |thread, _cx| {
+            let terminal_entity = thread.terminals.get(&terminal_id_2).unwrap();
+            terminal_entity.read_with(cx, |term, _cx| term.output().is_some())
+        });
+        assert!(terminal_2_completed, "Terminal 2 should still be completed");
+
+        // Verify the running terminal (created after checkpoint) was removed
+        let terminal_3_exists =
+            thread.read_with(cx, |thread, _| thread.terminals.contains_key(&terminal_id));
+        assert!(
+            !terminal_3_exists,
+            "Terminal 3 (created after checkpoint) should have been removed"
+        );
+
+        // Verify total count is 2 (the two from before the checkpoint)
+        let terminal_count = thread.read_with(cx, |thread, _| thread.terminals.len());
+        assert_eq!(
+            terminal_count, 2,
+            "Should have exactly 2 terminals (the completed ones from before checkpoint)"
+        );
+    }
 }