edit_file_tool.rs

   1use crate::{AgentTool, Thread, ToolCallEventStream};
   2use acp_thread::Diff;
   3use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
   4use anyhow::{Context as _, Result, anyhow};
   5use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
   9use indoc::formatdoc;
  10use language::language_settings::{self, FormatOnSave};
  11use language::{LanguageRegistry, ToPoint};
  12use language_model::LanguageModelToolResultContent;
  13use paths;
  14use project::lsp_store::{FormatTrigger, LspFormatTarget};
  15use project::{Project, ProjectPath};
  16use schemars::JsonSchema;
  17use serde::{Deserialize, Serialize};
  18use settings::Settings;
  19use smol::stream::StreamExt as _;
  20use std::ffi::OsStr;
  21use std::path::{Path, PathBuf};
  22use std::sync::Arc;
  23use ui::SharedString;
  24use util::ResultExt;
  25use util::rel_path::RelPath;
  26
/// Last-resort title for an `edit_file` tool call, used by `initial_title` when the input
/// (complete or partial) yields neither a non-empty path nor a non-empty description.
const DEFAULT_UI_TEXT: &str = "Editing file";
  28
// NOTE(review): this struct derives `JsonSchema`, and schemars turns `///` doc comments into
// schema `description`s, so the `///` text below is presumably the tool/field prompt the model
// reads. Treat edits to it as prompt changes. Only plain `//` comments (which schemars ignores)
// are used for maintainer notes here. The `display_description` docs ask the model to emit that
// field first so the UI can show it immediately; keep it declared first.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
///    - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    // Resolved against the project by `resolve_path`. In `create` mode the file need not exist
    // yet. The tests also show a path without the root prefix resolving in a single-root project.
    pub path: PathBuf,
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    // `run` sends `Edit` through `EditAgent::edit`; `Create` and `Overwrite` both go through
    // `EditAgent::overwrite`.
    pub mode: EditFileMode,
}
  77
// Lenient view of an `EditFileToolInput` that failed to deserialize, used by `initial_title`
// to derive a title from whatever is already present (presumably while the tool input is
// still being streamed). Both fields default to an empty string, so a JSON object missing
// either one still parses.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    #[serde(default)]
    path: String,
    #[serde(default)]
    display_description: String,
}
  85
// How `edit_file` treats its target path. Serialized in lowercase ("edit", "create",
// "overwrite"). `#[schemars(inline)]` inlines this type into `EditFileToolInput`'s schema, so
// maintainer notes stay in `//` comments (`///` docs would presumably leak into the schema).
// - `Edit`: `resolve_path` requires an existing file; `run` uses `EditAgent::edit`.
// - `Create`: `resolve_path` requires that nothing exists at the path and that its parent is
//   an existing directory; `run` uses `EditAgent::overwrite`.
// - `Overwrite`: same path checks as `Edit`; `run` uses `EditAgent::overwrite`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    Edit,
    Create,
    Overwrite,
}
  94
/// Result of a completed `edit_file` call.
///
/// `replay` rebuilds the diff view from it, and the `From` impl turns it into the text result
/// returned to the model.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// The path exactly as supplied in [`EditFileToolInput::path`].
    // The alias also accepts the older `original_path` field name
    // (presumably from previously serialized outputs — confirm).
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Buffer text read after the edit (and any format-on-save) was saved.
    new_text: String,
    /// Buffer text captured before the edit agent ran.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty when nothing changed.
    /// Defaults to empty when missing from serialized data.
    #[serde(default)]
    diff: String,
    /// Output reported by the edit agent.
    // The alias also accepts the older `raw_output` field name.
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
 106
 107impl From<EditFileToolOutput> for LanguageModelToolResultContent {
 108    fn from(output: EditFileToolOutput) -> Self {
 109        if output.diff.is_empty() {
 110            "No edits were made.".into()
 111        } else {
 112            format!(
 113                "Edited {}:\n\n```diff\n{}\n```",
 114                output.input_path.display(),
 115                output.diff
 116            )
 117            .into()
 118        }
 119    }
 120}
 121
