edit_file_tool.rs

   1use crate::{AgentTool, Thread, ToolCallEventStream};
   2use acp_thread::Diff;
   3use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
   4use anyhow::{Context as _, Result, anyhow};
   5use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
   9use indoc::formatdoc;
  10use language::language_settings::{self, FormatOnSave};
  11use language::{LanguageRegistry, ToPoint};
  12use language_model::LanguageModelToolResultContent;
  13use paths;
  14use project::lsp_store::{FormatTrigger, LspFormatTarget};
  15use project::{Project, ProjectPath};
  16use schemars::JsonSchema;
  17use serde::{Deserialize, Serialize};
  18use settings::Settings;
  19use smol::stream::StreamExt as _;
  20use std::ffi::OsStr;
  21use std::path::{Path, PathBuf};
  22use std::sync::Arc;
  23use ui::SharedString;
  24use util::ResultExt;
  25use util::rel_path::RelPath;
  26
  27const DEFAULT_UI_TEXT: &str = "Editing file";
  28
  29/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
  30///
  31/// Before using this tool:
  32///
  33/// 1. Use the `read_file` tool to understand the file's contents and context
  34///
  35/// 2. Verify the directory path is correct (only applicable when creating new files):
  36///    - Use the `list_directory` tool to verify the parent directory exists and is the correct location
  37#[derive(Debug, Serialize, Deserialize, JsonSchema)]
  38pub struct EditFileToolInput {
  39    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
  40    ///
  41    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
  42    ///
  43    /// NEVER mention the file path in this description.
  44    ///
  45    /// <example>Fix API endpoint URLs</example>
  46    /// <example>Update copyright year in `page_footer`</example>
  47    ///
  48    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
  49    pub display_description: String,
  50
  51    /// The full path of the file to create or modify in the project.
  52    ///
  53    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
  54    ///
  55    /// The following examples assume we have two root directories in the project:
  56    /// - /a/b/backend
  57    /// - /c/d/frontend
  58    ///
  59    /// <example>
  60    /// `backend/src/main.rs`
  61    ///
  62    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
  63    /// </example>
  64    ///
  65    /// <example>
  66    /// `frontend/db.js`
  67    /// </example>
  68    pub path: PathBuf,
  69    /// The mode of operation on the file. Possible values:
  70    /// - 'edit': Make granular edits to an existing file.
  71    /// - 'create': Create a new file if it doesn't exist.
  72    /// - 'overwrite': Replace the entire contents of an existing file.
  73    ///
  74    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
  75    pub mode: EditFileMode,
  76}
  77
// Lenient view of the tool input, deserialized in `initial_title` when the full
// `EditFileToolInput` could not be parsed (presumably because the model's JSON is
// still streaming in). Every field defaults to an empty string, so a partially
// received object still deserializes; only the fields needed for a title are kept.
// NOTE: plain `//` comments on purpose, so nothing leaks into the derived JSON schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // Raw (untrimmed) path text as received so far.
    #[serde(default)]
    path: String,
    // Raw (untrimmed) description text as received so far.
    #[serde(default)]
    display_description: String,
}
  85
// How `edit_file` treats the target file; serialized in lowercase ("edit",
// "create", "overwrite"). `resolve_path` requires `Edit`/`Overwrite` to name an
// existing file and `Create` to name a not-yet-existing file in an existing
// directory. `run` routes `Edit` through the edit agent's `edit` and the other two
// modes through its `overwrite`.
// NOTE: plain `//` comments on purpose: `///` docs on the variants would likely be
// emitted into the inlined JSON schema and change the tool definition.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    Edit,
    Create,
    Overwrite,
}
  94
/// Result of a completed `edit_file` call.
///
/// It is serializable so it can be stored and later handed back to `replay` to
/// re-render the diff. The serde aliases and defaults let it accept serialized data
/// that uses older field names or lacks `diff`; presumably these exist for
/// backward compatibility with previously stored outputs.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as the model supplied it in the tool input (not resolved).
    /// `original_path` is also accepted when deserializing.
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit (and any format-on-save) was applied and saved.
    new_text: String,
    /// Full buffer text captured before the edit agent ran.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty when nothing changed.
    /// Defaults to empty when missing from serialized data.
    #[serde(default)]
    diff: String,
    /// Raw output returned by the edit agent. `raw_output` is also accepted when deserializing.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
 106
 107impl From<EditFileToolOutput> for LanguageModelToolResultContent {
 108    fn from(output: EditFileToolOutput) -> Self {
 109        if output.diff.is_empty() {
 110            "No edits were made.".into()
 111        } else {
 112            format!(
 113                "Edited {}:\n\n```diff\n{}\n```",
 114                output.input_path.display(),
 115                output.diff
 116            )
 117            .into()
 118        }
 119    }
 120}
 121
