edit_file_tool.rs

   1use crate::{
   2    AgentTool, Templates, Thread, ToolCallEventStream,
   3    edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
   4};
   5use acp_thread::Diff;
   6use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
   7use anyhow::{Context as _, Result, anyhow};
   8use cloud_llm_client::CompletionIntent;
   9use collections::HashSet;
  10use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
  11use indoc::formatdoc;
  12use language::language_settings::{self, FormatOnSave};
  13use language::{LanguageRegistry, ToPoint};
  14use language_model::LanguageModelToolResultContent;
  15use paths;
  16use project::lsp_store::{FormatTrigger, LspFormatTarget};
  17use project::{Project, ProjectPath};
  18use schemars::JsonSchema;
  19use serde::{Deserialize, Serialize};
  20use settings::Settings;
  21use smol::stream::StreamExt as _;
  22use std::ffi::OsStr;
  23use std::path::{Path, PathBuf};
  24use std::sync::Arc;
  25use ui::SharedString;
  26use util::ResultExt;
  27use util::rel_path::RelPath;
  28
  29const DEFAULT_UI_TEXT: &str = "Editing file";
  30
  31/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
  32///
  33/// Before using this tool:
  34///
  35/// 1. Use the `read_file` tool to understand the file's contents and context
  36///
  37/// 2. Verify the directory path is correct (only applicable when creating new files):
  38///    - Use the `list_directory` tool to verify the parent directory exists and is the correct location
  39#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
  40pub struct EditFileToolInput {
  41    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
  42    ///
  43    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
  44    ///
  45    /// NEVER mention the file path in this description.
  46    ///
  47    /// <example>Fix API endpoint URLs</example>
  48    /// <example>Update copyright year in `page_footer`</example>
  49    ///
  50    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
  51    pub display_description: String,
  52
  53    /// The full path of the file to create or modify in the project.
  54    ///
  55    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
  56    ///
  57    /// The following examples assume we have two root directories in the project:
  58    /// - /a/b/backend
  59    /// - /c/d/frontend
  60    ///
  61    /// <example>
  62    /// `backend/src/main.rs`
  63    ///
  64    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
  65    /// </example>
  66    ///
  67    /// <example>
  68    /// `frontend/db.js`
  69    /// </example>
  70    pub path: PathBuf,
  71    /// The mode of operation on the file. Possible values:
  72    /// - 'edit': Make granular edits to an existing file.
  73    /// - 'create': Create a new file if it doesn't exist.
  74    /// - 'overwrite': Replace the entire contents of an existing file.
  75    ///
  76    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
  77    pub mode: EditFileMode,
  78}
  79
// Lenient view of `EditFileToolInput` used by `initial_title` when the raw input does
// not (yet) deserialize into the full type. Only the two fields needed for a title are
// read, and each falls back to an empty string when absent.
//
// Plain `//` comments are used here on purpose: `///` would become part of the derived
// JSON schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // Raw path string as sent so far; empty if missing.
    #[serde(default)]
    path: String,
    // Description sent so far; empty if missing.
    #[serde(default)]
    display_description: String,
}
  87
// How `run` treats the target file. Variants (de)serialize as "edit", "create" and
// "overwrite"; the schema is inlined into `EditFileToolInput`'s schema, which is why
// these notes use `//` rather than `///` (doc comments would alter the model-facing
// schema).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    // Targeted edits to an existing file (`EditAgent::edit`).
    Edit,
    // New file; `resolve_path` rejects paths that already exist. Content is produced
    // with `EditAgent::overwrite`.
    Create,
    // Replace the contents of an existing file (`EditAgent::overwrite`).
    Overwrite,
}
  96