/// Agent tool that creates, edits, or overwrites a single project file.
///
/// The actual text changes are produced by an `EditAgent` in `run`.
pub struct EditFileTool {
    // Held weakly: `read_with`/`update` fail once the thread is gone, which the tool reports
    // as "thread was dropped".
    thread: WeakEntity<Thread>,
    // Used by `replay` to build a finalized `Diff`.
    language_registry: Arc<LanguageRegistry>,
    // Used by `initial_title` to resolve and shorten paths for display.
    project: Entity<Project>,
}
 127
 128impl EditFileTool {
 129    pub fn new(
 130        project: Entity<Project>,
 131        thread: WeakEntity<Thread>,
 132        language_registry: Arc<LanguageRegistry>,
 133    ) -> Self {
 134        Self {
 135            project,
 136            thread,
 137            language_registry,
 138        }
 139    }
 140
 141    fn authorize(
 142        &self,
 143        input: &EditFileToolInput,
 144        event_stream: &ToolCallEventStream,
 145        cx: &mut App,
 146    ) -> Task<Result<()>> {
 147        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
 148            return Task::ready(Ok(()));
 149        }
 150
 151        // If any path component matches the local settings folder, then this could affect
 152        // the editor in ways beyond the project source, so prompt.
 153        let local_settings_folder = paths::local_settings_folder_name();
 154        let path = Path::new(&input.path);
 155        if path.components().any(|component| {
 156            component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
 157        }) {
 158            return event_stream.authorize(
 159                format!("{} (local settings)", input.display_description),
 160                cx,
 161            );
 162        }
 163
 164        // It's also possible that the global config dir is configured to be inside the project,
 165        // so check for that edge case too.
 166        // TODO this is broken when remoting
 167        if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
 168            && canonical_path.starts_with(paths::config_dir())
 169        {
 170            return event_stream.authorize(
 171                format!("{} (global settings)", input.display_description),
 172                cx,
 173            );
 174        }
 175
 176        // Check if path is inside the global config directory
 177        // First check if it's already inside project - if not, try to canonicalize
 178        let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
 179            thread.project().read(cx).find_project_path(&input.path, cx)
 180        }) else {
 181            return Task::ready(Err(anyhow!("thread was dropped")));
 182        };
 183
 184        // If the path is inside the project, and it's not one of the above edge cases,
 185        // then no confirmation is necessary. Otherwise, confirmation is necessary.
 186        if project_path.is_some() {
 187            Task::ready(Ok(()))
 188        } else {
 189            event_stream.authorize(&input.display_description, cx)
 190        }
 191    }
 192}
 193
 194impl AgentTool for EditFileTool {
 195    type Input = EditFileToolInput;
 196    type Output = EditFileToolOutput;
 197
 198    fn name() -> &'static str {
 199        "edit_file"
 200    }
 201
 202    fn kind() -> acp::ToolKind {
 203        acp::ToolKind::Edit
 204    }
 205
 206    fn initial_title(
 207        &self,
 208        input: Result<Self::Input, serde_json::Value>,
 209        cx: &mut App,
 210    ) -> SharedString {
 211        match input {
 212            Ok(input) => self
 213                .project
 214                .read(cx)
 215                .find_project_path(&input.path, cx)
 216                .and_then(|project_path| {
 217                    self.project
 218                        .read(cx)
 219                        .short_full_path_for_project_path(&project_path, cx)
 220                })
 221                .unwrap_or(input.path.to_string_lossy().to_string())
 222                .into(),
 223            Err(raw_input) => {
 224                if let Some(input) =
 225                    serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
 226                {
 227                    let path = input.path.trim();
 228                    if !path.is_empty() {
 229                        return self
 230                            .project
 231                            .read(cx)
 232                            .find_project_path(&input.path, cx)
 233                            .and_then(|project_path| {
 234                                self.project
 235                                    .read(cx)
 236                                    .short_full_path_for_project_path(&project_path, cx)
 237                            })
 238                            .unwrap_or(input.path)
 239                            .into();
 240                    }
 241
 242                    let description = input.display_description.trim();
 243                    if !description.is_empty() {
 244                        return description.to_string().into();
 245                    }
 246                }
 247
 248                DEFAULT_UI_TEXT.into()
 249            }
 250        }
 251    }
 252
 253    fn run(
 254        self: Arc<Self>,
 255        input: Self::Input,
 256        event_stream: ToolCallEventStream,
 257        cx: &mut App,
 258    ) -> Task<Result<Self::Output>> {
 259        let Ok(project) = self
 260            .thread
 261            .read_with(cx, |thread, _cx| thread.project().clone())
 262        else {
 263            return Task::ready(Err(anyhow!("thread was dropped")));
 264        };
 265        let project_path = match resolve_path(&input, project.clone(), cx) {
 266            Ok(path) => path,
 267            Err(err) => return Task::ready(Err(anyhow!(err))),
 268        };
 269        let abs_path = project.read(cx).absolute_path(&project_path, cx);
 270        if let Some(abs_path) = abs_path.clone() {
 271            event_stream.update_fields(ToolCallUpdateFields {
 272                locations: Some(vec![acp::ToolCallLocation {
 273                    path: abs_path,
 274                    line: None,
 275                    meta: None,
 276                }]),
 277                ..Default::default()
 278            });
 279        }
 280
 281        let authorize = self.authorize(&input, &event_stream, cx);
 282        cx.spawn(async move |cx: &mut AsyncApp| {
 283            authorize.await?;
 284
 285            let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
 286                let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
 287                (request, thread.model().cloned(), thread.action_log().clone())
 288            })?;
 289            let request = request?;
 290            let model = model.context("No language model configured")?;
 291
 292            let edit_format = EditFormat::from_model(model.clone())?;
 293            let edit_agent = EditAgent::new(
 294                model,
 295                project.clone(),
 296                action_log.clone(),
 297                // TODO: move edit agent to this crate so we can use our templates
 298                assistant_tools::templates::Templates::new(),
 299                edit_format,
 300            );
 301
 302            let buffer = project
 303                .update(cx, |project, cx| {
 304                    project.open_buffer(project_path.clone(), cx)
 305                })?
 306                .await?;
 307
 308            let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
 309            event_stream.update_diff(diff.clone());
 310            let _finalize_diff = util::defer({
 311               let diff = diff.downgrade();
 312               let mut cx = cx.clone();
 313               move || {
 314                   diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
 315               }
 316            });
 317
 318            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 319            let old_text = cx
 320                .background_spawn({
 321                    let old_snapshot = old_snapshot.clone();
 322                    async move { Arc::new(old_snapshot.text()) }
 323                })
 324                .await;
 325
 326
 327            let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
 328                edit_agent.edit(
 329                    buffer.clone(),
 330                    input.display_description.clone(),
 331                    &request,
 332                    cx,
 333                )
 334            } else {
 335                edit_agent.overwrite(
 336                    buffer.clone(),
 337                    input.display_description.clone(),
 338                    &request,
 339                    cx,
 340                )
 341            };
 342
 343            let mut hallucinated_old_text = false;
 344            let mut ambiguous_ranges = Vec::new();
 345            let mut emitted_location = false;
 346            while let Some(event) = events.next().await {
 347                match event {
 348                    EditAgentOutputEvent::Edited(range) => {
 349                        if !emitted_location {
 350                            let line = buffer.update(cx, |buffer, _cx| {
 351                                range.start.to_point(&buffer.snapshot()).row
 352                            }).ok();
 353                            if let Some(abs_path) = abs_path.clone() {
 354                                event_stream.update_fields(ToolCallUpdateFields {
 355                                    locations: Some(vec![ToolCallLocation { path: abs_path, line, meta: None }]),
 356                                    ..Default::default()
 357                                });
 358                            }
 359                            emitted_location = true;
 360                        }
 361                    },
 362                    EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
 363                    EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
 364                    EditAgentOutputEvent::ResolvingEditRange(range) => {
 365                        diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
 366                        // if !emitted_location {
 367                        //     let line = buffer.update(cx, |buffer, _cx| {
 368                        //         range.start.to_point(&buffer.snapshot()).row
 369                        //     }).ok();
 370                        //     if let Some(abs_path) = abs_path.clone() {
 371                        //         event_stream.update_fields(ToolCallUpdateFields {
 372                        //             locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
 373                        //             ..Default::default()
 374                        //         });
 375                        //     }
 376                        // }
 377                    }
 378                }
 379            }
 380
 381            // If format_on_save is enabled, format the buffer
 382            let format_on_save_enabled = buffer
 383                .read_with(cx, |buffer, cx| {
 384                    let settings = language_settings::language_settings(
 385                        buffer.language().map(|l| l.name()),
 386                        buffer.file(),
 387                        cx,
 388                    );
 389                    settings.format_on_save != FormatOnSave::Off
 390                })
 391                .unwrap_or(false);
 392
 393            let edit_agent_output = output.await?;
 394
 395            if format_on_save_enabled {
 396                action_log.update(cx, |log, cx| {
 397                    log.buffer_edited(buffer.clone(), cx);
 398                })?;
 399
 400                let format_task = project.update(cx, |project, cx| {
 401                    project.format(
 402                        HashSet::from_iter([buffer.clone()]),
 403                        LspFormatTarget::Buffers,
 404                        false, // Don't push to history since the tool did it.
 405                        FormatTrigger::Save,
 406                        cx,
 407                    )
 408                })?;
 409                format_task.await.log_err();
 410            }
 411
 412            project
 413                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
 414                .await?;
 415
 416            action_log.update(cx, |log, cx| {
 417                log.buffer_edited(buffer.clone(), cx);
 418            })?;
 419
 420            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 421            let (new_text, unified_diff) = cx
 422                .background_spawn({
 423                    let new_snapshot = new_snapshot.clone();
 424                    let old_text = old_text.clone();
 425                    async move {
 426                        let new_text = new_snapshot.text();
 427                        let diff = language::unified_diff(&old_text, &new_text);
 428                        (new_text, diff)
 429                    }
 430                })
 431                .await;
 432
 433            let input_path = input.path.display();
 434            if unified_diff.is_empty() {
 435                anyhow::ensure!(
 436                    !hallucinated_old_text,
 437                    formatdoc! {"
 438                        Some edits were produced but none of them could be applied.
 439                        Read the relevant sections of {input_path} again so that
 440                        I can perform the requested edits.
 441                    "}
 442                );
 443                anyhow::ensure!(
 444                    ambiguous_ranges.is_empty(),
 445                    {
 446                        let line_numbers = ambiguous_ranges
 447                            .iter()
 448                            .map(|range| range.start.to_string())
 449                            .collect::<Vec<_>>()
 450                            .join(", ");
 451                        formatdoc! {"
 452                            <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
 453                            relevant sections of {input_path} again and extend <old_text> so
 454                            that I can perform the requested edits.
 455                        "}
 456                    }
 457                );
 458            }
 459
 460            Ok(EditFileToolOutput {
 461                input_path: input.path,
 462                new_text,
 463                old_text,
 464                diff: unified_diff,
 465                edit_agent_output,
 466            })
 467        })
 468    }
 469
 470    fn replay(
 471        &self,
 472        _input: Self::Input,
 473        output: Self::Output,
 474        event_stream: ToolCallEventStream,
 475        cx: &mut App,
 476    ) -> Result<()> {
 477        event_stream.update_diff(cx.new(|cx| {
 478            Diff::finalized(
 479                output.input_path.to_string_lossy().to_string(),
 480                Some(output.old_text.to_string()),
 481                output.new_text,
 482                self.language_registry.clone(),
 483                cx,
 484            )
 485        }));
 486        Ok(())
 487    }
 488}
 489
 490/// Validate that the file path is valid, meaning:
 491///
 492/// - For `edit` and `overwrite`, the path must point to an existing file.
 493/// - For `create`, the file must not already exist, but it's parent dir must exist.
 494fn resolve_path(
 495    input: &EditFileToolInput,
 496    project: Entity<Project>,
 497    cx: &mut App,
 498) -> Result<ProjectPath> {
 499    let project = project.read(cx);
 500
 501    match input.mode {
 502        EditFileMode::Edit | EditFileMode::Overwrite => {
 503            let path = project
 504                .find_project_path(&input.path, cx)
 505                .context("Can't edit file: path not found")?;
 506
 507            let entry = project
 508                .entry_for_path(&path, cx)
 509                .context("Can't edit file: path not found")?;
 510
 511            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
 512            Ok(path)
 513        }
 514
 515        EditFileMode::Create => {
 516            if let Some(path) = project.find_project_path(&input.path, cx) {
 517                anyhow::ensure!(
 518                    project.entry_for_path(&path, cx).is_none(),
 519                    "Can't create file: file already exists"
 520                );
 521            }
 522
 523            let parent_path = input
 524                .path
 525                .parent()
 526                .context("Can't create file: incorrect path")?;
 527
 528            let parent_project_path = project.find_project_path(&parent_path, cx);
 529
 530            let parent_entry = parent_project_path
 531                .as_ref()
 532                .and_then(|path| project.entry_for_path(path, cx))
 533                .context("Can't create file: parent directory doesn't exist")?;
 534
 535            anyhow::ensure!(
 536                parent_entry.is_dir(),
 537                "Can't create file: parent is not a directory"
 538            );
 539
 540            let file_name = input
 541                .path
 542                .file_name()
 543                .and_then(|file_name| file_name.to_str())
 544                .and_then(|file_name| RelPath::new(file_name).ok())
 545                .context("Can't create file: invalid filename")?;
 546
 547            let new_file_path = parent_project_path.map(|parent| ProjectPath {
 548                path: parent.path.join(file_name),
 549                ..parent
 550            });
 551
 552            new_file_path.context("Can't create file")
 553        }
 554    }
 555}
 556
 557#[cfg(test)]
 558mod tests {
 559    use super::*;
 560    use crate::{ContextServerRegistry, Templates};
 561    use client::TelemetrySettings;
 562    use fs::Fs;
 563    use gpui::{TestAppContext, UpdateGlobal};
 564    use language_model::fake_provider::FakeLanguageModel;
 565    use prompt_store::ProjectContext;
 566    use serde_json::json;
 567    use settings::SettingsStore;
 568    use util::path;
 569
 570    #[gpui::test]
 571    async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
 572        init_test(cx);
 573
 574        let fs = project::FakeFs::new(cx.executor());
 575        fs.insert_tree("/root", json!({})).await;
 576        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 577        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 578        let context_server_registry =
 579            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 580        let model = Arc::new(FakeLanguageModel::default());
 581        let thread = cx.new(|cx| {
 582            Thread::new(
 583                project.clone(),
 584                cx.new(|_cx| ProjectContext::default()),
 585                context_server_registry,
 586                Templates::new(),
 587                Some(model),
 588                cx,
 589            )
 590        });
 591        let result = cx
 592            .update(|cx| {
 593                let input = EditFileToolInput {
 594                    display_description: "Some edit".into(),
 595                    path: "root/nonexistent_file.txt".into(),
 596                    mode: EditFileMode::Edit,
 597                };
 598                Arc::new(EditFileTool::new(
 599                    project,
 600                    thread.downgrade(),
 601                    language_registry,
 602                ))
 603                .run(input, ToolCallEventStream::test().0, cx)
 604            })
 605            .await;
 606        assert_eq!(
 607            result.unwrap_err().to_string(),
 608            "Can't edit file: path not found"
 609        );
 610    }
 611
 612    #[gpui::test]
 613    async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
 614        let mode = &EditFileMode::Create;
 615
 616        let result = test_resolve_path(mode, "root/new.txt", cx);
 617        assert_resolved_path_eq(result.await, "new.txt");
 618
 619        let result = test_resolve_path(mode, "new.txt", cx);
 620        assert_resolved_path_eq(result.await, "new.txt");
 621
 622        let result = test_resolve_path(mode, "dir/new.txt", cx);
 623        assert_resolved_path_eq(result.await, "dir/new.txt");
 624
 625        let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
 626        assert_eq!(
 627            result.await.unwrap_err().to_string(),
 628            "Can't create file: file already exists"
 629        );
 630
 631        let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
 632        assert_eq!(
 633            result.await.unwrap_err().to_string(),
 634            "Can't create file: parent directory doesn't exist"
 635        );
 636    }
 637
 638    #[gpui::test]
 639    async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
 640        let mode = &EditFileMode::Edit;
 641
 642        let path_with_root = "root/dir/subdir/existing.txt";
 643        let path_without_root = "dir/subdir/existing.txt";
 644        let result = test_resolve_path(mode, path_with_root, cx);
 645        assert_resolved_path_eq(result.await, path_without_root);
 646
 647        let result = test_resolve_path(mode, path_without_root, cx);
 648        assert_resolved_path_eq(result.await, path_without_root);
 649
 650        let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
 651        assert_eq!(
 652            result.await.unwrap_err().to_string(),
 653            "Can't edit file: path not found"
 654        );
 655
 656        let result = test_resolve_path(mode, "root/dir", cx);
 657        assert_eq!(
 658            result.await.unwrap_err().to_string(),
 659            "Can't edit file: path is a directory"
 660        );
 661    }
 662
 663    async fn test_resolve_path(
 664        mode: &EditFileMode,
 665        path: &str,
 666        cx: &mut TestAppContext,
 667    ) -> anyhow::Result<ProjectPath> {
 668        init_test(cx);
 669
 670        let fs = project::FakeFs::new(cx.executor());
 671        fs.insert_tree(
 672            "/root",
 673            json!({
 674                "dir": {
 675                    "subdir": {
 676                        "existing.txt": "hello"
 677                    }
 678                }
 679            }),
 680        )
 681        .await;
 682        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 683
 684        let input = EditFileToolInput {
 685            display_description: "Some edit".into(),
 686            path: path.into(),
 687            mode: mode.clone(),
 688        };
 689
 690        cx.update(|cx| resolve_path(&input, project, cx))
 691    }
 692
 693    #[track_caller]
 694    fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
 695        let actual = path.expect("Should return valid path").path;
 696        let actual = actual.as_str();
 697        assert_eq!(actual, expected);
 698    }
 699
 700    #[gpui::test]
 701    async fn test_format_on_save(cx: &mut TestAppContext) {
 702        init_test(cx);
 703
 704        let fs = project::FakeFs::new(cx.executor());
 705        fs.insert_tree("/root", json!({"src": {}})).await;
 706
 707        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 708
 709        // Set up a Rust language with LSP formatting support
 710        let rust_language = Arc::new(language::Language::new(
 711            language::LanguageConfig {
 712                name: "Rust".into(),
 713                matcher: language::LanguageMatcher {
 714                    path_suffixes: vec!["rs".to_string()],
 715                    ..Default::default()
 716                },
 717                ..Default::default()
 718            },
 719            None,
 720        ));
 721
 722        // Register the language and fake LSP
 723        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
 724        language_registry.add(rust_language);
 725
 726        let mut fake_language_servers = language_registry.register_fake_lsp(
 727            "Rust",
 728            language::FakeLspAdapter {
 729                capabilities: lsp::ServerCapabilities {
 730                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
 731                    ..Default::default()
 732                },
 733                ..Default::default()
 734            },
 735        );
 736
 737        // Create the file
 738        fs.save(
 739            path!("/root/src/main.rs").as_ref(),
 740            &"initial content".into(),
 741            language::LineEnding::Unix,
 742        )
 743        .await
 744        .unwrap();
 745
 746        // Open the buffer to trigger LSP initialization
 747        let buffer = project
 748            .update(cx, |project, cx| {
 749                project.open_local_buffer(path!("/root/src/main.rs"), cx)
 750            })
 751            .await
 752            .unwrap();
 753
 754        // Register the buffer with language servers
 755        let _handle = project.update(cx, |project, cx| {
 756            project.register_buffer_with_language_servers(&buffer, cx)
 757        });
 758
 759        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
 760        const FORMATTED_CONTENT: &str =
 761            "This file was formatted by the fake formatter in the test.\n";
 762
 763        // Get the fake language server and set up formatting handler
 764        let fake_language_server = fake_language_servers.next().await.unwrap();
 765        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
 766            |_, _| async move {
 767                Ok(Some(vec![lsp::TextEdit {
 768                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
 769                    new_text: FORMATTED_CONTENT.to_string(),
 770                }]))
 771            }
 772        });
 773
 774        let context_server_registry =
 775            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 776        let model = Arc::new(FakeLanguageModel::default());
 777        let thread = cx.new(|cx| {
 778            Thread::new(
 779                project.clone(),
 780                cx.new(|_cx| ProjectContext::default()),
 781                context_server_registry,
 782                Templates::new(),
 783                Some(model.clone()),
 784                cx,
 785            )
 786        });
 787
 788        // First, test with format_on_save enabled
 789        cx.update(|cx| {
 790            SettingsStore::update_global(cx, |store, cx| {
 791                store.update_user_settings(cx, |settings| {
 792                    settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
 793                    settings.project.all_languages.defaults.formatter =
 794                        Some(language::language_settings::SelectedFormatter::Auto);
 795                });
 796            });
 797        });
 798
 799        // Have the model stream unformatted content
 800        let edit_result = {
 801            let edit_task = cx.update(|cx| {
 802                let input = EditFileToolInput {
 803                    display_description: "Create main function".into(),
 804                    path: "root/src/main.rs".into(),
 805                    mode: EditFileMode::Overwrite,
 806                };
 807                Arc::new(EditFileTool::new(
 808                    project.clone(),
 809                    thread.downgrade(),
 810                    language_registry.clone(),
 811                ))
 812                .run(input, ToolCallEventStream::test().0, cx)
 813            });
 814
 815            // Stream the unformatted content
 816            cx.executor().run_until_parked();
 817            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
 818            model.end_last_completion_stream();
 819
 820            edit_task.await
 821        };
 822        assert!(edit_result.is_ok());
 823
 824        // Wait for any async operations (e.g. formatting) to complete
 825        cx.executor().run_until_parked();
 826
 827        // Read the file to verify it was formatted automatically
 828        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
 829        assert_eq!(
 830            // Ignore carriage returns on Windows
 831            new_content.replace("\r\n", "\n"),
 832            FORMATTED_CONTENT,
 833            "Code should be formatted when format_on_save is enabled"
 834        );
 835
 836        let stale_buffer_count = thread
 837            .read_with(cx, |thread, _cx| thread.action_log.clone())
 838            .read_with(cx, |log, cx| log.stale_buffers(cx).count());
 839
 840        assert_eq!(
 841            stale_buffer_count, 0,
 842            "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
 843             This causes the agent to think the file was modified externally when it was just formatted.",
 844            stale_buffer_count
 845        );
 846
 847        // Next, test with format_on_save disabled
 848        cx.update(|cx| {
 849            SettingsStore::update_global(cx, |store, cx| {
 850                store.update_user_settings(cx, |settings| {
 851                    settings.project.all_languages.defaults.format_on_save =
 852                        Some(FormatOnSave::Off);
 853                });
 854            });
 855        });
 856
 857        // Stream unformatted edits again
 858        let edit_result = {
 859            let edit_task = cx.update(|cx| {
 860                let input = EditFileToolInput {
 861                    display_description: "Update main function".into(),
 862                    path: "root/src/main.rs".into(),
 863                    mode: EditFileMode::Overwrite,
 864                };
 865                Arc::new(EditFileTool::new(
 866                    project.clone(),
 867                    thread.downgrade(),
 868                    language_registry,
 869                ))
 870                .run(input, ToolCallEventStream::test().0, cx)
 871            });
 872
 873            // Stream the unformatted content
 874            cx.executor().run_until_parked();
 875            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
 876            model.end_last_completion_stream();
 877
 878            edit_task.await
 879        };
 880        assert!(edit_result.is_ok());
 881
 882        // Wait for any async operations (e.g. formatting) to complete
 883        cx.executor().run_until_parked();
 884
 885        // Verify the file was not formatted
 886        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
 887        assert_eq!(
 888            // Ignore carriage returns on Windows
 889            new_content.replace("\r\n", "\n"),
 890            UNFORMATTED_CONTENT,
 891            "Code should not be formatted when format_on_save is disabled"
 892        );
 893    }
 894
    /// Verifies that files written by `EditFileTool` honor the
    /// `remove_trailing_whitespace_on_save` setting: trailing whitespace is stripped
    /// when the setting is enabled and preserved when it is disabled.
    #[gpui::test]
    async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        // Seed the target file with placeholder content. The trailing whitespace under
        // test comes from the model's Overwrite edits below, not from this initial file.
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // First, test with remove_trailing_whitespace_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings
                        .project
                        .all_languages
                        .defaults
                        .remove_trailing_whitespace_on_save = Some(true);
                });
            });
        });

        // Two lines end in two spaces each; the final line break is kept either way.
        const CONTENT_WITH_TRAILING_WHITESPACE: &str =
            "fn main() {  \n    println!(\"Hello!\");  \n}\n";

        // Have the model stream content that contains trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry.clone(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Let the tool issue its completion request before feeding the fake model.
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any pending async work (e.g. the save-time whitespace cleanup) to finish
        cx.executor().run_until_parked();

        // Read the file to verify trailing whitespace was removed automatically
        assert_eq!(
            // Ignore carriage returns on Windows
            fs.load(path!("/root/src/main.rs").as_ref())
                .await
                .unwrap()
                .replace("\r\n", "\n"),
            "fn main() {\n    println!(\"Hello!\");\n}\n",
            "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
        );

        // Next, test with remove_trailing_whitespace_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings
                        .project
                        .all_languages
                        .defaults
                        .remove_trailing_whitespace_on_save = Some(false);
                });
            });
        });

        // Stream edits again with trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry,
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Let the tool issue its completion request before feeding the fake model.
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any pending async work to finish before reading the file back
        cx.executor().run_until_parked();

        // With the setting disabled, the saved file must match the streamed content exactly,
        // trailing whitespace included.
        let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            final_content.replace("\r\n", "\n"),
            CONTENT_WITH_TRAILING_WHITESPACE,
            "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
        );
    }
