diff --git a/crates/acp_thread/src/acp_thread.rs b/crates/acp_thread/src/acp_thread.rs index 9842a9d1203f94a57208383d9eced56bd96582dd..5c8088c1af87d6869c65eafffc939fc272c1c71d 100644 --- a/crates/acp_thread/src/acp_thread.rs +++ b/crates/acp_thread/src/acp_thread.rs @@ -1063,8 +1063,10 @@ pub struct AcpThread { pending_terminal_exit: HashMap, inferred_edit_candidates: HashMap>, + inferred_edit_tool_call_turns: HashMap, finalizing_inferred_edit_tool_calls: HashSet, next_inferred_edit_candidate_nonce: u64, + had_error: bool, /// The user's unsent prompt text, persisted so it can be restored when reloading the thread. draft_prompt: Option>, @@ -1253,8 +1255,10 @@ impl AcpThread { pending_terminal_output: HashMap::default(), pending_terminal_exit: HashMap::default(), inferred_edit_candidates: HashMap::default(), + inferred_edit_tool_call_turns: HashMap::default(), finalizing_inferred_edit_tool_calls: HashSet::default(), next_inferred_edit_candidate_nonce: 0, + had_error: false, draft_prompt: None, ui_scroll_position: None, @@ -1757,6 +1761,17 @@ impl AcpThread { && !tool_call.locations.is_empty() } + fn should_track_inferred_external_edits(tool_call: &ToolCall) -> bool { + Self::should_infer_external_edits(tool_call) + && matches!( + tool_call.status, + ToolCallStatus::Pending + | ToolCallStatus::InProgress + | ToolCallStatus::Completed + | ToolCallStatus::Failed + ) + } + fn is_inferred_edit_terminal_status(status: &ToolCallStatus) -> bool { matches!(status, ToolCallStatus::Completed | ToolCallStatus::Failed) } @@ -1780,6 +1795,37 @@ impl AcpThread { } } + fn inferred_edit_tracking_turn_id(&self) -> u32 { + self.running_turn + .as_ref() + .map_or(self.turn_id, |turn| turn.id) + } + + fn inferred_edit_tool_call_belongs_to_turn( + &self, + tool_call_id: &acp::ToolCallId, + turn_id: u32, + ) -> bool { + self.inferred_edit_tool_call_turns.get(tool_call_id) == Some(&turn_id) + } + + fn record_inferred_edit_tool_call_turn_if_needed( + &mut self, + tool_call_id: &acp::ToolCallId, + cx: &mut Context, + ) { + let turn_id = self.inferred_edit_tracking_turn_id(); + if self.inferred_edit_tool_call_belongs_to_turn(tool_call_id, turn_id) { + return; + } + + let buffers_to_end = self.clear_inferred_edit_tool_call_tracking(tool_call_id); + self.end_expected_external_edits(buffers_to_end, cx); + + self.inferred_edit_tool_call_turns + .insert(tool_call_id.clone(), turn_id); + } + fn remove_inferred_edit_tool_call_if_empty(&mut self, tool_call_id: &acp::ToolCallId) { let remove_tool_call = self .inferred_edit_candidates @@ -1788,6 +1834,7 @@ impl AcpThread { if remove_tool_call { self.inferred_edit_candidates.remove(tool_call_id); + self.inferred_edit_tool_call_turns.remove(tool_call_id); self.finalizing_inferred_edit_tool_calls .remove(tool_call_id); } @@ -1805,6 +1852,7 @@ impl AcpThread { .filter_map(|candidate_state| candidate_state.into_buffer_to_end()) .collect::>(); + self.inferred_edit_tool_call_turns.remove(tool_call_id); self.finalizing_inferred_edit_tool_calls .remove(tool_call_id); @@ -1853,13 +1901,20 @@ impl AcpThread { self.end_expected_external_edits(buffers_to_end, cx); } - fn finalize_all_inferred_edit_tool_calls(&mut self, cx: &mut Context) { + fn finalize_all_inferred_edit_tool_calls_for_turn( + &mut self, + turn_id: u32, + cx: &mut Context, + ) { let tool_call_ids = self .inferred_edit_candidates .keys() + .filter(|tool_call_id| { + self.inferred_edit_tool_call_belongs_to_turn(tool_call_id, turn_id) + }) .filter(|tool_call_id| { self.tool_call(tool_call_id).is_some_and(|(_, tool_call)| { - Self::should_infer_external_edits(tool_call) + Self::should_track_inferred_external_edits(tool_call) && Self::is_inferred_edit_terminal_status(&tool_call.status) }) }) @@ -1870,6 +1925,31 @@ impl AcpThread { } } + fn finish_inferred_edit_tracking_for_stopped_turn( + &mut self, + turn_id: u32, + cx: &mut Context, + ) { + self.finalize_all_inferred_edit_tool_calls_for_turn(turn_id, cx); + + let tool_call_ids_to_clear = self + .inferred_edit_candidates + .keys() + .filter(|tool_call_id| { + self.inferred_edit_tool_call_belongs_to_turn(tool_call_id, turn_id) + }) + .filter(|tool_call_id| { + !self.tool_call(tool_call_id).is_some_and(|(_, tool_call)| { + Self::should_track_inferred_external_edits(tool_call) + && Self::is_inferred_edit_terminal_status(&tool_call.status) + }) + }) + .cloned() + .collect::>(); + + self.clear_inferred_edit_candidates_for_tool_calls(tool_call_ids_to_clear, cx); + } + fn sync_inferred_edit_candidate_paths( &mut self, tool_call_id: &acp::ToolCallId, @@ -2162,7 +2242,7 @@ impl AcpThread { return; }; - let should_track = Self::should_infer_external_edits(tool_call); + let should_track = Self::should_track_inferred_external_edits(tool_call); let should_finalize = Self::is_inferred_edit_terminal_status(&tool_call.status); let locations = tool_call.locations.clone(); @@ -2171,6 +2251,7 @@ impl AcpThread { return; } + self.record_inferred_edit_tool_call_turn_if_needed(&tool_call_id, cx); self.register_inferred_edit_locations(tool_call_id.clone(), &locations, cx); if should_finalize { @@ -2489,6 +2570,18 @@ impl AcpThread { return; }; + let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = &call.status else { + return; + }; + + if respond_tx.is_canceled() { + log::warn!( + "dropping stale tool authorization for call `{}` because it is no longer waiting for confirmation", + id + ); + return; + } + let new_status = match option_kind { acp::PermissionOptionKind::RejectOnce | acp::PermissionOptionKind::RejectAlways => { ToolCallStatus::Rejected @@ -2499,15 +2592,24 @@ impl AcpThread { _ => ToolCallStatus::InProgress, }; - let curr_status = mem::replace(&mut call.status, new_status); + let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = + mem::replace(&mut call.status, ToolCallStatus::Canceled) + else { + return; + }; - if let ToolCallStatus::WaitingForConfirmation { respond_tx, .. } = curr_status { - respond_tx.send(outcome).log_err(); - } else if cfg!(debug_assertions) { - panic!("tried to authorize an already authorized tool call"); + if respond_tx.send(outcome).is_err() { + log::warn!( + "dropping stale tool authorization for call `{}` because it is no longer waiting for confirmation", + id + ); + return; } + call.status = new_status; + cx.emit(AcpThreadEvent::EntryUpdated(ix)); + self.refresh_inferred_edit_tool_call(id, cx); } pub fn plan(&self) -> &Plan { @@ -2681,7 +2783,7 @@ impl AcpThread { Self::flush_streaming_text(&mut this.streaming_text_buffer, cx); if r.stop_reason == acp::StopReason::MaxTokens { - this.finalize_all_inferred_edit_tool_calls(cx); + this.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx); this.had_error = true; cx.emit(AcpThreadEvent::Error); log::error!("Max tokens reached. Usage: {:?}", this.token_usage); @@ -2690,12 +2792,8 @@ impl AcpThread { let canceled = matches!(r.stop_reason, acp::StopReason::Cancelled); if canceled { - let canceled_tool_call_ids = this.mark_pending_tools_as_canceled(); - this.clear_inferred_edit_candidates_for_tool_calls( - canceled_tool_call_ids, - cx, - ); - this.finalize_all_inferred_edit_tool_calls(cx); + this.mark_pending_tools_as_canceled(); + this.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx); } // Handle refusal - distinguish between user prompt and tool call refusals @@ -2749,14 +2847,14 @@ impl AcpThread { } } - this.finalize_all_inferred_edit_tool_calls(cx); + this.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx); cx.emit(AcpThreadEvent::Stopped(r.stop_reason)); Ok(Some(r)) } Err(e) => { Self::flush_streaming_text(&mut this.streaming_text_buffer, cx); - this.finalize_all_inferred_edit_tool_calls(cx); + this.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx); this.had_error = true; cx.emit(AcpThreadEvent::Error); log::error!("Error in run turn: {:?}", e); @@ -2772,12 +2870,12 @@ impl AcpThread { let Some(turn) = self.running_turn.take() else { return Task::ready(()); }; + let turn_id = turn.id; self.connection.cancel(&self.session_id, cx); Self::flush_streaming_text(&mut self.streaming_text_buffer, cx); - let canceled_tool_call_ids = self.mark_pending_tools_as_canceled(); - self.clear_inferred_edit_candidates_for_tool_calls(canceled_tool_call_ids, cx); - self.finalize_all_inferred_edit_tool_calls(cx); + self.mark_pending_tools_as_canceled(); + self.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx); cx.background_spawn(turn.send_task) } @@ -4326,6 +4424,434 @@ mod tests { assert!(cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls())); } + #[gpui::test] + async fn test_waiting_for_confirmation_does_not_start_inferred_edit_tracking( + cx: &mut TestAppContext, + ) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"})) + .await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let connection = Rc::new(FakeAgentConnection::new()); + + let thread = cx + .update(|cx| { + connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx) + }) + .await + .unwrap(); + + let tool_call_id = acp::ToolCallId::new("test"); + let allow_option_id = acp::PermissionOptionId::new("allow"); + let deny_option_id = acp::PermissionOptionId::new("deny"); + let _authorization_task = thread + .update(cx, |thread, cx| { + thread.request_tool_call_authorization( + acp::ToolCall::new(tool_call_id.clone(), "Label") + .kind(acp::ToolKind::Edit) + .locations(vec![acp::ToolCallLocation::new(path!("/test/file.txt"))]) + .into(), + PermissionOptions::Flat(vec![ + acp::PermissionOption::new( + allow_option_id.clone(), + "Allow", + acp::PermissionOptionKind::AllowOnce, + ), + acp::PermissionOption::new( + deny_option_id, + "Deny", + acp::PermissionOptionKind::RejectOnce, + ), + ]), + cx, + ) + }) + .unwrap(); + + cx.run_until_parked(); + + assert_eq!(inferred_edit_candidate_count(&thread, cx), 0); + assert!(!cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls())); + assert!(cx.read(|cx| thread.read(cx).is_waiting_for_confirmation())); + } + + #[gpui::test] + async fn test_authorizing_waiting_tool_call_starts_inferred_edit_tracking( + cx: &mut TestAppContext, + ) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"})) + .await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let connection = Rc::new(FakeAgentConnection::new()); + + let thread = cx + .update(|cx| { + connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx) + }) + .await + .unwrap(); + + let tool_call_id = acp::ToolCallId::new("test"); + let allow_option_id = acp::PermissionOptionId::new("allow"); + let deny_option_id = acp::PermissionOptionId::new("deny"); + let _authorization_task = thread + .update(cx, |thread, cx| { + thread.request_tool_call_authorization( + acp::ToolCall::new(tool_call_id.clone(), "Label") + .kind(acp::ToolKind::Edit) + .locations(vec![acp::ToolCallLocation::new(path!("/test/file.txt"))]) + .into(), + PermissionOptions::Flat(vec![ + acp::PermissionOption::new( + allow_option_id.clone(), + "Allow", + acp::PermissionOptionKind::AllowOnce, + ), + acp::PermissionOption::new( + deny_option_id, + "Deny", + acp::PermissionOptionKind::RejectOnce, + ), + ]), + cx, + ) + }) + .unwrap(); + + thread.update(cx, |thread, cx| { + thread.authorize_tool_call( + tool_call_id.clone(), + allow_option_id.into(), + acp::PermissionOptionKind::AllowOnce, + cx, + ); + }); + cx.run_until_parked(); + + assert_eq!(inferred_edit_candidate_count(&thread, cx), 1); + assert!(cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls())); + } + + #[gpui::test] + async fn test_stale_authorization_does_not_rewrite_status_or_start_inferred_edit_tracking( + cx: &mut TestAppContext, + ) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"})) + .await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let connection = Rc::new(FakeAgentConnection::new()); + + let thread = cx + .update(|cx| { + connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx) + }) + .await + .unwrap(); + + let tool_call_id = acp::ToolCallId::new("test"); + let allow_option_id = acp::PermissionOptionId::new("allow"); + let deny_option_id = acp::PermissionOptionId::new("deny"); + let _authorization_task = thread + .update(cx, |thread, cx| { + thread.request_tool_call_authorization( + acp::ToolCall::new(tool_call_id.clone(), "Label") + .kind(acp::ToolKind::Edit) + .locations(vec![acp::ToolCallLocation::new(path!("/test/file.txt"))]) + .into(), + PermissionOptions::Flat(vec![ + acp::PermissionOption::new( + allow_option_id.clone(), + "Allow", + acp::PermissionOptionKind::AllowOnce, + ), + acp::PermissionOption::new( + deny_option_id, + "Deny", + acp::PermissionOptionKind::RejectOnce, + ), + ]), + cx, + ) + }) + .unwrap(); + + thread.update(cx, |thread, _cx| { + let (_, tool_call) = thread.tool_call_mut(&tool_call_id).unwrap(); + tool_call.status = ToolCallStatus::Rejected; + }); + + thread.update(cx, |thread, cx| { + thread.authorize_tool_call( + tool_call_id.clone(), + allow_option_id.into(), + acp::PermissionOptionKind::AllowOnce, + cx, + ); + }); + cx.run_until_parked(); + + assert_eq!(inferred_edit_candidate_count(&thread, cx), 0); + thread.read_with(cx, |thread, _| { + let (_, tool_call) = thread.tool_call(&tool_call_id).unwrap(); + assert!(matches!(tool_call.status, ToolCallStatus::Rejected)); + }); + assert!(!cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls())); + } + + #[gpui::test] + async fn test_duplicate_authorization_does_not_rewrite_status_or_tracking( + cx: &mut TestAppContext, + ) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"})) + .await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let connection = Rc::new(FakeAgentConnection::new()); + + let thread = cx + .update(|cx| { + connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx) + }) + .await + .unwrap(); + + let tool_call_id = acp::ToolCallId::new("test"); + let allow_option_id = acp::PermissionOptionId::new("allow"); + let deny_option_id = acp::PermissionOptionId::new("deny"); + let _authorization_task = thread + .update(cx, |thread, cx| { + thread.request_tool_call_authorization( + acp::ToolCall::new(tool_call_id.clone(), "Label") + .kind(acp::ToolKind::Edit) + .locations(vec![acp::ToolCallLocation::new(path!("/test/file.txt"))]) + .into(), + PermissionOptions::Flat(vec![ + acp::PermissionOption::new( + allow_option_id.clone(), + "Allow", + acp::PermissionOptionKind::AllowOnce, + ), + acp::PermissionOption::new( + deny_option_id.clone(), + "Deny", + acp::PermissionOptionKind::RejectOnce, + ), + ]), + cx, + ) + }) + .unwrap(); + + thread.update(cx, |thread, cx| { + thread.authorize_tool_call( + tool_call_id.clone(), + allow_option_id.into(), + acp::PermissionOptionKind::AllowOnce, + cx, + ); + }); + cx.run_until_parked(); + + assert_eq!(inferred_edit_candidate_count(&thread, cx), 1); + thread.read_with(cx, |thread, _| { + let (_, tool_call) = thread.tool_call(&tool_call_id).unwrap(); + assert!(matches!(tool_call.status, ToolCallStatus::InProgress)); + }); + + thread.update(cx, |thread, cx| { + thread.authorize_tool_call( + tool_call_id.clone(), + deny_option_id.into(), + acp::PermissionOptionKind::RejectOnce, + cx, + ); + }); + cx.run_until_parked(); + + assert_eq!(inferred_edit_candidate_count(&thread, cx), 1); + thread.read_with(cx, |thread, _| { + let (_, tool_call) = thread.tool_call(&tool_call_id).unwrap(); + assert!(matches!(tool_call.status, ToolCallStatus::InProgress)); + }); + } + + #[gpui::test] + async fn test_authorization_send_failure_cancels_call_without_tracking( + cx: &mut TestAppContext, + ) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"})) + .await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let connection = Rc::new(FakeAgentConnection::new()); + + let thread = cx + .update(|cx| { + connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx) + }) + .await + .unwrap(); + + let tool_call_id = acp::ToolCallId::new("test"); + let allow_option_id = acp::PermissionOptionId::new("allow"); + let deny_option_id = acp::PermissionOptionId::new("deny"); + let authorization_task = thread + .update(cx, |thread, cx| { + thread.request_tool_call_authorization( + acp::ToolCall::new(tool_call_id.clone(), "Label") + .kind(acp::ToolKind::Edit) + .locations(vec![acp::ToolCallLocation::new(path!("/test/file.txt"))]) + .into(), + PermissionOptions::Flat(vec![ + acp::PermissionOption::new( + allow_option_id.clone(), + "Allow", + acp::PermissionOptionKind::AllowOnce, + ), + acp::PermissionOption::new( + deny_option_id, + "Deny", + acp::PermissionOptionKind::RejectOnce, + ), + ]), + cx, + ) + }) + .unwrap(); + drop(authorization_task); + cx.run_until_parked(); + + thread.update(cx, |thread, cx| { + thread.authorize_tool_call( + tool_call_id.clone(), + allow_option_id.into(), + acp::PermissionOptionKind::AllowOnce, + cx, + ); + }); + cx.run_until_parked(); + + assert_eq!(inferred_edit_candidate_count(&thread, cx), 0); + thread.read_with(cx, |thread, _| { + let (_, tool_call) = thread.tool_call(&tool_call_id).unwrap(); + assert!(matches!(tool_call.status, ToolCallStatus::Canceled)); + }); + assert!(!cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls())); + } + + #[gpui::test] + async fn test_stopped_turn_clears_unfinished_inferred_edit_tracking(cx: &mut TestAppContext) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree(path!("/test"), json!({"file.txt": "one\ntwo\n"})) + .await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let connection = Rc::new(FakeAgentConnection::new()); + + let thread = cx + .update(|cx| { + connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx) + }) + .await + .unwrap(); + + let tool_call_id = acp::ToolCallId::new("test"); + start_external_edit_tool_call( + &thread, + &tool_call_id, + vec![acp::ToolCallLocation::new(path!("/test/file.txt"))], + cx, + ); + cx.run_until_parked(); + + assert_eq!(inferred_edit_candidate_count(&thread, cx), 1); + + thread.update(cx, |thread, cx| { + let turn_id = thread.inferred_edit_tracking_turn_id(); + thread.finish_inferred_edit_tracking_for_stopped_turn(turn_id, cx); + }); + + cx.run_until_parked(); + + assert_eq!(inferred_edit_candidate_count(&thread, cx), 0); + } + + #[gpui::test] + async fn test_stopped_turn_only_clears_inferred_edit_tracking_for_its_own_turn( + cx: &mut TestAppContext, + ) { + init_test(cx); + let fs = FakeFs::new(cx.background_executor.clone()); + fs.insert_tree( + path!("/test"), + json!({ + "old.txt": "one\n", + "new.txt": "two\n", + }), + ) + .await; + let project = Project::test(fs, [path!("/test").as_ref()], cx).await; + let connection = Rc::new(FakeAgentConnection::new()); + + let thread = cx + .update(|cx| { + connection.new_session(project, PathList::new(&[Path::new(path!("/test"))]), cx) + }) + .await + .unwrap(); + + let old_tool_call_id = acp::ToolCallId::new("old"); + let new_tool_call_id = acp::ToolCallId::new("new"); + + set_running_turn_for_test(&thread, 1, cx); + start_external_edit_tool_call( + &thread, + &old_tool_call_id, + vec![acp::ToolCallLocation::new(path!("/test/old.txt"))], + cx, + ); + + set_running_turn_for_test(&thread, 2, cx); + start_external_edit_tool_call( + &thread, + &new_tool_call_id, + vec![acp::ToolCallLocation::new(path!("/test/new.txt"))], + cx, + ); + cx.run_until_parked(); + + assert_eq!( + inferred_edit_candidate_count_for_tool_call(&thread, &old_tool_call_id, cx), + 1 + ); + assert_eq!( + inferred_edit_candidate_count_for_tool_call(&thread, &new_tool_call_id, cx), + 1 + ); + + thread.update(cx, |thread, cx| { + thread.finish_inferred_edit_tracking_for_stopped_turn(1, cx); + }); + cx.run_until_parked(); + + assert_eq!( + inferred_edit_candidate_count_for_tool_call(&thread, &old_tool_call_id, cx), + 0 + ); + assert_eq!( + inferred_edit_candidate_count_for_tool_call(&thread, &new_tool_call_id, cx), + 1 + ); + assert!(cx.read(|cx| thread.read(cx).has_pending_edit_tool_calls())); + } + #[gpui::test] async fn test_infer_external_modified_file_edits_from_tool_call_locations( cx: &mut TestAppContext, @@ -4576,6 +5102,20 @@ mod tests { action_log.read_with(cx, |action_log, cx| action_log.changed_buffers(cx).len()) } + fn set_running_turn_for_test( + thread: &Entity, + turn_id: u32, + cx: &mut TestAppContext, + ) { + thread.update(cx, |thread, _cx| { + thread.turn_id = turn_id; + thread.running_turn = Some(RunningTurn { + id: turn_id, + send_task: Task::ready(()), + }); + }); + } + fn inferred_edit_candidate_count(thread: &Entity, cx: &TestAppContext) -> usize { thread.read_with(cx, |thread, _| { thread @@ -4586,6 +5126,19 @@ mod tests { }) } + fn inferred_edit_candidate_count_for_tool_call( + thread: &Entity, + tool_call_id: &acp::ToolCallId, + cx: &TestAppContext, + ) -> usize { + thread.read_with(cx, |thread, _| { + thread + .inferred_edit_candidates + .get(tool_call_id) + .map_or(0, HashMap::len) + }) + } + fn inferred_edit_tool_call_is_finalizing( thread: &Entity, tool_call_id: &acp::ToolCallId, @@ -4650,6 +5203,7 @@ mod tests { .unwrap(); let tool_call_id = acp::ToolCallId::new("test"); + set_running_turn_for_test(&thread, 1, cx); start_external_edit_tool_call( &thread, &tool_call_id, @@ -4660,13 +5214,8 @@ mod tests { assert_eq!(inferred_edit_candidate_count(&thread, cx), 1); - let cancel = thread.update(cx, |thread, cx| { - thread.running_turn = Some(RunningTurn { - id: 1, - send_task: Task::ready(()), - }); - thread.cancel(cx) - }); + let cancel = thread.update(cx, |thread, cx| thread.cancel(cx)); + cancel.await; cx.run_until_parked(); @@ -4696,6 +5245,7 @@ mod tests { .unwrap(); let tool_call_id = acp::ToolCallId::new("test"); + set_running_turn_for_test(&thread, 1, cx); start_external_edit_tool_call( &thread, &tool_call_id, @@ -4704,13 +5254,8 @@ mod tests { ); cx.run_until_parked(); - let cancel = thread.update(cx, |thread, cx| { - thread.running_turn = Some(RunningTurn { - id: 1, - send_task: Task::ready(()), - }); - thread.cancel(cx) - }); + let cancel = thread.update(cx, |thread, cx| thread.cancel(cx)); + cancel.await; cx.run_until_parked(); @@ -4911,6 +5456,7 @@ mod tests { let buffer = open_test_buffer(&project, Path::new(path!("/test/file.txt")), cx).await; let tool_call_id = acp::ToolCallId::new("test"); + set_running_turn_for_test(&thread, 1, cx); start_external_edit_tool_call(&thread, &tool_call_id, Vec::new(), cx); let nonce = thread.update(cx, |thread, _cx| { @@ -4920,6 +5466,9 @@ mod tests { tool_call.locations = vec![acp::ToolCallLocation::new(abs_path.clone())]; tool_call.resolved_locations = vec![None]; } + thread + .inferred_edit_tool_call_turns + .insert(tool_call_id.clone(), 1); thread .inferred_edit_candidates .entry(tool_call_id.clone()) diff --git a/crates/action_log/src/action_log.rs b/crates/action_log/src/action_log.rs index 8fee58ea62725b71bb7f4d2bc3b12ea7abcc790b..cbabb881bdc7356f7da4ab9ee65dbdf65259e4fc 100644 --- a/crates/action_log/src/action_log.rs +++ b/crates/action_log/src/action_log.rs @@ -881,7 +881,12 @@ impl ActionLog { ) { let version = buffer.read(cx).version(); let diff_base = match &status { - TrackedBufferStatus::Created { .. } => Rope::default(), + TrackedBufferStatus::Created { + existing_file_content: Some(existing_file_content), + } => existing_file_content.clone(), + TrackedBufferStatus::Created { + existing_file_content: None, + } => Rope::default(), TrackedBufferStatus::Modified | TrackedBufferStatus::Deleted => { baseline_snapshot.as_rope().clone() } @@ -922,13 +927,9 @@ impl ActionLog { .file() .is_some_and(|file| file.disk_state().exists()); let had_tracked_buffer = self.tracked_buffers.contains_key(&buffer); - let has_attributed_change = self - .tracked_buffers - .get(&buffer) - .is_some_and(|tracked_buffer| tracked_buffer.has_edits(cx)); - let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx); if !had_tracked_buffer { + let tracked_buffer = self.track_buffer_internal(buffer.clone(), false, cx); tracked_buffer.mode = TrackedBufferMode::ExpectationOnly; tracked_buffer.status = TrackedBufferStatus::Modified; tracked_buffer.diff_base = buffer.read(cx).as_rope().clone(); @@ -936,6 +937,12 @@ impl ActionLog { tracked_buffer.unreviewed_edits.clear(); } + // Reusing an existing tracked buffer must preserve its prior version so stale-buffer + // detection continues to reflect any user edits that predate the expectation. + let Some(tracked_buffer) = self.tracked_buffers.get_mut(&buffer) else { + return; + }; + let expected_external_edit = tracked_buffer .expected_external_edit @@ -945,7 +952,7 @@ impl ActionLog { initial_exists_on_disk, observed_external_file_change: false, armed_explicit_reload: false, - has_attributed_change, + has_attributed_change: false, pending_delete: false, is_disqualified: false, }); @@ -1094,17 +1101,17 @@ impl ActionLog { ) { if let Some(linked_action_log) = &self.linked_action_log { let linked_baseline_snapshot = baseline_snapshot.clone(); - if !linked_action_log.read(cx).has_changed_buffer(buffer, cx) { - linked_action_log.update(cx, |log, cx| { - log.infer_buffer_from_snapshot_impl( - buffer.clone(), - linked_baseline_snapshot, - kind, - false, - cx, - ); - }); - } + // Later inferred snapshots must keep refreshing linked logs for the same buffer so + // parent and child review state do not diverge after the first forwarded hunk. + linked_action_log.update(cx, |log, cx| { + log.infer_buffer_from_snapshot_impl( + buffer.clone(), + linked_baseline_snapshot, + kind, + false, + cx, + ); + }); } } @@ -1129,10 +1136,11 @@ impl ActionLog { } } + let tracked_buffer_status = kind.tracked_buffer_status(&baseline_snapshot); self.prime_tracked_buffer_from_snapshot( buffer.clone(), baseline_snapshot, - kind.tracked_buffer_status(), + tracked_buffer_status, cx, ); @@ -1818,10 +1826,17 @@ enum InferredSnapshotKind { } impl InferredSnapshotKind { - fn tracked_buffer_status(self) -> TrackedBufferStatus { + fn tracked_buffer_status( + self, + baseline_snapshot: &text::BufferSnapshot, + ) -> TrackedBufferStatus { match self { Self::Created => TrackedBufferStatus::Created { - existing_file_content: None, + existing_file_content: if baseline_snapshot.text().is_empty() { + None + } else { + Some(baseline_snapshot.as_rope().clone()) + }, }, Self::Edited => TrackedBufferStatus::Modified, Self::Deleted => TrackedBufferStatus::Deleted, @@ -4077,6 +4092,88 @@ mod tests { ); } + #[gpui::test] + async fn test_linked_action_log_forwards_sequential_inferred_snapshots( + cx: &mut TestAppContext, + ) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "one\ntwo\n"})) + .await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let parent_log = cx.new(|_| ActionLog::new(project.clone())); + let child_log = + cx.new(|_| ActionLog::new(project.clone()).with_linked_action_log(parent_log.clone())); + + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .expect("test file should exist"); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .expect("test buffer should open"); + + let first_baseline_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot()); + buffer.update(cx, |buffer, cx| buffer.set_text("one\ntwo\nthree\n", cx)); + let second_baseline_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot()); + buffer.update(cx, |buffer, cx| { + buffer.set_text("one\ntwo\nthree\nfour\n", cx) + }); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .expect("final inferred buffer contents should save"); + + cx.update(|cx| { + child_log.update(cx, |log, cx| { + log.infer_buffer_edited_from_snapshot( + buffer.clone(), + first_baseline_snapshot.clone(), + cx, + ); + }); + }); + cx.run_until_parked(); + + let first_child_hunks = unreviewed_hunks(&child_log, cx); + assert!( + !first_child_hunks.is_empty(), + "the first inferred snapshot should produce review hunks" + ); + assert_eq!( + unreviewed_hunks(&parent_log, cx), + first_child_hunks, + "parent should match the first forwarded inferred snapshot" + ); + + cx.update(|cx| { + child_log.update(cx, |log, cx| { + log.infer_buffer_edited_from_snapshot( + buffer.clone(), + second_baseline_snapshot.clone(), + cx, + ); + }); + }); + cx.run_until_parked(); + + let second_child_hunks = unreviewed_hunks(&child_log, cx); + assert!( + !second_child_hunks.is_empty(), + "the second inferred snapshot should still produce review hunks" + ); + assert_ne!( + second_child_hunks, first_child_hunks, + "the second inferred snapshot should refresh the tracked diff" + ); + assert_eq!( + unreviewed_hunks(&parent_log, cx), + second_child_hunks, + "parent should stay in sync after sequential inferred snapshots on one buffer" + ); + } + #[gpui::test] async fn test_linked_action_log_infer_buffer_created(cx: &mut TestAppContext) { init_test(cx); @@ -4499,6 +4596,171 @@ mod tests { ); } + #[gpui::test] + async fn test_expected_external_edit_starts_unattributed_even_with_existing_hunks( + cx: &mut TestAppContext, + ) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "one\ntwo\n"})) + .await; + let project = Project::test(fs, [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + buffer.update(cx, |buffer, cx| buffer.set_text("one\ntwo\nthree\n", cx)); + action_log.update(cx, |log, cx| log.buffer_edited(buffer.clone(), cx)); + }); + cx.run_until_parked(); + + assert!( + !unreviewed_hunks(&action_log, cx).is_empty(), + "buffer should already have tracked hunks before the expectation starts" + ); + + cx.update(|cx| { + action_log.update(cx, |log, cx| { + log.begin_expected_external_edit(buffer.clone(), cx); + }); + }); + + assert!( + action_log.read_with(cx, |log, _| { + log.tracked_buffers + .get(&buffer) + .and_then(|tracked_buffer| tracked_buffer.expected_external_edit.as_ref()) + .is_some_and(|expected_external_edit| { + !expected_external_edit.has_attributed_change + }) + }), + "a new expected external edit should start as unattributed even when the buffer already has hunks" + ); + } + + #[gpui::test] + async fn test_expected_external_edit_preserves_stale_tracking_for_existing_tracked_buffer( + cx: &mut TestAppContext, + ) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({"file": "one\ntwo\n"})) + .await; + let project = Project::test(fs, [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let file_path = project + .read_with(cx, |project, cx| project.find_project_path("dir/file", cx)) + .expect("test file should exist"); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .expect("test buffer should open"); + + cx.update(|cx| { + action_log.update(cx, |log, cx| log.buffer_read(buffer.clone(), cx)); + }); + + cx.update(|cx| { + buffer.update(cx, |buffer, cx| { + assert!(buffer.edit([(0..0, "zero\n")], None, cx).is_some()); + }); + }); + cx.run_until_parked(); + + assert_eq!( + action_log.read_with(cx, |log, cx| { + log.stale_buffers(cx).cloned().collect::>() + }), + vec![buffer.clone()], + "user edits after a read should mark the tracked buffer as stale" + ); + + cx.update(|cx| { + action_log.update(cx, |log, cx| { + log.begin_expected_external_edit(buffer.clone(), cx); + }); + }); + + assert_eq!( + action_log.read_with(cx, |log, cx| { + log.stale_buffers(cx).cloned().collect::>() + }), + vec![buffer.clone()], + "starting an expected external edit should not clear existing stale tracking" + ); + + cx.update(|cx| { + action_log.update(cx, |log, cx| { + log.end_expected_external_edit(buffer.clone(), cx); + }); + }); + + assert_eq!( + action_log.read_with(cx, |log, cx| { + log.stale_buffers(cx).cloned().collect::>() + }), + vec![buffer], + "ending an unattributed expected external edit should preserve existing stale tracking" + ); + } + + #[gpui::test] + async fn test_infer_buffer_created_preserves_non_empty_baseline_on_reject( + cx: &mut TestAppContext, + ) { + init_test(cx); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree(path!("/dir"), json!({})).await; + let project = Project::test(fs.clone(), [path!("/dir").as_ref()], cx).await; + let action_log = cx.new(|_| ActionLog::new(project.clone())); + let file_path = project + .read_with(cx, |project, cx| { + project.find_project_path("dir/new_file", cx) + }) + .unwrap(); + let buffer = project + .update(cx, |project, cx| project.open_buffer(file_path, cx)) + .await + .unwrap(); + + buffer.update(cx, |buffer, cx| buffer.set_text("draft\n", cx)); + let baseline_snapshot = buffer.read_with(cx, |buffer, _| buffer.text_snapshot()); + + buffer.update(cx, |buffer, cx| buffer.set_text("draft\nagent\n", cx)); + project + .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx)) + .await + .unwrap(); + + cx.update(|cx| { + action_log.update(cx, |log, cx| { + log.infer_buffer_created(buffer.clone(), baseline_snapshot.clone(), cx); + }); + }); + cx.run_until_parked(); + + action_log + .update(cx, |log, cx| log.reject_all_edits(None, cx)) + .await; + cx.run_until_parked(); + + assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "draft\n"); + assert_eq!( + String::from_utf8(fs.read_file_sync(path!("/dir/new_file")).unwrap()).unwrap(), + "draft\n" + ); + } + #[gpui::test] async fn test_infer_buffer_edited_from_snapshot(cx: &mut TestAppContext) { init_test(cx);