Cargo.lock 🔗
@@ -683,6 +683,7 @@ dependencies = [
"language_model",
"language_models",
"log",
+ "lsp",
"markdown",
"open",
"paths",
Richard Feldman and Agus Zubiaga created
Re-enables format on save for agent changes (when the user has that
enabled in settings), except differently from before:
- Now we do the format-on-save in the separate buffer the edit tool
uses, *before* the diff
- This means it never triggers separate staleness
- It has the downside that edits are now blocked on the formatter
completing, but that's true of saving in general.
Release Notes:
- N/A
---------
Co-authored-by: Agus Zubiaga <hi@aguz.me>
Cargo.lock | 1
crates/assistant_tools/Cargo.toml | 2
crates/assistant_tools/src/edit_file_tool.rs | 403 +++++++++++++++++++++
3 files changed, 393 insertions(+), 13 deletions(-)
@@ -683,6 +683,7 @@ dependencies = [
"language_model",
"language_models",
"log",
+ "lsp",
"markdown",
"open",
"paths",
@@ -36,6 +36,7 @@ itertools.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
+lsp.workspace = true
markdown.workspace = true
open.workspace = true
paths.workspace = true
@@ -64,6 +65,7 @@ workspace.workspace = true
zed_llm_client.workspace = true
[dev-dependencies]
+lsp = { workspace = true, features = ["test-support"] }
client = { workspace = true, features = ["test-support"] }
clock = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
@@ -18,16 +18,21 @@ use gpui::{
use indoc::formatdoc;
use language::{
Anchor, Buffer, Capability, LanguageRegistry, LineEnding, OffsetRangeExt, Point, Rope,
- TextBuffer, language_settings::SoftWrap,
+ TextBuffer,
+ language_settings::{self, FormatOnSave, SoftWrap},
};
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelToolSchemaFormat};
use markdown::{Markdown, MarkdownElement, MarkdownStyle};
-use project::{Project, ProjectPath};
+use project::{
+ Project, ProjectPath,
+ lsp_store::{FormatTrigger, LspFormatTarget},
+};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
use std::{
cmp::Reverse,
+ collections::HashSet,
ops::Range,
path::{Path, PathBuf},
sync::Arc,
@@ -189,8 +194,10 @@ impl Tool for EditFileTool {
});
let card_clone = card.clone();
+ let action_log_clone = action_log.clone();
let task = cx.spawn(async move |cx: &mut AsyncApp| {
- let edit_agent = EditAgent::new(model, project.clone(), action_log, Templates::new());
+ let edit_agent =
+ EditAgent::new(model, project.clone(), action_log_clone, Templates::new());
let buffer = project
.update(cx, |project, cx| {
@@ -244,19 +251,53 @@ impl Tool for EditFileTool {
}
let agent_output = output.await?;
+ // If format_on_save is enabled, format the buffer
+ let format_on_save_enabled = buffer
+ .read_with(cx, |buffer, cx| {
+ let settings = language_settings::language_settings(
+ buffer.language().map(|l| l.name()),
+ buffer.file(),
+ cx,
+ );
+ !matches!(settings.format_on_save, FormatOnSave::Off)
+ })
+ .unwrap_or(false);
+
+ if format_on_save_enabled {
+ let format_task = project.update(cx, |project, cx| {
+ project.format(
+ HashSet::from_iter([buffer.clone()]),
+ LspFormatTarget::Buffers,
+ false, // Don't push to history since the tool did it.
+ FormatTrigger::Save,
+ cx,
+ )
+ })?;
+ format_task.await.log_err();
+ }
+
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
+ // Notify the action log that we've edited the buffer (*after* formatting has completed).
+ action_log.update(cx, |log, cx| {
+ log.buffer_edited(buffer.clone(), cx);
+ })?;
+
let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
- let new_text = cx.background_spawn({
- let new_snapshot = new_snapshot.clone();
- async move { new_snapshot.text() }
- });
- let diff = cx.background_spawn(async move {
- language::unified_diff(&old_snapshot.text(), &new_snapshot.text())
- });
- let (new_text, diff) = futures::join!(new_text, diff);
+ let (new_text, diff) = cx
+ .background_spawn({
+ let new_snapshot = new_snapshot.clone();
+ let old_text = old_text.clone();
+ async move {
+ let new_text = new_snapshot.text();
+ let diff = language::unified_diff(&old_text, &new_text);
+
+ (new_text, diff)
+ }
+ })
+ .await;
let output = EditFileToolOutput {
original_path: project_path.path.to_path_buf(),
@@ -1099,8 +1140,8 @@ async fn build_buffer_diff(
mod tests {
use super::*;
use client::TelemetrySettings;
- use fs::FakeFs;
- use gpui::TestAppContext;
+ use fs::{FakeFs, Fs};
+ use gpui::{TestAppContext, UpdateGlobal};
use language_model::fake_provider::FakeLanguageModel;
use serde_json::json;
use settings::SettingsStore;
@@ -1310,4 +1351,340 @@ mod tests {
Project::init_settings(cx);
});
}
+
+ #[gpui::test]
+ async fn test_format_on_save(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree("/root", json!({"src": {}})).await;
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+
+ // Set up a Rust language with LSP formatting support
+ let rust_language = Arc::new(language::Language::new(
+ language::LanguageConfig {
+ name: "Rust".into(),
+ matcher: language::LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ None,
+ ));
+
+ // Register the language and fake LSP
+ let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+ language_registry.add(rust_language);
+
+ let mut fake_language_servers = language_registry.register_fake_lsp(
+ "Rust",
+ language::FakeLspAdapter {
+ capabilities: lsp::ServerCapabilities {
+ document_formatting_provider: Some(lsp::OneOf::Left(true)),
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ );
+
+ // Create the file
+ fs.save(
+ path!("/root/src/main.rs").as_ref(),
+ &"initial content".into(),
+ language::LineEnding::Unix,
+ )
+ .await
+ .unwrap();
+
+ // Open the buffer to trigger LSP initialization
+ let buffer = project
+ .update(cx, |project, cx| {
+ project.open_local_buffer(path!("/root/src/main.rs"), cx)
+ })
+ .await
+ .unwrap();
+
+ // Register the buffer with language servers
+ let _handle = project.update(cx, |project, cx| {
+ project.register_buffer_with_language_servers(&buffer, cx)
+ });
+
+ const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
+ const FORMATTED_CONTENT: &str =
+ "This file was formatted by the fake formatter in the test.\n";
+
+ // Get the fake language server and set up formatting handler
+ let fake_language_server = fake_language_servers.next().await.unwrap();
+ fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
+ |_, _| async move {
+ Ok(Some(vec![lsp::TextEdit {
+ range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
+ new_text: FORMATTED_CONTENT.to_string(),
+ }]))
+ }
+ });
+
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ // First, test with format_on_save enabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+ cx,
+ |settings| {
+ settings.defaults.format_on_save = Some(FormatOnSave::On);
+ settings.defaults.formatter =
+ Some(language::language_settings::SelectedFormatter::Auto);
+ },
+ );
+ });
+ });
+
+ // Have the model stream unformatted content
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = serde_json::to_value(EditFileToolInput {
+ display_description: "Create main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ })
+ .unwrap();
+ Arc::new(EditFileTool)
+ .run(
+ input,
+ Arc::default(),
+ project.clone(),
+ action_log.clone(),
+ model.clone(),
+ None,
+ cx,
+ )
+ .output
+ });
+
+ // Stream the unformatted content
+ cx.executor().run_until_parked();
+ model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Read the file to verify it was formatted automatically
+ let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ new_content.replace("\r\n", "\n"),
+ FORMATTED_CONTENT,
+ "Code should be formatted when format_on_save is enabled"
+ );
+
+ let stale_buffer_count = action_log.read_with(cx, |log, cx| log.stale_buffers(cx).count());
+
+ assert_eq!(
+ stale_buffer_count, 0,
+ "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
+ This causes the agent to think the file was modified externally when it was just formatted.",
+ stale_buffer_count
+ );
+
+ // Next, test with format_on_save disabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+ cx,
+ |settings| {
+ settings.defaults.format_on_save = Some(FormatOnSave::Off);
+ },
+ );
+ });
+ });
+
+ // Stream unformatted edits again
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = serde_json::to_value(EditFileToolInput {
+ display_description: "Update main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ })
+ .unwrap();
+ Arc::new(EditFileTool)
+ .run(
+ input,
+ Arc::default(),
+ project.clone(),
+ action_log.clone(),
+ model.clone(),
+ None,
+ cx,
+ )
+ .output
+ });
+
+ // Stream the unformatted content
+ cx.executor().run_until_parked();
+ model.stream_last_completion_response(UNFORMATTED_CONTENT.to_string());
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Verify the file was not formatted
+ let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ new_content.replace("\r\n", "\n"),
+ UNFORMATTED_CONTENT,
+ "Code should not be formatted when format_on_save is disabled"
+ );
+ }
+
+ #[gpui::test]
+ async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree("/root", json!({"src": {}})).await;
+
+ // Create a simple file with trailing whitespace
+ fs.save(
+ path!("/root/src/main.rs").as_ref(),
+ &"initial content".into(),
+ language::LineEnding::Unix,
+ )
+ .await
+ .unwrap();
+
+ let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+ let action_log = cx.new(|_| ActionLog::new(project.clone()));
+ let model = Arc::new(FakeLanguageModel::default());
+
+ // First, test with remove_trailing_whitespace_on_save enabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+ cx,
+ |settings| {
+ settings.defaults.remove_trailing_whitespace_on_save = Some(true);
+ },
+ );
+ });
+ });
+
+ const CONTENT_WITH_TRAILING_WHITESPACE: &str =
+ "fn main() { \n println!(\"Hello!\"); \n}\n";
+
+ // Have the model stream content that contains trailing whitespace
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = serde_json::to_value(EditFileToolInput {
+ display_description: "Create main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ })
+ .unwrap();
+ Arc::new(EditFileTool)
+ .run(
+ input,
+ Arc::default(),
+ project.clone(),
+ action_log.clone(),
+ model.clone(),
+ None,
+ cx,
+ )
+ .output
+ });
+
+ // Stream the content with trailing whitespace
+ cx.executor().run_until_parked();
+ model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Read the file to verify trailing whitespace was removed automatically
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ fs.load(path!("/root/src/main.rs").as_ref())
+ .await
+ .unwrap()
+ .replace("\r\n", "\n"),
+ "fn main() {\n println!(\"Hello!\");\n}\n",
+ "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
+ );
+
+ // Next, test with remove_trailing_whitespace_on_save disabled
+ cx.update(|cx| {
+ SettingsStore::update_global(cx, |store, cx| {
+ store.update_user_settings::<language::language_settings::AllLanguageSettings>(
+ cx,
+ |settings| {
+ settings.defaults.remove_trailing_whitespace_on_save = Some(false);
+ },
+ );
+ });
+ });
+
+ // Stream edits again with trailing whitespace
+ let edit_result = {
+ let edit_task = cx.update(|cx| {
+ let input = serde_json::to_value(EditFileToolInput {
+ display_description: "Update main function".into(),
+ path: "root/src/main.rs".into(),
+ mode: EditFileMode::Overwrite,
+ })
+ .unwrap();
+ Arc::new(EditFileTool)
+ .run(
+ input,
+ Arc::default(),
+ project.clone(),
+ action_log.clone(),
+ model.clone(),
+ None,
+ cx,
+ )
+ .output
+ });
+
+ // Stream the content with trailing whitespace
+ cx.executor().run_until_parked();
+ model.stream_last_completion_response(CONTENT_WITH_TRAILING_WHITESPACE.to_string());
+ model.end_last_completion_stream();
+
+ edit_task.await
+ };
+ assert!(edit_result.is_ok());
+
+ // Wait for any async operations (e.g. formatting) to complete
+ cx.executor().run_until_parked();
+
+ // Verify the file still has trailing whitespace
+ // Read the file again - it should still have trailing whitespace
+ let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
+ assert_eq!(
+ // Ignore carriage returns on Windows
+ final_content.replace("\r\n", "\n"),
+ CONTENT_WITH_TRAILING_WHITESPACE,
+ "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
+ );
+ }
}