1037
    /// Exercises `EditFileTool::authorize` for local-settings paths (`.zed/...`), absolute
    /// paths outside the project, ordinary project paths, and the
    /// `always_allow_tool_actions` override.
    #[gpui::test]
    async fn test_authorize(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            language_registry,
        ));
        // NOTE(review): "/root" is inserted only after `Project::test` above created the
        // worktree; confirm this ordering is intentional.
        fs.insert_tree("/root", json!({})).await;

        // Test 1: Path with .zed component should require confirmation, and the
        // authorization title is suffixed with "(local settings)".
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let _auth = cx.update(|cx| {
            tool.authorize(
                &EditFileToolInput {
                    display_description: "test 1".into(),
                    path: ".zed/settings.json".into(),
                    mode: EditFileMode::Edit,
                },
                &stream_tx,
                cx,
            )
        });

        let event = stream_rx.expect_authorization().await;
        assert_eq!(
            event.tool_call.fields.title,
            Some("test 1 (local settings)".into())
        );

        // Test 2: Path outside project should require confirmation; the title is the plain
        // description with no suffix.
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let _auth = cx.update(|cx| {
            tool.authorize(
                &EditFileToolInput {
                    display_description: "test 2".into(),
                    path: "/etc/hosts".into(),
                    mode: EditFileMode::Edit,
                },
                &stream_tx,
                cx,
            )
        });

        let event = stream_rx.expect_authorization().await;
        assert_eq!(event.tool_call.fields.title, Some("test 2".into()));

        // Test 3: Relative path without .zed should not require confirmation: the
        // authorization resolves and nothing is pushed to the event stream.
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        cx.update(|cx| {
            tool.authorize(
                &EditFileToolInput {
                    display_description: "test 3".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Edit,
                },
                &stream_tx,
                cx,
            )
        })
        .await
        .unwrap();
        assert!(stream_rx.try_next().is_err());

        // Test 4: Path with .zed in the middle should require confirmation
        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        let _auth = cx.update(|cx| {
            tool.authorize(
                &EditFileToolInput {
                    display_description: "test 4".into(),
                    path: "root/.zed/tasks.json".into(),
                    mode: EditFileMode::Edit,
                },
                &stream_tx,
                cx,
            )
        });
        let event = stream_rx.expect_authorization().await;
        assert_eq!(
            event.tool_call.fields.title,
            Some("test 4 (local settings)".into())
        );

        // Test 5: When always_allow_tool_actions is enabled, no confirmation needed, even for
        // paths that required it above (5.1: local settings, 5.2: outside the project).
        cx.update(|cx| {
            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            agent_settings::AgentSettings::override_global(settings, cx);
        });

        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        cx.update(|cx| {
            tool.authorize(
                &EditFileToolInput {
                    display_description: "test 5.1".into(),
                    path: ".zed/settings.json".into(),
                    mode: EditFileMode::Edit,
                },
                &stream_tx,
                cx,
            )
        })
        .await
        .unwrap();
        assert!(stream_rx.try_next().is_err());

        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
        cx.update(|cx| {
            tool.authorize(
                &EditFileToolInput {
                    display_description: "test 5.2".into(),
                    path: "/etc/hosts".into(),
                    mode: EditFileMode::Edit,
                },
                &stream_tx,
                cx,
            )
        })
        .await
        .unwrap();
        assert!(stream_rx.try_next().is_err());
    }