/// The `edit_file` agent tool: creates, edits, or overwrites a single project file,
/// delegating the actual text changes to an [`EditAgent`].
pub struct EditFileTool {
    /// Owning thread, held weakly; tool operations fail with "thread was dropped"
    /// once it has been released.
    thread: WeakEntity<Thread>,
    /// Used by `replay` to build the finalized diff view.
    language_registry: Arc<LanguageRegistry>,
    /// Project used to compute the tool-call title in `initial_title`.
    project: Entity<Project>,
}
 127
 128impl EditFileTool {
 129    pub fn new(
 130        project: Entity<Project>,
 131        thread: WeakEntity<Thread>,
 132        language_registry: Arc<LanguageRegistry>,
 133    ) -> Self {
 134        Self {
 135            project,
 136            thread,
 137            language_registry,
 138        }
 139    }
 140
 141    fn authorize(
 142        &self,
 143        input: &EditFileToolInput,
 144        event_stream: &ToolCallEventStream,
 145        cx: &mut App,
 146    ) -> Task<Result<()>> {
 147        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
 148            return Task::ready(Ok(()));
 149        }
 150
 151        // If any path component matches the local settings folder, then this could affect
 152        // the editor in ways beyond the project source, so prompt.
 153        let local_settings_folder = paths::local_settings_folder_name();
 154        let path = Path::new(&input.path);
 155        if path.components().any(|component| {
 156            component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
 157        }) {
 158            return event_stream.authorize(
 159                format!("{} (local settings)", input.display_description),
 160                cx,
 161            );
 162        }
 163
 164        // It's also possible that the global config dir is configured to be inside the project,
 165        // so check for that edge case too.
 166        // TODO this is broken when remoting
 167        if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
 168            && canonical_path.starts_with(paths::config_dir())
 169        {
 170            return event_stream.authorize(
 171                format!("{} (global settings)", input.display_description),
 172                cx,
 173            );
 174        }
 175
 176        // Check if path is inside the global config directory
 177        // First check if it's already inside project - if not, try to canonicalize
 178        let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
 179            thread.project().read(cx).find_project_path(&input.path, cx)
 180        }) else {
 181            return Task::ready(Err(anyhow!("thread was dropped")));
 182        };
 183
 184        // If the path is inside the project, and it's not one of the above edge cases,
 185        // then no confirmation is necessary. Otherwise, confirmation is necessary.
 186        if project_path.is_some() {
 187            Task::ready(Ok(()))
 188        } else {
 189            event_stream.authorize(&input.display_description, cx)
 190        }
 191    }
 192}
 193
 194impl AgentTool for EditFileTool {
 195    type Input = EditFileToolInput;
 196    type Output = EditFileToolOutput;
 197
 198    fn name() -> &'static str {
 199        "edit_file"
 200    }
 201
 202    fn kind() -> acp::ToolKind {
 203        acp::ToolKind::Edit
 204    }
 205
 206    fn initial_title(
 207        &self,
 208        input: Result<Self::Input, serde_json::Value>,
 209        cx: &mut App,
 210    ) -> SharedString {
 211        match input {
 212            Ok(input) => self
 213                .project
 214                .read(cx)
 215                .find_project_path(&input.path, cx)
 216                .and_then(|project_path| {
 217                    self.project
 218                        .read(cx)
 219                        .short_full_path_for_project_path(&project_path, cx)
 220                })
 221                .unwrap_or(input.path.to_string_lossy().into_owned())
 222                .into(),
 223            Err(raw_input) => {
 224                if let Some(input) =
 225                    serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
 226                {
 227                    let path = input.path.trim();
 228                    if !path.is_empty() {
 229                        return self
 230                            .project
 231                            .read(cx)
 232                            .find_project_path(&input.path, cx)
 233                            .and_then(|project_path| {
 234                                self.project
 235                                    .read(cx)
 236                                    .short_full_path_for_project_path(&project_path, cx)
 237                            })
 238                            .unwrap_or(input.path)
 239                            .into();
 240                    }
 241
 242                    let description = input.display_description.trim();
 243                    if !description.is_empty() {
 244                        return description.to_string().into();
 245                    }
 246                }
 247
 248                DEFAULT_UI_TEXT.into()
 249            }
 250        }
 251    }
 252
 253    fn run(
 254        self: Arc<Self>,
 255        input: Self::Input,
 256        event_stream: ToolCallEventStream,
 257        cx: &mut App,
 258    ) -> Task<Result<Self::Output>> {
 259        let Ok(project) = self
 260            .thread
 261            .read_with(cx, |thread, _cx| thread.project().clone())
 262        else {
 263            return Task::ready(Err(anyhow!("thread was dropped")));
 264        };
 265        let project_path = match resolve_path(&input, project.clone(), cx) {
 266            Ok(path) => path,
 267            Err(err) => return Task::ready(Err(anyhow!(err))),
 268        };
 269        let abs_path = project.read(cx).absolute_path(&project_path, cx);
 270        if let Some(abs_path) = abs_path.clone() {
 271            event_stream.update_fields(ToolCallUpdateFields {
 272                locations: Some(vec![acp::ToolCallLocation {
 273                    path: abs_path,
 274                    line: None,
 275                    meta: None,
 276                }]),
 277                ..Default::default()
 278            });
 279        }
 280
 281        let authorize = self.authorize(&input, &event_stream, cx);
 282        cx.spawn(async move |cx: &mut AsyncApp| {
 283            authorize.await?;
 284
 285            let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
 286                let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
 287                (request, thread.model().cloned(), thread.action_log().clone())
 288            })?;
 289            let request = request?;
 290            let model = model.context("No language model configured")?;
 291
 292            let edit_format = EditFormat::from_model(model.clone())?;
 293            let edit_agent = EditAgent::new(
 294                model,
 295                project.clone(),
 296                action_log.clone(),
 297                // TODO: move edit agent to this crate so we can use our templates
 298                assistant_tools::templates::Templates::new(),
 299                edit_format,
 300            );
 301
 302            let buffer = project
 303                .update(cx, |project, cx| {
 304                    project.open_buffer(project_path.clone(), cx)
 305                })?
 306                .await?;
 307
 308            let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
 309            event_stream.update_diff(diff.clone());
 310            let _finalize_diff = util::defer({
 311               let diff = diff.downgrade();
 312               let mut cx = cx.clone();
 313               move || {
 314                   diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
 315               }
 316            });
 317
 318            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 319            let old_text = cx
 320                .background_spawn({
 321                    let old_snapshot = old_snapshot.clone();
 322                    async move { Arc::new(old_snapshot.text()) }
 323                })
 324                .await;
 325
 326
 327            let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
 328                edit_agent.edit(
 329                    buffer.clone(),
 330                    input.display_description.clone(),
 331                    &request,
 332                    cx,
 333                )
 334            } else {
 335                edit_agent.overwrite(
 336                    buffer.clone(),
 337                    input.display_description.clone(),
 338                    &request,
 339                    cx,
 340                )
 341            };
 342
 343            let mut hallucinated_old_text = false;
 344            let mut ambiguous_ranges = Vec::new();
 345            let mut emitted_location = false;
 346            while let Some(event) = events.next().await {
 347                match event {
 348                    EditAgentOutputEvent::Edited(range) => {
 349                        if !emitted_location {
 350                            let line = buffer.update(cx, |buffer, _cx| {
 351                                range.start.to_point(&buffer.snapshot()).row
 352                            }).ok();
 353                            if let Some(abs_path) = abs_path.clone() {
 354                                event_stream.update_fields(ToolCallUpdateFields {
 355                                    locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
 356                                    ..Default::default()
 357                                });
 358                            }
 359                            emitted_location = true;
 360                        }
 361                    },
 362                    EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
 363                    EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
 364                    EditAgentOutputEvent::ResolvingEditRange(range) => {
 365                        diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
 366                        // if !emitted_location {
 367                        //     let line = buffer.update(cx, |buffer, _cx| {
 368                        //         range.start.to_point(&buffer.snapshot()).row
 369                        //     }).ok();
 370                        //     if let Some(abs_path) = abs_path.clone() {
 371                        //         event_stream.update_fields(ToolCallUpdateFields {
 372                        //             locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
 373                        //             ..Default::default()
 374                        //         });
 375                        //     }
 376                        // }
 377                    }
 378                }
 379            }
 380
 381            // If format_on_save is enabled, format the buffer
 382            let format_on_save_enabled = buffer
 383                .read_with(cx, |buffer, cx| {
 384                    let settings = language_settings::language_settings(
 385                        buffer.language().map(|l| l.name()),
 386                        buffer.file(),
 387                        cx,
 388                    );
 389                    settings.format_on_save != FormatOnSave::Off
 390                })
 391                .unwrap_or(false);
 392
 393            let edit_agent_output = output.await?;
 394
 395            if format_on_save_enabled {
 396                action_log.update(cx, |log, cx| {
 397                    log.buffer_edited(buffer.clone(), cx);
 398                })?;
 399
 400                let format_task = project.update(cx, |project, cx| {
 401                    project.format(
 402                        HashSet::from_iter([buffer.clone()]),
 403                        LspFormatTarget::Buffers,
 404                        false, // Don't push to history since the tool did it.
 405                        FormatTrigger::Save,
 406                        cx,
 407                    )
 408                })?;
 409                format_task.await.log_err();
 410            }
 411
 412            project
 413                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
 414                .await?;
 415
 416            action_log.update(cx, |log, cx| {
 417                log.buffer_edited(buffer.clone(), cx);
 418            })?;
 419
 420            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 421            let (new_text, unified_diff) = cx
 422                .background_spawn({
 423                    let new_snapshot = new_snapshot.clone();
 424                    let old_text = old_text.clone();
 425                    async move {
 426                        let new_text = new_snapshot.text();
 427                        let diff = language::unified_diff(&old_text, &new_text);
 428                        (new_text, diff)
 429                    }
 430                })
 431                .await;
 432
 433            let input_path = input.path.display();
 434            if unified_diff.is_empty() {
 435                anyhow::ensure!(
 436                    !hallucinated_old_text,
 437                    formatdoc! {"
 438                        Some edits were produced but none of them could be applied.
 439                        Read the relevant sections of {input_path} again so that
 440                        I can perform the requested edits.
 441                    "}
 442                );
 443                anyhow::ensure!(
 444                    ambiguous_ranges.is_empty(),
 445                    {
 446                        let line_numbers = ambiguous_ranges
 447                            .iter()
 448                            .map(|range| range.start.to_string())
 449                            .collect::<Vec<_>>()
 450                            .join(", ");
 451                        formatdoc! {"
 452                            <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
 453                            relevant sections of {input_path} again and extend <old_text> so
 454                            that I can perform the requested edits.
 455                        "}
 456                    }
 457                );
 458            }
 459
 460            Ok(EditFileToolOutput {
 461                input_path: input.path,
 462                new_text,
 463                old_text,
 464                diff: unified_diff,
 465                edit_agent_output,
 466            })
 467        })
 468    }
 469
 470    fn replay(
 471        &self,
 472        _input: Self::Input,
 473        output: Self::Output,
 474        event_stream: ToolCallEventStream,
 475        cx: &mut App,
 476    ) -> Result<()> {
 477        event_stream.update_diff(cx.new(|cx| {
 478            Diff::finalized(
 479                output.input_path.to_string_lossy().into_owned(),
 480                Some(output.old_text.to_string()),
 481                output.new_text,
 482                self.language_registry.clone(),
 483                cx,
 484            )
 485        }));
 486        Ok(())
 487    }
 488}
 489
 490/// Validate that the file path is valid, meaning:
 491///
 492/// - For `edit` and `overwrite`, the path must point to an existing file.
 493/// - For `create`, the file must not already exist, but it's parent dir must exist.
 494fn resolve_path(
 495    input: &EditFileToolInput,
 496    project: Entity<Project>,
 497    cx: &mut App,
 498) -> Result<ProjectPath> {
 499    let project = project.read(cx);
 500
 501    match input.mode {
 502        EditFileMode::Edit | EditFileMode::Overwrite => {
 503            let path = project
 504                .find_project_path(&input.path, cx)
 505                .context("Can't edit file: path not found")?;
 506
 507            let entry = project
 508                .entry_for_path(&path, cx)
 509                .context("Can't edit file: path not found")?;
 510
 511            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
 512            Ok(path)
 513        }
 514
 515        EditFileMode::Create => {
 516            if let Some(path) = project.find_project_path(&input.path, cx) {
 517                anyhow::ensure!(
 518                    project.entry_for_path(&path, cx).is_none(),
 519                    "Can't create file: file already exists"
 520                );
 521            }
 522
 523            let parent_path = input
 524                .path
 525                .parent()
 526                .context("Can't create file: incorrect path")?;
 527
 528            let parent_project_path = project.find_project_path(&parent_path, cx);
 529
 530            let parent_entry = parent_project_path
 531                .as_ref()
 532                .and_then(|path| project.entry_for_path(path, cx))
 533                .context("Can't create file: parent directory doesn't exist")?;
 534
 535            anyhow::ensure!(
 536                parent_entry.is_dir(),
 537                "Can't create file: parent is not a directory"
 538            );
 539
 540            let file_name = input
 541                .path
 542                .file_name()
 543                .and_then(|file_name| file_name.to_str())
 544                .and_then(|file_name| RelPath::unix(file_name).ok())
 545                .context("Can't create file: invalid filename")?;
 546
 547            let new_file_path = parent_project_path.map(|parent| ProjectPath {
 548                path: parent.path.join(file_name),
 549                ..parent
 550            });
 551
 552            new_file_path.context("Can't create file")
 553        }
 554    }
 555}
 556
 557#[cfg(test)]
 558mod tests {
 559    use super::*;
 560    use crate::{ContextServerRegistry, Templates};
 561    use client::TelemetrySettings;
 562    use fs::Fs;
 563    use gpui::{TestAppContext, UpdateGlobal};
 564    use language_model::fake_provider::FakeLanguageModel;
 565    use prompt_store::ProjectContext;
 566    use serde_json::json;
 567    use settings::SettingsStore;
 568    use util::{path, rel_path::rel_path};
 569
 570    #[gpui::test]
 571    async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
 572        init_test(cx);
 573
 574        let fs = project::FakeFs::new(cx.executor());
 575        fs.insert_tree("/root", json!({})).await;
 576        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 577        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 578        let context_server_registry =
 579            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 580        let model = Arc::new(FakeLanguageModel::default());
 581        let thread = cx.new(|cx| {
 582            Thread::new(
 583                project.clone(),
 584                cx.new(|_cx| ProjectContext::default()),
 585                context_server_registry,
 586                Templates::new(),
 587                Some(model),
 588                cx,
 589            )
 590        });
 591        let result = cx
 592            .update(|cx| {
 593                let input = EditFileToolInput {
 594                    display_description: "Some edit".into(),
 595                    path: "root/nonexistent_file.txt".into(),
 596                    mode: EditFileMode::Edit,
 597                };
 598                Arc::new(EditFileTool::new(
 599                    project,
 600                    thread.downgrade(),
 601                    language_registry,
 602                ))
 603                .run(input, ToolCallEventStream::test().0, cx)
 604            })
 605            .await;
 606        assert_eq!(
 607            result.unwrap_err().to_string(),
 608            "Can't edit file: path not found"
 609        );
 610    }
 611
 612    #[gpui::test]
 613    async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
 614        let mode = &EditFileMode::Create;
 615
 616        let result = test_resolve_path(mode, "root/new.txt", cx);
 617        assert_resolved_path_eq(result.await, rel_path("new.txt"));
 618
 619        let result = test_resolve_path(mode, "new.txt", cx);
 620        assert_resolved_path_eq(result.await, rel_path("new.txt"));
 621
 622        let result = test_resolve_path(mode, "dir/new.txt", cx);
 623        assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
 624
 625        let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
 626        assert_eq!(
 627            result.await.unwrap_err().to_string(),
 628            "Can't create file: file already exists"
 629        );
 630
 631        let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
 632        assert_eq!(
 633            result.await.unwrap_err().to_string(),
 634            "Can't create file: parent directory doesn't exist"
 635        );
 636    }
 637
 638    #[gpui::test]
 639    async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
 640        let mode = &EditFileMode::Edit;
 641
 642        let path_with_root = "root/dir/subdir/existing.txt";
 643        let path_without_root = "dir/subdir/existing.txt";
 644        let result = test_resolve_path(mode, path_with_root, cx);
 645        assert_resolved_path_eq(result.await, rel_path(path_without_root));
 646
 647        let result = test_resolve_path(mode, path_without_root, cx);
 648        assert_resolved_path_eq(result.await, rel_path(path_without_root));
 649
 650        let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
 651        assert_eq!(
 652            result.await.unwrap_err().to_string(),
 653            "Can't edit file: path not found"
 654        );
 655
 656        let result = test_resolve_path(mode, "root/dir", cx);
 657        assert_eq!(
 658            result.await.unwrap_err().to_string(),
 659            "Can't edit file: path is a directory"
 660        );
 661    }
 662
 663    async fn test_resolve_path(
 664        mode: &EditFileMode,
 665        path: &str,
 666        cx: &mut TestAppContext,
 667    ) -> anyhow::Result<ProjectPath> {
 668        init_test(cx);
 669
 670        let fs = project::FakeFs::new(cx.executor());
 671        fs.insert_tree(
 672            "/root",
 673            json!({
 674                "dir": {
 675                    "subdir": {
 676                        "existing.txt": "hello"
 677                    }
 678                }
 679            }),
 680        )
 681        .await;
 682        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 683
 684        let input = EditFileToolInput {
 685            display_description: "Some edit".into(),
 686            path: path.into(),
 687            mode: mode.clone(),
 688        };
 689
 690        cx.update(|cx| resolve_path(&input, project, cx))
 691    }
 692
 693    #[track_caller]
 694    fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
 695        let actual = path.expect("Should return valid path").path;
 696        assert_eq!(actual.as_ref(), expected);
 697    }
 698
 699    #[gpui::test]
 700    async fn test_format_on_save(cx: &mut TestAppContext) {
 701        init_test(cx);
 702
 703        let fs = project::FakeFs::new(cx.executor());
 704        fs.insert_tree("/root", json!({"src": {}})).await;
 705
 706        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 707
 708        // Set up a Rust language with LSP formatting support
 709        let rust_language = Arc::new(language::Language::new(
 710            language::LanguageConfig {
 711                name: "Rust".into(),
 712                matcher: language::LanguageMatcher {
 713                    path_suffixes: vec!["rs".to_string()],
 714                    ..Default::default()
 715                },
 716                ..Default::default()
 717            },
 718            None,
 719        ));
 720
 721        // Register the language and fake LSP
 722        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
 723        language_registry.add(rust_language);
 724
 725        let mut fake_language_servers = language_registry.register_fake_lsp(
 726            "Rust",
 727            language::FakeLspAdapter {
 728                capabilities: lsp::ServerCapabilities {
 729                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
 730                    ..Default::default()
 731                },
 732                ..Default::default()
 733            },
 734        );
 735
 736        // Create the file
 737        fs.save(
 738            path!("/root/src/main.rs").as_ref(),
 739            &"initial content".into(),
 740            language::LineEnding::Unix,
 741        )
 742        .await
 743        .unwrap();
 744
 745        // Open the buffer to trigger LSP initialization
 746        let buffer = project
 747            .update(cx, |project, cx| {
 748                project.open_local_buffer(path!("/root/src/main.rs"), cx)
 749            })
 750            .await
 751            .unwrap();
 752
 753        // Register the buffer with language servers
 754        let _handle = project.update(cx, |project, cx| {
 755            project.register_buffer_with_language_servers(&buffer, cx)
 756        });
 757
 758        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
 759        const FORMATTED_CONTENT: &str =
 760            "This file was formatted by the fake formatter in the test.\n";
 761
 762        // Get the fake language server and set up formatting handler
 763        let fake_language_server = fake_language_servers.next().await.unwrap();
 764        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
 765            |_, _| async move {
 766                Ok(Some(vec![lsp::TextEdit {
 767                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
 768                    new_text: FORMATTED_CONTENT.to_string(),
 769                }]))
 770            }
 771        });
 772
 773        let context_server_registry =
 774            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 775        let model = Arc::new(FakeLanguageModel::default());
 776        let thread = cx.new(|cx| {
 777            Thread::new(
 778                project.clone(),
 779                cx.new(|_cx| ProjectContext::default()),
 780                context_server_registry,
 781                Templates::new(),
 782                Some(model.clone()),
 783                cx,
 784            )
 785        });
 786
 787        // First, test with format_on_save enabled
 788        cx.update(|cx| {
 789            SettingsStore::update_global(cx, |store, cx| {
 790                store.update_user_settings(cx, |settings| {
 791                    settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
 792                    settings.project.all_languages.defaults.formatter =
 793                        Some(language::language_settings::SelectedFormatter::Auto);
 794                });
 795            });
 796        });
 797
 798        // Have the model stream unformatted content
 799        let edit_result = {
 800            let edit_task = cx.update(|cx| {
 801                let input = EditFileToolInput {
 802                    display_description: "Create main function".into(),
 803                    path: "root/src/main.rs".into(),
 804                    mode: EditFileMode::Overwrite,
 805                };
 806                Arc::new(EditFileTool::new(
 807                    project.clone(),
 808                    thread.downgrade(),
 809                    language_registry.clone(),
 810                ))
 811                .run(input, ToolCallEventStream::test().0, cx)
 812            });
 813
 814            // Stream the unformatted content
 815            cx.executor().run_until_parked();
 816            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
 817            model.end_last_completion_stream();
 818
 819            edit_task.await
 820        };
 821        assert!(edit_result.is_ok());
 822
 823        // Wait for any async operations (e.g. formatting) to complete
 824        cx.executor().run_until_parked();
 825
 826        // Read the file to verify it was formatted automatically
 827        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
 828        assert_eq!(
 829            // Ignore carriage returns on Windows
 830            new_content.replace("\r\n", "\n"),
 831            FORMATTED_CONTENT,
 832            "Code should be formatted when format_on_save is enabled"
 833        );
 834
 835        let stale_buffer_count = thread
 836            .read_with(cx, |thread, _cx| thread.action_log.clone())
 837            .read_with(cx, |log, cx| log.stale_buffers(cx).count());
 838
 839        assert_eq!(
 840            stale_buffer_count, 0,
 841            "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
 842             This causes the agent to think the file was modified externally when it was just formatted.",
 843            stale_buffer_count
 844        );
 845
 846        // Next, test with format_on_save disabled
 847        cx.update(|cx| {
 848            SettingsStore::update_global(cx, |store, cx| {
 849                store.update_user_settings(cx, |settings| {
 850                    settings.project.all_languages.defaults.format_on_save =
 851                        Some(FormatOnSave::Off);
 852                });
 853            });
 854        });
 855
 856        // Stream unformatted edits again
 857        let edit_result = {
 858            let edit_task = cx.update(|cx| {
 859                let input = EditFileToolInput {
 860                    display_description: "Update main function".into(),
 861                    path: "root/src/main.rs".into(),
 862                    mode: EditFileMode::Overwrite,
 863                };
 864                Arc::new(EditFileTool::new(
 865                    project.clone(),
 866                    thread.downgrade(),
 867                    language_registry,
 868                ))
 869                .run(input, ToolCallEventStream::test().0, cx)
 870            });
 871
 872            // Stream the unformatted content
 873            cx.executor().run_until_parked();
 874            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
 875            model.end_last_completion_stream();
 876
 877            edit_task.await
 878        };
 879        assert!(edit_result.is_ok());
 880
 881        // Wait for any async operations (e.g. formatting) to complete
 882        cx.executor().run_until_parked();
 883
 884        // Verify the file was not formatted
 885        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
 886        assert_eq!(
 887            // Ignore carriage returns on Windows
 888            new_content.replace("\r\n", "\n"),
 889            UNFORMATTED_CONTENT,
 890            "Code should not be formatted when format_on_save is disabled"
 891        );
 892    }
 893
    // Checks that edits written by the tool honor `remove_trailing_whitespace_on_save`:
    // with the setting on, the trailing spaces in the model's output are stripped from the
    // saved file; with it off, the model's output is saved verbatim.
    #[gpui::test]
    async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        // Seed `main.rs` with placeholder text (it has no trailing whitespace); both edits
        // below use `EditFileMode::Overwrite`, so this content is fully replaced.
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // First, test with remove_trailing_whitespace_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings
                        .project
                        .all_languages
                        .defaults
                        .remove_trailing_whitespace_on_save = Some(true);
                });
            });
        });

        // The first two lines each end in two spaces; the closing-brace line does not.
        const CONTENT_WITH_TRAILING_WHITESPACE: &str =
            "fn main() {  \n    println!(\"Hello!\");  \n}\n";

        // Have the model stream content that contains trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry.clone(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Run the edit task until it is waiting on the fake model, then stream the
            // content with trailing whitespace and close the completion stream.
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify trailing whitespace was removed automatically
        assert_eq!(
            // Ignore carriage returns on Windows
            fs.load(path!("/root/src/main.rs").as_ref())
                .await
                .unwrap()
                .replace("\r\n", "\n"),
            "fn main() {\n    println!(\"Hello!\");\n}\n",
            "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
        );

        // Next, test with remove_trailing_whitespace_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings
                        .project
                        .all_languages
                        .defaults
                        .remove_trailing_whitespace_on_save = Some(false);
                });
            });
        });

        // Stream edits again with trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                // Last use of `language_registry`, so it is moved rather than cloned.
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry,
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the content with trailing whitespace
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file again: with the setting disabled, the model's output (including
        // its trailing whitespace) must be saved unchanged.
        let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            final_content.replace("\r\n", "\n"),
            CONTENT_WITH_TRAILING_WHITESPACE,
            "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
        );
    }
