edit_file_tool.rs

   1use crate::{AgentTool, Thread, ToolCallEventStream};
   2use acp_thread::Diff;
   3use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
   4use anyhow::{Context as _, Result, anyhow};
   5use assistant_tools::edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat};
   6use cloud_llm_client::CompletionIntent;
   7use collections::HashSet;
   8use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
   9use indoc::formatdoc;
  10use language::language_settings::{self, FormatOnSave};
  11use language::{LanguageRegistry, ToPoint};
  12use language_model::LanguageModelToolResultContent;
  13use paths;
  14use project::lsp_store::{FormatTrigger, LspFormatTarget};
  15use project::{Project, ProjectPath};
  16use schemars::JsonSchema;
  17use serde::{Deserialize, Serialize};
  18use settings::Settings;
  19use smol::stream::StreamExt as _;
  20use std::path::{Path, PathBuf};
  21use std::sync::Arc;
  22use ui::SharedString;
  23use util::ResultExt;
  24
  25const DEFAULT_UI_TEXT: &str = "Editing file";
  26
  27/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
  28///
  29/// Before using this tool:
  30///
  31/// 1. Use the `read_file` tool to understand the file's contents and context
  32///
  33/// 2. Verify the directory path is correct (only applicable when creating new files):
  34///    - Use the `list_directory` tool to verify the parent directory exists and is the correct location
  35#[derive(Debug, Serialize, Deserialize, JsonSchema)]
  36pub struct EditFileToolInput {
  37    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
  38    ///
  39    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
  40    ///
  41    /// NEVER mention the file path in this description.
  42    ///
  43    /// <example>Fix API endpoint URLs</example>
  44    /// <example>Update copyright year in `page_footer`</example>
  45    ///
  46    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
  47    pub display_description: String,
  48
  49    /// The full path of the file to create or modify in the project.
  50    ///
  51    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
  52    ///
  53    /// The following examples assume we have two root directories in the project:
  54    /// - /a/b/backend
  55    /// - /c/d/frontend
  56    ///
  57    /// <example>
  58    /// `backend/src/main.rs`
  59    ///
  60    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
  61    /// </example>
  62    ///
  63    /// <example>
  64    /// `frontend/db.js`
  65    /// </example>
  66    pub path: PathBuf,
  67    /// The mode of operation on the file. Possible values:
  68    /// - 'edit': Make granular edits to an existing file.
  69    /// - 'create': Create a new file if it doesn't exist.
  70    /// - 'overwrite': Replace the entire contents of an existing file.
  71    ///
  72    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
  73    pub mode: EditFileMode,
  74}
  75
  76#[derive(Debug, Serialize, Deserialize, JsonSchema)]
  77struct EditFileToolPartialInput {
  78    #[serde(default)]
  79    path: String,
  80    #[serde(default)]
  81    display_description: String,
  82}
  83
  84#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
  85#[serde(rename_all = "lowercase")]
  86#[schemars(inline)]
  87pub enum EditFileMode {
  88    Edit,
  89    Create,
  90    Overwrite,
  91}
  92
  93#[derive(Debug, Serialize, Deserialize)]
  94pub struct EditFileToolOutput {
  95    #[serde(alias = "original_path")]
  96    input_path: PathBuf,
  97    new_text: String,
  98    old_text: Arc<String>,
  99    #[serde(default)]
 100    diff: String,
 101    #[serde(alias = "raw_output")]
 102    edit_agent_output: EditAgentOutput,
 103}
 104
 105impl From<EditFileToolOutput> for LanguageModelToolResultContent {
 106    fn from(output: EditFileToolOutput) -> Self {
 107        if output.diff.is_empty() {
 108            "No edits were made.".into()
 109        } else {
 110            format!(
 111                "Edited {}:\n\n```diff\n{}\n```",
 112                output.input_path.display(),
 113                output.diff
 114            )
 115            .into()
 116        }
 117    }
 118}
 119
 120pub struct EditFileTool {
 121    thread: WeakEntity<Thread>,
 122    language_registry: Arc<LanguageRegistry>,
 123    project: Entity<Project>,
 124}
 125
 126impl EditFileTool {
 127    pub fn new(
 128        project: Entity<Project>,
 129        thread: WeakEntity<Thread>,
 130        language_registry: Arc<LanguageRegistry>,
 131    ) -> Self {
 132        Self {
 133            project,
 134            thread,
 135            language_registry,
 136        }
 137    }
 138
 139    fn authorize(
 140        &self,
 141        input: &EditFileToolInput,
 142        event_stream: &ToolCallEventStream,
 143        cx: &mut App,
 144    ) -> Task<Result<()>> {
 145        if agent_settings::AgentSettings::get_global(cx).always_allow_tool_actions {
 146            return Task::ready(Ok(()));
 147        }
 148
 149        // If any path component matches the local settings folder, then this could affect
 150        // the editor in ways beyond the project source, so prompt.
 151        let local_settings_folder = paths::local_settings_folder_relative_path();
 152        let path = Path::new(&input.path);
 153        if path
 154            .components()
 155            .any(|component| component.as_os_str() == local_settings_folder.as_os_str())
 156        {
 157            return event_stream.authorize(
 158                format!("{} (local settings)", input.display_description),
 159                cx,
 160            );
 161        }
 162
 163        // It's also possible that the global config dir is configured to be inside the project,
 164        // so check for that edge case too.
 165        if let Ok(canonical_path) = std::fs::canonicalize(&input.path)
 166            && canonical_path.starts_with(paths::config_dir())
 167        {
 168            return event_stream.authorize(
 169                format!("{} (global settings)", input.display_description),
 170                cx,
 171            );
 172        }
 173
 174        // Check if path is inside the global config directory
 175        // First check if it's already inside project - if not, try to canonicalize
 176        let Ok(project_path) = self.thread.read_with(cx, |thread, cx| {
 177            thread.project().read(cx).find_project_path(&input.path, cx)
 178        }) else {
 179            return Task::ready(Err(anyhow!("thread was dropped")));
 180        };
 181
 182        // If the path is inside the project, and it's not one of the above edge cases,
 183        // then no confirmation is necessary. Otherwise, confirmation is necessary.
 184        if project_path.is_some() {
 185            Task::ready(Ok(()))
 186        } else {
 187            event_stream.authorize(&input.display_description, cx)
 188        }
 189    }
 190}
 191
 192impl AgentTool for EditFileTool {
 193    type Input = EditFileToolInput;
 194    type Output = EditFileToolOutput;
 195
 196    fn name() -> &'static str {
 197        "edit_file"
 198    }
 199
 200    fn kind() -> acp::ToolKind {
 201        acp::ToolKind::Edit
 202    }
 203
 204    fn initial_title(
 205        &self,
 206        input: Result<Self::Input, serde_json::Value>,
 207        cx: &mut App,
 208    ) -> SharedString {
 209        match input {
 210            Ok(input) => self
 211                .project
 212                .read(cx)
 213                .find_project_path(&input.path, cx)
 214                .and_then(|project_path| {
 215                    self.project
 216                        .read(cx)
 217                        .short_full_path_for_project_path(&project_path, cx)
 218                })
 219                .unwrap_or(Path::new(&input.path).into())
 220                .to_string_lossy()
 221                .to_string()
 222                .into(),
 223            Err(raw_input) => {
 224                if let Some(input) =
 225                    serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
 226                {
 227                    let path = input.path.trim();
 228                    if !path.is_empty() {
 229                        return self
 230                            .project
 231                            .read(cx)
 232                            .find_project_path(&input.path, cx)
 233                            .and_then(|project_path| {
 234                                self.project
 235                                    .read(cx)
 236                                    .short_full_path_for_project_path(&project_path, cx)
 237                            })
 238                            .unwrap_or(Path::new(&input.path).into())
 239                            .to_string_lossy()
 240                            .to_string()
 241                            .into();
 242                    }
 243
 244                    let description = input.display_description.trim();
 245                    if !description.is_empty() {
 246                        return description.to_string().into();
 247                    }
 248                }
 249
 250                DEFAULT_UI_TEXT.into()
 251            }
 252        }
 253    }
 254
 255    fn run(
 256        self: Arc<Self>,
 257        input: Self::Input,
 258        event_stream: ToolCallEventStream,
 259        cx: &mut App,
 260    ) -> Task<Result<Self::Output>> {
 261        let Ok(project) = self
 262            .thread
 263            .read_with(cx, |thread, _cx| thread.project().clone())
 264        else {
 265            return Task::ready(Err(anyhow!("thread was dropped")));
 266        };
 267        let project_path = match resolve_path(&input, project.clone(), cx) {
 268            Ok(path) => path,
 269            Err(err) => return Task::ready(Err(anyhow!(err))),
 270        };
 271        let abs_path = project.read(cx).absolute_path(&project_path, cx);
 272        if let Some(abs_path) = abs_path.clone() {
 273            event_stream.update_fields(ToolCallUpdateFields {
 274                locations: Some(vec![acp::ToolCallLocation {
 275                    path: abs_path,
 276                    line: None,
 277                }]),
 278                ..Default::default()
 279            });
 280        }
 281
 282        let authorize = self.authorize(&input, &event_stream, cx);
 283        cx.spawn(async move |cx: &mut AsyncApp| {
 284            authorize.await?;
 285
 286            let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
 287                let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
 288                (request, thread.model().cloned(), thread.action_log().clone())
 289            })?;
 290            let request = request?;
 291            let model = model.context("No language model configured")?;
 292
 293            let edit_format = EditFormat::from_model(model.clone())?;
 294            let edit_agent = EditAgent::new(
 295                model,
 296                project.clone(),
 297                action_log.clone(),
 298                // TODO: move edit agent to this crate so we can use our templates
 299                assistant_tools::templates::Templates::new(),
 300                edit_format,
 301            );
 302
 303            let buffer = project
 304                .update(cx, |project, cx| {
 305                    project.open_buffer(project_path.clone(), cx)
 306                })?
 307                .await?;
 308
 309            let diff = cx.new(|cx| Diff::new(buffer.clone(), cx))?;
 310            event_stream.update_diff(diff.clone());
 311            let _finalize_diff = util::defer({
 312               let diff = diff.downgrade();
 313               let mut cx = cx.clone();
 314               move || {
 315                   diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
 316               }
 317            });
 318
 319            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 320            let old_text = cx
 321                .background_spawn({
 322                    let old_snapshot = old_snapshot.clone();
 323                    async move { Arc::new(old_snapshot.text()) }
 324                })
 325                .await;
 326
 327
 328            let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
 329                edit_agent.edit(
 330                    buffer.clone(),
 331                    input.display_description.clone(),
 332                    &request,
 333                    cx,
 334                )
 335            } else {
 336                edit_agent.overwrite(
 337                    buffer.clone(),
 338                    input.display_description.clone(),
 339                    &request,
 340                    cx,
 341                )
 342            };
 343
 344            let mut hallucinated_old_text = false;
 345            let mut ambiguous_ranges = Vec::new();
 346            let mut emitted_location = false;
 347            while let Some(event) = events.next().await {
 348                match event {
 349                    EditAgentOutputEvent::Edited(range) => {
 350                        if !emitted_location {
 351                            let line = buffer.update(cx, |buffer, _cx| {
 352                                range.start.to_point(&buffer.snapshot()).row
 353                            }).ok();
 354                            if let Some(abs_path) = abs_path.clone() {
 355                                event_stream.update_fields(ToolCallUpdateFields {
 356                                    locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
 357                                    ..Default::default()
 358                                });
 359                            }
 360                            emitted_location = true;
 361                        }
 362                    },
 363                    EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
 364                    EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
 365                    EditAgentOutputEvent::ResolvingEditRange(range) => {
 366                        diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx))?;
 367                        // if !emitted_location {
 368                        //     let line = buffer.update(cx, |buffer, _cx| {
 369                        //         range.start.to_point(&buffer.snapshot()).row
 370                        //     }).ok();
 371                        //     if let Some(abs_path) = abs_path.clone() {
 372                        //         event_stream.update_fields(ToolCallUpdateFields {
 373                        //             locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
 374                        //             ..Default::default()
 375                        //         });
 376                        //     }
 377                        // }
 378                    }
 379                }
 380            }
 381
 382            // If format_on_save is enabled, format the buffer
 383            let format_on_save_enabled = buffer
 384                .read_with(cx, |buffer, cx| {
 385                    let settings = language_settings::language_settings(
 386                        buffer.language().map(|l| l.name()),
 387                        buffer.file(),
 388                        cx,
 389                    );
 390                    settings.format_on_save != FormatOnSave::Off
 391                })
 392                .unwrap_or(false);
 393
 394            let edit_agent_output = output.await?;
 395
 396            if format_on_save_enabled {
 397                action_log.update(cx, |log, cx| {
 398                    log.buffer_edited(buffer.clone(), cx);
 399                })?;
 400
 401                let format_task = project.update(cx, |project, cx| {
 402                    project.format(
 403                        HashSet::from_iter([buffer.clone()]),
 404                        LspFormatTarget::Buffers,
 405                        false, // Don't push to history since the tool did it.
 406                        FormatTrigger::Save,
 407                        cx,
 408                    )
 409                })?;
 410                format_task.await.log_err();
 411            }
 412
 413            project
 414                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
 415                .await?;
 416
 417            action_log.update(cx, |log, cx| {
 418                log.buffer_edited(buffer.clone(), cx);
 419            })?;
 420
 421            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
 422            let (new_text, unified_diff) = cx
 423                .background_spawn({
 424                    let new_snapshot = new_snapshot.clone();
 425                    let old_text = old_text.clone();
 426                    async move {
 427                        let new_text = new_snapshot.text();
 428                        let diff = language::unified_diff(&old_text, &new_text);
 429                        (new_text, diff)
 430                    }
 431                })
 432                .await;
 433
 434            let input_path = input.path.display();
 435            if unified_diff.is_empty() {
 436                anyhow::ensure!(
 437                    !hallucinated_old_text,
 438                    formatdoc! {"
 439                        Some edits were produced but none of them could be applied.
 440                        Read the relevant sections of {input_path} again so that
 441                        I can perform the requested edits.
 442                    "}
 443                );
 444                anyhow::ensure!(
 445                    ambiguous_ranges.is_empty(),
 446                    {
 447                        let line_numbers = ambiguous_ranges
 448                            .iter()
 449                            .map(|range| range.start.to_string())
 450                            .collect::<Vec<_>>()
 451                            .join(", ");
 452                        formatdoc! {"
 453                            <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
 454                            relevant sections of {input_path} again and extend <old_text> so
 455                            that I can perform the requested edits.
 456                        "}
 457                    }
 458                );
 459            }
 460
 461            Ok(EditFileToolOutput {
 462                input_path: input.path,
 463                new_text,
 464                old_text,
 465                diff: unified_diff,
 466                edit_agent_output,
 467            })
 468        })
 469    }
 470
 471    fn replay(
 472        &self,
 473        _input: Self::Input,
 474        output: Self::Output,
 475        event_stream: ToolCallEventStream,
 476        cx: &mut App,
 477    ) -> Result<()> {
 478        event_stream.update_diff(cx.new(|cx| {
 479            Diff::finalized(
 480                output.input_path,
 481                Some(output.old_text.to_string()),
 482                output.new_text,
 483                self.language_registry.clone(),
 484                cx,
 485            )
 486        }));
 487        Ok(())
 488    }
 489}
 490
 491/// Validate that the file path is valid, meaning:
 492///
 493/// - For `edit` and `overwrite`, the path must point to an existing file.
 494/// - For `create`, the file must not already exist, but it's parent dir must exist.
 495fn resolve_path(
 496    input: &EditFileToolInput,
 497    project: Entity<Project>,
 498    cx: &mut App,
 499) -> Result<ProjectPath> {
 500    let project = project.read(cx);
 501
 502    match input.mode {
 503        EditFileMode::Edit | EditFileMode::Overwrite => {
 504            let path = project
 505                .find_project_path(&input.path, cx)
 506                .context("Can't edit file: path not found")?;
 507
 508            let entry = project
 509                .entry_for_path(&path, cx)
 510                .context("Can't edit file: path not found")?;
 511
 512            anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
 513            Ok(path)
 514        }
 515
 516        EditFileMode::Create => {
 517            if let Some(path) = project.find_project_path(&input.path, cx) {
 518                anyhow::ensure!(
 519                    project.entry_for_path(&path, cx).is_none(),
 520                    "Can't create file: file already exists"
 521                );
 522            }
 523
 524            let parent_path = input
 525                .path
 526                .parent()
 527                .context("Can't create file: incorrect path")?;
 528
 529            let parent_project_path = project.find_project_path(&parent_path, cx);
 530
 531            let parent_entry = parent_project_path
 532                .as_ref()
 533                .and_then(|path| project.entry_for_path(path, cx))
 534                .context("Can't create file: parent directory doesn't exist")?;
 535
 536            anyhow::ensure!(
 537                parent_entry.is_dir(),
 538                "Can't create file: parent is not a directory"
 539            );
 540
 541            let file_name = input
 542                .path
 543                .file_name()
 544                .context("Can't create file: invalid filename")?;
 545
 546            let new_file_path = parent_project_path.map(|parent| ProjectPath {
 547                path: Arc::from(parent.path.join(file_name)),
 548                ..parent
 549            });
 550
 551            new_file_path.context("Can't create file")
 552        }
 553    }
 554}
 555
 556#[cfg(test)]
 557mod tests {
 558    use super::*;
 559    use crate::{ContextServerRegistry, Templates};
 560    use client::TelemetrySettings;
 561    use fs::Fs;
 562    use gpui::{TestAppContext, UpdateGlobal};
 563    use language_model::fake_provider::FakeLanguageModel;
 564    use prompt_store::ProjectContext;
 565    use serde_json::json;
 566    use settings::SettingsStore;
 567    use util::path;
 568
 569    #[gpui::test]
 570    async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
 571        init_test(cx);
 572
 573        let fs = project::FakeFs::new(cx.executor());
 574        fs.insert_tree("/root", json!({})).await;
 575        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 576        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
 577        let context_server_registry =
 578            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 579        let model = Arc::new(FakeLanguageModel::default());
 580        let thread = cx.new(|cx| {
 581            Thread::new(
 582                project.clone(),
 583                cx.new(|_cx| ProjectContext::default()),
 584                context_server_registry,
 585                Templates::new(),
 586                Some(model),
 587                cx,
 588            )
 589        });
 590        let result = cx
 591            .update(|cx| {
 592                let input = EditFileToolInput {
 593                    display_description: "Some edit".into(),
 594                    path: "root/nonexistent_file.txt".into(),
 595                    mode: EditFileMode::Edit,
 596                };
 597                Arc::new(EditFileTool::new(
 598                    project,
 599                    thread.downgrade(),
 600                    language_registry,
 601                ))
 602                .run(input, ToolCallEventStream::test().0, cx)
 603            })
 604            .await;
 605        assert_eq!(
 606            result.unwrap_err().to_string(),
 607            "Can't edit file: path not found"
 608        );
 609    }
 610
 611    #[gpui::test]
 612    async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
 613        let mode = &EditFileMode::Create;
 614
 615        let result = test_resolve_path(mode, "root/new.txt", cx);
 616        assert_resolved_path_eq(result.await, "new.txt");
 617
 618        let result = test_resolve_path(mode, "new.txt", cx);
 619        assert_resolved_path_eq(result.await, "new.txt");
 620
 621        let result = test_resolve_path(mode, "dir/new.txt", cx);
 622        assert_resolved_path_eq(result.await, "dir/new.txt");
 623
 624        let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
 625        assert_eq!(
 626            result.await.unwrap_err().to_string(),
 627            "Can't create file: file already exists"
 628        );
 629
 630        let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
 631        assert_eq!(
 632            result.await.unwrap_err().to_string(),
 633            "Can't create file: parent directory doesn't exist"
 634        );
 635    }
 636
 637    #[gpui::test]
 638    async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
 639        let mode = &EditFileMode::Edit;
 640
 641        let path_with_root = "root/dir/subdir/existing.txt";
 642        let path_without_root = "dir/subdir/existing.txt";
 643        let result = test_resolve_path(mode, path_with_root, cx);
 644        assert_resolved_path_eq(result.await, path_without_root);
 645
 646        let result = test_resolve_path(mode, path_without_root, cx);
 647        assert_resolved_path_eq(result.await, path_without_root);
 648
 649        let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
 650        assert_eq!(
 651            result.await.unwrap_err().to_string(),
 652            "Can't edit file: path not found"
 653        );
 654
 655        let result = test_resolve_path(mode, "root/dir", cx);
 656        assert_eq!(
 657            result.await.unwrap_err().to_string(),
 658            "Can't edit file: path is a directory"
 659        );
 660    }
 661
 662    async fn test_resolve_path(
 663        mode: &EditFileMode,
 664        path: &str,
 665        cx: &mut TestAppContext,
 666    ) -> anyhow::Result<ProjectPath> {
 667        init_test(cx);
 668
 669        let fs = project::FakeFs::new(cx.executor());
 670        fs.insert_tree(
 671            "/root",
 672            json!({
 673                "dir": {
 674                    "subdir": {
 675                        "existing.txt": "hello"
 676                    }
 677                }
 678            }),
 679        )
 680        .await;
 681        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 682
 683        let input = EditFileToolInput {
 684            display_description: "Some edit".into(),
 685            path: path.into(),
 686            mode: mode.clone(),
 687        };
 688
 689        cx.update(|cx| resolve_path(&input, project, cx))
 690    }
 691
 692    fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &str) {
 693        let actual = path
 694            .expect("Should return valid path")
 695            .path
 696            .to_str()
 697            .unwrap()
 698            .replace("\\", "/"); // Naive Windows paths normalization
 699        assert_eq!(actual, expected);
 700    }
 701
 702    #[gpui::test]
 703    async fn test_format_on_save(cx: &mut TestAppContext) {
 704        init_test(cx);
 705
 706        let fs = project::FakeFs::new(cx.executor());
 707        fs.insert_tree("/root", json!({"src": {}})).await;
 708
 709        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
 710
 711        // Set up a Rust language with LSP formatting support
 712        let rust_language = Arc::new(language::Language::new(
 713            language::LanguageConfig {
 714                name: "Rust".into(),
 715                matcher: language::LanguageMatcher {
 716                    path_suffixes: vec!["rs".to_string()],
 717                    ..Default::default()
 718                },
 719                ..Default::default()
 720            },
 721            None,
 722        ));
 723
 724        // Register the language and fake LSP
 725        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
 726        language_registry.add(rust_language);
 727
 728        let mut fake_language_servers = language_registry.register_fake_lsp(
 729            "Rust",
 730            language::FakeLspAdapter {
 731                capabilities: lsp::ServerCapabilities {
 732                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
 733                    ..Default::default()
 734                },
 735                ..Default::default()
 736            },
 737        );
 738
 739        // Create the file
 740        fs.save(
 741            path!("/root/src/main.rs").as_ref(),
 742            &"initial content".into(),
 743            language::LineEnding::Unix,
 744        )
 745        .await
 746        .unwrap();
 747
 748        // Open the buffer to trigger LSP initialization
 749        let buffer = project
 750            .update(cx, |project, cx| {
 751                project.open_local_buffer(path!("/root/src/main.rs"), cx)
 752            })
 753            .await
 754            .unwrap();
 755
 756        // Register the buffer with language servers
 757        let _handle = project.update(cx, |project, cx| {
 758            project.register_buffer_with_language_servers(&buffer, cx)
 759        });
 760
 761        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
 762        const FORMATTED_CONTENT: &str =
 763            "This file was formatted by the fake formatter in the test.\n";
 764
 765        // Get the fake language server and set up formatting handler
 766        let fake_language_server = fake_language_servers.next().await.unwrap();
 767        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
 768            |_, _| async move {
 769                Ok(Some(vec![lsp::TextEdit {
 770                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
 771                    new_text: FORMATTED_CONTENT.to_string(),
 772                }]))
 773            }
 774        });
 775
 776        let context_server_registry =
 777            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
 778        let model = Arc::new(FakeLanguageModel::default());
 779        let thread = cx.new(|cx| {
 780            Thread::new(
 781                project.clone(),
 782                cx.new(|_cx| ProjectContext::default()),
 783                context_server_registry,
 784                Templates::new(),
 785                Some(model.clone()),
 786                cx,
 787            )
 788        });
 789
 790        // First, test with format_on_save enabled
 791        cx.update(|cx| {
 792            SettingsStore::update_global(cx, |store, cx| {
 793                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
 794                    cx,
 795                    |settings| {
 796                        settings.defaults.format_on_save = Some(FormatOnSave::On);
 797                        settings.defaults.formatter =
 798                            Some(language::language_settings::SelectedFormatter::Auto);
 799                    },
 800                );
 801            });
 802        });
 803
 804        // Have the model stream unformatted content
 805        let edit_result = {
 806            let edit_task = cx.update(|cx| {
 807                let input = EditFileToolInput {
 808                    display_description: "Create main function".into(),
 809                    path: "root/src/main.rs".into(),
 810                    mode: EditFileMode::Overwrite,
 811                };
 812                Arc::new(EditFileTool::new(
 813                    project.clone(),
 814                    thread.downgrade(),
 815                    language_registry.clone(),
 816                ))
 817                .run(input, ToolCallEventStream::test().0, cx)
 818            });
 819
 820            // Stream the unformatted content
 821            cx.executor().run_until_parked();
 822            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
 823            model.end_last_completion_stream();
 824
 825            edit_task.await
 826        };
 827        assert!(edit_result.is_ok());
 828
 829        // Wait for any async operations (e.g. formatting) to complete
 830        cx.executor().run_until_parked();
 831
 832        // Read the file to verify it was formatted automatically
 833        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
 834        assert_eq!(
 835            // Ignore carriage returns on Windows
 836            new_content.replace("\r\n", "\n"),
 837            FORMATTED_CONTENT,
 838            "Code should be formatted when format_on_save is enabled"
 839        );
 840
 841        let stale_buffer_count = thread
 842            .read_with(cx, |thread, _cx| thread.action_log.clone())
 843            .read_with(cx, |log, cx| log.stale_buffers(cx).count());
 844
 845        assert_eq!(
 846            stale_buffer_count, 0,
 847            "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
 848             This causes the agent to think the file was modified externally when it was just formatted.",
 849            stale_buffer_count
 850        );
 851
 852        // Next, test with format_on_save disabled
 853        cx.update(|cx| {
 854            SettingsStore::update_global(cx, |store, cx| {
 855                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
 856                    cx,
 857                    |settings| {
 858                        settings.defaults.format_on_save = Some(FormatOnSave::Off);
 859                    },
 860                );
 861            });
 862        });
 863
 864        // Stream unformatted edits again
 865        let edit_result = {
 866            let edit_task = cx.update(|cx| {
 867                let input = EditFileToolInput {
 868                    display_description: "Update main function".into(),
 869                    path: "root/src/main.rs".into(),
 870                    mode: EditFileMode::Overwrite,
 871                };
 872                Arc::new(EditFileTool::new(
 873                    project.clone(),
 874                    thread.downgrade(),
 875                    language_registry,
 876                ))
 877                .run(input, ToolCallEventStream::test().0, cx)
 878            });
 879
 880            // Stream the unformatted content
 881            cx.executor().run_until_parked();
 882            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
 883            model.end_last_completion_stream();
 884
 885            edit_task.await
 886        };
 887        assert!(edit_result.is_ok());
 888
 889        // Wait for any async operations (e.g. formatting) to complete
 890        cx.executor().run_until_parked();
 891
 892        // Verify the file was not formatted
 893        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
 894        assert_eq!(
 895            // Ignore carriage returns on Windows
 896            new_content.replace("\r\n", "\n"),
 897            UNFORMATTED_CONTENT,
 898            "Code should not be formatted when format_on_save is disabled"
 899        );
 900    }
 901
    // Checks that edits made through `EditFileTool` honor the
    // `remove_trailing_whitespace_on_save` setting. With the setting enabled,
    // trailing spaces streamed by the model are gone from the file on disk.
    // With it disabled, they are kept verbatim.
    #[gpui::test]
    async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        // Seed the file with placeholder content (no trailing whitespace yet).
        // Both edits below use `EditFileMode::Overwrite`, so this text is
        // fully replaced by what the fake model streams.
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // First, test with remove_trailing_whitespace_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
                    cx,
                    |settings| {
                        settings.defaults.remove_trailing_whitespace_on_save = Some(true);
                    },
                );
            });
        });

        // Two lines end in two spaces each; the closing brace line does not.
        const CONTENT_WITH_TRAILING_WHITESPACE: &str =
            "fn main() {  \n    println!(\"Hello!\");  \n}\n";

        // Have the model stream content that contains trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry.clone(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Let the tool reach the point where it is waiting on the model's
            // completion stream before feeding it text.
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify trailing whitespace was removed automatically
        assert_eq!(
            // Ignore carriage returns on Windows
            fs.load(path!("/root/src/main.rs").as_ref())
                .await
                .unwrap()
                .replace("\r\n", "\n"),
            "fn main() {\n    println!(\"Hello!\");\n}\n",
            "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
        );

        // Next, test with remove_trailing_whitespace_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings::<language::language_settings::AllLanguageSettings>(
                    cx,
                    |settings| {
                        settings.defaults.remove_trailing_whitespace_on_save = Some(false);
                    },
                );
            });
        });

        // Stream edits again with trailing whitespace
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry,
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the content with trailing whitespace
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(
                CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
            );
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // With the setting off, the streamed text should be on disk unchanged.
        let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            final_content.replace("\r\n", "\n"),
            CONTENT_WITH_TRAILING_WHITESPACE,
            "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
        );
    }
