@@ -2459,6 +2459,30 @@ impl ToolCallEventStreamReceiver {
}
}
+ pub async fn expect_update_fields(&mut self) -> acp::ToolCallUpdateFields {
+ let event = self.0.next().await;
+ if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateFields(
+ update,
+ )))) = event
+ {
+ update.fields
+ } else {
+ panic!("Expected update fields but got: {:?}", event);
+ }
+ }
+
+ pub async fn expect_diff(&mut self) -> Entity<acp_thread::Diff> {
+ let event = self.0.next().await;
+ if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateDiff(
+ update,
+ )))) = event
+ {
+ update.diff
+ } else {
+ panic!("Expected diff but got: {:?}", event);
+ }
+ }
+
pub async fn expect_terminal(&mut self) -> Entity<acp_thread::Terminal> {
let event = self.0.next().await;
if let Some(Ok(ThreadEvent::ToolCallUpdate(acp_thread::ToolCallUpdate::UpdateTerminal(
@@ -273,6 +273,13 @@ impl AgentTool for EditFileTool {
let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
event_stream.update_diff(diff.clone());
+ let _finalize_diff = util::defer({
+ let diff = diff.downgrade();
+ let mut cx = cx.clone();
+ move || {
+ diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
+ }
+ });
let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let old_text = cx
@@ -389,8 +396,6 @@ impl AgentTool for EditFileTool {
})
.await;
- diff.update(cx, |diff, cx| diff.finalize(cx)).ok();
-
let input_path = input.path.display();
if unified_diff.is_empty() {
anyhow::ensure!(
@@ -1545,6 +1550,100 @@ mod tests {
);
}
+ #[gpui::test]
+ async fn test_diff_finalization(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = project::FakeFs::new(cx.executor());
+ fs.insert_tree("/", json!({"main.rs": ""})).await;
+
+ let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
+ let languages = project.read_with(cx, |project, _cx| project.languages().clone());
+ let context_server_registry =
+ cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
+ let model = Arc::new(FakeLanguageModel::default());
+ let thread = cx.new(|cx| {
+ Thread::new(
+ project.clone(),
+ cx.new(|_cx| ProjectContext::default()),
+ context_server_registry.clone(),
+ Templates::new(),
+ Some(model.clone()),
+ cx,
+ )
+ });
+
+ // Ensure the diff is finalized after the edit completes.
+ {
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let edit = cx.update(|cx| {
+ tool.run(
+ EditFileToolInput {
+ display_description: "Edit file".into(),
+ path: path!("/main.rs").into(),
+ mode: EditFileMode::Edit,
+ },
+ stream_tx,
+ cx,
+ )
+ });
+ stream_rx.expect_update_fields().await;
+ let diff = stream_rx.expect_diff().await;
+ diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
+ cx.run_until_parked();
+ model.end_last_completion_stream();
+ edit.await.unwrap();
+ diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
+ }
+
+ // Ensure the diff is finalized if an error occurs while editing.
+ {
+ model.forbid_requests();
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let edit = cx.update(|cx| {
+ tool.run(
+ EditFileToolInput {
+ display_description: "Edit file".into(),
+ path: path!("/main.rs").into(),
+ mode: EditFileMode::Edit,
+ },
+ stream_tx,
+ cx,
+ )
+ });
+ stream_rx.expect_update_fields().await;
+ let diff = stream_rx.expect_diff().await;
+ diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
+ edit.await.unwrap_err();
+ diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
+ model.allow_requests();
+ }
+
+ // Ensure the diff is finalized if the tool call gets dropped.
+ {
+ let tool = Arc::new(EditFileTool::new(thread.downgrade(), languages.clone()));
+ let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
+ let edit = cx.update(|cx| {
+ tool.run(
+ EditFileToolInput {
+ display_description: "Edit file".into(),
+ path: path!("/main.rs").into(),
+ mode: EditFileMode::Edit,
+ },
+ stream_tx,
+ cx,
+ )
+ });
+ stream_rx.expect_update_fields().await;
+ let diff = stream_rx.expect_diff().await;
+ diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
+ drop(edit);
+ cx.run_until_parked();
+ diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
+ }
+ }
+
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
@@ -4,12 +4,16 @@ use crate::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, LanguageModelToolChoice,
};
+use anyhow::anyhow;
use futures::{FutureExt, channel::mpsc, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, Entity, Task, Window};
use http_client::Result;
use parking_lot::Mutex;
use smol::stream::StreamExt;
-use std::sync::Arc;
+use std::sync::{
+ Arc,
+ atomic::{AtomicBool, Ordering::SeqCst},
+};
#[derive(Clone)]
pub struct FakeLanguageModelProvider {
@@ -106,6 +110,7 @@ pub struct FakeLanguageModel {
>,
)>,
>,
+ forbid_requests: AtomicBool,
}
impl Default for FakeLanguageModel {
@@ -114,11 +119,20 @@ impl Default for FakeLanguageModel {
provider_id: LanguageModelProviderId::from("fake".to_string()),
provider_name: LanguageModelProviderName::from("Fake".to_string()),
current_completion_txs: Mutex::new(Vec::new()),
+ forbid_requests: AtomicBool::new(false),
}
}
}
impl FakeLanguageModel {
+ pub fn allow_requests(&self) {
+ self.forbid_requests.store(false, SeqCst);
+ }
+
+ pub fn forbid_requests(&self) {
+ self.forbid_requests.store(true, SeqCst);
+ }
+
pub fn pending_completions(&self) -> Vec<LanguageModelRequest> {
self.current_completion_txs
.lock()
@@ -251,9 +265,18 @@ impl LanguageModel for FakeLanguageModel {
LanguageModelCompletionError,
>,
> {
- let (tx, rx) = mpsc::unbounded();
- self.current_completion_txs.lock().push((request, tx));
- async move { Ok(rx.boxed()) }.boxed()
+ if self.forbid_requests.load(SeqCst) {
+ async move {
+ Err(LanguageModelCompletionError::Other(anyhow!(
+ "requests are forbidden"
+ )))
+ }
+ .boxed()
+ } else {
+ let (tx, rx) = mpsc::unbounded();
+ self.current_completion_txs.lock().push((request, tx));
+ async move { Ok(rx.boxed()) }.boxed()
+ }
}
fn as_fake(&self) -> &Self {