Format streamed edits on save (#31623)

Richard Feldman and Agus Zubiaga created

Re-enables format on save for agent changes (when the user has that
enabled in settings), except differently from before:
- Now we do the format-on-save in the separate buffer the edit tool
uses, *before* the diff
- This means it never triggers separate staleness
- It has the downside that edits are now blocked on the formatter
completing, but that's true of saving in general.

Release Notes:

- N/A

---------

Co-authored-by: Agus Zubiaga <hi@aguz.me>

Change summary

Cargo.lock                                   |   1 
crates/assistant_tools/Cargo.toml            |   2 
crates/assistant_tools/src/edit_file_tool.rs | 403 +++++++++++++++++++++
3 files changed, 393 insertions(+), 13 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -683,6 +683,7 @@ dependencies = [
  "language_model",
  "language_models",
  "log",
+ "lsp",
  "markdown",
  "open",
  "paths",

crates/assistant_tools/Cargo.toml 🔗

@@ -36,6 +36,7 @@ itertools.workspace = true
 language.workspace = true
 language_model.workspace = true
 log.workspace = true
+lsp.workspace = true
 markdown.workspace = true
 open.workspace = true
 paths.workspace = true
@@ -64,6 +65,7 @@ workspace.workspace = true
 zed_llm_client.workspace = true
 
 [dev-dependencies]
+lsp = { workspace = true, features = ["test-support"] }
 client = { workspace = true, features = ["test-support"] }
 clock = { workspace = true, features = ["test-support"] }
 collections = { workspace = true, features = ["test-support"] }

crates/assistant_tools/src/edit_file_tool.rs 🔗

@@ -18,16 +18,21 @@ use gpui::{
 use indoc::formatdoc;
 use language::{
     Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Point, Rope,
-    TextBuffer, language_settings::SoftWrap,
+    TextBuffer,
+    language_settings::{self, FormatOnSave, SoftWrap},
 };
 use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
 use markdown::{Markdown, MarkdownElement, MarkdownStyle};
-use project::{Project, ProjectPath};
+use project::{
+    Project, ProjectPath,
+    lsp_store::{FormatTrigger, LspFormatTarget},
+};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::Settings;
 use std::{
     cmp::Reverse,
+    collections::HashSet,
     ops::Range,
     path::{Path, PathBuf},
     sync::Arc,
@@ -189,8 +194,10 @@ impl Tool for EditFileTool {
         });
 
         let card_clone = card.clone();
+        let action_log_clone = action_log.clone();
         let task = cx.spawn(async move |cx: &mut AsyncApp| {
-            let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
+            let edit_agent =
+                EditAgent::new(model, project.clone(), action_log_clone, Templates::new());
 
             let buffer = project
                 .update(cx, |project, cx| {
@@ -244,19 +251,53 @@ impl Tool for EditFileTool {
             }
             let agent_output = output.await?;
 
+            // If format_on_save is enabled, format the buffer
+            let format_on_save_enabled = buffer
+                .read_with(cx, |buffer, cx| {
+                    let settings = language_settings::language_settings(
+                        buffer.language().map(|l| l.name()),
+                        buffer.file(),
+                        cx,
+                    );
+                    !matches!(settings.format_on_save, FormatOnSave::Off)
+                })
+                .unwrap_or(false);
+
+            if format_on_save_enabled {
+                let format_task = project.update(cx, |project, cx| {
+                    project.format(
+                        HashSet::from_iter([buffer.clone()]),
+                        LspFormatTarget::Buffers,
+                        false, // Don't push to history since the tool did it.
+                        FormatTrigger::Save,
+                        cx,
+                    )
+                })?;
+                format_task.await.log_err();
+            }
+
             project
                 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
                 .await?;
 
+            // Notify the action log that we've edited the buffer (*after* formatting has completed).
+            action_log.update(cx, |log, cx| {
+                log.buffer_edited(buffer.clone(), cx);
+            })?;
+
             let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
-            let new_text = cx.background_spawn({
-                let new_snapshot = new_snapshot.clone();
-                async move { new_snapshot.text() }
-            });
-            let diff = cx.background_spawn(async move {
-                language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
-            });
-            let (new_text, diff) = futures::join!(new_text, diff);
+            let (new_text, diff) = cx
+                .background_spawn({
+                    let new_snapshot = new_snapshot.clone();
+                    let old_text = old_text.clone();
+                    async move {
+                        let new_text = new_snapshot.text();
+                        let diff = language::unified_diff(&old_text, &new_text);
+
+                        (new_text, diff)
+                    }
+                })
+                .await;
 
             let output = EditFileToolOutput {
                 original_path: project_path.path.to_path_buf(),
@@ -1099,8 +1140,8 @@ async fn build_buffer_diff(
 mod tests {
     use super::*;
     use client::TelemetrySettings;
-    use fs::FakeFs;
-    use gpui::TestAppContext;
+    use fs::{FakeFs, Fs};
+    use gpui::{TestAppContext, UpdateGlobal};
     use language_model::fake_provider::FakeLanguageModel;
     use serde_json::json;
     use settings::SettingsStore;
@@ -1310,4 +1351,340 @@ mod tests {
             Project::init_settings(cx);
         });
     }
+
+    #[gpui::test]
+    async fn test_format_on_save(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree("/root", json!({"src": {}})).await;
+
+        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+
+        // Set up a Rust language with LSP formatting support
+        let rust_language = Arc::new(language::Language::new(
+            language::LanguageConfig {
+                name: "Rust".into(),
+                matcher: language::LanguageMatcher {
+                    path_suffixes: vec!["rs".to_string()],
+                    ..Default::default()
+                },
+                ..Default::default()
+            },
+            None,
+        ));
+
+        // Register the language and fake LSP
+        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+        language_registry.add(rust_language);
+
+        let mut fake_language_servers = language_registry.register_fake_lsp(
+            "Rust",
+            language::FakeLspAdapter {
+                capabilities: lsp::ServerCapabilities {
+                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
+                    ..Default::default()
+                },
+                ..Default::default()
+            },
+        );
+
+        // Create the file
+        fs.save(
+            path!("/root/src/main.rs").as_ref(),
+            &"initial content".into(),
+            language::LineEnding::Unix,
+        )
+        .await
+        .unwrap();
+
+        // Open the buffer to trigger LSP initialization
+        let buffer = project
+            .update(cx, |project, cx| {
+                project.open_local_buffer(path!("/root/src/main.rs"), cx)
+            })
+            .await
+            .unwrap();
+
+        // Register the buffer with language servers
+        let _handle = project.update(cx, |project, cx| {
+            project.register_buffer_with_language_servers(&buffer, cx)
+        });
+
+        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
+        const FORMATTED_CONTENT: &str =
+            "This file was formatted by the fake formatter in the test.\n";
+
+        // Get the fake language server and set up formatting handler
+        let fake_language_server = fake_language_servers.next().await.unwrap();
+        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
+            |_, _| async move {
+                Ok(Some(vec![lsp::TextEdit {
+                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
+                    new_text: FORMATTED_CONTENT.to_string(),
+                }]))
+            }
+        });
+
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let model = Arc::new(FakeLanguageModel::default());
+
+        // First, test with format_on_save enabled
+        cx.update(|cx| {
+            SettingsStore::update_global(cx, |store, cx| {
+                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+                    cx,
+                    |settings| {
+                        settings.defaults.format_on_save = Some(FormatOnSave::On);
+                        settings.defaults.formatter =
+                            Some(language::language_settings::SelectedFormatter::Auto);
+                    },
+                );
+            });
+        });
+
+        // Have the model stream unformatted content
+        let edit_result = {
+            let edit_task = cx.update(|cx| {
+                let input = serde_json::to_value(EditFileToolInput {
+                    display_description: "Create main function".into(),
+                    path: "root/src/main.rs".into(),
+                    mode: EditFileMode::Overwrite,
+                })
+                .unwrap();
+                Arc::new(EditFileTool)
+                    .run(
+                        input,
+                        Arc::default(),
+                        project.clone(),
+                        action_log.clone(),
+                        model.clone(),
+                        None,
+                        cx,
+                    )
+                    .output
+            });
+
+            // Stream the unformatted content
+            cx.executor().run_until_parked();
+            model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
+            model.end_last_completion_stream();
+
+            edit_task.await
+        };
+        assert!(edit_result.is_ok());
+
+        // Wait for any async operations (e.g. formatting) to complete
+        cx.executor().run_until_parked();
+
+        // Read the file to verify it was formatted automatically
+        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+        assert_eq!(
+            // Ignore carriage returns on Windows
+            new_content.replace("\r\n", "\n"),
+            FORMATTED_CONTENT,
+            "Code should be formatted when format_on_save is enabled"
+        );
+
+        let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count());
+
+        assert_eq!(
+            stale_buffer_count, 0,
+            "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
+             This causes the agent to think the file was modified externally when it was just formatted.",
+            stale_buffer_count
+        );
+
+        // Next, test with format_on_save disabled
+        cx.update(|cx| {
+            SettingsStore::update_global(cx, |store, cx| {
+                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+                    cx,
+                    |settings| {
+                        settings.defaults.format_on_save = Some(FormatOnSave::Off);
+                    },
+                );
+            });
+        });
+
+        // Stream unformatted edits again
+        let edit_result = {
+            let edit_task = cx.update(|cx| {
+                let input = serde_json::to_value(EditFileToolInput {
+                    display_description: "Update main function".into(),
+                    path: "root/src/main.rs".into(),
+                    mode: EditFileMode::Overwrite,
+                })
+                .unwrap();
+                Arc::new(EditFileTool)
+                    .run(
+                        input,
+                        Arc::default(),
+                        project.clone(),
+                        action_log.clone(),
+                        model.clone(),
+                        None,
+                        cx,
+                    )
+                    .output
+            });
+
+            // Stream the unformatted content
+            cx.executor().run_until_parked();
+            model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
+            model.end_last_completion_stream();
+
+            edit_task.await
+        };
+        assert!(edit_result.is_ok());
+
+        // Wait for any async operations (e.g. formatting) to complete
+        cx.executor().run_until_parked();
+
+        // Verify the file was not formatted
+        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+        assert_eq!(
+            // Ignore carriage returns on Windows
+            new_content.replace("\r\n", "\n"),
+            UNFORMATTED_CONTENT,
+            "Code should not be formatted when format_on_save is disabled"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree("/root", json!({"src": {}})).await;
+
+        // Create a simple file with trailing whitespace
+        fs.save(
+            path!("/root/src/main.rs").as_ref(),
+            &"initial content".into(),
+            language::LineEnding::Unix,
+        )
+        .await
+        .unwrap();
+
+        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let model = Arc::new(FakeLanguageModel::default());
+
+        // First, test with remove_trailing_whitespace_on_save enabled
+        cx.update(|cx| {
+            SettingsStore::update_global(cx, |store, cx| {
+                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+                    cx,
+                    |settings| {
+                        settings.defaults.remove_trailing_whitespace_on_save = Some(true);
+                    },
+                );
+            });
+        });
+
+        const CONTENT_WITH_TRAILING_WHITESPACE: &str =
+            "fn main() {  \n    println!(\"Hello!\");  \n}\n";
+
+        // Have the model stream content that contains trailing whitespace
+        let edit_result = {
+            let edit_task = cx.update(|cx| {
+                let input = serde_json::to_value(EditFileToolInput {
+                    display_description: "Create main function".into(),
+                    path: "root/src/main.rs".into(),
+                    mode: EditFileMode::Overwrite,
+                })
+                .unwrap();
+                Arc::new(EditFileTool)
+                    .run(
+                        input,
+                        Arc::default(),
+                        project.clone(),
+                        action_log.clone(),
+                        model.clone(),
+                        None,
+                        cx,
+                    )
+                    .output
+            });
+
+            // Stream the content with trailing whitespace
+            cx.executor().run_until_parked();
+            model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
+            model.end_last_completion_stream();
+
+            edit_task.await
+        };
+        assert!(edit_result.is_ok());
+
+        // Wait for any async operations (e.g. formatting) to complete
+        cx.executor().run_until_parked();
+
+        // Read the file to verify trailing whitespace was removed automatically
+        assert_eq!(
+            // Ignore carriage returns on Windows
+            fs.load(path!("/root/src/main.rs").as_ref())
+                .await
+                .unwrap()
+                .replace("\r\n", "\n"),
+            "fn main() {\n    println!(\"Hello!\");\n}\n",
+            "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
+        );
+
+        // Next, test with remove_trailing_whitespace_on_save disabled
+        cx.update(|cx| {
+            SettingsStore::update_global(cx, |store, cx| {
+                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+                    cx,
+                    |settings| {
+                        settings.defaults.remove_trailing_whitespace_on_save = Some(false);
+                    },
+                );
+            });
+        });
+
+        // Stream edits again with trailing whitespace
+        let edit_result = {
+            let edit_task = cx.update(|cx| {
+                let input = serde_json::to_value(EditFileToolInput {
+                    display_description: "Update main function".into(),
+                    path: "root/src/main.rs".into(),
+                    mode: EditFileMode::Overwrite,
+                })
+                .unwrap();
+                Arc::new(EditFileTool)
+                    .run(
+                        input,
+                        Arc::default(),
+                        project.clone(),
+                        action_log.clone(),
+                        model.clone(),
+                        None,
+                        cx,
+                    )
+                    .output
+            });
+
+            // Stream the content with trailing whitespace
+            cx.executor().run_until_parked();
+            model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
+            model.end_last_completion_stream();
+
+            edit_task.await
+        };
+        assert!(edit_result.is_ok());
+
+        // Wait for any async operations (e.g. formatting) to complete
+        cx.executor().run_until_parked();
+
+        // Verify the file still has trailing whitespace
+        // Read the file again - it should still have trailing whitespace
+        let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+        assert_eq!(
+            // Ignore carriage returns on Windows
+            final_content.replace("\r\n", "\n"),
+            CONTENT_WITH_TRAILING_WHITESPACE,
+            "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
+        );
+    }
 }