1042
1043    #[gpui::test]
1044    async fn test_authorize(cx: &mut TestAppContext) {
1045        init_test(cx);
1046        let fs = project::FakeFs::new(cx.executor());
1047        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1048        let context_server_registry =
1049            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1050        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1051        let model = Arc::new(FakeLanguageModel::default());
1052        let thread = cx.new(|cx| {
1053            Thread::new(
1054                project.clone(),
1055                cx.new(|_cx| ProjectContext::default()),
1056                context_server_registry,
1057                Templates::new(),
1058                Some(model.clone()),
1059                cx,
1060            )
1061        });
1062        let tool = Arc::new(EditFileTool::new(
1063            project.clone(),
1064            thread.downgrade(),
1065            language_registry,
1066        ));
1067        fs.insert_tree("/root", json!({})).await;
1068
1069        // Test 1: Path with .zed component should require confirmation
1070        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1071        let _auth = cx.update(|cx| {
1072            tool.authorize(
1073                &EditFileToolInput {
1074                    display_description: "test 1".into(),
1075                    path: ".zed/settings.json".into(),
1076                    mode: EditFileMode::Edit,
1077                },
1078                &stream_tx,
1079                cx,
1080            )
1081        });
1082
1083        let event = stream_rx.expect_authorization().await;
1084        assert_eq!(
1085            event.tool_call.fields.title,
1086            Some("test 1 (local settings)".into())
1087        );
1088
1089        // Test 2: Path outside project should require confirmation
1090        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1091        let _auth = cx.update(|cx| {
1092            tool.authorize(
1093                &EditFileToolInput {
1094                    display_description: "test 2".into(),
1095                    path: "/etc/hosts".into(),
1096                    mode: EditFileMode::Edit,
1097                },
1098                &stream_tx,
1099                cx,
1100            )
1101        });
1102
1103        let event = stream_rx.expect_authorization().await;
1104        assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1105
1106        // Test 3: Relative path without .zed should not require confirmation
1107        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1108        cx.update(|cx| {
1109            tool.authorize(
1110                &EditFileToolInput {
1111                    display_description: "test 3".into(),
1112                    path: "root/src/main.rs".into(),
1113                    mode: EditFileMode::Edit,
1114                },
1115                &stream_tx,
1116                cx,
1117            )
1118        })
1119        .await
1120        .unwrap();
1121        assert!(stream_rx.try_next().is_err());
1122
1123        // Test 4: Path with .zed in the middle should require confirmation
1124        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1125        let _auth = cx.update(|cx| {
1126            tool.authorize(
1127                &EditFileToolInput {
1128                    display_description: "test 4".into(),
1129                    path: "root/.zed/tasks.json".into(),
1130                    mode: EditFileMode::Edit,
1131                },
1132                &stream_tx,
1133                cx,
1134            )
1135        });
1136        let event = stream_rx.expect_authorization().await;
1137        assert_eq!(
1138            event.tool_call.fields.title,
1139            Some("test 4 (local settings)".into())
1140        );
1141
1142        // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1143        cx.update(|cx| {
1144            let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1145            settings.always_allow_tool_actions = true;
1146            agent_settings::AgentSettings::override_global(settings, cx);
1147        });
1148
1149        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1150        cx.update(|cx| {
1151            tool.authorize(
1152                &EditFileToolInput {
1153                    display_description: "test 5.1".into(),
1154                    path: ".zed/settings.json".into(),
1155                    mode: EditFileMode::Edit,
1156                },
1157                &stream_tx,
1158                cx,
1159            )
1160        })
1161        .await
1162        .unwrap();
1163        assert!(stream_rx.try_next().is_err());
1164
1165        let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1166        cx.update(|cx| {
1167            tool.authorize(
1168                &EditFileToolInput {
1169                    display_description: "test 5.2".into(),
1170                    path: "/etc/hosts".into(),
1171                    mode: EditFileMode::Edit,
1172                },
1173                &stream_tx,
1174                cx,
1175            )
1176        })
1177        .await
1178        .unwrap();
1179        assert!(stream_rx.try_next().is_err());
1180    }
1181
1182    #[gpui::test]
1183    async fn test_authorize_global_config(cx: &mut TestAppContext) {
1184        init_test(cx);
1185        let fs = project::FakeFs::new(cx.executor());
1186        fs.insert_tree("/project", json!({})).await;
1187        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1188        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1189        let context_server_registry =
1190            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1191        let model = Arc::new(FakeLanguageModel::default());
1192        let thread = cx.new(|cx| {
1193            Thread::new(
1194                project.clone(),
1195                cx.new(|_cx| ProjectContext::default()),
1196                context_server_registry,
1197                Templates::new(),
1198                Some(model.clone()),
1199                cx,
1200            )
1201        });
1202        let tool = Arc::new(EditFileTool::new(
1203            project.clone(),
1204            thread.downgrade(),
1205            language_registry,
1206        ));
1207
1208        // Test global config paths - these should require confirmation if they exist and are outside the project
1209        let test_cases = vec![
1210            (
1211                "/etc/hosts",
1212                true,
1213                "System file should require confirmation",
1214            ),
1215            (
1216                "/usr/local/bin/script",
1217                true,
1218                "System bin file should require confirmation",
1219            ),
1220            (
1221                "project/normal_file.rs",
1222                false,
1223                "Normal project file should not require confirmation",
1224            ),
1225        ];
1226
1227        for (path, should_confirm, description) in test_cases {
1228            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1229            let auth = cx.update(|cx| {
1230                tool.authorize(
1231                    &EditFileToolInput {
1232                        display_description: "Edit file".into(),
1233                        path: path.into(),
1234                        mode: EditFileMode::Edit,
1235                    },
1236                    &stream_tx,
1237                    cx,
1238                )
1239            });
1240
1241            if should_confirm {
1242                stream_rx.expect_authorization().await;
1243            } else {
1244                auth.await.unwrap();
1245                assert!(
1246                    stream_rx.try_next().is_err(),
1247                    "Failed for case: {} - path: {} - expected no confirmation but got one",
1248                    description,
1249                    path
1250                );
1251            }
1252        }
1253    }
1254
1255    #[gpui::test]
1256    async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1257        init_test(cx);
1258        let fs = project::FakeFs::new(cx.executor());
1259
1260        // Create multiple worktree directories
1261        fs.insert_tree(
1262            "/workspace/frontend",
1263            json!({
1264                "src": {
1265                    "main.js": "console.log('frontend');"
1266                }
1267            }),
1268        )
1269        .await;
1270        fs.insert_tree(
1271            "/workspace/backend",
1272            json!({
1273                "src": {
1274                    "main.rs": "fn main() {}"
1275                }
1276            }),
1277        )
1278        .await;
1279        fs.insert_tree(
1280            "/workspace/shared",
1281            json!({
1282                ".zed": {
1283                    "settings.json": "{}"
1284                }
1285            }),
1286        )
1287        .await;
1288
1289        // Create project with multiple worktrees
1290        let project = Project::test(
1291            fs.clone(),
1292            [
1293                path!("/workspace/frontend").as_ref(),
1294                path!("/workspace/backend").as_ref(),
1295                path!("/workspace/shared").as_ref(),
1296            ],
1297            cx,
1298        )
1299        .await;
1300        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1301        let context_server_registry =
1302            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1303        let model = Arc::new(FakeLanguageModel::default());
1304        let thread = cx.new(|cx| {
1305            Thread::new(
1306                project.clone(),
1307                cx.new(|_cx| ProjectContext::default()),
1308                context_server_registry.clone(),
1309                Templates::new(),
1310                Some(model.clone()),
1311                cx,
1312            )
1313        });
1314        let tool = Arc::new(EditFileTool::new(
1315            project.clone(),
1316            thread.downgrade(),
1317            language_registry,
1318        ));
1319
1320        // Test files in different worktrees
1321        let test_cases = vec![
1322            ("frontend/src/main.js", false, "File in first worktree"),
1323            ("backend/src/main.rs", false, "File in second worktree"),
1324            (
1325                "shared/.zed/settings.json",
1326                true,
1327                ".zed file in third worktree",
1328            ),
1329            ("/etc/hosts", true, "Absolute path outside all worktrees"),
1330            (
1331                "../outside/file.txt",
1332                true,
1333                "Relative path outside worktrees",
1334            ),
1335        ];
1336
1337        for (path, should_confirm, description) in test_cases {
1338            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1339            let auth = cx.update(|cx| {
1340                tool.authorize(
1341                    &EditFileToolInput {
1342                        display_description: "Edit file".into(),
1343                        path: path.into(),
1344                        mode: EditFileMode::Edit,
1345                    },
1346                    &stream_tx,
1347                    cx,
1348                )
1349            });
1350
1351            if should_confirm {
1352                stream_rx.expect_authorization().await;
1353            } else {
1354                auth.await.unwrap();
1355                assert!(
1356                    stream_rx.try_next().is_err(),
1357                    "Failed for case: {} - path: {} - expected no confirmation but got one",
1358                    description,
1359                    path
1360                );
1361            }
1362        }
1363    }
1364
1365    #[gpui::test]
1366    async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1367        init_test(cx);
1368        let fs = project::FakeFs::new(cx.executor());
1369        fs.insert_tree(
1370            "/project",
1371            json!({
1372                ".zed": {
1373                    "settings.json": "{}"
1374                },
1375                "src": {
1376                    ".zed": {
1377                        "local.json": "{}"
1378                    }
1379                }
1380            }),
1381        )
1382        .await;
1383        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1384        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1385        let context_server_registry =
1386            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1387        let model = Arc::new(FakeLanguageModel::default());
1388        let thread = cx.new(|cx| {
1389            Thread::new(
1390                project.clone(),
1391                cx.new(|_cx| ProjectContext::default()),
1392                context_server_registry.clone(),
1393                Templates::new(),
1394                Some(model.clone()),
1395                cx,
1396            )
1397        });
1398        let tool = Arc::new(EditFileTool::new(
1399            project.clone(),
1400            thread.downgrade(),
1401            language_registry,
1402        ));
1403
1404        // Test edge cases
1405        let test_cases = vec![
1406            // Empty path - find_project_path returns Some for empty paths
1407            ("", false, "Empty path is treated as project root"),
1408            // Root directory
1409            ("/", true, "Root directory should be outside project"),
1410            // Parent directory references - find_project_path resolves these
1411            (
1412                "project/../other",
1413                false,
1414                "Path with .. is resolved by find_project_path",
1415            ),
1416            (
1417                "project/./src/file.rs",
1418                false,
1419                "Path with . should work normally",
1420            ),
1421            // Windows-style paths (if on Windows)
1422            #[cfg(target_os = "windows")]
1423            ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1424            #[cfg(target_os = "windows")]
1425            ("project\\src\\main.rs", false, "Windows-style project path"),
1426        ];
1427
1428        for (path, should_confirm, description) in test_cases {
1429            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1430            let auth = cx.update(|cx| {
1431                tool.authorize(
1432                    &EditFileToolInput {
1433                        display_description: "Edit file".into(),
1434                        path: path.into(),
1435                        mode: EditFileMode::Edit,
1436                    },
1437                    &stream_tx,
1438                    cx,
1439                )
1440            });
1441
1442            if should_confirm {
1443                stream_rx.expect_authorization().await;
1444            } else {
1445                auth.await.unwrap();
1446                assert!(
1447                    stream_rx.try_next().is_err(),
1448                    "Failed for case: {} - path: {} - expected no confirmation but got one",
1449                    description,
1450                    path
1451                );
1452            }
1453        }
1454    }
1455
1456    #[gpui::test]
1457    async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1458        init_test(cx);
1459        let fs = project::FakeFs::new(cx.executor());
1460        fs.insert_tree(
1461            "/project",
1462            json!({
1463                "existing.txt": "content",
1464                ".zed": {
1465                    "settings.json": "{}"
1466                }
1467            }),
1468        )
1469        .await;
1470        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1471        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1472        let context_server_registry =
1473            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1474        let model = Arc::new(FakeLanguageModel::default());
1475        let thread = cx.new(|cx| {
1476            Thread::new(
1477                project.clone(),
1478                cx.new(|_cx| ProjectContext::default()),
1479                context_server_registry.clone(),
1480                Templates::new(),
1481                Some(model.clone()),
1482                cx,
1483            )
1484        });
1485        let tool = Arc::new(EditFileTool::new(
1486            project.clone(),
1487            thread.downgrade(),
1488            language_registry,
1489        ));
1490
1491        // Test different EditFileMode values
1492        let modes = vec![
1493            EditFileMode::Edit,
1494            EditFileMode::Create,
1495            EditFileMode::Overwrite,
1496        ];
1497
1498        for mode in modes {
1499            // Test .zed path with different modes
1500            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1501            let _auth = cx.update(|cx| {
1502                tool.authorize(
1503                    &EditFileToolInput {
1504                        display_description: "Edit settings".into(),
1505                        path: "project/.zed/settings.json".into(),
1506                        mode: mode.clone(),
1507                    },
1508                    &stream_tx,
1509                    cx,
1510                )
1511            });
1512
1513            stream_rx.expect_authorization().await;
1514
1515            // Test outside path with different modes
1516            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1517            let _auth = cx.update(|cx| {
1518                tool.authorize(
1519                    &EditFileToolInput {
1520                        display_description: "Edit file".into(),
1521                        path: "/outside/file.txt".into(),
1522                        mode: mode.clone(),
1523                    },
1524                    &stream_tx,
1525                    cx,
1526                )
1527            });
1528
1529            stream_rx.expect_authorization().await;
1530
1531            // Test normal path with different modes
1532            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1533            cx.update(|cx| {
1534                tool.authorize(
1535                    &EditFileToolInput {
1536                        display_description: "Edit file".into(),
1537                        path: "project/normal.txt".into(),
1538                        mode: mode.clone(),
1539                    },
1540                    &stream_tx,
1541                    cx,
1542                )
1543            })
1544            .await
1545            .unwrap();
1546            assert!(stream_rx.try_next().is_err());
1547        }
1548    }
1549
1550    #[gpui::test]
1551    async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1552        init_test(cx);
1553        let fs = project::FakeFs::new(cx.executor());
1554        let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1555        let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1556        let context_server_registry =
1557            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1558        let model = Arc::new(FakeLanguageModel::default());
1559        let thread = cx.new(|cx| {
1560            Thread::new(
1561                project.clone(),
1562                cx.new(|_cx| ProjectContext::default()),
1563                context_server_registry,
1564                Templates::new(),
1565                Some(model.clone()),
1566                cx,
1567            )
1568        });
1569        let tool = Arc::new(EditFileTool::new(
1570            project,
1571            thread.downgrade(),
1572            language_registry,
1573        ));
1574
1575        cx.update(|cx| {
1576            // ...
1577            assert_eq!(
1578                tool.initial_title(
1579                    Err(json!({
1580                        "path": "src/main.rs",
1581                        "display_description": "",
1582                        "old_string": "old code",
1583                        "new_string": "new code"
1584                    })),
1585                    cx
1586                ),
1587                "src/main.rs"
1588            );
1589            assert_eq!(
1590                tool.initial_title(
1591                    Err(json!({
1592                        "path": "",
1593                        "display_description": "Fix error handling",
1594                        "old_string": "old code",
1595                        "new_string": "new code"
1596                    })),
1597                    cx
1598                ),
1599                "Fix error handling"
1600            );
1601            assert_eq!(
1602                tool.initial_title(
1603                    Err(json!({
1604                        "path": "src/main.rs",
1605                        "display_description": "Fix error handling",
1606                        "old_string": "old code",
1607                        "new_string": "new code"
1608                    })),
1609                    cx
1610                ),
1611                "src/main.rs"
1612            );
1613            assert_eq!(
1614                tool.initial_title(
1615                    Err(json!({
1616                        "path": "",
1617                        "display_description": "",
1618                        "old_string": "old code",
1619                        "new_string": "new code"
1620                    })),
1621                    cx
1622                ),
1623                DEFAULT_UI_TEXT
1624            );
1625            assert_eq!(
1626                tool.initial_title(Err(serde_json::Value::Null), cx),
1627                DEFAULT_UI_TEXT
1628            );
1629        });
1630    }
1631
1632    #[gpui::test]
1633    async fn test_diff_finalization(cx: &mut TestAppContext) {
1634        init_test(cx);
1635        let fs = project::FakeFs::new(cx.executor());
1636        fs.insert_tree("/", json!({"main.rs": ""})).await;
1637
1638        let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1639        let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1640        let context_server_registry =
1641            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1642        let model = Arc::new(FakeLanguageModel::default());
1643        let thread = cx.new(|cx| {
1644            Thread::new(
1645                project.clone(),
1646                cx.new(|_cx| ProjectContext::default()),
1647                context_server_registry.clone(),
1648                Templates::new(),
1649                Some(model.clone()),
1650                cx,
1651            )
1652        });
1653
1654        // Ensure the diff is finalized after the edit completes.
1655        {
1656            let tool = Arc::new(EditFileTool::new(
1657                project.clone(),
1658                thread.downgrade(),
1659                languages.clone(),
1660            ));
1661            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1662            let edit = cx.update(|cx| {
1663                tool.run(
1664                    EditFileToolInput {
1665                        display_description: "Edit file".into(),
1666                        path: path!("/main.rs").into(),
1667                        mode: EditFileMode::Edit,
1668                    },
1669                    stream_tx,
1670                    cx,
1671                )
1672            });
1673            stream_rx.expect_update_fields().await;
1674            let diff = stream_rx.expect_diff().await;
1675            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1676            cx.run_until_parked();
1677            model.end_last_completion_stream();
1678            edit.await.unwrap();
1679            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1680        }
1681
1682        // Ensure the diff is finalized if an error occurs while editing.
1683        {
1684            model.forbid_requests();
1685            let tool = Arc::new(EditFileTool::new(
1686                project.clone(),
1687                thread.downgrade(),
1688                languages.clone(),
1689            ));
1690            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1691            let edit = cx.update(|cx| {
1692                tool.run(
1693                    EditFileToolInput {
1694                        display_description: "Edit file".into(),
1695                        path: path!("/main.rs").into(),
1696                        mode: EditFileMode::Edit,
1697                    },
1698                    stream_tx,
1699                    cx,
1700                )
1701            });
1702            stream_rx.expect_update_fields().await;
1703            let diff = stream_rx.expect_diff().await;
1704            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1705            edit.await.unwrap_err();
1706            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1707            model.allow_requests();
1708        }
1709
1710        // Ensure the diff is finalized if the tool call gets dropped.
1711        {
1712            let tool = Arc::new(EditFileTool::new(
1713                project.clone(),
1714                thread.downgrade(),
1715                languages.clone(),
1716            ));
1717            let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1718            let edit = cx.update(|cx| {
1719                tool.run(
1720                    EditFileToolInput {
1721                        display_description: "Edit file".into(),
1722                        path: path!("/main.rs").into(),
1723                        mode: EditFileMode::Edit,
1724                    },
1725                    stream_tx,
1726                    cx,
1727                )
1728            });
1729            stream_rx.expect_update_fields().await;
1730            let diff = stream_rx.expect_diff().await;
1731            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1732            drop(edit);
1733            cx.run_until_parked();
1734            diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1735        }
1736    }
1737
1738    fn init_test(cx: &mut TestAppContext) {
1739        cx.update(|cx| {
1740            let settings_store = SettingsStore::test(cx);
1741            cx.set_global(settings_store);
1742            language::init(cx);
1743            TelemetrySettings::register(cx);
1744            agent_settings::AgentSettings::register(cx);
1745            Project::init_settings(cx);
1746        });
1747    }
1748}