Have `edit_file_tool` respect `format_on_save` (#31047)

Richard Feldman created

Release Notes:

- Agents now automatically format after edits if `format_on_save` is
enabled.

Change summary

Cargo.lock                                   |   1 
crates/assistant_tools/Cargo.toml            |   3 
crates/assistant_tools/src/edit_file_tool.rs | 418 +++++++++++++++++++++
3 files changed, 419 insertions(+), 3 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -683,6 +683,7 @@ dependencies = [
  "language_model",
  "language_models",
  "log",
+ "lsp",
  "markdown",
  "open",
  "paths",

crates/assistant_tools/Cargo.toml 🔗

@@ -59,11 +59,12 @@ ui.workspace = true
 util.workspace = true
 web_search.workspace = true
 which.workspace = true
-workspace-hack.workspace = true
 workspace.workspace = true
+workspace-hack.workspace = true
 zed_llm_client.workspace = true
 
 [dev-dependencies]
+lsp = { workspace = true, features = ["test-support"] }
 client = { workspace = true, features = ["test-support"] }
 clock = { workspace = true, features = ["test-support"] }
 collections = { workspace = true, features = ["test-support"] }

crates/assistant_tools/src/edit_file_tool.rs 🔗

@@ -8,6 +8,10 @@ use assistant_tool::{
     ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput,
     ToolUseStatus,
 };
+use language::language_settings::{self, FormatOnSave};
+use project::lsp_store::{FormatTrigger, LspFormatTarget};
+use std::collections::HashSet;
+
 use buffer_diff::{BufferDiff, BufferDiffSnapshot};
 use editor::{Editor, EditorMode, MultiBuffer, PathKey};
 use futures::StreamExt;
@@ -249,6 +253,40 @@ impl Tool for EditFileTool {
             }
             let agent_output = output.await?;
 
+            // Format buffer if format_on_save is enabled, before saving.
+            // If any part of the formatting operation fails, log an error but
+            // don't block the completion of the edit tool's work.
+            let should_format = buffer
+                .read_with(cx, |buffer, cx| {
+                    let settings = language_settings::language_settings(
+                        buffer.language().map(|l| l.name()),
+                        buffer.file(),
+                        cx,
+                    );
+                    !matches!(settings.format_on_save, FormatOnSave::Off)
+                })
+                .log_err()
+                .unwrap_or(false);
+
+            if should_format {
+                let buffers = HashSet::from_iter([buffer.clone()]);
+
+                if let Some(format_task) = project
+                    .update(cx, move |project, cx| {
+                        project.format(
+                            buffers,
+                            LspFormatTarget::Buffers,
+                            false, // Don't push to history since the tool did it.
+                            FormatTrigger::Save,
+                            cx,
+                        )
+                    })
+                    .log_err()
+                {
+                    format_task.await.log_err();
+                }
+            }
+
             project
                 .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
                 .await?;
@@ -918,11 +956,15 @@ async fn build_buffer_diff(
 mod tests {
     use super::*;
     use client::TelemetrySettings;
-    use fs::FakeFs;
-    use gpui::TestAppContext;
+    use fs::{FakeFs, Fs};
+    use gpui::{TestAppContext, UpdateGlobal};
+    use language::{FakeLspAdapter, Language, LanguageConfig, LanguageMatcher};
     use language_model::fake_provider::FakeLanguageModel;
+    use language_settings::{AllLanguageSettings, Formatter, FormatterList, SelectedFormatter};
+    use lsp;
     use serde_json::json;
     use settings::SettingsStore;
+    use std::sync::Arc;
     use util::path;
 
     #[gpui::test]
@@ -1129,4 +1171,376 @@ mod tests {
             Project::init_settings(cx);
         });
     }
+
+    #[gpui::test]
+    async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree("/root", json!({"src": {}})).await;
+
+        // Create a simple file with trailing whitespace
+        fs.save(
+            path!("/root/src/main.rs").as_ref(),
+            &"initial content".into(),
+            LineEnding::Unix,
+        )
+        .await
+        .unwrap();
+
+        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let model = Arc::new(FakeLanguageModel::default());
+
+        // First, test with remove_trailing_whitespace_on_save enabled
+        cx.update(|cx| {
+            SettingsStore::update_global(cx, |store, cx| {
+                store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
+                    settings.defaults.remove_trailing_whitespace_on_save = Some(true);
+                });
+            });
+        });
+
+        const CONTENT_WITH_TRAILING_WHITESPACE: &str =
+            "fn main() {  \n    println!(\"Hello!\");  \n}\n";
+
+        // Have the model stream content that contains trailing whitespace
+        let edit_result = {
+            let edit_task = cx.update(|cx| {
+                let input = serde_json::to_value(EditFileToolInput {
+                    display_description: "Create main function".into(),
+                    path: "root/src/main.rs".into(),
+                    mode: EditFileMode::Overwrite,
+                })
+                .unwrap();
+                Arc::new(EditFileTool)
+                    .run(
+                        input,
+                        Arc::default(),
+                        project.clone(),
+                        action_log.clone(),
+                        model.clone(),
+                        None,
+                        cx,
+                    )
+                    .output
+            });
+
+            // Stream the content with trailing whitespace
+            cx.executor().run_until_parked();
+            model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
+            model.end_last_completion_stream();
+
+            edit_task.await
+        };
+        assert!(edit_result.is_ok());
+
+        // Wait for any async operations (e.g. formatting) to complete
+        cx.executor().run_until_parked();
+
+        // Read the file to verify trailing whitespace was removed automatically
+        assert_eq!(
+            // Ignore carriage returns on Windows
+            fs.load(path!("/root/src/main.rs").as_ref())
+                .await
+                .unwrap()
+                .replace("\r\n", "\n"),
+            "fn main() {\n    println!(\"Hello!\");\n}\n",
+            "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
+        );
+
+        // Next, test with remove_trailing_whitespace_on_save disabled
+        cx.update(|cx| {
+            SettingsStore::update_global(cx, |store, cx| {
+                store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
+                    settings.defaults.remove_trailing_whitespace_on_save = Some(false);
+                });
+            });
+        });
+
+        // Stream edits again with trailing whitespace
+        let edit_result = {
+            let edit_task = cx.update(|cx| {
+                let input = serde_json::to_value(EditFileToolInput {
+                    display_description: "Update main function".into(),
+                    path: "root/src/main.rs".into(),
+                    mode: EditFileMode::Overwrite,
+                })
+                .unwrap();
+                Arc::new(EditFileTool)
+                    .run(
+                        input,
+                        Arc::default(),
+                        project.clone(),
+                        action_log.clone(),
+                        model.clone(),
+                        None,
+                        cx,
+                    )
+                    .output
+            });
+
+            // Stream the content with trailing whitespace
+            cx.executor().run_until_parked();
+            model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
+            model.end_last_completion_stream();
+
+            edit_task.await
+        };
+        assert!(edit_result.is_ok());
+
+        // Wait for any async operations (e.g. formatting) to complete
+        cx.executor().run_until_parked();
+
+        // Verify the file still has trailing whitespace
+        // Read the file again - it should still have trailing whitespace
+        let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+        assert_eq!(
+            // Ignore carriage returns on Windows
+            final_content.replace("\r\n", "\n"),
+            CONTENT_WITH_TRAILING_WHITESPACE,
+            "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
+        );
+    }
+
+    #[gpui::test]
+    async fn test_format_on_save(cx: &mut TestAppContext) {
+        init_test(cx);
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree("/root", json!({"src": {}})).await;
+
+        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+
+        // Set up a Rust language with LSP formatting support
+        let rust_language = Arc::new(Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                matcher: LanguageMatcher {
+                    path_suffixes: vec!["rs".to_string()],
+                    ..Default::default()
+                },
+                ..Default::default()
+            },
+            None,
+        ));
+
+        // Register the language and fake LSP
+        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+        language_registry.add(rust_language);
+
+        let mut fake_language_servers = language_registry.register_fake_lsp(
+            "Rust",
+            FakeLspAdapter {
+                capabilities: lsp::ServerCapabilities {
+                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
+                    ..Default::default()
+                },
+                ..Default::default()
+            },
+        );
+
+        // Create the file
+        fs.save(
+            path!("/root/src/main.rs").as_ref(),
+            &"initial content".into(),
+            LineEnding::Unix,
+        )
+        .await
+        .unwrap();
+
+        // Open the buffer to trigger LSP initialization
+        let buffer = project
+            .update(cx, |project, cx| {
+                project.open_local_buffer(path!("/root/src/main.rs"), cx)
+            })
+            .await
+            .unwrap();
+
+        // Register the buffer with language servers
+        let _handle = project.update(cx, |project, cx| {
+            project.register_buffer_with_language_servers(&buffer, cx)
+        });
+
+        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
+        const FORMATTED_CONTENT: &str =
+            "This file was formatted by the fake formatter in the test.\n";
+
+        // Get the fake language server and set up formatting handler
+        let fake_language_server = fake_language_servers.next().await.unwrap();
+        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
+            |_, _| async move {
+                Ok(Some(vec![lsp::TextEdit {
+                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
+                    new_text: FORMATTED_CONTENT.to_string(),
+                }]))
+            }
+        });
+
+        let action_log = cx.new(|_| ActionLog::new(project.clone()));
+        let model = Arc::new(FakeLanguageModel::default());
+
+        // First, test with format_on_save enabled
+        cx.update(|cx| {
+            SettingsStore::update_global(cx, |store, cx| {
+                store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
+                    settings.defaults.format_on_save = Some(FormatOnSave::On);
+                    settings.defaults.formatter = Some(SelectedFormatter::Auto);
+                });
+            });
+        });
+
+        // Have the model stream unformatted content
+        let edit_result = {
+            let edit_task = cx.update(|cx| {
+                let input = serde_json::to_value(EditFileToolInput {
+                    display_description: "Create main function".into(),
+                    path: "root/src/main.rs".into(),
+                    mode: EditFileMode::Overwrite,
+                })
+                .unwrap();
+                Arc::new(EditFileTool)
+                    .run(
+                        input,
+                        Arc::default(),
+                        project.clone(),
+                        action_log.clone(),
+                        model.clone(),
+                        None,
+                        cx,
+                    )
+                    .output
+            });
+
+            // Stream the unformatted content
+            cx.executor().run_until_parked();
+            model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
+            model.end_last_completion_stream();
+
+            edit_task.await
+        };
+        assert!(edit_result.is_ok());
+
+        // Wait for any async operations (e.g. formatting) to complete
+        cx.executor().run_until_parked();
+
+        // Read the file to verify it was formatted automatically
+        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+        assert_eq!(
+            // Ignore carriage returns on Windows
+            new_content.replace("\r\n", "\n"),
+            FORMATTED_CONTENT,
+            "Code should be formatted when format_on_save is enabled"
+        );
+
+        // Next, test with format_on_save disabled
+        cx.update(|cx| {
+            SettingsStore::update_global(cx, |store, cx| {
+                store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
+                    settings.defaults.format_on_save = Some(FormatOnSave::Off);
+                });
+            });
+        });
+
+        // Stream unformatted edits again
+        let edit_result = {
+            let edit_task = cx.update(|cx| {
+                let input = serde_json::to_value(EditFileToolInput {
+                    display_description: "Update main function".into(),
+                    path: "root/src/main.rs".into(),
+                    mode: EditFileMode::Overwrite,
+                })
+                .unwrap();
+                Arc::new(EditFileTool)
+                    .run(
+                        input,
+                        Arc::default(),
+                        project.clone(),
+                        action_log.clone(),
+                        model.clone(),
+                        None,
+                        cx,
+                    )
+                    .output
+            });
+
+            // Stream the unformatted content
+            cx.executor().run_until_parked();
+            model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
+            model.end_last_completion_stream();
+
+            edit_task.await
+        };
+        assert!(edit_result.is_ok());
+
+        // Wait for any async operations (e.g. formatting) to complete
+        cx.executor().run_until_parked();
+
+        // Verify the file is still unformatted
+        assert_eq!(
+            // Ignore carriage returns on Windows
+            fs.load(path!("/root/src/main.rs").as_ref())
+                .await
+                .unwrap()
+                .replace("\r\n", "\n"),
+            UNFORMATTED_CONTENT,
+            "Code should remain unformatted when format_on_save is disabled"
+        );
+
+        // Finally, test with format_on_save set to a list
+        cx.update(|cx| {
+            SettingsStore::update_global(cx, |store, cx| {
+                store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
+                    settings.defaults.format_on_save = Some(FormatOnSave::List(FormatterList(
+                        vec![Formatter::LanguageServer { name: None }].into(),
+                    )));
+                });
+            });
+        });
+
+        // Stream unformatted edits again
+        let edit_result = {
+            let edit_task = cx.update(|cx| {
+                let input = serde_json::to_value(EditFileToolInput {
+                    display_description: "Update main function with list formatter".into(),
+                    path: "root/src/main.rs".into(),
+                    mode: EditFileMode::Overwrite,
+                })
+                .unwrap();
+                Arc::new(EditFileTool)
+                    .run(
+                        input,
+                        Arc::default(),
+                        project.clone(),
+                        action_log.clone(),
+                        model.clone(),
+                        None,
+                        cx,
+                    )
+                    .output
+            });
+
+            // Stream the unformatted content
+            cx.executor().run_until_parked();
+            model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
+            model.end_last_completion_stream();
+
+            edit_task.await
+        };
+        assert!(edit_result.is_ok());
+
+        // Wait for any async operations (e.g. formatting) to complete
+        cx.executor().run_until_parked();
+
+        // Read the file to verify it was formatted with the specified formatter
+        assert_eq!(
+            // Ignore carriage returns on Windows
+            fs.load(path!("/root/src/main.rs").as_ref())
+                .await
+                .unwrap()
+                .replace("\r\n", "\n"),
+            FORMATTED_CONTENT,
+            "Code should be formatted when format_on_save is set to a list"
+        );
+    }
 }