1176
    /// Checks that absolute system paths outside the single project worktree require
    /// confirmation, while a path inside the worktree is authorized silently.
    #[gpui::test]
    async fn test_authorize_global_config(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/project", json!({})).await;
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            language_registry,
        ));

        // Absolute paths outside the project must require confirmation even though none of
        // them exist in the fake filesystem (only "/project" is created above). A path
        // prefixed with the worktree root name ("project/...") must not.
        let test_cases = vec![
            (
                "/etc/hosts",
                true,
                "System file should require confirmation",
            ),
            (
                "/usr/local/bin/script",
                true,
                "System bin file should require confirmation",
            ),
            (
                "project/normal_file.rs",
                false,
                "Normal project file should not require confirmation",
            ),
        ];

        for (path, should_confirm, description) in test_cases {
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let auth = cx.update(|cx| {
                tool.authorize(
                    &EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path.into(),
                        mode: EditFileMode::Edit,
                    },
                    &stream_tx,
                    cx,
                )
            });

            if should_confirm {
                // The pending authorization task is dropped unanswered at the end of this
                // iteration; only the emitted request is asserted.
                stream_rx.expect_authorization().await;
            } else {
                auth.await.unwrap();
                assert!(
                    stream_rx.try_next().is_err(),
                    "Failed for case: {} - path: {} - expected no confirmation but got one",
                    description,
                    path
                );
            }
        }
    }