1036
1037    #[gpui::test]
1038    async fn test_authorize(cx: &mut TestAppContext) {
1039        init_test(cx);
1040        let fs = project::FakeFs::new(cx.executor());
1041        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1042        let context_server_registry =
1043            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1044        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1045        let model = Arc::new(FakeLanguageModel::default());
1046        let thread = cx.new(|cx| {
1047            Thread::new(
1048                project.clone(),
1049                cx.new(|_cx| ProjectContext::default()),
1050                context_server_registry,
1051                Templates::new(),
1052                Some(model.clone()),
1053                cx,
1054            )
1055        });
1056        let tool = Arc::new(EditFileTool::new(
1057            project.clone(),
1058            thread.downgrade(),
1059            language_registry,
1060        ));
1061        fs.insert_tree("/root", json!({})).await;
1062
1063        // Test 1: Path with .zed component should require confirmation
1064        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1065        let _auth = cx.update(|cx| {
1066            tool.authorize(
1067                &EditFileToolInput {
1068                    display_description: "test 1".into(),
1069                    path: ".zed/settings.json".into(),
1070                    mode: EditFileMode::Edit,
1071                },
1072                &stream_tx,
1073                cx,
1074            )
1075        });
1076
1077        let event = stream_rx.expect_authorization().await;
1078        assert_eq!(
1079            event.tool_call.fields.title,
1080            Some("test 1 (local settings)".into())
1081        );
1082
1083        // Test 2: Path outside project should require confirmation
1084        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1085        let _auth = cx.update(|cx| {
1086            tool.authorize(
1087                &EditFileToolInput {
1088                    display_description: "test 2".into(),
1089                    path: "/etc/hosts".into(),
1090                    mode: EditFileMode::Edit,
1091                },
1092                &stream_tx,
1093                cx,
1094            )
1095        });
1096
1097        let event = stream_rx.expect_authorization().await;
1098        assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1099
1100        // Test 3: Relative path without .zed should not require confirmation
1101        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1102        cx.update(|cx| {
1103            tool.authorize(
1104                &EditFileToolInput {
1105                    display_description: "test 3".into(),
1106                    path: "root/src/main.rs".into(),
1107                    mode: EditFileMode::Edit,
1108                },
1109                &stream_tx,
1110                cx,
1111            )
1112        })
1113        .await
1114        .unwrap();
1115        assert!(stream_rx.try_next().is_err());
1116
1117        // Test 4: Path with .zed in the middle should require confirmation
1118        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1119        let _auth = cx.update(|cx| {
1120            tool.authorize(
1121                &EditFileToolInput {
1122                    display_description: "test 4".into(),
1123                    path: "root/.zed/tasks.json".into(),
1124                    mode: EditFileMode::Edit,
1125                },
1126                &stream_tx,
1127                cx,
1128            )
1129        });
1130        let event = stream_rx.expect_authorization().await;
1131        assert_eq!(
1132            event.tool_call.fields.title,
1133            Some("test 4 (local settings)".into())
1134        );
1135
1136        // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1137        cx.update(|cx| {
1138            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1139            settings.always_allow_tool_actions = true;
1140            agent_settings::AgentSettings::override_global(settings, cx);
1141        });
1142
1143        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1144        cx.update(|cx| {
1145            tool.authorize(
1146                &EditFileToolInput {
1147                    display_description: "test 5.1".into(),
1148                    path: ".zed/settings.json".into(),
1149                    mode: EditFileMode::Edit,
1150                },
1151                &stream_tx,
1152                cx,
1153            )
1154        })
1155        .await
1156        .unwrap();
1157        assert!(stream_rx.try_next().is_err());
1158
1159        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1160        cx.update(|cx| {
1161            tool.authorize(
1162                &EditFileToolInput {
1163                    display_description: "test 5.2".into(),
1164                    path: "/etc/hosts".into(),
1165                    mode: EditFileMode::Edit,
1166                },
1167                &stream_tx,
1168                cx,
1169            )
1170        })
1171        .await
1172        .unwrap();
1173        assert!(stream_rx.try_next().is_err());
1174    }
1175
1176    #[gpui::test]
1177    async fn test_authorize_global_config(cx: &mut TestAppContext) {
1178        init_test(cx);
1179        let fs = project::FakeFs::new(cx.executor());
1180        fs.insert_tree("/project", json!({})).await;
1181        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1182        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1183        let context_server_registry =
1184            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1185        let model = Arc::new(FakeLanguageModel::default());
1186        let thread = cx.new(|cx| {
1187            Thread::new(
1188                project.clone(),
1189                cx.new(|_cx| ProjectContext::default()),
1190                context_server_registry,
1191                Templates::new(),
1192                Some(model.clone()),
1193                cx,
1194            )
1195        });
1196        let tool = Arc::new(EditFileTool::new(
1197            project.clone(),
1198            thread.downgrade(),
1199            language_registry,
1200        ));
1201
1202        // Test global config paths - these should require confirmation if they exist and are outside the project
1203        let test_cases = vec![
1204            (
1205                "/etc/hosts",
1206                true,
1207                "System file should require confirmation",
1208            ),
1209            (
1210                "/usr/local/bin/script",
1211                true,
1212                "System bin file should require confirmation",
1213            ),
1214            (
1215                "project/normal_file.rs",
1216                false,
1217                "Normal project file should not require confirmation",
1218            ),
1219        ];
1220
1221        for (path, should_confirm, description) in test_cases {
1222            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1223            let auth = cx.update(|cx| {
1224                tool.authorize(
1225                    &EditFileToolInput {
1226                        display_description: "Edit file".into(),
1227                        path: path.into(),
1228                        mode: EditFileMode::Edit,
1229                    },
1230                    &stream_tx,
1231                    cx,
1232                )
1233            });
1234
1235            if should_confirm {
1236                stream_rx.expect_authorization().await;
1237            } else {
1238                auth.await.unwrap();
1239                assert!(
1240                    stream_rx.try_next().is_err(),
1241                    "Failed for case: {} - path: {} - expected no confirmation but got one",
1242                    description,
1243                    path
1244                );
1245            }
1246        }
1247    }
1248
1249    #[gpui::test]
1250    async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1251        init_test(cx);
1252        let fs = project::FakeFs::new(cx.executor());
1253
1254        // Create multiple worktree directories
1255        fs.insert_tree(
1256            "/workspace/frontend",
1257            json!({
1258                "src": {
1259                    "main.js": "console.log('frontend');"
1260                }
1261            }),
1262        )
1263        .await;
1264        fs.insert_tree(
1265            "/workspace/backend",
1266            json!({
1267                "src": {
1268                    "main.rs": "fn main() {}"
1269                }
1270            }),
1271        )
1272        .await;
1273        fs.insert_tree(
1274            "/workspace/shared",
1275            json!({
1276                ".zed": {
1277                    "settings.json": "{}"
1278                }
1279            }),
1280        )
1281        .await;
1282
1283        // Create project with multiple worktrees
1284        let project = Project::test(
1285            fs.clone(),
1286            [
1287                path!("/workspace/frontend").as_ref(),
1288                path!("/workspace/backend").as_ref(),
1289                path!("/workspace/shared").as_ref(),
1290            ],
1291            cx,
1292        )
1293        .await;
1294        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1295        let context_server_registry =
1296            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1297        let model = Arc::new(FakeLanguageModel::default());
1298        let thread = cx.new(|cx| {
1299            Thread::new(
1300                project.clone(),
1301                cx.new(|_cx| ProjectContext::default()),
1302                context_server_registry.clone(),
1303                Templates::new(),
1304                Some(model.clone()),
1305                cx,
1306            )
1307        });
1308        let tool = Arc::new(EditFileTool::new(
1309            project.clone(),
1310            thread.downgrade(),
1311            language_registry,
1312        ));
1313
1314        // Test files in different worktrees
1315        let test_cases = vec![
1316            ("frontend/src/main.js", false, "File in first worktree"),
1317            ("backend/src/main.rs", false, "File in second worktree"),
1318            (
1319                "shared/.zed/settings.json",
1320                true,
1321                ".zed file in third worktree",
1322            ),
1323            ("/etc/hosts", true, "Absolute path outside all worktrees"),
1324            (
1325                "../outside/file.txt",
1326                true,
1327                "Relative path outside worktrees",
1328            ),
1329        ];
1330
1331        for (path, should_confirm, description) in test_cases {
1332            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1333            let auth = cx.update(|cx| {
1334                tool.authorize(
1335                    &EditFileToolInput {
1336                        display_description: "Edit file".into(),
1337                        path: path.into(),
1338                        mode: EditFileMode::Edit,
1339                    },
1340                    &stream_tx,
1341                    cx,
1342                )
1343            });
1344
1345            if should_confirm {
1346                stream_rx.expect_authorization().await;
1347            } else {
1348                auth.await.unwrap();
1349                assert!(
1350                    stream_rx.try_next().is_err(),
1351                    "Failed for case: {} - path: {} - expected no confirmation but got one",
1352                    description,
1353                    path
1354                );
1355            }
1356        }
1357    }
1358
    // Exercises `authorize` on unusual path shapes: an empty path, the filesystem root, paths
    // with `..` and `.` components and, on Windows only, backslash-separated paths. Each case
    // states whether a confirmation request is expected.
    #[gpui::test]
    async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            "/project",
            json!({
                ".zed": {
                    "settings.json": "{}"
                },
                "src": {
                    ".zed": {
                        "local.json": "{}"
                    }
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry.clone(),
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            language_registry,
        ));

        // Test edge cases. Each entry is (path, should_confirm, description).
        let test_cases = vec![
            // Empty path - find_project_path returns Some for empty paths
            ("", false, "Empty path is treated as project root"),
            // Root directory
            ("/", true, "Root directory should be outside project"),
            // Parent directory references - find_project_path resolves these
            (
                "project/../other",
                true,
                "Path with .. that goes outside of root directory",
            ),
            (
                "project/./src/file.rs",
                false,
                "Path with . should work normally",
            ),
            // Windows-style paths (compiled into the case list only on Windows)
            #[cfg(target_os = "windows")]
            ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
            #[cfg(target_os = "windows")]
            ("project\\src\\main.rs", false, "Windows-style project path"),
        ];

        for (path, should_confirm, description) in test_cases {
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let auth = cx.update(|cx| {
                tool.authorize(
                    &EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path.into(),
                        mode: EditFileMode::Edit,
                    },
                    &stream_tx,
                    cx,
                )
            });

            // Let pending work run so that any confirmation request the authorization
            // would send has already been queued before the stream is inspected.
            cx.run_until_parked();

            if should_confirm {
                stream_rx.expect_authorization().await;
            } else {
                // No request may be queued; only then is the task itself awaited, and it
                // must resolve successfully.
                assert!(
                    stream_rx.try_next().is_err(),
                    "Failed for case: {} - path: {} - expected no confirmation but got one",
                    description,
                    path
                );
                auth.await.unwrap();
            }
        }
    }
