more

Ben Brandt created

Change summary

crates/acp_thread/src/acp_thread.rs | 617 +++++++++++++++++++++++++++++-
crates/action_log/src/action_log.rs | 304 ++++++++++++++-
2 files changed, 866 insertions(+), 55 deletions(-)

Detailed changes

crates/acp_thread/src/acp_thread.rs 🔗

@@ -1063,8 +1063,10 @@ pub struct AcpThread {
     pending_terminal_exit: HashMap<acp::TerminalId, acp::TerminalExitStatus>,
     inferred_edit_candidates:
         HashMap<acp::ToolCallId, HashMap<PathBuf, InferredEditCandidateState>>,
+    inferred_edit_tool_call_turns: HashMap<acp::ToolCallId, u32>,
     finalizing_inferred_edit_tool_calls: HashSet<acp::ToolCallId>,
     next_inferred_edit_candidate_nonce: u64,
+
     had_error: bool,
     /// The user's unsent prompt text, persisted so it can be restored when reloading the thread.
     draft_prompt: Option<Vec<acp::ContentBlock>>,
@@ -1253,8 +1255,10 @@ impl AcpThread {
             pending_terminal_output: HashMap::default(),
             pending_terminal_exit: HashMap::default(),
             inferred_edit_candidates: HashMap::default(),
+            inferred_edit_tool_call_turns: HashMap::default(),
             finalizing_inferred_edit_tool_calls: HashSet::default(),
             next_inferred_edit_candidate_nonce: 0,
+
             had_error: false,
             draft_prompt: None,
             ui_scroll_position: None,
@@ -1757,6 +1761,17 @@ impl AcpThread {
             && !tool_call.locations.is_empty()
     }
 
+    fn should_track_inferred_external_edits(tool_call: &ToolCall) -> bool {
+        Self::should_infer_external_edits(tool_call)
+            && matches!(
+                tool_call.status,
+                ToolCallStatus::Pending
+                    | ToolCallStatus::InProgress
+                    | ToolCallStatus::Completed
+                    | ToolCallStatus::Failed
+            )
+    }
+
     fn is_inferred_edit_terminal_status(status: &ToolCallStatus) -> bool {
         matches!(status, ToolCallStatus::Completed | ToolCallStatus::Failed)
     }
@@ -1780,6 +1795,37 @@ impl AcpThread {
         }
     }
 
+    fn inferred_edit_tracking_turn_id(&self) -> u32 {
+        self.running_turn
+            .as_ref()
+            .map_or(self.turn_id, |turn| turn.id)
+    }
+
+    fn inferred_edit_tool_call_belongs_to_turn(
+        &self,
+        tool_call_id: &acp::ToolCallId,
+        turn_id: u32,
+    ) -> bool {
+        self.inferred_edit_tool_call_turns.get(tool_call_id) == Some(&turn_id)
+    }
+
+    fn record_inferred_edit_tool_call_turn_if_needed(
+        &mut self,
+        tool_call_id: &acp::ToolCallId,
+        cx: &mut Context<Self>,
+    ) {
+        let turn_id = self.inferred_edit_tracking_turn_id();
+        if self.inferred_edit_tool_call_belongs_to_turn(tool_call_id, turn_id) {
+            return;
+        }
+
+        let buffers_to_end = self.clear_inferred_edit_tool_call_tracking(tool_call_id);
+        self.end_expected_external_edits(buffers_to_end, cx);
+
+        self.inferred_edit_tool_call_turns
+            .insert(tool_call_id.clone(), turn_id);
+    }
+
     fn remove_inferred_edit_tool_call_if_empty(&mut self, tool_call_id: &acp::ToolCallId) {
         let remove_tool_call = self
             .inferred_edit_candidates
@@ -1788,6 +1834,7 @@ impl AcpThread {
 
         if remove_tool_call {
             self.inferred_edit_candidates.remove(tool_call_id);
+            self.inferred_edit_tool_call_turns.remove(tool_call_id);
             self.finalizing_inferred_edit_tool_calls
                 .remove(tool_call_id);
         }
@@ -1805,6 +1852,7 @@ impl AcpThread {
             .filter_map(|candidate_state| candidate_state.into_buffer_to_end())
             .collect::<Vec<_>>();
 
+        self.inferred_edit_tool_call_turns.remove(tool_call_id);
         self.finalizing_inferred_edit_tool_calls
             .remove(tool_call_id);
 
@@ -1853,13 +1901,20 @@ impl AcpThread {
         self.end_expected_external_edits(buffers_to_end, cx);
     }
 
-    fn finalize_all_inferred_edit_tool_calls(&mut self, cx: &mut Context<Self>) {
+    fn finalize_all_inferred_edit_tool_calls_for_turn(
+        &mut self,
+        turn_id: u32,
+        cx: &mut Context<Self>,
+    ) {
         let tool_call_ids = self
             .inferred_edit_candidates
             .keys()
+            .filter(|tool_call_id| {
+                self.inferred_edit_tool_call_belongs_to_turn(tool_call_id, turn_id)
+            })
             .filter(|tool_call_id| {
                 self.tool_call(tool_call_id).is_some_and(|(_, tool_call)| {
-                    Self::should_infer_external_edits(tool_call)
+                    Self::should_track_inferred_external_edits(tool_call)
                         && Self::is_inferred_edit_terminal_status(&tool_call.status)
                 })
             })
@@ -1870,6 +1925,31 @@ impl AcpThread {
         }
     }
 
+    fn finish_inferred_edit_tracking_for_stopped_turn(
+        &mut self,
+        turn_id: u32,
+        cx: &mut Context<Self>,
+    ) {
+        self.finalize_all_inferred_edit_tool_calls_for_turn(turn_id, cx);
+
+        let tool_call_ids_to_clear = self
+            .inferred_edit_candidates
+            .keys()
+            .filter(|tool_call_id| {
+                self.inferred_edit_tool_call_belongs_to_turn(tool_call_id, turn_id)
+            })
+            .filter(|tool_call_id| {
+                !self.tool_call(tool_call_id).is_some_and(|(_, tool_call)| {
+                    Self::should_track_inferred_external_edits(tool_call)
+                        && Self::is_inferred_edit_terminal_status(&tool_call.status)
+                })
+            })
+            .cloned()
+            .collect::<Vec<_>>();
+
+        self.clear_inferred_edit_candidates_for_tool_calls(tool_call_ids_to_clear, cx);
+    }
+
     fn sync_inferred_edit_candidate_paths(
         &mut self,
         tool_call_id: &acp::ToolCallId,
@@ -2162,7 +2242,7 @@ impl AcpThread {
             return;
         };
 
-        let should_track = Self::should_infer_external_edits(tool_call);
+        let should_track = Self::should_track_inferred_external_edits(tool_call);
         let should_finalize = Self::is_inferred_edit_terminal_status(&tool_call.status);
         let locations = tool_call.locations.clone();
 
@@ -2171,6 +2251,7 @@ impl AcpThread {
             return;
         }
 
+        self.record_inferred_edit_tool_call_turn_if_needed(&tool_call_id, cx);
         self.register_inferred_edit_locations(tool_call_id.clone(), &locations, cx);
 
         if should_finalize {
@@ -2489,6 +2570,18 @@ impl AcpThread {
             return;
         };
 
+        let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = &call.status else {
+            return;
+        };
+
+        if respond_tx.is_canceled() {
+            log::warn!(
+                "dropping stale tool authorization for call `{}` because it is no longer waiting for confirmation",
+                id
+            );
+            return;
+        }
+
         let new_status = match option_kind {
             acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => {
                 ToolCallStatus::Rejected
@@ -2499,15 +2592,24 @@ impl AcpThread {
             _ => ToolCallStatus::InProgress,
         };
 
-        let curr_status = mem::replace(&mut call.status, new_status);
+        let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } =
+            mem::replace(&mut call.status, ToolCallStatus::Canceled)
+        else {
+            return;
+        };
 
-        if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status {
-            respond_tx.send(outcome).log_err();
-        } else if cfg!(debug_assertions) {
-            panic!("tried to authorize an already authorized tool call");
+        if respond_tx.send(outcome).is_err() {
+            log::warn!(
+                "dropping stale tool authorization for call `{}` because it is no longer waiting for confirmation",
+                id
+            );
+            return;
         }
 
+        call.status = new_status;
+
         cx.emit(AcpThreadEvent::EntryUpdated(ix));
+        self.refresh_inferred_edit_tool_call(id, cx);
     }
 
     pub fn plan(&self) -> &Plan {
@@ -2681,7 +2783,7 @@ impl AcpThread {
                         Self::flush_streaming_text(&mut this.streaming_text_buffer, cx);
 
                         if r.stop_reason == acp::StopReason::MaxTokens {
-                            this.finalize_all_inferred_edit_tool_calls(cx);
+                            this.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx);
                             this.had_error = true;
                             cx.emit(AcpThreadEvent::Error);
                             log::error!("Max tokens reached. Usage: {:?}", this.token_usage);
@@ -2690,12 +2792,8 @@ impl AcpThread {
 
                         let canceled = matches!(r.stop_reason, acp::StopReason::Cancelled);
                         if canceled {
-                            let canceled_tool_call_ids = this.mark_pending_tools_as_canceled();
-                            this.clear_inferred_edit_candidates_for_tool_calls(
-                                canceled_tool_call_ids,
-                                cx,
-                            );
-                            this.finalize_all_inferred_edit_tool_calls(cx);
+                            this.mark_pending_tools_as_canceled();
+                            this.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx);
                         }
 
                         // Handle refusal - distinguish between user prompt and tool call refusals
@@ -2749,14 +2847,14 @@ impl AcpThread {
                             }
                         }
 
-                        this.finalize_all_inferred_edit_tool_calls(cx);
+                        this.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx);
                         cx.emit(AcpThreadEvent::Stopped(r.stop_reason));
                         Ok(Some(r))
                     }
                     Err(e) => {
                         Self::flush_streaming_text(&mut this.streaming_text_buffer, cx);
 
-                        this.finalize_all_inferred_edit_tool_calls(cx);
+                        this.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx);
                         this.had_error = true;
                         cx.emit(AcpThreadEvent::Error);
                         log::error!("Error in run turn: {:?}", e);
@@ -2772,12 +2870,12 @@ impl AcpThread {
         let Some(turn) = self.running_turn.take() else {
             return Task::ready(());
         };
+        let turn_id = turn.id;
         self.connection.cancel(&self.session_id, cx);
 
         Self::flush_streaming_text(&mut self.streaming_text_buffer, cx);
-        let canceled_tool_call_ids = self.mark_pending_tools_as_canceled();
-        self.clear_inferred_edit_candidates_for_tool_calls(canceled_tool_call_ids, cx);
-        self.finalize_all_inferred_edit_tool_calls(cx);
+        self.mark_pending_tools_as_canceled();
+        self.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx);
 
         cx.background_spawn(turn.send_task)
     }
@@ -4326,6 +4424,434 @@ mod tests {
         assert!(cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls()));
     }
 
+    #[gpui::test]
+    async fn test_waiting_for_confirmation_does_not_start_inferred_edit_tracking(
+        cx: &mut TestAppContext,
+    ) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"}))
+            .await;
+        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new());
+
+        let thread = cx
+            .update(|cx| {
+                connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+            })
+            .await
+            .unwrap();
+
+        let tool_call_id = acp::ToolCallId::new("test");
+        let allow_option_id = acp::PermissionOptionId::new("allow");
+        let deny_option_id = acp::PermissionOptionId::new("deny");
+        let _authorization_task = thread
+            .update(cx, |thread, cx| {
+                thread.request_tool_call_authorization(
+                    acp::ToolCall::new(tool_call_id.clone(), "Label")
+                        .kind(acp::ToolKind::Edit)
+                        .locations(vec![acp::ToolCallLocation::new(path!("/test/file.txt"))])
+                        .into(),
+                    PermissionOptions::Flat(vec![
+                        acp::PermissionOption::new(
+                            allow_option_id.clone(),
+                            "Allow",
+                            acp::PermissionOptionKind::AllowOnce,
+                        ),
+                        acp::PermissionOption::new(
+                            deny_option_id,
+                            "Deny",
+                            acp::PermissionOptionKind::RejectOnce,
+                        ),
+                    ]),
+                    cx,
+                )
+            })
+            .unwrap();
+
+        cx.run_until_parked();
+
+        assert_eq!(inferred_edit_candidate_count(&thread, cx), 0);
+        assert!(!cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls()));
+        assert!(cx.read(|cx| thread.read(cx).is_waiting_for_confirmation()));
+    }
+
+    #[gpui::test]
+    async fn test_authorizing_waiting_tool_call_starts_inferred_edit_tracking(
+        cx: &mut TestAppContext,
+    ) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"}))
+            .await;
+        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new());
+
+        let thread = cx
+            .update(|cx| {
+                connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+            })
+            .await
+            .unwrap();
+
+        let tool_call_id = acp::ToolCallId::new("test");
+        let allow_option_id = acp::PermissionOptionId::new("allow");
+        let deny_option_id = acp::PermissionOptionId::new("deny");
+        let _authorization_task = thread
+            .update(cx, |thread, cx| {
+                thread.request_tool_call_authorization(
+                    acp::ToolCall::new(tool_call_id.clone(), "Label")
+                        .kind(acp::ToolKind::Edit)
+                        .locations(vec![acp::ToolCallLocation::new(path!("/test/file.txt"))])
+                        .into(),
+                    PermissionOptions::Flat(vec![
+                        acp::PermissionOption::new(
+                            allow_option_id.clone(),
+                            "Allow",
+                            acp::PermissionOptionKind::AllowOnce,
+                        ),
+                        acp::PermissionOption::new(
+                            deny_option_id,
+                            "Deny",
+                            acp::PermissionOptionKind::RejectOnce,
+                        ),
+                    ]),
+                    cx,
+                )
+            })
+            .unwrap();
+
+        thread.update(cx, |thread, cx| {
+            thread.authorize_tool_call(
+                tool_call_id.clone(),
+                allow_option_id.into(),
+                acp::PermissionOptionKind::AllowOnce,
+                cx,
+            );
+        });
+        cx.run_until_parked();
+
+        assert_eq!(inferred_edit_candidate_count(&thread, cx), 1);
+        assert!(cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls()));
+    }
+
+    #[gpui::test]
+    async fn test_stale_authorization_does_not_rewrite_status_or_start_inferred_edit_tracking(
+        cx: &mut TestAppContext,
+    ) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"}))
+            .await;
+        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new());
+
+        let thread = cx
+            .update(|cx| {
+                connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+            })
+            .await
+            .unwrap();
+
+        let tool_call_id = acp::ToolCallId::new("test");
+        let allow_option_id = acp::PermissionOptionId::new("allow");
+        let deny_option_id = acp::PermissionOptionId::new("deny");
+        let _authorization_task = thread
+            .update(cx, |thread, cx| {
+                thread.request_tool_call_authorization(
+                    acp::ToolCall::new(tool_call_id.clone(), "Label")
+                        .kind(acp::ToolKind::Edit)
+                        .locations(vec![acp::ToolCallLocation::new(path!("/test/file.txt"))])
+                        .into(),
+                    PermissionOptions::Flat(vec![
+                        acp::PermissionOption::new(
+                            allow_option_id.clone(),
+                            "Allow",
+                            acp::PermissionOptionKind::AllowOnce,
+                        ),
+                        acp::PermissionOption::new(
+                            deny_option_id,
+                            "Deny",
+                            acp::PermissionOptionKind::RejectOnce,
+                        ),
+                    ]),
+                    cx,
+                )
+            })
+            .unwrap();
+
+        thread.update(cx, |thread, _cx| {
+            let (_, tool_call) = thread.tool_call_mut(&tool_call_id).unwrap();
+            tool_call.status = ToolCallStatus::Rejected;
+        });
+
+        thread.update(cx, |thread, cx| {
+            thread.authorize_tool_call(
+                tool_call_id.clone(),
+                allow_option_id.into(),
+                acp::PermissionOptionKind::AllowOnce,
+                cx,
+            );
+        });
+        cx.run_until_parked();
+
+        assert_eq!(inferred_edit_candidate_count(&thread, cx), 0);
+        thread.read_with(cx, |thread, _| {
+            let (_, tool_call) = thread.tool_call(&tool_call_id).unwrap();
+            assert!(matches!(tool_call.status, ToolCallStatus::Rejected));
+        });
+        assert!(!cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls()));
+    }
+
+    #[gpui::test]
+    async fn test_duplicate_authorization_does_not_rewrite_status_or_tracking(
+        cx: &mut TestAppContext,
+    ) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"}))
+            .await;
+        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new());
+
+        let thread = cx
+            .update(|cx| {
+                connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+            })
+            .await
+            .unwrap();
+
+        let tool_call_id = acp::ToolCallId::new("test");
+        let allow_option_id = acp::PermissionOptionId::new("allow");
+        let deny_option_id = acp::PermissionOptionId::new("deny");
+        let _authorization_task = thread
+            .update(cx, |thread, cx| {
+                thread.request_tool_call_authorization(
+                    acp::ToolCall::new(tool_call_id.clone(), "Label")
+                        .kind(acp::ToolKind::Edit)
+                        .locations(vec![acp::ToolCallLocation::new(path!("/test/file.txt"))])
+                        .into(),
+                    PermissionOptions::Flat(vec![
+                        acp::PermissionOption::new(
+                            allow_option_id.clone(),
+                            "Allow",
+                            acp::PermissionOptionKind::AllowOnce,
+                        ),
+                        acp::PermissionOption::new(
+                            deny_option_id.clone(),
+                            "Deny",
+                            acp::PermissionOptionKind::RejectOnce,
+                        ),
+                    ]),
+                    cx,
+                )
+            })
+            .unwrap();
+
+        thread.update(cx, |thread, cx| {
+            thread.authorize_tool_call(
+                tool_call_id.clone(),
+                allow_option_id.into(),
+                acp::PermissionOptionKind::AllowOnce,
+                cx,
+            );
+        });
+        cx.run_until_parked();
+
+        assert_eq!(inferred_edit_candidate_count(&thread, cx), 1);
+        thread.read_with(cx, |thread, _| {
+            let (_, tool_call) = thread.tool_call(&tool_call_id).unwrap();
+            assert!(matches!(tool_call.status, ToolCallStatus::InProgress));
+        });
+
+        thread.update(cx, |thread, cx| {
+            thread.authorize_tool_call(
+                tool_call_id.clone(),
+                deny_option_id.into(),
+                acp::PermissionOptionKind::RejectOnce,
+                cx,
+            );
+        });
+        cx.run_until_parked();
+
+        assert_eq!(inferred_edit_candidate_count(&thread, cx), 1);
+        thread.read_with(cx, |thread, _| {
+            let (_, tool_call) = thread.tool_call(&tool_call_id).unwrap();
+            assert!(matches!(tool_call.status, ToolCallStatus::InProgress));
+        });
+    }
+
+    #[gpui::test]
+    async fn test_authorization_send_failure_cancels_call_without_tracking(
+        cx: &mut TestAppContext,
+    ) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"}))
+            .await;
+        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new());
+
+        let thread = cx
+            .update(|cx| {
+                connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+            })
+            .await
+            .unwrap();
+
+        let tool_call_id = acp::ToolCallId::new("test");
+        let allow_option_id = acp::PermissionOptionId::new("allow");
+        let deny_option_id = acp::PermissionOptionId::new("deny");
+        let authorization_task = thread
+            .update(cx, |thread, cx| {
+                thread.request_tool_call_authorization(
+                    acp::ToolCall::new(tool_call_id.clone(), "Label")
+                        .kind(acp::ToolKind::Edit)
+                        .locations(vec![acp::ToolCallLocation::new(path!("/test/file.txt"))])
+                        .into(),
+                    PermissionOptions::Flat(vec![
+                        acp::PermissionOption::new(
+                            allow_option_id.clone(),
+                            "Allow",
+                            acp::PermissionOptionKind::AllowOnce,
+                        ),
+                        acp::PermissionOption::new(
+                            deny_option_id,
+                            "Deny",
+                            acp::PermissionOptionKind::RejectOnce,
+                        ),
+                    ]),
+                    cx,
+                )
+            })
+            .unwrap();
+        drop(authorization_task);
+        cx.run_until_parked();
+
+        thread.update(cx, |thread, cx| {
+            thread.authorize_tool_call(
+                tool_call_id.clone(),
+                allow_option_id.into(),
+                acp::PermissionOptionKind::AllowOnce,
+                cx,
+            );
+        });
+        cx.run_until_parked();
+
+        assert_eq!(inferred_edit_candidate_count(&thread, cx), 0);
+        thread.read_with(cx, |thread, _| {
+            let (_, tool_call) = thread.tool_call(&tool_call_id).unwrap();
+            assert!(matches!(tool_call.status, ToolCallStatus::Canceled));
+        });
+        assert!(!cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls()));
+    }
+
+    #[gpui::test]
+    async fn test_stopped_turn_clears_unfinished_inferred_edit_tracking(cx: &mut TestAppContext) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"}))
+            .await;
+        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new());
+
+        let thread = cx
+            .update(|cx| {
+                connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+            })
+            .await
+            .unwrap();
+
+        let tool_call_id = acp::ToolCallId::new("test");
+        start_external_edit_tool_call(
+            &thread,
+            &tool_call_id,
+            vec![acp::ToolCallLocation::new(path!("/test/file.txt"))],
+            cx,
+        );
+        cx.run_until_parked();
+
+        assert_eq!(inferred_edit_candidate_count(&thread, cx), 1);
+
+        thread.update(cx, |thread, cx| {
+            let turn_id = thread.inferred_edit_tracking_turn_id();
+            thread.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx);
+        });
+
+        cx.run_until_parked();
+
+        assert_eq!(inferred_edit_candidate_count(&thread, cx), 0);
+    }
+
+    #[gpui::test]
+    async fn test_stopped_turn_only_clears_inferred_edit_tracking_for_its_own_turn(
+        cx: &mut TestAppContext,
+    ) {
+        init_test(cx);
+        let fs = FakeFs::new(cx.background_executor.clone());
+        fs.insert_tree(
+            path!("/test"),
+            json!({
+                "old.txt": "one\n",
+                "new.txt": "two\n",
+            }),
+        )
+        .await;
+        let project = Project::test(fs, [path!("/test").as_ref()], cx).await;
+        let connection = Rc::new(FakeAgentConnection::new());
+
+        let thread = cx
+            .update(|cx| {
+                connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx)
+            })
+            .await
+            .unwrap();
+
+        let old_tool_call_id = acp::ToolCallId::new("old");
+        let new_tool_call_id = acp::ToolCallId::new("new");
+
+        set_running_turn_for_test(&thread, 1, cx);
+        start_external_edit_tool_call(
+            &thread,
+            &old_tool_call_id,
+            vec![acp::ToolCallLocation::new(path!("/test/old.txt"))],
+            cx,
+        );
+
+        set_running_turn_for_test(&thread, 2, cx);
+        start_external_edit_tool_call(
+            &thread,
+            &new_tool_call_id,
+            vec![acp::ToolCallLocation::new(path!("/test/new.txt"))],
+            cx,
+        );
+        cx.run_until_parked();
+
+        assert_eq!(
+            inferred_edit_candidate_count_for_tool_call(&thread, &old_tool_call_id, cx),
+            1
+        );
+        assert_eq!(
+            inferred_edit_candidate_count_for_tool_call(&thread, &new_tool_call_id, cx),
+            1
+        );
+
+        thread.update(cx, |thread, cx| {
+            thread.finish_inferred_edit_tracking_for_stopped_turn(1, cx);
+        });
+        cx.run_until_parked();
+
+        assert_eq!(
+            inferred_edit_candidate_count_for_tool_call(&thread, &old_tool_call_id, cx),
+            0
+        );
+        assert_eq!(
+            inferred_edit_candidate_count_for_tool_call(&thread, &new_tool_call_id, cx),
+            1
+        );
+        assert!(cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls()));
+    }
+
     #[gpui::test]
     async fn test_infer_external_modified_file_edits_from_tool_call_locations(
         cx: &mut TestAppContext,
@@ -4576,6 +5102,20 @@ mod tests {
         action_log.read_with(cx, |action_log, cx| action_log.changed_buffers(cx).len())
     }
 
+    fn set_running_turn_for_test(
+        thread: &Entity<AcpThread>,
+        turn_id: u32,
+        cx: &mut TestAppContext,
+    ) {
+        thread.update(cx, |thread, _cx| {
+            thread.turn_id = turn_id;
+            thread.running_turn = Some(RunningTurn {
+                id: turn_id,
+                send_task: Task::ready(()),
+            });
+        });
+    }
+
     fn inferred_edit_candidate_count(thread: &Entity<AcpThread>, cx: &TestAppContext) -> usize {
         thread.read_with(cx, |thread, _| {
             thread
@@ -4586,6 +5126,19 @@ mod tests {
         })
     }
 
+    fn inferred_edit_candidate_count_for_tool_call(
+        thread: &Entity<AcpThread>,
+        tool_call_id: &acp::ToolCallId,
+        cx: &TestAppContext,
+    ) -> usize {
+        thread.read_with(cx, |thread, _| {
+            thread
+                .inferred_edit_candidates
+                .get(tool_call_id)
+                .map_or(0, HashMap::len)
+        })
+    }
+
     fn inferred_edit_tool_call_is_finalizing(
         thread: &Entity<AcpThread>,
         tool_call_id: &acp::ToolCallId,
@@ -4650,6 +5203,7 @@ mod tests {
             .unwrap();
 
         let tool_call_id = acp::ToolCallId::new("test");
+        set_running_turn_for_test(&thread, 1, cx);
         start_external_edit_tool_call(
             &thread,
             &tool_call_id,
@@ -4660,13 +5214,8 @@ mod tests {
 
         assert_eq!(inferred_edit_candidate_count(&thread, cx), 1);
 
-        let cancel = thread.update(cx, |thread, cx| {
-            thread.running_turn = Some(RunningTurn {
-                id: 1,
-                send_task: Task::ready(()),
-            });
-            thread.cancel(cx)
-        });
+        let cancel = thread.update(cx, |thread, cx| thread.cancel(cx));
+
         cancel.await;
         cx.run_until_parked();
 
@@ -4696,6 +5245,7 @@ mod tests {
             .unwrap();
 
         let tool_call_id = acp::ToolCallId::new("test");
+        set_running_turn_for_test(&thread, 1, cx);
         start_external_edit_tool_call(
             &thread,
             &tool_call_id,
@@ -4704,13 +5254,8 @@ mod tests {
         );
         cx.run_until_parked();
 
-        let cancel = thread.update(cx, |thread, cx| {
-            thread.running_turn = Some(RunningTurn {
-                id: 1,
-                send_task: Task::ready(()),
-            });
-            thread.cancel(cx)
-        });
+        let cancel = thread.update(cx, |thread, cx| thread.cancel(cx));
+
         cancel.await;
         cx.run_until_parked();
 
@@ -4911,6 +5456,7 @@ mod tests {
         let buffer = open_test_buffer(&project, Path::new(path!("/test/file.txt")), cx).await;
 
         let tool_call_id = acp::ToolCallId::new("test");
+        set_running_turn_for_test(&thread, 1, cx);
         start_external_edit_tool_call(&thread, &tool_call_id, Vec::new(), cx);
 
         let nonce = thread.update(cx, |thread, _cx| {
@@ -4920,6 +5466,9 @@ mod tests {
                 tool_call.locations = vec![acp::ToolCallLocation::new(abs_path.clone())];
                 tool_call.resolved_locations = vec![None];
             }
+            thread
+                .inferred_edit_tool_call_turns
+                .insert(tool_call_id.clone(), 1);
             thread
                 .inferred_edit_candidates
                 .entry(tool_call_id.clone())

crates/action_log/src/action_log.rs 🔗

@@ -881,7 +881,12 @@ impl ActionLog {
     ) {
         let version = buffer.read(cx).version();
         let diff_base = match &status {
-            TrackedBufferStatus::Created { .. } => Rope::default(),
+            TrackedBufferStatus::Created {
+                existing_file_content: Some(existing_file_content),
+            } => existing_file_content.clone(),
+            TrackedBufferStatus::Created {
+                existing_file_content: None,
+            } => Rope::default(),
             TrackedBufferStatus::Modified | TrackedBufferStatus::Deleted => {
                 baseline_snapshot.as_rope().clone()
             }
@@ -922,13 +927,9 @@ impl ActionLog {
             .file()
             .is_some_and(|file| file.disk_state().exists());
         let had_tracked_buffer = self.tracked_buffers.contains_key(&buffer);
-        let has_attributed_change = self
-            .tracked_buffers
-            .get(&buffer)
-            .is_some_and(|tracked_buffer| tracked_buffer.has_edits(cx));
 
-        let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx);
         if !had_tracked_buffer {
+            let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx);
             tracked_buffer.mode = TrackedBufferMode::ExpectationOnly;
             tracked_buffer.status = TrackedBufferStatus::Modified;
             tracked_buffer.diff_base = buffer.read(cx).as_rope().clone();
@@ -936,6 +937,12 @@ impl ActionLog {
             tracked_buffer.unreviewed_edits.clear();
         }
 
+        // Reusing an existing tracked buffer must preserve its prior version so stale-buffer
+        // detection continues to reflect any user edits that predate the expectation.
+        let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else {
+            return;
+        };
+
         let expected_external_edit =
             tracked_buffer
                 .expected_external_edit
@@ -945,7 +952,7 @@ impl ActionLog {
                     initial_exists_on_disk,
                     observed_external_file_change: false,
                     armed_explicit_reload: false,
-                    has_attributed_change,
+                    has_attributed_change: false,
                     pending_delete: false,
                     is_disqualified: false,
                 });
@@ -1094,17 +1101,17 @@ impl ActionLog {
     ) {
         if let Some(linked_action_log) = &self.linked_action_log {
             let linked_baseline_snapshot = baseline_snapshot.clone();
-            if !linked_action_log.read(cx).has_changed_buffer(buffer, cx) {
-                linked_action_log.update(cx, |log, cx| {
-                    log.infer_buffer_from_snapshot_impl(
-                        buffer.clone(),
-                        linked_baseline_snapshot,
-                        kind,
-                        false,
-                        cx,
-                    );
-                });
-            }
+            // Later inferred snapshots must keep refreshing linked logs for the same buffer so
+            // parent and child review state do not diverge after the first forwarded hunk.
+            linked_action_log.update(cx, |log, cx| {
+                log.infer_buffer_from_snapshot_impl(
+                    buffer.clone(),
+                    linked_baseline_snapshot,
+                    kind,
+                    false,
+                    cx,
+                );
+            });
         }
     }
 
@@ -1129,10 +1136,11 @@ impl ActionLog {
             }
         }
 
+        let tracked_buffer_status = kind.tracked_buffer_status(&baseline_snapshot);
         self.prime_tracked_buffer_from_snapshot(
             buffer.clone(),
             baseline_snapshot,
-            kind.tracked_buffer_status(),
+            tracked_buffer_status,
             cx,
         );
 
@@ -1818,10 +1826,17 @@ enum InferredSnapshotKind {
 }
 
 impl InferredSnapshotKind {
-    fn tracked_buffer_status(self) -> TrackedBufferStatus {
+    fn tracked_buffer_status(
+        self,
+        baseline_snapshot: &text::BufferSnapshot,
+    ) -> TrackedBufferStatus {
         match self {
             Self::Created => TrackedBufferStatus::Created {
-                existing_file_content: None,
+                existing_file_content: if baseline_snapshot.text().is_empty() {
+                    None
+                } else {
+                    Some(baseline_snapshot.as_rope().clone())
+                },
             },
             Self::Edited => TrackedBufferStatus::Modified,
             Self::Deleted => TrackedBufferStatus::Deleted,
@@ -4077,6 +4092,88 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_linked_action_log_forwards_sequential_inferred_snapshots(
+        cx: &mut TestAppContext,
+    ) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(path!("/dir"), json!({"file": "one\ntwo\n"}))
+            .await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let parent_log = cx.new(|_| ActionLog::new(project.clone()));
+        let child_log =
+            cx.new(|_| ActionLog::new(project.clone()).with_linked_action_log(parent_log.clone()));
+
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+            .expect("test file should exist");
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .expect("test buffer should open");
+
+        let first_baseline_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
+        buffer.update(cx, |buffer, cx| buffer.set_text("one\ntwo\nthree\n", cx));
+        let second_baseline_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
+        buffer.update(cx, |buffer, cx| {
+            buffer.set_text("one\ntwo\nthree\nfour\n", cx)
+        });
+        project
+            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+            .await
+            .expect("final inferred buffer contents should save");
+
+        cx.update(|cx| {
+            child_log.update(cx, |log, cx| {
+                log.infer_buffer_edited_from_snapshot(
+                    buffer.clone(),
+                    first_baseline_snapshot.clone(),
+                    cx,
+                );
+            });
+        });
+        cx.run_until_parked();
+
+        let first_child_hunks = unreviewed_hunks(&child_log, cx);
+        assert!(
+            !first_child_hunks.is_empty(),
+            "the first inferred snapshot should produce review hunks"
+        );
+        assert_eq!(
+            unreviewed_hunks(&parent_log, cx),
+            first_child_hunks,
+            "parent should match the first forwarded inferred snapshot"
+        );
+
+        cx.update(|cx| {
+            child_log.update(cx, |log, cx| {
+                log.infer_buffer_edited_from_snapshot(
+                    buffer.clone(),
+                    second_baseline_snapshot.clone(),
+                    cx,
+                );
+            });
+        });
+        cx.run_until_parked();
+
+        let second_child_hunks = unreviewed_hunks(&child_log, cx);
+        assert!(
+            !second_child_hunks.is_empty(),
+            "the second inferred snapshot should still produce review hunks"
+        );
+        assert_ne!(
+            second_child_hunks, first_child_hunks,
+            "the second inferred snapshot should refresh the tracked diff"
+        );
+        assert_eq!(
+            unreviewed_hunks(&parent_log, cx),
+            second_child_hunks,
+            "parent should stay in sync after sequential inferred snapshots on one buffer"
+        );
+    }
+
     #[gpui::test]
     async fn test_linked_action_log_infer_buffer_created(cx: &mut TestAppContext) {
         init_test(cx);
@@ -4499,6 +4596,171 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_expected_external_edit_starts_unattributed_even_with_existing_hunks(
+        cx: &mut TestAppContext,
+    ) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(path!("/dir"), json!({"file": "one\ntwo\n"}))
+            .await;
+        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+            .unwrap();
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
+
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+            buffer.update(cx, |buffer, cx| buffer.set_text("one\ntwo\nthree\n", cx));
+            action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx));
+        });
+        cx.run_until_parked();
+
+        assert!(
+            !unreviewed_hunks(&action_log, cx).is_empty(),
+            "buffer should already have tracked hunks before the expectation starts"
+        );
+
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| {
+                log.begin_expected_external_edit(buffer.clone(), cx);
+            });
+        });
+
+        assert!(
+            action_log.read_with(cx, |log, _| {
+                log.tracked_buffers
+                    .get(&buffer)
+                    .and_then(|tracked_buffer| tracked_buffer.expected_external_edit.as_ref())
+                    .is_some_and(|expected_external_edit| {
+                        !expected_external_edit.has_attributed_change
+                    })
+            }),
+            "a new expected external edit should start as unattributed even when the buffer already has hunks"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_expected_external_edit_preserves_stale_tracking_for_existing_tracked_buffer(
+        cx: &mut TestAppContext,
+    ) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(path!("/dir"), json!({"file": "one\ntwo\n"}))
+            .await;
+        let project = Project::test(fs, [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let file_path = project
+            .read_with(cx, |project, cx| project.find_project_path("dir/file", cx))
+            .expect("test file should exist");
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .expect("test buffer should open");
+
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx));
+        });
+
+        cx.update(|cx| {
+            buffer.update(cx, |buffer, cx| {
+                assert!(buffer.edit([(0..0, "zero\n")], None, cx).is_some());
+            });
+        });
+        cx.run_until_parked();
+
+        assert_eq!(
+            action_log.read_with(cx, |log, cx| {
+                log.stale_buffers(cx).cloned().collect::<Vec<_>>()
+            }),
+            vec![buffer.clone()],
+            "user edits after a read should mark the tracked buffer as stale"
+        );
+
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| {
+                log.begin_expected_external_edit(buffer.clone(), cx);
+            });
+        });
+
+        assert_eq!(
+            action_log.read_with(cx, |log, cx| {
+                log.stale_buffers(cx).cloned().collect::<Vec<_>>()
+            }),
+            vec![buffer.clone()],
+            "starting an expected external edit should not clear existing stale tracking"
+        );
+
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| {
+                log.end_expected_external_edit(buffer.clone(), cx);
+            });
+        });
+
+        assert_eq!(
+            action_log.read_with(cx, |log, cx| {
+                log.stale_buffers(cx).cloned().collect::<Vec<_>>()
+            }),
+            vec![buffer],
+            "ending an unattributed expected external edit should preserve existing stale tracking"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_infer_buffer_created_preserves_non_empty_baseline_on_reject(
+        cx: &mut TestAppContext,
+    ) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(path!("/dir"), json!({})).await;
+        let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let file_path = project
+            .read_with(cx, |project, cx| {
+                project.find_project_path("dir/new_file", cx)
+            })
+            .unwrap();
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(file_path, cx))
+            .await
+            .unwrap();
+
+        buffer.update(cx, |buffer, cx| buffer.set_text("draft\n", cx));
+        let baseline_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot());
+
+        buffer.update(cx, |buffer, cx| buffer.set_text("draft\nagent\n", cx));
+        project
+            .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
+            .await
+            .unwrap();
+
+        cx.update(|cx| {
+            action_log.update(cx, |log, cx| {
+                log.infer_buffer_created(buffer.clone(), baseline_snapshot.clone(), cx);
+            });
+        });
+        cx.run_until_parked();
+
+        action_log
+            .update(cx, |log, cx| log.reject_all_edits(None, cx))
+            .await;
+        cx.run_until_parked();
+
+        assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "draft\n");
+        assert_eq!(
+            String::from_utf8(fs.read_file_sync(path!("/dir/new_file")).unwrap()).unwrap(),
+            "draft\n"
+        );
+    }
+
     #[gpui::test]
     async fn test_infer_buffer_edited_from_snapshot(cx: &mut TestAppContext) {
         init_test(cx);