1249
    /// Checks authorization when the project has several worktrees: ordinary files in any
    /// worktree pass silently, while `.zed` files and paths outside every worktree require
    /// confirmation.
    #[gpui::test]
    async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());

        // Create multiple worktree directories
        fs.insert_tree(
            "/workspace/frontend",
            json!({
                "src": {
                    "main.js": "console.log('frontend');"
                }
            }),
        )
        .await;
        fs.insert_tree(
            "/workspace/backend",
            json!({
                "src": {
                    "main.rs": "fn main() {}"
                }
            }),
        )
        .await;
        fs.insert_tree(
            "/workspace/shared",
            json!({
                ".zed": {
                    "settings.json": "{}"
                }
            }),
        )
        .await;

        // Create project with multiple worktrees
        let project = Project::test(
            fs.clone(),
            [
                path!("/workspace/frontend").as_ref(),
                path!("/workspace/backend").as_ref(),
                path!("/workspace/shared").as_ref(),
            ],
            cx,
        )
        .await;
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry.clone(),
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            language_registry,
        ));

        // Test files in different worktrees. Relative paths are prefixed with the worktree
        // root directory name ("frontend", "backend", "shared").
        let test_cases = vec![
            ("frontend/src/main.js", false, "File in first worktree"),
            ("backend/src/main.rs", false, "File in second worktree"),
            (
                "shared/.zed/settings.json",
                true,
                ".zed file in third worktree",
            ),
            ("/etc/hosts", true, "Absolute path outside all worktrees"),
            (
                "../outside/file.txt",
                true,
                "Relative path outside worktrees",
            ),
        ];

        for (path, should_confirm, description) in test_cases {
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let auth = cx.update(|cx| {
                tool.authorize(
                    &EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path.into(),
                        mode: EditFileMode::Edit,
                    },
                    &stream_tx,
                    cx,
                )
            });

            if should_confirm {
                // Only the emitted request is asserted; the pending task is dropped.
                stream_rx.expect_authorization().await;
            } else {
                auth.await.unwrap();
                assert!(
                    stream_rx.try_next().is_err(),
                    "Failed for case: {} - path: {} - expected no confirmation but got one",
                    description,
                    path
                );
            }
        }
    }