1451
    // Checks that whether confirmation is needed depends on the target path, not on the
    // `EditFileMode`. For each of Edit, Create and Overwrite, a `.zed` settings path and a
    // path outside the project prompt, while an ordinary project path does not.
    #[gpui::test]
    async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            "/project",
            json!({
                "existing.txt": "content",
                ".zed": {
                    "settings.json": "{}"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry.clone(),
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            language_registry,
        ));

        // Test different EditFileMode values
        let modes = vec![
            EditFileMode::Edit,
            EditFileMode::Create,
            EditFileMode::Overwrite,
        ];

        for mode in modes {
            // Test .zed path with different modes
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let _auth = cx.update(|cx| {
                tool.authorize(
                    &EditFileToolInput {
                        display_description: "Edit settings".into(),
                        path: "project/.zed/settings.json".into(),
                        mode: mode.clone(),
                    },
                    &stream_tx,
                    cx,
                )
            });

            stream_rx.expect_authorization().await;

            // Test outside path with different modes.
            // These bindings shadow the ones above. The earlier sender, receiver and `_auth`
            // task are not dropped until the end of this loop iteration.
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let _auth = cx.update(|cx| {
                tool.authorize(
                    &EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: "/outside/file.txt".into(),
                        mode: mode.clone(),
                    },
                    &stream_tx,
                    cx,
                )
            });

            stream_rx.expect_authorization().await;

            // Test normal path with different modes: authorization resolves successfully
            // and nothing is sent on the event stream.
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            cx.update(|cx| {
                tool.authorize(
                    &EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: "project/normal.txt".into(),
                        mode: mode.clone(),
                    },
                    &stream_tx,
                    cx,
                )
            })
            .await
            .unwrap();
            assert!(stream_rx.try_next().is_err());
        }
    }