/// Result of a completed `edit_file` call, stored so the edit can be replayed and
/// summarized for the model.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as given in the tool input.
    ///
    /// Also accepts the key `original_path` when deserializing (presumably the name used
    /// by older serialized outputs — confirm).
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit, any format-on-save pass, and the save.
    new_text: String,
    /// Full buffer text captured before the edit agent started.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty when nothing changed.
    /// Defaults to empty when missing from serialized data.
    #[serde(default)]
    diff: String,
    /// Output returned by the edit agent once its stream completed.
    ///
    /// Also accepts the key `raw_output` when deserializing (presumably a legacy name —
    /// confirm).
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
 108
 109impl From<EditFileToolOutput> for LanguageModelToolResultContent {
 110    fn from(output: EditFileToolOutput) -> Self {
 111        if output.diff.is_empty() {
 112            "No edits were made.".into()
 113        } else {
 114            format!(
 115                "Edited {}:\n\n```diff\n{}\n```",
 116                output.input_path.display(),
 117                output.diff
 118            )
 119            .into()
 120        }
 121    }
 122}
 123
 124pub struct EditFileTool {
 125    thread: WeakEntity<Thread>,
 126    language_registry: Arc<LanguageRegistry>,
 127    project: Entity<Project>,
 128    templates: Arc<Templates>,
 129}
 130
 131impl EditFileTool {
 132    pub fn new(
 133        project: Entity<Project>,
 134        thread: WeakEntity<Thread>,
 135        language_registry: Arc<LanguageRegistry>,
 136        templates: Arc<Templates>,
 137    ) -> Self {
 138        Self {
 139            project,
 140            thread,
 141            language_registry,
 142            templates,
 143        }
 144    }
 145
 146    fn authorize(
 147        &self,
 148        input: &EditFileToolInput,
 149        event_stream: &ToolCallEventStream,
 150        cx: &mut App,
 151    ) -> Task<Result<()>> {
 152        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
 153            return Task::ready(Ok(()));
 154        }
 155
 156        // If any path component matches the local settings folder, then this could affect
 157        // the editor in ways beyond the project source, so prompt.
 158        let local_settings_folder = paths::local_settings_folder_name();
 159        let path = Path::new(&input.path);
 160        if path.components().any(|component| {
 161            component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
 162        }) {
 163            return event_stream.authorize(
 164                format!("{} (local settings)", input.display_description),
 165                cx,
 166            );
 167        }
 168
 169        // It's also possible that the global config dir is configured to be inside the project,
 170        // so check for that edge case too.
 171        // TODO this is broken when remoting
 172        if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
 173            && canonical_path.starts_with(paths::config_dir())
 174        {
 175            return event_stream.authorize(
 176                format!("{} (global settings)", input.display_description),
 177                cx,
 178            );
 179        }
 180
 181        // Check if path is inside the global config directory
 182        // First check if it's already inside project - if not, try to canonicalize
 183        let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
 184            thread.project().read(cx).find_project_path(&input.path, cx)
 185        }) else {
 186            return Task::ready(Err(anyhow!("thread was dropped")));
 187        };
 188
 189        // If the path is inside the project, and it's not one of the above edge cases,
 190        // then no confirmation is necessary. Otherwise, confirmation is necessary.
 191        if project_path.is_some() {
 192            Task::ready(Ok(()))
 193        } else {
 194            event_stream.authorize(&input.display_description, cx)
 195        }
 196    }
 197}
 198
 199impl AgentTool for EditFileTool {
 200    type Input = EditFileToolInput;
 201    type Output = EditFileToolOutput;
 202
 203    fn name() -> &'static str {
 204        "edit_file"
 205    }
 206
 207    fn kind() -> acp::ToolKind {
 208        acp::ToolKind::Edit
 209    }
 210
 211    fn initial_title(
 212        &self,
 213        input: Result<Self::Input, serde_json::Value>,
 214        cx: &mut App,
 215    ) -> SharedString {
 216        match input {
 217            Ok(input) => self
 218                .project
 219                .read(cx)
 220                .find_project_path(&input.path, cx)
 221                .and_then(|project_path| {
 222                    self.project
 223                        .read(cx)
 224                        .short_full_path_for_project_path(&project_path, cx)
 225                })
 226                .unwrap_or(input.path.to_string_lossy().into_owned())
 227                .into(),
 228            Err(raw_input) => {
 229                if let Some(input) =
 230                    serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
 231                {
 232                    let path = input.path.trim();
 233                    if !path.is_empty() {
 234                        return self
 235                            .project
 236                            .read(cx)
 237                            .find_project_path(&input.path, cx)
 238                            .and_then(|project_path| {
 239                                self.project
 240                                    .read(cx)
 241                                    .short_full_path_for_project_path(&project_path, cx)
 242                            })
 243                            .unwrap_or(input.path)
 244                            .into();
 245                    }
 246
 247                    let description = input.display_description.trim();
 248                    if !description.is_empty() {
 249                        return description.to_string().into();
 250                    }
 251                }
 252
 253                DEFAULT_UI_TEXT.into()
 254            }
 255        }
 256    }
 257
 258    fn run(
 259        self: Arc<Self>,
 260        input: Self::Input,
 261        event_stream: ToolCallEventStream,
 262        cx: &mut App,
 263    ) -> Task<Result<Self::Output>> {
 264        let Ok(project) = self
 265            .thread
 266            .read_with(cx, |thread, _cx| thread.project().clone())
 267        else {
 268            return Task::ready(Err(anyhow!("thread was dropped")));
 269        };
 270        let project_path = match resolve_path(&input, project.clone(), cx) {
 271            Ok(path) => path,
 272            Err(err) => return Task::ready(Err(anyhow!(err))),
 273        };
 274        let abs_path = project.read(cx).absolute_path(&project_path, cx);
 275        if let Some(abs_path) = abs_path.clone() {
 276            event_stream.update_fields(ToolCallUpdateFields {
 277                locations: Some(vec![acp::ToolCallLocation {
 278                    path: abs_path,
 279                    line: None,
 280                    meta: None,
 281                }]),
 282                ..Default::default()
 283            });
 284        }
 285
 286        let authorize = self.authorize(&input, &event_stream, cx);
 287        cx.spawn(async move |cx: &mut AsyncApp| {
 288            authorize.await?;
 289
 290            let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
 291                let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
 292                (request, thread.model().cloned(), thread.action_log().clone())
 293            })?;
 294            let request = request?;
 295            let model = model.context("No language model configured")?;
 296
 297            let edit_format = EditFormat::from_model(model.clone())?;
 298            let edit_agent = EditAgent::new(
 299                model,
 300                project.clone(),
 301                action_log.clone(),
 302                self.templates.clone(),
 303                edit_format,
 304            );
 305
 306            let buffer = project
 307                .update(cx, |project, cx| {
 308                    project.open_buffer(project_path.clone(), cx)
 309                })?
 310                .await?;
 311
 312            let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
 313            event_stream.update_diff(diff.clone());
 314            let _finalize_diff = util::defer({
 315               let diff = diff.downgrade();
 316               let mut cx = cx.clone();
 317               move || {
 318                   diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
 319               }
 320            });
 321
 322            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 323            let old_text = cx
 324                .background_spawn({
 325                    let old_snapshot = old_snapshot.clone();
 326                    async move { Arc::new(old_snapshot.text()) }
 327                })
 328                .await;
 329
 330
 331            let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
 332                edit_agent.edit(
 333                    buffer.clone(),
 334                    input.display_description.clone(),
 335                    &request,
 336                    cx,
 337                )
 338            } else {
 339                edit_agent.overwrite(
 340                    buffer.clone(),
 341                    input.display_description.clone(),
 342                    &request,
 343                    cx,
 344                )
 345            };
 346
 347            let mut hallucinated_old_text = false;
 348            let mut ambiguous_ranges = Vec::new();
 349            let mut emitted_location = false;
 350            while let Some(event) = events.next().await {
 351                match event {
 352                    EditAgentOutputEvent::Edited(range) => {
 353                        if !emitted_location {
 354                            let line = buffer.update(cx, |buffer, _cx| {
 355                                range.start.to_point(&buffer.snapshot()).row
 356                            }).ok();
 357                            if let Some(abs_path) = abs_path.clone() {
 358                                event_stream.update_fields(ToolCallUpdateFields {
 359                                    locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
 360                                    ..Default::default()
 361                                });
 362                            }
 363                            emitted_location = true;
 364                        }
 365                    },
 366                    EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
 367                    EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
 368                    EditAgentOutputEvent::ResolvingEditRange(range) => {
 369                        diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
 370                        // if !emitted_location {
 371                        //     let line = buffer.update(cx, |buffer, _cx| {
 372                        //         range.start.to_point(&buffer.snapshot()).row
 373                        //     }).ok();
 374                        //     if let Some(abs_path) = abs_path.clone() {
 375                        //         event_stream.update_fields(ToolCallUpdateFields {
 376                        //             locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
 377                        //             ..Default::default()
 378                        //         });
 379                        //     }
 380                        // }
 381                    }
 382                }
 383            }
 384
 385            // If format_on_save is enabled, format the buffer
 386            let format_on_save_enabled = buffer
 387                .read_with(cx, |buffer, cx| {
 388                    let settings = language_settings::language_settings(
 389                        buffer.language().map(|l| l.name()),
 390                        buffer.file(),
 391                        cx,
 392                    );
 393                    settings.format_on_save != FormatOnSave::Off
 394                })
 395                .unwrap_or(false);
 396
 397            let edit_agent_output = output.await?;
 398
 399            if format_on_save_enabled {
 400                action_log.update(cx, |log, cx| {
 401                    log.buffer_edited(buffer.clone(), cx);
 402                })?;
 403
 404                let format_task = project.update(cx, |project, cx| {
 405                    project.format(
 406                        HashSet::from_iter([buffer.clone()]),
 407                        LspFormatTarget::Buffers,
 408                        false, // Don't push to history since the tool did it.
 409                        FormatTrigger::Save,
 410                        cx,
 411                    )
 412                })?;
 413                format_task.await.log_err();
 414            }
 415
 416            project
 417                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
 418                .await?;
 419
 420            action_log.update(cx, |log, cx| {
 421                log.buffer_edited(buffer.clone(), cx);
 422            })?;
 423
 424            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 425            let (new_text, unified_diff) = cx
 426                .background_spawn({
 427                    let new_snapshot = new_snapshot.clone();
 428                    let old_text = old_text.clone();
 429                    async move {
 430                        let new_text = new_snapshot.text();
 431                        let diff = language::unified_diff(&old_text, &new_text);
 432                        (new_text, diff)
 433                    }
 434                })
 435                .await;
 436
 437            let input_path = input.path.display();
 438            if unified_diff.is_empty() {
 439                anyhow::ensure!(
 440                    !hallucinated_old_text,
 441                    formatdoc! {"
 442                        Some edits were produced but none of them could be applied.
 443                        Read the relevant sections of {input_path} again so that
 444                        I can perform the requested edits.
 445                    "}
 446                );
 447                anyhow::ensure!(
 448                    ambiguous_ranges.is_empty(),
 449                    {
 450                        let line_numbers = ambiguous_ranges
 451                            .iter()
 452                            .map(|range| range.start.to_string())
 453                            .collect::<Vec<_>>()
 454                            .join(", ");
 455                        formatdoc! {"
 456                            <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
 457                            relevant sections of {input_path} again and extend <old_text> so
 458                            that I can perform the requested edits.
 459                        "}
 460                    }
 461                );
 462            }
 463
 464            Ok(EditFileToolOutput {
 465                input_path: input.path,
 466                new_text,
 467                old_text,
 468                diff: unified_diff,
 469                edit_agent_output,
 470            })
 471        })
 472    }
 473
 474    fn replay(
 475        &self,
 476        _input: Self::Input,
 477        output: Self::Output,
 478        event_stream: ToolCallEventStream,
 479        cx: &mut App,
 480    ) -> Result<()> {
 481        event_stream.update_diff(cx.new(|cx| {
 482            Diff::finalized(
 483                output.input_path.to_string_lossy().into_owned(),
 484                Some(output.old_text.to_string()),
 485                output.new_text,
 486                self.language_registry.clone(),
 487                cx,
 488            )
 489        }));
 490        Ok(())
 491    }
 492}
 493
 494/// Validate that the file path is valid, meaning:
 495///
 496/// - For `edit` and `overwrite`, the path must point to an existing file.
 497/// - For `create`, the file must not already exist, but it's parent dir must exist.
 498fn resolve_path(
 499    input: &EditFileToolInput,
 500    project: Entity<Project>,
 501    cx: &mut App,
 502) -> Result<ProjectPath> {
 503    let project = project.read(cx);
 504
 505    match input.mode {
 506        EditFileMode::Edit | EditFileMode::Overwrite => {
 507            let path = project
 508                .find_project_path(&input.path, cx)
 509                .context("Can't edit file: path not found")?;
 510
 511            let entry = project
 512                .entry_for_path(&path, cx)
 513                .context("Can't edit file: path not found")?;
 514
 515            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
 516            Ok(path)
 517        }
 518
 519        EditFileMode::Create => {
 520            if let Some(path) = project.find_project_path(&input.path, cx) {
 521                anyhow::ensure!(
 522                    project.entry_for_path(&path, cx).is_none(),
 523                    "Can't create file: file already exists"
 524                );
 525            }
 526
 527            let parent_path = input
 528                .path
 529                .parent()
 530                .context("Can't create file: incorrect path")?;
 531
 532            let parent_project_path = project.find_project_path(&parent_path, cx);
 533
 534            let parent_entry = parent_project_path
 535                .as_ref()
 536                .and_then(|path| project.entry_for_path(path, cx))
 537                .context("Can't create file: parent directory doesn't exist")?;
 538
 539            anyhow::ensure!(
 540                parent_entry.is_dir(),
 541                "Can't create file: parent is not a directory"
 542            );
 543
 544            let file_name = input
 545                .path
 546                .file_name()
 547                .and_then(|file_name| file_name.to_str())
 548                .and_then(|file_name| RelPath::unix(file_name).ok())
 549                .context("Can't create file: invalid filename")?;
 550
 551            let new_file_path = parent_project_path.map(|parent| ProjectPath {
 552                path: parent.path.join(file_name),
 553                ..parent
 554            });
 555
 556            new_file_path.context("Can't create file")
 557        }
 558    }
 559}
 560
 561#[cfg(test)]
 562mod tests {
 563    use super::*;
 564    use crate::{ContextServerRegistry, Templates};
 565    use client::TelemetrySettings;
 566    use encoding::all::UTF_8;
 567    use fs::{Fs, encodings::EncodingWrapper};
 568    use gpui::{TestAppContext, UpdateGlobal};
 569    use language_model::fake_provider::FakeLanguageModel;
 570    use prompt_store::ProjectContext;
 571    use serde_json::json;
 572    use settings::SettingsStore;
 573    use text::Rope;
 574    use util::{path, rel_path::rel_path};
 575
 576    #[gpui::test]
 577    async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
 578        init_test(cx);
 579
 580        let fs = project::FakeFs::new(cx.executor());
 581        fs.insert_tree("/root", json!({})).await;
 582        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 583        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 584        let context_server_registry =
 585            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 586        let model = Arc::new(FakeLanguageModel::default());
 587        let thread = cx.new(|cx| {
 588            Thread::new(
 589                project.clone(),
 590                cx.new(|_cx| ProjectContext::default()),
 591                context_server_registry,
 592                Templates::new(),
 593                Some(model),
 594                cx,
 595            )
 596        });
 597        let result = cx
 598            .update(|cx| {
 599                let input = EditFileToolInput {
 600                    display_description: "Some edit".into(),
 601                    path: "root/nonexistent_file.txt".into(),
 602                    mode: EditFileMode::Edit,
 603                };
 604                Arc::new(EditFileTool::new(
 605                    project,
 606                    thread.downgrade(),
 607                    language_registry,
 608                    Templates::new(),
 609                ))
 610                .run(input, ToolCallEventStream::test().0, cx)
 611            })
 612            .await;
 613        assert_eq!(
 614            result.unwrap_err().to_string(),
 615            "Can't edit file: path not found"
 616        );
 617    }
 618
 619    #[gpui::test]
 620    async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
 621        let mode = &EditFileMode::Create;
 622
 623        let result = test_resolve_path(mode, "root/new.txt", cx);
 624        assert_resolved_path_eq(result.await, rel_path("new.txt"));
 625
 626        let result = test_resolve_path(mode, "new.txt", cx);
 627        assert_resolved_path_eq(result.await, rel_path("new.txt"));
 628
 629        let result = test_resolve_path(mode, "dir/new.txt", cx);
 630        assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
 631
 632        let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
 633        assert_eq!(
 634            result.await.unwrap_err().to_string(),
 635            "Can't create file: file already exists"
 636        );
 637
 638        let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
 639        assert_eq!(
 640            result.await.unwrap_err().to_string(),
 641            "Can't create file: parent directory doesn't exist"
 642        );
 643    }
 644
 645    #[gpui::test]
 646    async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
 647        let mode = &EditFileMode::Edit;
 648
 649        let path_with_root = "root/dir/subdir/existing.txt";
 650        let path_without_root = "dir/subdir/existing.txt";
 651        let result = test_resolve_path(mode, path_with_root, cx);
 652        assert_resolved_path_eq(result.await, rel_path(path_without_root));
 653
 654        let result = test_resolve_path(mode, path_without_root, cx);
 655        assert_resolved_path_eq(result.await, rel_path(path_without_root));
 656
 657        let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
 658        assert_eq!(
 659            result.await.unwrap_err().to_string(),
 660            "Can't edit file: path not found"
 661        );
 662
 663        let result = test_resolve_path(mode, "root/dir", cx);
 664        assert_eq!(
 665            result.await.unwrap_err().to_string(),
 666            "Can't edit file: path is a directory"
 667        );
 668    }
 669
 670    async fn test_resolve_path(
 671        mode: &EditFileMode,
 672        path: &str,
 673        cx: &mut TestAppContext,
 674    ) -> anyhow::Result<ProjectPath> {
 675        init_test(cx);
 676
 677        let fs = project::FakeFs::new(cx.executor());
 678        fs.insert_tree(
 679            "/root",
 680            json!({
 681                "dir": {
 682                    "subdir": {
 683                        "existing.txt": "hello"
 684                    }
 685                }
 686            }),
 687        )
 688        .await;
 689        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 690
 691        let input = EditFileToolInput {
 692            display_description: "Some edit".into(),
 693            path: path.into(),
 694            mode: mode.clone(),
 695        };
 696
 697        cx.update(|cx| resolve_path(&input, project, cx))
 698    }
 699
 700    #[track_caller]
 701    fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
 702        let actual = path.expect("Should return valid path").path;
 703        assert_eq!(actual.as_ref(), expected);
 704    }
 705
 706    #[gpui::test]
 707    async fn test_format_on_save(cx: &mut TestAppContext) {
 708        init_test(cx);
 709
 710        let fs = project::FakeFs::new(cx.executor());
 711        fs.insert_tree("/root", json!({"src": {}})).await;
 712
 713        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 714
 715        // Set up a Rust language with LSP formatting support
 716        let rust_language = Arc::new(language::Language::new(
 717            language::LanguageConfig {
 718                name: "Rust".into(),
 719                matcher: language::LanguageMatcher {
 720                    path_suffixes: vec!["rs".to_string()],
 721                    ..Default::default()
 722                },
 723                ..Default::default()
 724            },
 725            None,
 726        ));
 727
 728        // Register the language and fake LSP
 729        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
 730        language_registry.add(rust_language);
 731
 732        let mut fake_language_servers = language_registry.register_fake_lsp(
 733            "Rust",
 734            language::FakeLspAdapter {
 735                capabilities: lsp::ServerCapabilities {
 736                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
 737                    ..Default::default()
 738                },
 739                ..Default::default()
 740            },
 741        );
 742
 743        // Create the file
 744        fs.save(
 745            path!("/root/src/main.rs").as_ref(),
 746            &Rope::from_str_small("initial content"),
 747            language::LineEnding::Unix,
 748            EncodingWrapper::new(UTF_8),
 749        )
 750        .await
 751        .unwrap();
 752
 753        // Open the buffer to trigger LSP initialization
 754        let buffer = project
 755            .update(cx, |project, cx| {
 756                project.open_local_buffer(path!("/root/src/main.rs"), cx)
 757            })
 758            .await
 759            .unwrap();
 760
 761        // Register the buffer with language servers
 762        let _handle = project.update(cx, |project, cx| {
 763            project.register_buffer_with_language_servers(&buffer, cx)
 764        });
 765
 766        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
 767        const FORMATTED_CONTENT: &str =
 768            "This file was formatted by the fake formatter in the test.\n";
 769
 770        // Get the fake language server and set up formatting handler
 771        let fake_language_server = fake_language_servers.next().await.unwrap();
 772        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
 773            |_, _| async move {
 774                Ok(Some(vec![lsp::TextEdit {
 775                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
 776                    new_text: FORMATTED_CONTENT.to_string(),
 777                }]))
 778            }
 779        });
 780
 781        let context_server_registry =
 782            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 783        let model = Arc::new(FakeLanguageModel::default());
 784        let thread = cx.new(|cx| {
 785            Thread::new(
 786                project.clone(),
 787                cx.new(|_cx| ProjectContext::default()),
 788                context_server_registry,
 789                Templates::new(),
 790                Some(model.clone()),
 791                cx,
 792            )
 793        });
 794
 795        // First, test with format_on_save enabled
 796        cx.update(|cx| {
 797            SettingsStore::update_global(cx, |store, cx| {
 798                store.update_user_settings(cx, |settings| {
 799                    settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
 800                    settings.project.all_languages.defaults.formatter =
 801                        Some(language::language_settings::FormatterList::default());
 802                });
 803            });
 804        });
 805
 806        // Have the model stream unformatted content
 807        let edit_result = {
 808            let edit_task = cx.update(|cx| {
 809                let input = EditFileToolInput {
 810                    display_description: "Create main function".into(),
 811                    path: "root/src/main.rs".into(),
 812                    mode: EditFileMode::Overwrite,
 813                };
 814                Arc::new(EditFileTool::new(
 815                    project.clone(),
 816                    thread.downgrade(),
 817                    language_registry.clone(),
 818                    Templates::new(),
 819                ))
 820                .run(input, ToolCallEventStream::test().0, cx)
 821            });
 822
 823            // Stream the unformatted content
 824            cx.executor().run_until_parked();
 825            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
 826            model.end_last_completion_stream();
 827
 828            edit_task.await
 829        };
 830        assert!(edit_result.is_ok());
 831
 832        // Wait for any async operations (e.g. formatting) to complete
 833        cx.executor().run_until_parked();
 834
 835        // Read the file to verify it was formatted automatically
 836        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
 837        assert_eq!(
 838            // Ignore carriage returns on Windows
 839            new_content.replace("\r\n", "\n"),
 840            FORMATTED_CONTENT,
 841            "Code should be formatted when format_on_save is enabled"
 842        );
 843
 844        let stale_buffer_count = thread
 845            .read_with(cx, |thread, _cx| thread.action_log.clone())
 846            .read_with(cx, |log, cx| log.stale_buffers(cx).count());
 847
 848        assert_eq!(
 849            stale_buffer_count, 0,
 850            "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
 851             This causes the agent to think the file was modified externally when it was just formatted.",
 852            stale_buffer_count
 853        );
 854
 855        // Next, test with format_on_save disabled
 856        cx.update(|cx| {
 857            SettingsStore::update_global(cx, |store, cx| {
 858                store.update_user_settings(cx, |settings| {
 859                    settings.project.all_languages.defaults.format_on_save =
 860                        Some(FormatOnSave::Off);
 861                });
 862            });
 863        });
 864
 865        // Stream unformatted edits again
 866        let edit_result = {
 867            let edit_task = cx.update(|cx| {
 868                let input = EditFileToolInput {
 869                    display_description: "Update main function".into(),
 870                    path: "root/src/main.rs".into(),
 871                    mode: EditFileMode::Overwrite,
 872                };
 873                Arc::new(EditFileTool::new(
 874                    project.clone(),
 875                    thread.downgrade(),
 876                    language_registry,
 877                    Templates::new(),
 878                ))
 879                .run(input, ToolCallEventStream::test().0, cx)
 880            });
 881
 882            // Stream the unformatted content
 883            cx.executor().run_until_parked();
 884            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
 885            model.end_last_completion_stream();
 886
 887            edit_task.await
 888        };
 889        assert!(edit_result.is_ok());
 890
 891        // Wait for any async operations (e.g. formatting) to complete
 892        cx.executor().run_until_parked();
 893
 894        // Verify the file was not formatted
 895        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
 896        assert_eq!(
 897            // Ignore carriage returns on Windows
 898            new_content.replace("\r\n", "\n"),
 899            UNFORMATTED_CONTENT,
 900            "Code should not be formatted when format_on_save is disabled"
 901        );
 902    }
 903
 904    #[gpui::test]
 905    async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
 906        init_test(cx);
 907
 908        let fs = project::FakeFs::new(cx.executor());
 909        fs.insert_tree("/root", json!({"src": {}})).await;
 910
 911        // Create a simple file with trailing whitespace
 912        fs.save(
 913            path!("/root/src/main.rs").as_ref(),
 914            &Rope::from_str_small("initial content"),
 915            language::LineEnding::Unix,
 916            EncodingWrapper::new(UTF_8),
 917        )
 918        .await
 919        .unwrap();
 920
 921        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 922        let context_server_registry =
 923            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 924        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 925        let model = Arc::new(FakeLanguageModel::default());
 926        let thread = cx.new(|cx| {
 927            Thread::new(
 928                project.clone(),
 929                cx.new(|_cx| ProjectContext::default()),
 930                context_server_registry,
 931                Templates::new(),
 932                Some(model.clone()),
 933                cx,
 934            )
 935        });
 936
 937        // First, test with remove_trailing_whitespace_on_save enabled
 938        cx.update(|cx| {
 939            SettingsStore::update_global(cx, |store, cx| {
 940                store.update_user_settings(cx, |settings| {
 941                    settings
 942                        .project
 943                        .all_languages
 944                        .defaults
 945                        .remove_trailing_whitespace_on_save = Some(true);
 946                });
 947            });
 948        });
 949
 950        const CONTENT_WITH_TRAILING_WHITESPACE: &str =
 951            "fn main() {  \n    println!(\"Hello!\");  \n}\n";
 952
 953        // Have the model stream content that contains trailing whitespace
 954        let edit_result = {
 955            let edit_task = cx.update(|cx| {
 956                let input = EditFileToolInput {
 957                    display_description: "Create main function".into(),
 958                    path: "root/src/main.rs".into(),
 959                    mode: EditFileMode::Overwrite,
 960                };
 961                Arc::new(EditFileTool::new(
 962                    project.clone(),
 963                    thread.downgrade(),
 964                    language_registry.clone(),
 965                    Templates::new(),
 966                ))
 967                .run(input, ToolCallEventStream::test().0, cx)
 968            });
 969
 970            // Stream the content with trailing whitespace
 971            cx.executor().run_until_parked();
 972            model.send_last_completion_stream_text_chunk(
 973                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
 974            );
 975            model.end_last_completion_stream();
 976
 977            edit_task.await
 978        };
 979        assert!(edit_result.is_ok());
 980
 981        // Wait for any async operations (e.g. formatting) to complete
 982        cx.executor().run_until_parked();
 983
 984        // Read the file to verify trailing whitespace was removed automatically
 985        assert_eq!(
 986            // Ignore carriage returns on Windows
 987            fs.load(path!("/root/src/main.rs").as_ref())
 988                .await
 989                .unwrap()
 990                .replace("\r\n", "\n"),
 991            "fn main() {\n    println!(\"Hello!\");\n}\n",
 992            "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
 993        );
 994
 995        // Next, test with remove_trailing_whitespace_on_save disabled
 996        cx.update(|cx| {
 997            SettingsStore::update_global(cx, |store, cx| {
 998                store.update_user_settings(cx, |settings| {
 999                    settings
1000                        .project
1001                        .all_languages
1002                        .defaults
1003                        .remove_trailing_whitespace_on_save = Some(false);
1004                });
1005            });
1006        });
1007
1008        // Stream edits again with trailing whitespace
1009        let edit_result = {
1010            let edit_task = cx.update(|cx| {
1011                let input = EditFileToolInput {
1012                    display_description: "Update main function".into(),
1013                    path: "root/src/main.rs".into(),
1014                    mode: EditFileMode::Overwrite,
1015                };
1016                Arc::new(EditFileTool::new(
1017                    project.clone(),
1018                    thread.downgrade(),
1019                    language_registry,
1020                    Templates::new(),
1021                ))
1022                .run(input, ToolCallEventStream::test().0, cx)
1023            });
1024
1025            // Stream the content with trailing whitespace
1026            cx.executor().run_until_parked();
1027            model.send_last_completion_stream_text_chunk(
1028                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1029            );
1030            model.end_last_completion_stream();
1031
1032            edit_task.await
1033        };
1034        assert!(edit_result.is_ok());
1035
1036        // Wait for any async operations (e.g. formatting) to complete
1037        cx.executor().run_until_parked();
1038
1039        // Verify the file still has trailing whitespace
1040        // Read the file again - it should still have trailing whitespace
1041        let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1042        assert_eq!(
1043            // Ignore carriage returns on Windows
1044            final_content.replace("\r\n", "\n"),
1045            CONTENT_WITH_TRAILING_WHITESPACE,
1046            "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1047        );
1048    }
1049
1050    #[gpui::test]
1051    async fn test_authorize(cx: &mut TestAppContext) {
1052        init_test(cx);
1053        let fs = project::FakeFs::new(cx.executor());
1054        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1055        let context_server_registry =
1056            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1057        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1058        let model = Arc::new(FakeLanguageModel::default());
1059        let thread = cx.new(|cx| {
1060            Thread::new(
1061                project.clone(),
1062                cx.new(|_cx| ProjectContext::default()),
1063                context_server_registry,
1064                Templates::new(),
1065                Some(model.clone()),
1066                cx,
1067            )
1068        });
1069        let tool = Arc::new(EditFileTool::new(
1070            project.clone(),
1071            thread.downgrade(),
1072            language_registry,
1073            Templates::new(),
1074        ));
1075        fs.insert_tree("/root", json!({})).await;
1076
1077        // Test 1: Path with .zed component should require confirmation
1078        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1079        let _auth = cx.update(|cx| {
1080            tool.authorize(
1081                &EditFileToolInput {
1082                    display_description: "test 1".into(),
1083                    path: ".zed/settings.json".into(),
1084                    mode: EditFileMode::Edit,
1085                },
1086                &stream_tx,
1087                cx,
1088            )
1089        });
1090
1091        let event = stream_rx.expect_authorization().await;
1092        assert_eq!(
1093            event.tool_call.fields.title,
1094            Some("test 1 (local settings)".into())
1095        );
1096
1097        // Test 2: Path outside project should require confirmation
1098        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1099        let _auth = cx.update(|cx| {
1100            tool.authorize(
1101                &EditFileToolInput {
1102                    display_description: "test 2".into(),
1103                    path: "/etc/hosts".into(),
1104                    mode: EditFileMode::Edit,
1105                },
1106                &stream_tx,
1107                cx,
1108            )
1109        });
1110
1111        let event = stream_rx.expect_authorization().await;
1112        assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1113
1114        // Test 3: Relative path without .zed should not require confirmation
1115        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1116        cx.update(|cx| {
1117            tool.authorize(
1118                &EditFileToolInput {
1119                    display_description: "test 3".into(),
1120                    path: "root/src/main.rs".into(),
1121                    mode: EditFileMode::Edit,
1122                },
1123                &stream_tx,
1124                cx,
1125            )
1126        })
1127        .await
1128        .unwrap();
1129        assert!(stream_rx.try_next().is_err());
1130
1131        // Test 4: Path with .zed in the middle should require confirmation
1132        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1133        let _auth = cx.update(|cx| {
1134            tool.authorize(
1135                &EditFileToolInput {
1136                    display_description: "test 4".into(),
1137                    path: "root/.zed/tasks.json".into(),
1138                    mode: EditFileMode::Edit,
1139                },
1140                &stream_tx,
1141                cx,
1142            )
1143        });
1144        let event = stream_rx.expect_authorization().await;
1145        assert_eq!(
1146            event.tool_call.fields.title,
1147            Some("test 4 (local settings)".into())
1148        );
1149
1150        // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1151        cx.update(|cx| {
1152            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1153            settings.always_allow_tool_actions = true;
1154            agent_settings::AgentSettings::override_global(settings, cx);
1155        });
1156
1157        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1158        cx.update(|cx| {
1159            tool.authorize(
1160                &EditFileToolInput {
1161                    display_description: "test 5.1".into(),
1162                    path: ".zed/settings.json".into(),
1163                    mode: EditFileMode::Edit,
1164                },
1165                &stream_tx,
1166                cx,
1167            )
1168        })
1169        .await
1170        .unwrap();
1171        assert!(stream_rx.try_next().is_err());
1172
1173        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1174        cx.update(|cx| {
1175            tool.authorize(
1176                &EditFileToolInput {
1177                    display_description: "test 5.2".into(),
1178                    path: "/etc/hosts".into(),
1179                    mode: EditFileMode::Edit,
1180                },
1181                &stream_tx,
1182                cx,
1183            )
1184        })
1185        .await
1186        .unwrap();
1187        assert!(stream_rx.try_next().is_err());
1188    }
1189
1190    #[gpui::test]
1191    async fn test_authorize_global_config(cx: &mut TestAppContext) {
1192        init_test(cx);
1193        let fs = project::FakeFs::new(cx.executor());
1194        fs.insert_tree("/project", json!({})).await;
1195        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1196        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1197        let context_server_registry =
1198            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1199        let model = Arc::new(FakeLanguageModel::default());
1200        let thread = cx.new(|cx| {
1201            Thread::new(
1202                project.clone(),
1203                cx.new(|_cx| ProjectContext::default()),
1204                context_server_registry,
1205                Templates::new(),
1206                Some(model.clone()),
1207                cx,
1208            )
1209        });
1210        let tool = Arc::new(EditFileTool::new(
1211            project.clone(),
1212            thread.downgrade(),
1213            language_registry,
1214            Templates::new(),
1215        ));
1216
1217        // Test global config paths - these should require confirmation if they exist and are outside the project
1218        let test_cases = vec![
1219            (
1220                "/etc/hosts",
1221                true,
1222                "System file should require confirmation",
1223            ),
1224            (
1225                "/usr/local/bin/script",
1226                true,
1227                "System bin file should require confirmation",
1228            ),
1229            (
1230                "project/normal_file.rs",
1231                false,
1232                "Normal project file should not require confirmation",
1233            ),
1234        ];
1235
1236        for (path, should_confirm, description) in test_cases {
1237            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1238            let auth = cx.update(|cx| {
1239                tool.authorize(
1240                    &EditFileToolInput {
1241                        display_description: "Edit file".into(),
1242                        path: path.into(),
1243                        mode: EditFileMode::Edit,
1244                    },
1245                    &stream_tx,
1246                    cx,
1247                )
1248            });
1249
1250            if should_confirm {
1251                stream_rx.expect_authorization().await;
1252            } else {
1253                auth.await.unwrap();
1254                assert!(
1255                    stream_rx.try_next().is_err(),
1256                    "Failed for case: {} - path: {} - expected no confirmation but got one",
1257                    description,
1258                    path
1259                );
1260            }
1261        }
1262    }
1263
1264    #[gpui::test]
1265    async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1266        init_test(cx);
1267        let fs = project::FakeFs::new(cx.executor());
1268
1269        // Create multiple worktree directories
1270        fs.insert_tree(
1271            "/workspace/frontend",
1272            json!({
1273                "src": {
1274                    "main.js": "console.log('frontend');"
1275                }
1276            }),
1277        )
1278        .await;
1279        fs.insert_tree(
1280            "/workspace/backend",
1281            json!({
1282                "src": {
1283                    "main.rs": "fn main() {}"
1284                }
1285            }),
1286        )
1287        .await;
1288        fs.insert_tree(
1289            "/workspace/shared",
1290            json!({
1291                ".zed": {
1292                    "settings.json": "{}"
1293                }
1294            }),
1295        )
1296        .await;
1297
1298        // Create project with multiple worktrees
1299        let project = Project::test(
1300            fs.clone(),
1301            [
1302                path!("/workspace/frontend").as_ref(),
1303                path!("/workspace/backend").as_ref(),
1304                path!("/workspace/shared").as_ref(),
1305            ],
1306            cx,
1307        )
1308        .await;
1309        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1310        let context_server_registry =
1311            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1312        let model = Arc::new(FakeLanguageModel::default());
1313        let thread = cx.new(|cx| {
1314            Thread::new(
1315                project.clone(),
1316                cx.new(|_cx| ProjectContext::default()),
1317                context_server_registry.clone(),
1318                Templates::new(),
1319                Some(model.clone()),
1320                cx,
1321            )
1322        });
1323        let tool = Arc::new(EditFileTool::new(
1324            project.clone(),
1325            thread.downgrade(),
1326            language_registry,
1327            Templates::new(),
1328        ));
1329
1330        // Test files in different worktrees
1331        let test_cases = vec![
1332            ("frontend/src/main.js", false, "File in first worktree"),
1333            ("backend/src/main.rs", false, "File in second worktree"),
1334            (
1335                "shared/.zed/settings.json",
1336                true,
1337                ".zed file in third worktree",
1338            ),
1339            ("/etc/hosts", true, "Absolute path outside all worktrees"),
1340            (
1341                "../outside/file.txt",
1342                true,
1343                "Relative path outside worktrees",
1344            ),
1345        ];
1346
1347        for (path, should_confirm, description) in test_cases {
1348            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1349            let auth = cx.update(|cx| {
1350                tool.authorize(
1351                    &EditFileToolInput {
1352                        display_description: "Edit file".into(),
1353                        path: path.into(),
1354                        mode: EditFileMode::Edit,
1355                    },
1356                    &stream_tx,
1357                    cx,
1358                )
1359            });
1360
1361            if should_confirm {
1362                stream_rx.expect_authorization().await;
1363            } else {
1364                auth.await.unwrap();
1365                assert!(
1366                    stream_rx.try_next().is_err(),
1367                    "Failed for case: {} - path: {} - expected no confirmation but got one",
1368                    description,
1369                    path
1370                );
1371            }
1372        }
1373    }
1374
1375    #[gpui::test]
1376    async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1377        init_test(cx);
1378        let fs = project::FakeFs::new(cx.executor());
1379        fs.insert_tree(
1380            "/project",
1381            json!({
1382                ".zed": {
1383                    "settings.json": "{}"
1384                },
1385                "src": {
1386                    ".zed": {
1387                        "local.json": "{}"
1388                    }
1389                }
1390            }),
1391        )
1392        .await;
1393        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1394        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1395        let context_server_registry =
1396            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1397        let model = Arc::new(FakeLanguageModel::default());
1398        let thread = cx.new(|cx| {
1399            Thread::new(
1400                project.clone(),
1401                cx.new(|_cx| ProjectContext::default()),
1402                context_server_registry.clone(),
1403                Templates::new(),
1404                Some(model.clone()),
1405                cx,
1406            )
1407        });
1408        let tool = Arc::new(EditFileTool::new(
1409            project.clone(),
1410            thread.downgrade(),
1411            language_registry,
1412            Templates::new(),
1413        ));
1414
1415        // Test edge cases
1416        let test_cases = vec![
1417            // Empty path - find_project_path returns Some for empty paths
1418            ("", false, "Empty path is treated as project root"),
1419            // Root directory
1420            ("/", true, "Root directory should be outside project"),
1421            // Parent directory references - find_project_path resolves these
1422            (
1423                "project/../other",
1424                true,
1425                "Path with .. that goes outside of root directory",
1426            ),
1427            (
1428                "project/./src/file.rs",
1429                false,
1430                "Path with . should work normally",
1431            ),
1432            // Windows-style paths (if on Windows)
1433            #[cfg(target_os = "windows")]
1434            ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1435            #[cfg(target_os = "windows")]
1436            ("project\\src\\main.rs", false, "Windows-style project path"),
1437        ];
1438
1439        for (path, should_confirm, description) in test_cases {
1440            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1441            let auth = cx.update(|cx| {
1442                tool.authorize(
1443                    &EditFileToolInput {
1444                        display_description: "Edit file".into(),
1445                        path: path.into(),
1446                        mode: EditFileMode::Edit,
1447                    },
1448                    &stream_tx,
1449                    cx,
1450                )
1451            });
1452
1453            cx.run_until_parked();
1454
1455            if should_confirm {
1456                stream_rx.expect_authorization().await;
1457            } else {
1458                assert!(
1459                    stream_rx.try_next().is_err(),
1460                    "Failed for case: {} - path: {} - expected no confirmation but got one",
1461                    description,
1462                    path
1463                );
1464                auth.await.unwrap();
1465            }
1466        }
1467    }
1468
1469    #[gpui::test]
1470    async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1471        init_test(cx);
1472        let fs = project::FakeFs::new(cx.executor());
1473        fs.insert_tree(
1474            "/project",
1475            json!({
1476                "existing.txt": "content",
1477                ".zed": {
1478                    "settings.json": "{}"
1479                }
1480            }),
1481        )
1482        .await;
1483        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1484        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1485        let context_server_registry =
1486            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1487        let model = Arc::new(FakeLanguageModel::default());
1488        let thread = cx.new(|cx| {
1489            Thread::new(
1490                project.clone(),
1491                cx.new(|_cx| ProjectContext::default()),
1492                context_server_registry.clone(),
1493                Templates::new(),
1494                Some(model.clone()),
1495                cx,
1496            )
1497        });
1498        let tool = Arc::new(EditFileTool::new(
1499            project.clone(),
1500            thread.downgrade(),
1501            language_registry,
1502            Templates::new(),
1503        ));
1504
1505        // Test different EditFileMode values
1506        let modes = vec![
1507            EditFileMode::Edit,
1508            EditFileMode::Create,
1509            EditFileMode::Overwrite,
1510        ];
1511
1512        for mode in modes {
1513            // Test .zed path with different modes
1514            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1515            let _auth = cx.update(|cx| {
1516                tool.authorize(
1517                    &EditFileToolInput {
1518                        display_description: "Edit settings".into(),
1519                        path: "project/.zed/settings.json".into(),
1520                        mode: mode.clone(),
1521                    },
1522                    &stream_tx,
1523                    cx,
1524                )
1525            });
1526
1527            stream_rx.expect_authorization().await;
1528
1529            // Test outside path with different modes
1530            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1531            let _auth = cx.update(|cx| {
1532                tool.authorize(
1533                    &EditFileToolInput {
1534                        display_description: "Edit file".into(),
1535                        path: "/outside/file.txt".into(),
1536                        mode: mode.clone(),
1537                    },
1538                    &stream_tx,
1539                    cx,
1540                )
1541            });
1542
1543            stream_rx.expect_authorization().await;
1544
1545            // Test normal path with different modes
1546            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1547            cx.update(|cx| {
1548                tool.authorize(
1549                    &EditFileToolInput {
1550                        display_description: "Edit file".into(),
1551                        path: "project/normal.txt".into(),
1552                        mode: mode.clone(),
1553                    },
1554                    &stream_tx,
1555                    cx,
1556                )
1557            })
1558            .await
1559            .unwrap();
1560            assert!(stream_rx.try_next().is_err());
1561        }
1562    }
1563
1564    #[gpui::test]
1565    async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1566        init_test(cx);
1567        let fs = project::FakeFs::new(cx.executor());
1568        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1569        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1570        let context_server_registry =
1571            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1572        let model = Arc::new(FakeLanguageModel::default());
1573        let thread = cx.new(|cx| {
1574            Thread::new(
1575                project.clone(),
1576                cx.new(|_cx| ProjectContext::default()),
1577                context_server_registry,
1578                Templates::new(),
1579                Some(model.clone()),
1580                cx,
1581            )
1582        });
1583        let tool = Arc::new(EditFileTool::new(
1584            project,
1585            thread.downgrade(),
1586            language_registry,
1587            Templates::new(),
1588        ));
1589
1590        cx.update(|cx| {
1591            // ...
1592            assert_eq!(
1593                tool.initial_title(
1594                    Err(json!({
1595                        "path": "src/main.rs",
1596                        "display_description": "",
1597                        "old_string": "old code",
1598                        "new_string": "new code"
1599                    })),
1600                    cx
1601                ),
1602                "src/main.rs"
1603            );
1604            assert_eq!(
1605                tool.initial_title(
1606                    Err(json!({
1607                        "path": "",
1608                        "display_description": "Fix error handling",
1609                        "old_string": "old code",
1610                        "new_string": "new code"
1611                    })),
1612                    cx
1613                ),
1614                "Fix error handling"
1615            );
1616            assert_eq!(
1617                tool.initial_title(
1618                    Err(json!({
1619                        "path": "src/main.rs",
1620                        "display_description": "Fix error handling",
1621                        "old_string": "old code",
1622                        "new_string": "new code"
1623                    })),
1624                    cx
1625                ),
1626                "src/main.rs"
1627            );
1628            assert_eq!(
1629                tool.initial_title(
1630                    Err(json!({
1631                        "path": "",
1632                        "display_description": "",
1633                        "old_string": "old code",
1634                        "new_string": "new code"
1635                    })),
1636                    cx
1637                ),
1638                DEFAULT_UI_TEXT
1639            );
1640            assert_eq!(
1641                tool.initial_title(Err(serde_json::Value::Null), cx),
1642                DEFAULT_UI_TEXT
1643            );
1644        });
1645    }
1646
1647    #[gpui::test]
1648    async fn test_diff_finalization(cx: &mut TestAppContext) {
1649        init_test(cx);
1650        let fs = project::FakeFs::new(cx.executor());
1651        fs.insert_tree("/", json!({"main.rs": ""})).await;
1652
1653        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1654        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1655        let context_server_registry =
1656            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1657        let model = Arc::new(FakeLanguageModel::default());
1658        let thread = cx.new(|cx| {
1659            Thread::new(
1660                project.clone(),
1661                cx.new(|_cx| ProjectContext::default()),
1662                context_server_registry.clone(),
1663                Templates::new(),
1664                Some(model.clone()),
1665                cx,
1666            )
1667        });
1668
1669        // Ensure the diff is finalized after the edit completes.
1670        {
1671            let tool = Arc::new(EditFileTool::new(
1672                project.clone(),
1673                thread.downgrade(),
1674                languages.clone(),
1675                Templates::new(),
1676            ));
1677            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1678            let edit = cx.update(|cx| {
1679                tool.run(
1680                    EditFileToolInput {
1681                        display_description: "Edit file".into(),
1682                        path: path!("/main.rs").into(),
1683                        mode: EditFileMode::Edit,
1684                    },
1685                    stream_tx,
1686                    cx,
1687                )
1688            });
1689            stream_rx.expect_update_fields().await;
1690            let diff = stream_rx.expect_diff().await;
1691            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1692            cx.run_until_parked();
1693            model.end_last_completion_stream();
1694            edit.await.unwrap();
1695            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1696        }
1697
1698        // Ensure the diff is finalized if an error occurs while editing.
1699        {
1700            model.forbid_requests();
1701            let tool = Arc::new(EditFileTool::new(
1702                project.clone(),
1703                thread.downgrade(),
1704                languages.clone(),
1705                Templates::new(),
1706            ));
1707            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1708            let edit = cx.update(|cx| {
1709                tool.run(
1710                    EditFileToolInput {
1711                        display_description: "Edit file".into(),
1712                        path: path!("/main.rs").into(),
1713                        mode: EditFileMode::Edit,
1714                    },
1715                    stream_tx,
1716                    cx,
1717                )
1718            });
1719            stream_rx.expect_update_fields().await;
1720            let diff = stream_rx.expect_diff().await;
1721            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1722            edit.await.unwrap_err();
1723            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1724            model.allow_requests();
1725        }
1726
1727        // Ensure the diff is finalized if the tool call gets dropped.
1728        {
1729            let tool = Arc::new(EditFileTool::new(
1730                project.clone(),
1731                thread.downgrade(),
1732                languages.clone(),
1733                Templates::new(),
1734            ));
1735            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1736            let edit = cx.update(|cx| {
1737                tool.run(
1738                    EditFileToolInput {
1739                        display_description: "Edit file".into(),
1740                        path: path!("/main.rs").into(),
1741                        mode: EditFileMode::Edit,
1742                    },
1743                    stream_tx,
1744                    cx,
1745                )
1746            });
1747            stream_rx.expect_update_fields().await;
1748            let diff = stream_rx.expect_diff().await;
1749            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1750            drop(edit);
1751            cx.run_until_parked();
1752            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1753        }
1754    }
1755
1756    fn init_test(cx: &mut TestAppContext) {
1757        cx.update(|cx| {
1758            let settings_store = SettingsStore::test(cx);
1759            cx.set_global(settings_store);
1760            language::init(cx);
1761            TelemetrySettings::register(cx);
1762            agent_settings::AgentSettings::register(cx);
1763            Project::init_settings(cx);
1764        });
1765    }
1766}