1359
    /// Checks authorization for unusual path inputs: the empty path, the filesystem root,
    /// `..` and `.` components, and (on Windows only) backslash-separated paths.
    #[gpui::test]
    async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            "/project",
            json!({
                ".zed": {
                    "settings.json": "{}"
                },
                "src": {
                    ".zed": {
                        "local.json": "{}"
                    }
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry.clone(),
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            language_registry,
        ));

        // Test edge cases
        let test_cases = vec![
            // Empty path - find_project_path returns Some for empty paths
            ("", false, "Empty path is treated as project root"),
            // Root directory
            ("/", true, "Root directory should be outside project"),
            // Parent directory references - find_project_path resolves these
            (
                "project/../other",
                true,
                "Path with .. that goes outside of root directory",
            ),
            (
                "project/./src/file.rs",
                false,
                "Path with . should work normally",
            ),
            // Windows-style paths; these cases are compiled in only on Windows targets
            #[cfg(target_os = "windows")]
            ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
            #[cfg(target_os = "windows")]
            ("project\\src\\main.rs", false, "Windows-style project path"),
        ];

        for (path, should_confirm, description) in test_cases {
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let auth = cx.update(|cx| {
                tool.authorize(
                    &EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: path.into(),
                        mode: EditFileMode::Edit,
                    },
                    &stream_tx,
                    cx,
                )
            });

            // Drain pending work first so that any authorization request would already be
            // in the stream before the "no confirmation" branch inspects it.
            cx.run_until_parked();

            if should_confirm {
                stream_rx.expect_authorization().await;
            } else {
                assert!(
                    stream_rx.try_next().is_err(),
                    "Failed for case: {} - path: {} - expected no confirmation but got one",
                    description,
                    path
                );
                auth.await.unwrap();
            }
        }
    }