1545
1546    #[gpui::test]
1547    async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1548        init_test(cx);
1549        let fs = project::FakeFs::new(cx.executor());
1550        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1551        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1552        let context_server_registry =
1553            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1554        let model = Arc::new(FakeLanguageModel::default());
1555        let thread = cx.new(|cx| {
1556            Thread::new(
1557                project.clone(),
1558                cx.new(|_cx| ProjectContext::default()),
1559                context_server_registry,
1560                Templates::new(),
1561                Some(model.clone()),
1562                cx,
1563            )
1564        });
1565        let tool = Arc::new(EditFileTool::new(
1566            project,
1567            thread.downgrade(),
1568            language_registry,
1569        ));
1570
1571        cx.update(|cx| {
1572            // ...
1573            assert_eq!(
1574                tool.initial_title(
1575                    Err(json!({
1576                        "path": "src/main.rs",
1577                        "display_description": "",
1578                        "old_string": "old code",
1579                        "new_string": "new code"
1580                    })),
1581                    cx
1582                ),
1583                "src/main.rs"
1584            );
1585            assert_eq!(
1586                tool.initial_title(
1587                    Err(json!({
1588                        "path": "",
1589                        "display_description": "Fix error handling",
1590                        "old_string": "old code",
1591                        "new_string": "new code"
1592                    })),
1593                    cx
1594                ),
1595                "Fix error handling"
1596            );
1597            assert_eq!(
1598                tool.initial_title(
1599                    Err(json!({
1600                        "path": "src/main.rs",
1601                        "display_description": "Fix error handling",
1602                        "old_string": "old code",
1603                        "new_string": "new code"
1604                    })),
1605                    cx
1606                ),
1607                "src/main.rs"
1608            );
1609            assert_eq!(
1610                tool.initial_title(
1611                    Err(json!({
1612                        "path": "",
1613                        "display_description": "",
1614                        "old_string": "old code",
1615                        "new_string": "new code"
1616                    })),
1617                    cx
1618                ),
1619                DEFAULT_UI_TEXT
1620            );
1621            assert_eq!(
1622                tool.initial_title(Err(serde_json::Value::Null), cx),
1623                DEFAULT_UI_TEXT
1624            );
1625        });
1626    }
1627
1628    #[gpui::test]
1629    async fn test_diff_finalization(cx: &mut TestAppContext) {
1630        init_test(cx);
1631        let fs = project::FakeFs::new(cx.executor());
1632        fs.insert_tree("/", json!({"main.rs": ""})).await;
1633
1634        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1635        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1636        let context_server_registry =
1637            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1638        let model = Arc::new(FakeLanguageModel::default());
1639        let thread = cx.new(|cx| {
1640            Thread::new(
1641                project.clone(),
1642                cx.new(|_cx| ProjectContext::default()),
1643                context_server_registry.clone(),
1644                Templates::new(),
1645                Some(model.clone()),
1646                cx,
1647            )
1648        });
1649
1650        // Ensure the diff is finalized after the edit completes.
1651        {
1652            let tool = Arc::new(EditFileTool::new(
1653                project.clone(),
1654                thread.downgrade(),
1655                languages.clone(),
1656            ));
1657            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1658            let edit = cx.update(|cx| {
1659                tool.run(
1660                    EditFileToolInput {
1661                        display_description: "Edit file".into(),
1662                        path: path!("/main.rs").into(),
1663                        mode: EditFileMode::Edit,
1664                    },
1665                    stream_tx,
1666                    cx,
1667                )
1668            });
1669            stream_rx.expect_update_fields().await;
1670            let diff = stream_rx.expect_diff().await;
1671            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1672            cx.run_until_parked();
1673            model.end_last_completion_stream();
1674            edit.await.unwrap();
1675            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1676        }
1677
1678        // Ensure the diff is finalized if an error occurs while editing.
1679        {
1680            model.forbid_requests();
1681            let tool = Arc::new(EditFileTool::new(
1682                project.clone(),
1683                thread.downgrade(),
1684                languages.clone(),
1685            ));
1686            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1687            let edit = cx.update(|cx| {
1688                tool.run(
1689                    EditFileToolInput {
1690                        display_description: "Edit file".into(),
1691                        path: path!("/main.rs").into(),
1692                        mode: EditFileMode::Edit,
1693                    },
1694                    stream_tx,
1695                    cx,
1696                )
1697            });
1698            stream_rx.expect_update_fields().await;
1699            let diff = stream_rx.expect_diff().await;
1700            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1701            edit.await.unwrap_err();
1702            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1703            model.allow_requests();
1704        }
1705
1706        // Ensure the diff is finalized if the tool call gets dropped.
1707        {
1708            let tool = Arc::new(EditFileTool::new(
1709                project.clone(),
1710                thread.downgrade(),
1711                languages.clone(),
1712            ));
1713            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1714            let edit = cx.update(|cx| {
1715                tool.run(
1716                    EditFileToolInput {
1717                        display_description: "Edit file".into(),
1718                        path: path!("/main.rs").into(),
1719                        mode: EditFileMode::Edit,
1720                    },
1721                    stream_tx,
1722                    cx,
1723                )
1724            });
1725            stream_rx.expect_update_fields().await;
1726            let diff = stream_rx.expect_diff().await;
1727            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1728            drop(edit);
1729            cx.run_until_parked();
1730            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1731        }
1732    }
1733
1734    fn init_test(cx: &mut TestAppContext) {
1735        cx.update(|cx| {
1736            let settings_store = SettingsStore::test(cx);
1737            cx.set_global(settings_store);
1738            language::init(cx);
1739            TelemetrySettings::register(cx);
1740            agent_settings::AgentSettings::register(cx);
1741            Project::init_settings(cx);
1742        });
1743    }
1744}