Cargo.lock 🔗
@@ -683,6 +683,7 @@ dependencies = [
"language_model",
"language_models",
"log",
+ "lsp",
"markdown",
"open",
"paths",
Richard Feldman created
Release Notes:
- Agents now automatically format after edits if `format_on_save` is
enabled.
Cargo.lock | 1
crates/assistant_tools/Cargo.toml | 3
crates/assistant_tools/src/edit_file_tool.rs | 418 +++++++++++++++++++++
3 files changed, 419 insertions(+), 3 deletions(-)
@@ -683,6 +683,7 @@ dependencies = [
"language_model",
"language_models",
"log",
+ "lsp",
"markdown",
"open",
"paths",
@@ -59,11 +59,12 @@ ui.workspace = true
util.workspace = true
web_search.workspace = true
which.workspace = true
-workspace-hack.workspace = true
workspace.workspace = true
+workspace-hack.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
+lsp = { workspace = true, features = ["test-support"] }
client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
@@ -8,6 +8,10 @@ use assistant_tool::{
ActionLog, AnyToolCard, Tool, ToolCard, ToolResult, ToolResultContent, ToolResultOutput,
ToolUseStatus,
};
+use language::language_settings::{self, FormatOnSave};
+use project::lsp_store::{FormatTrigger, LspFormatTarget};
+use std::collections::HashSet;
+
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use editor::{Editor, EditorMode, MultiBuffer, PathKey};
use futures::StreamExt;
@@ -249,6 +253,40 @@ impl Tool for EditFileTool {
}
let agent_output = output.await?;
+ // Format buffer if format_on_save is enabled, before saving.
+ // If any part of the formatting operation fails, log an error but
+ // don't block the completion of the edit tool's work.
+ let should_format = buffer
+ .read_with(cx, |buffer, cx| {
+ let settings = language_settings::language_settings(
+ buffer.language().map(|l| l.name()),
+ buffer.file(),
+ cx,
+ );
+ !matches!(settings.format_on_save, FormatOnSave::Off)
+ })
+ .log_err()
+ .unwrap_or(false);
+
+ if should_format {
+ let buffers = HashSet::from_iter([buffer.clone()]);
+
+ if let Some(format_task) = project
+ .update(cx, move |project, cx| {
+ project.format(
+ buffers,
+ LspFormatTarget::Buffers,
+ false, // Don't push to history since the tool did it.
+ FormatTrigger::Save,
+ cx,
+ )
+ })
+ .log_err()
+ {
+ format_task.await.log_err();
+ }
+ }
+
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
@@ -918,11 +956,15 @@ async fn build_buffer_diff(
mod tests {
use super::*;
use client::TelemetrySettings;
- use fs::FakeFs;
- use gpui::TestAppContext;
+ use fs::{FakeFs, Fs};
+ use gpui::{TestAppContext, UpdateGlobal};
+ use language::{FakeLspAdapter, Language, LanguageConfig, LanguageMatcher};
use language_model::fake_provider::FakeLanguageModel;
+ use language_settings::{AllLanguageSettings, Formatter, FormatterList, SelectedFormatter};
+ use lsp;
use serde_json::json;
use settings::SettingsStore;
+ use std::sync::Arc;
use util::path;
#[gpui::test]
@@ -1129,4 +1171,376 @@ mod tests {
Project::init_settings(cx);
});
}
+
+ #[gpui::test]
+ async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree("/root", json!({"src": {}})).await;
+
+ // Create a simple file with trailing whitespace
+ fs.save(
+ path!("/root/src/main.rs").as_ref(),
+ &"initial content".into(),
+ LineEnding::Unix,
+ )
+ .await
+ .unwrap();
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ // First, test with remove_trailing_whitespace_on_save enabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
+ settings.defaults.remove_trailing_whitespace_on_save = Some(true);
+ });
+ });
+ });
+
+ const CONTENT_WITH_TRAILING_WHITESPACE: &str =
+ "fn main() { \n println!(\"Hello!\"); \n}\n";
+
+ // Have the model stream content that contains trailing whitespace
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = serde_json::to_value(EditFileToolInput {
+ display_description: "Create main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ })
+ .unwrap();
+ Arc::new(EditFileTool)
+ .run(
+ input,
+ Arc::default(),
+ project.clone(),
+ action_log.clone(),
+ model.clone(),
+ None,
+ cx,
+ )
+ .output
+ });
+
+ // Stream the content with trailing whitespace
+ cx.executor().run_until_parked();
+ model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Read the file to verify trailing whitespace was removed automatically
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ fs.load(path!("/root/src/main.rs").as_ref())
+ .await
+ .unwrap()
+ .replace("\r\n", "\n"),
+ "fn main() {\n println!(\"Hello!\");\n}\n",
+ "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
+ );
+
+ // Next, test with remove_trailing_whitespace_on_save disabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
+ settings.defaults.remove_trailing_whitespace_on_save = Some(false);
+ });
+ });
+ });
+
+ // Stream edits again with trailing whitespace
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = serde_json::to_value(EditFileToolInput {
+ display_description: "Update main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ })
+ .unwrap();
+ Arc::new(EditFileTool)
+ .run(
+ input,
+ Arc::default(),
+ project.clone(),
+ action_log.clone(),
+ model.clone(),
+ None,
+ cx,
+ )
+ .output
+ });
+
+ // Stream the content with trailing whitespace
+ cx.executor().run_until_parked();
+ model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Verify the file still has trailing whitespace
+ // Read the file again - it should still have trailing whitespace
+ let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ final_content.replace("\r\n", "\n"),
+ CONTENT_WITH_TRAILING_WHITESPACE,
+ "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_format_on_save(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree("/root", json!({"src": {}})).await;
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+
+ // Set up a Rust language with LSP formatting support
+ let rust_language = Arc::new(Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ None,
+ ));
+
+ // Register the language and fake LSP
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_language);
+
+ let mut fake_language_servers = language_registry.register_fake_lsp(
+ "Rust",
+ FakeLspAdapter {
+ capabilities: lsp::ServerCapabilities {
+ document_formatting_provider: Some(lsp::OneOf::Left(true)),
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ );
+
+ // Create the file
+ fs.save(
+ path!("/root/src/main.rs").as_ref(),
+ &"initial content".into(),
+ LineEnding::Unix,
+ )
+ .await
+ .unwrap();
+
+ // Open the buffer to trigger LSP initialization
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/root/src/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ // Register the buffer with language servers
+ let _handle = project.update(cx, |project, cx| {
+ project.register_buffer_with_language_servers(&buffer, cx)
+ });
+
+ const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
+ const FORMATTED_CONTENT: &str =
+ "This file was formatted by the fake formatter in the test.\n";
+
+ // Get the fake language server and set up formatting handler
+ let fake_language_server = fake_language_servers.next().await.unwrap();
+ fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
+ |_, _| async move {
+ Ok(Some(vec![lsp::TextEdit {
+ range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
+ new_text: FORMATTED_CONTENT.to_string(),
+ }]))
+ }
+ });
+
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ // First, test with format_on_save enabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
+ settings.defaults.format_on_save = Some(FormatOnSave::On);
+ settings.defaults.formatter = Some(SelectedFormatter::Auto);
+ });
+ });
+ });
+
+ // Have the model stream unformatted content
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = serde_json::to_value(EditFileToolInput {
+ display_description: "Create main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ })
+ .unwrap();
+ Arc::new(EditFileTool)
+ .run(
+ input,
+ Arc::default(),
+ project.clone(),
+ action_log.clone(),
+ model.clone(),
+ None,
+ cx,
+ )
+ .output
+ });
+
+ // Stream the unformatted content
+ cx.executor().run_until_parked();
+ model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Read the file to verify it was formatted automatically
+ let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ new_content.replace("\r\n", "\n"),
+ FORMATTED_CONTENT,
+ "Code should be formatted when format_on_save is enabled"
+ );
+
+ // Next, test with format_on_save disabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
+ settings.defaults.format_on_save = Some(FormatOnSave::Off);
+ });
+ });
+ });
+
+ // Stream unformatted edits again
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = serde_json::to_value(EditFileToolInput {
+ display_description: "Update main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ })
+ .unwrap();
+ Arc::new(EditFileTool)
+ .run(
+ input,
+ Arc::default(),
+ project.clone(),
+ action_log.clone(),
+ model.clone(),
+ None,
+ cx,
+ )
+ .output
+ });
+
+ // Stream the unformatted content
+ cx.executor().run_until_parked();
+ model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Verify the file is still unformatted
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ fs.load(path!("/root/src/main.rs").as_ref())
+ .await
+ .unwrap()
+ .replace("\r\n", "\n"),
+ UNFORMATTED_CONTENT,
+ "Code should remain unformatted when format_on_save is disabled"
+ );
+
+ // Finally, test with format_on_save set to a list
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<AllLanguageSettings>(cx, |settings| {
+ settings.defaults.format_on_save = Some(FormatOnSave::List(FormatterList(
+ vec![Formatter::LanguageServer { name: None }].into(),
+ )));
+ });
+ });
+ });
+
+ // Stream unformatted edits again
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = serde_json::to_value(EditFileToolInput {
+ display_description: "Update main function with list formatter".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ })
+ .unwrap();
+ Arc::new(EditFileTool)
+ .run(
+ input,
+ Arc::default(),
+ project.clone(),
+ action_log.clone(),
+ model.clone(),
+ None,
+ cx,
+ )
+ .output
+ });
+
+ // Stream the unformatted content
+ cx.executor().run_until_parked();
+ model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Read the file to verify it was formatted with the specified formatter
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ fs.load(path!("/root/src/main.rs").as_ref())
+ .await
+ .unwrap()
+ .replace("\r\n", "\n"),
+ FORMATTED_CONTENT,
+ "Code should be formatted when format_on_save is set to a list"
+ );
+ }
}