1452
    /// Checks that authorization does not depend on `EditFileMode`. For each of Edit,
    /// Create and Overwrite, a `.zed` path and an outside path require confirmation, and
    /// an ordinary project path does not.
    #[gpui::test]
    async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
        init_test(cx);
        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree(
            "/project",
            json!({
                "existing.txt": "content",
                ".zed": {
                    "settings.json": "{}"
                }
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry.clone(),
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });
        let tool = Arc::new(EditFileTool::new(
            project.clone(),
            thread.downgrade(),
            language_registry,
        ));

        // Test different EditFileMode values
        let modes = vec![
            EditFileMode::Edit,
            EditFileMode::Create,
            EditFileMode::Overwrite,
        ];

        for mode in modes {
            // Test .zed path with different modes
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let _auth = cx.update(|cx| {
                tool.authorize(
                    &EditFileToolInput {
                        display_description: "Edit settings".into(),
                        path: "project/.zed/settings.json".into(),
                        mode: mode.clone(),
                    },
                    &stream_tx,
                    cx,
                )
            });

            stream_rx.expect_authorization().await;

            // Test outside path with different modes. The path is not created in the fake fs.
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            let _auth = cx.update(|cx| {
                tool.authorize(
                    &EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: "/outside/file.txt".into(),
                        mode: mode.clone(),
                    },
                    &stream_tx,
                    cx,
                )
            });

            stream_rx.expect_authorization().await;

            // Test normal path with different modes: authorization resolves with no request
            // emitted (the file itself does not exist in the fake fs)
            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
            cx.update(|cx| {
                tool.authorize(
                    &EditFileToolInput {
                        display_description: "Edit file".into(),
                        path: "project/normal.txt".into(),
                        mode: mode.clone(),
                    },
                    &stream_tx,
                    cx,
                )
            })
            .await
            .unwrap();
            assert!(stream_rx.try_next().is_err());
        }
    }
1546
1547    #[gpui::test]
1548    async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1549        init_test(cx);
1550        let fs = project::FakeFs::new(cx.executor());
1551        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1552        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1553        let context_server_registry =
1554            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1555        let model = Arc::new(FakeLanguageModel::default());
1556        let thread = cx.new(|cx| {
1557            Thread::new(
1558                project.clone(),
1559                cx.new(|_cx| ProjectContext::default()),
1560                context_server_registry,
1561                Templates::new(),
1562                Some(model.clone()),
1563                cx,
1564            )
1565        });
1566        let tool = Arc::new(EditFileTool::new(
1567            project,
1568            thread.downgrade(),
1569            language_registry,
1570        ));
1571
1572        cx.update(|cx| {
1573            // ...
1574            assert_eq!(
1575                tool.initial_title(
1576                    Err(json!({
1577                        "path": "src/main.rs",
1578                        "display_description": "",
1579                        "old_string": "old code",
1580                        "new_string": "new code"
1581                    })),
1582                    cx
1583                ),
1584                "src/main.rs"
1585            );
1586            assert_eq!(
1587                tool.initial_title(
1588                    Err(json!({
1589                        "path": "",
1590                        "display_description": "Fix error handling",
1591                        "old_string": "old code",
1592                        "new_string": "new code"
1593                    })),
1594                    cx
1595                ),
1596                "Fix error handling"
1597            );
1598            assert_eq!(
1599                tool.initial_title(
1600                    Err(json!({
1601                        "path": "src/main.rs",
1602                        "display_description": "Fix error handling",
1603                        "old_string": "old code",
1604                        "new_string": "new code"
1605                    })),
1606                    cx
1607                ),
1608                "src/main.rs"
1609            );
1610            assert_eq!(
1611                tool.initial_title(
1612                    Err(json!({
1613                        "path": "",
1614                        "display_description": "",
1615                        "old_string": "old code",
1616                        "new_string": "new code"
1617                    })),
1618                    cx
1619                ),
1620                DEFAULT_UI_TEXT
1621            );
1622            assert_eq!(
1623                tool.initial_title(Err(serde_json::Value::Null), cx),
1624                DEFAULT_UI_TEXT
1625            );
1626        });
1627    }
1628
1629    #[gpui::test]
1630    async fn test_diff_finalization(cx: &mut TestAppContext) {
1631        init_test(cx);
1632        let fs = project::FakeFs::new(cx.executor());
1633        fs.insert_tree("/", json!({"main.rs": ""})).await;
1634
1635        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1636        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1637        let context_server_registry =
1638            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1639        let model = Arc::new(FakeLanguageModel::default());
1640        let thread = cx.new(|cx| {
1641            Thread::new(
1642                project.clone(),
1643                cx.new(|_cx| ProjectContext::default()),
1644                context_server_registry.clone(),
1645                Templates::new(),
1646                Some(model.clone()),
1647                cx,
1648            )
1649        });
1650
1651        // Ensure the diff is finalized after the edit completes.
1652        {
1653            let tool = Arc::new(EditFileTool::new(
1654                project.clone(),
1655                thread.downgrade(),
1656                languages.clone(),
1657            ));
1658            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1659            let edit = cx.update(|cx| {
1660                tool.run(
1661                    EditFileToolInput {
1662                        display_description: "Edit file".into(),
1663                        path: path!("/main.rs").into(),
1664                        mode: EditFileMode::Edit,
1665                    },
1666                    stream_tx,
1667                    cx,
1668                )
1669            });
1670            stream_rx.expect_update_fields().await;
1671            let diff = stream_rx.expect_diff().await;
1672            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1673            cx.run_until_parked();
1674            model.end_last_completion_stream();
1675            edit.await.unwrap();
1676            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1677        }
1678
1679        // Ensure the diff is finalized if an error occurs while editing.
1680        {
1681            model.forbid_requests();
1682            let tool = Arc::new(EditFileTool::new(
1683                project.clone(),
1684                thread.downgrade(),
1685                languages.clone(),
1686            ));
1687            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1688            let edit = cx.update(|cx| {
1689                tool.run(
1690                    EditFileToolInput {
1691                        display_description: "Edit file".into(),
1692                        path: path!("/main.rs").into(),
1693                        mode: EditFileMode::Edit,
1694                    },
1695                    stream_tx,
1696                    cx,
1697                )
1698            });
1699            stream_rx.expect_update_fields().await;
1700            let diff = stream_rx.expect_diff().await;
1701            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1702            edit.await.unwrap_err();
1703            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1704            model.allow_requests();
1705        }
1706
1707        // Ensure the diff is finalized if the tool call gets dropped.
1708        {
1709            let tool = Arc::new(EditFileTool::new(
1710                project.clone(),
1711                thread.downgrade(),
1712                languages.clone(),
1713            ));
1714            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1715            let edit = cx.update(|cx| {
1716                tool.run(
1717                    EditFileToolInput {
1718                        display_description: "Edit file".into(),
1719                        path: path!("/main.rs").into(),
1720                        mode: EditFileMode::Edit,
1721                    },
1722                    stream_tx,
1723                    cx,
1724                )
1725            });
1726            stream_rx.expect_update_fields().await;
1727            let diff = stream_rx.expect_diff().await;
1728            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1729            drop(edit);
1730            cx.run_until_parked();
1731            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1732        }
1733    }
1734
1735    fn init_test(cx: &mut TestAppContext) {
1736        cx.update(|cx| {
1737            let settings_store = SettingsStore::test(cx);
1738            cx.set_global(settings_store);
1739            language::init(cx);
1740            TelemetrySettings::register(cx);
1741            agent_settings::AgentSettings::register(cx);
1742            Project::init_settings(cx);
1743        });
1744    }
1745}