1use super::restore_file_from_disk_tool::RestoreFileFromDiskTool;
2use super::save_file_tool::SaveFileTool;
3use crate::{
4 AgentTool, Templates, Thread, ToolCallEventStream, ToolPermissionDecision,
5 decide_permission_for_path,
6 edit_agent::{EditAgent, EditAgentOutput, EditAgentOutputEvent, EditFormat},
7};
8use acp_thread::Diff;
9use agent_client_protocol::{self as acp, ToolCallLocation, ToolCallUpdateFields};
10use anyhow::{Context as _, Result, anyhow};
11use cloud_llm_client::CompletionIntent;
12use collections::HashSet;
13use futures::{FutureExt as _, StreamExt as _};
14use gpui::{App, AppContext, AsyncApp, Entity, Task, WeakEntity};
15use indoc::formatdoc;
16use language::language_settings::{self, FormatOnSave};
17use language::{LanguageRegistry, ToPoint};
18use language_model::LanguageModelToolResultContent;
19use paths;
20use project::lsp_store::{FormatTrigger, LspFormatTarget};
21use project::{Project, ProjectPath};
22use schemars::JsonSchema;
23use serde::{Deserialize, Serialize};
24use settings::Settings;
25use std::ffi::OsStr;
26use std::path::{Path, PathBuf};
27use std::sync::Arc;
28use ui::SharedString;
29use util::ResultExt;
30use util::rel_path::RelPath;
31
// Fallback tool-call title, used by `initial_title` when neither a path nor a
// description can be extracted from the (possibly partial) tool input.
const DEFAULT_UI_TEXT: &str = "Editing file";
33
// NOTE(review): `JsonSchema` is derived for this type, so the `///` doc comments below
// are presumably emitted as the schema descriptions the model sees (schemars' default
// behavior) — treat their wording as prompt text and edit it with care.
/// This is a tool for creating a new file or editing an existing file. For moving or renaming files, you should generally use the `terminal` tool with the 'mv' command instead.
///
/// Before using this tool:
///
/// 1. Use the `read_file` tool to understand the file's contents and context
///
/// 2. Verify the directory path is correct (only applicable when creating new files):
/// - Use the `list_directory` tool to verify the parent directory exists and is the correct location
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
pub struct EditFileToolInput {
    /// A one-line, user-friendly markdown description of the edit. This will be shown in the UI and also passed to another model to perform the edit.
    ///
    /// Be terse, but also descriptive in what you want to achieve with this edit. Avoid generic instructions.
    ///
    /// NEVER mention the file path in this description.
    ///
    /// <example>Fix API endpoint URLs</example>
    /// <example>Update copyright year in `page_footer`</example>
    ///
    /// Make sure to include this field before all the others in the input object so that we can display it immediately.
    pub display_description: String,

    /// The full path of the file to create or modify in the project.
    ///
    /// WARNING: When specifying which file path need changing, you MUST start each path with one of the project's root directories.
    ///
    /// The following examples assume we have two root directories in the project:
    /// - /a/b/backend
    /// - /c/d/frontend
    ///
    /// <example>
    /// `backend/src/main.rs`
    ///
    /// Notice how the file path starts with `backend`. Without that, the path would be ambiguous and the call would fail!
    /// </example>
    ///
    /// <example>
    /// `frontend/db.js`
    /// </example>
    pub path: PathBuf,
    /// The mode of operation on the file. Possible values:
    /// - 'edit': Make granular edits to an existing file.
    /// - 'create': Create a new file if it doesn't exist.
    /// - 'overwrite': Replace the entire contents of an existing file.
    ///
    /// When a file already exists or you just created it, prefer editing it as opposed to recreating it from scratch.
    pub mode: EditFileMode,
}
82
// Lenient view of the tool input: both fields fall back to an empty string via
// `#[serde(default)]`, so deserialization succeeds even when they are missing.
// `initial_title` uses it to derive a title from raw JSON that did not parse as
// `EditFileToolInput`.
// (Plain `//` comments on purpose: NOTE(review) `///` text would presumably end up
// in the derived JSON schema.)
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct EditFileToolPartialInput {
    // Raw path text as received so far; may be empty.
    #[serde(default)]
    path: String,
    // Raw description text as received so far; may be empty.
    #[serde(default)]
    display_description: String,
}
90
// How `EditFileTool` should treat the target file. Serialized in lowercase
// (`"edit"`, `"create"`, `"overwrite"`) because of `rename_all = "lowercase"`.
// (Plain `//` comments on purpose: NOTE(review) `///` text would presumably end up
// in the derived JSON schema.)
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "lowercase")]
#[schemars(inline)]
pub enum EditFileMode {
    // Granular edits; `resolve_path` requires the path to be an existing file.
    Edit,
    // New file; `resolve_path` requires the file to be absent and its parent
    // directory to exist.
    Create,
    // Whole-file replacement; `resolve_path` requires the path to be an existing file.
    Overwrite,
}
99
/// Result of one `edit_file` tool call.
///
/// Converted into the model-facing tool result by the `From` impl below and used by
/// `replay` to rebuild the diff for a previously recorded call.
#[derive(Debug, Serialize, Deserialize)]
pub struct EditFileToolOutput {
    /// Path exactly as given in the tool input (also accepted as `original_path`
    /// when deserializing).
    #[serde(alias = "original_path")]
    input_path: PathBuf,
    /// Full buffer text after the edit (and any formatting/saving) completed.
    new_text: String,
    /// Full buffer text captured before the edit started.
    old_text: Arc<String>,
    /// Unified diff from `old_text` to `new_text`; empty when nothing changed.
    /// Defaults to an empty string when absent from serialized data.
    #[serde(default)]
    diff: String,
    /// Output returned by the edit agent (also accepted as `raw_output` when
    /// deserializing).
    #[serde(alias = "raw_output")]
    edit_agent_output: EditAgentOutput,
}
111
112impl From<EditFileToolOutput> for LanguageModelToolResultContent {
113 fn from(output: EditFileToolOutput) -> Self {
114 if output.diff.is_empty() {
115 "No edits were made.".into()
116 } else {
117 format!(
118 "Edited {}:\n\n```diff\n{}\n```",
119 output.input_path.display(),
120 output.diff
121 )
122 .into()
123 }
124 }
125}
126
/// Agent tool that creates, edits, or overwrites a file in the project by delegating
/// the actual text changes to an `EditAgent`.
pub struct EditFileTool {
    /// Weak handle to the owning thread; read for the project, the completion
    /// request, the model, the action log and the recorded file read times.
    thread: WeakEntity<Thread>,
    /// Passed to `Diff::finalized` when replaying a recorded tool call.
    language_registry: Arc<LanguageRegistry>,
    /// Project used to resolve display paths for the tool-call title.
    project: Entity<Project>,
    /// Prompt templates handed to the `EditAgent`.
    templates: Arc<Templates>,
}
133
134impl EditFileTool {
135 pub fn new(
136 project: Entity<Project>,
137 thread: WeakEntity<Thread>,
138 language_registry: Arc<LanguageRegistry>,
139 templates: Arc<Templates>,
140 ) -> Self {
141 Self {
142 project,
143 thread,
144 language_registry,
145 templates,
146 }
147 }
148
149 pub fn with_thread(&self, new_thread: WeakEntity<Thread>) -> Self {
150 Self {
151 project: self.project.clone(),
152 thread: new_thread,
153 language_registry: self.language_registry.clone(),
154 templates: self.templates.clone(),
155 }
156 }
157
158 fn authorize(
159 &self,
160 input: &EditFileToolInput,
161 event_stream: &ToolCallEventStream,
162 cx: &mut App,
163 ) -> Task<Result<()>> {
164 authorize_file_edit(
165 Self::NAME,
166 &input.path,
167 &input.display_description,
168 &self.thread,
169 event_stream,
170 cx,
171 )
172 }
173}
174
/// Which kind of sensitive settings location a path targets, as classified by
/// `sensitive_settings_kind`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SensitiveSettingsKind {
    /// The path contains the project-local settings folder as one of its components.
    Local,
    /// The path resolves to a location inside the global config directory.
    Global,
}
179
/// Canonicalize a path, stripping the Windows extended-length path prefix (`\\?\`)
/// that `std::fs::canonicalize` adds on Windows. This ensures that canonicalized
/// paths can be compared with non-canonicalized paths via `starts_with`.
///
/// Verbatim UNC paths (`\\?\UNC\server\share\...`) are mapped back to their regular
/// UNC form (`\\server\share\...`); stripping only `\\?\` would leave the bogus
/// relative path `UNC\server\share\...`.
///
/// On non-Windows targets this is exactly `std::fs::canonicalize`.
///
/// # Errors
///
/// Returns the I/O error from `std::fs::canonicalize` (e.g. when the path does not exist).
fn safe_canonicalize(path: &Path) -> std::io::Result<PathBuf> {
    let canonical = std::fs::canonicalize(path)?;
    #[cfg(target_os = "windows")]
    {
        let s = canonical.to_string_lossy();
        // Must be checked before the plain `\\?\` prefix, which it also starts with.
        if let Some(unc) = s.strip_prefix(r"\\?\UNC\") {
            return Ok(PathBuf::from(format!(r"\\{unc}")));
        }
        if let Some(stripped) = s.strip_prefix(r"\\?\") {
            return Ok(PathBuf::from(stripped));
        }
    }
    Ok(canonical)
}
194
195/// Returns the kind of sensitive settings location this path targets, if any:
196/// either inside a `.zed/` local-settings directory or inside the global config dir.
197pub fn sensitive_settings_kind(path: &Path) -> Option<SensitiveSettingsKind> {
198 let local_settings_folder = paths::local_settings_folder_name();
199 if path.components().any(|component| {
200 component.as_os_str() == <_ as AsRef<OsStr>>::as_ref(&local_settings_folder)
201 }) {
202 return Some(SensitiveSettingsKind::Local);
203 }
204
205 // Walk up the path hierarchy until we find an ancestor that exists and can
206 // be canonicalized, then reconstruct the path from there. This handles
207 // cases where multiple levels of subdirectories don't exist yet (e.g.
208 // ~/.config/zed/new_subdir/evil.json).
209 let canonical_path = {
210 let mut current: Option<&Path> = Some(path);
211 let mut suffix_components = Vec::new();
212 loop {
213 match current {
214 Some(ancestor) => match safe_canonicalize(ancestor) {
215 Ok(canonical) => {
216 let mut result = canonical;
217 for component in suffix_components.into_iter().rev() {
218 result.push(component);
219 }
220 break Some(result);
221 }
222 Err(_) => {
223 if let Some(file_name) = ancestor.file_name() {
224 suffix_components.push(file_name.to_os_string());
225 }
226 current = ancestor.parent();
227 }
228 },
229 None => break None,
230 }
231 }
232 };
233 if let Some(canonical_path) = canonical_path {
234 let config_dir = safe_canonicalize(paths::config_dir())
235 .unwrap_or_else(|_| paths::config_dir().to_path_buf());
236 if canonical_path.starts_with(&config_dir) {
237 return Some(SensitiveSettingsKind::Global);
238 }
239 }
240
241 None
242}
243
/// Returns `true` when `sensitive_settings_kind` classifies `path` as either a local
/// or a global settings location.
pub fn is_sensitive_settings_path(path: &Path) -> bool {
    sensitive_settings_kind(path).is_some()
}
247
248pub fn authorize_file_edit(
249 tool_name: &str,
250 path: &Path,
251 display_description: &str,
252 thread: &WeakEntity<Thread>,
253 event_stream: &ToolCallEventStream,
254 cx: &mut App,
255) -> Task<Result<()>> {
256 let path_str = path.to_string_lossy();
257 let settings = agent_settings::AgentSettings::get_global(cx);
258 let decision = decide_permission_for_path(tool_name, &path_str, settings);
259
260 if let ToolPermissionDecision::Deny(reason) = decision {
261 return Task::ready(Err(anyhow!("{}", reason)));
262 }
263
264 let explicitly_allowed = matches!(decision, ToolPermissionDecision::Allow);
265
266 if explicitly_allowed
267 && (settings.always_allow_tool_actions || !is_sensitive_settings_path(path))
268 {
269 return Task::ready(Ok(()));
270 }
271
272 match sensitive_settings_kind(path) {
273 Some(SensitiveSettingsKind::Local) => {
274 let context = crate::ToolPermissionContext {
275 tool_name: tool_name.to_string(),
276 input_value: path_str.to_string(),
277 };
278 return event_stream.authorize(
279 format!("{} (local settings)", display_description),
280 context,
281 cx,
282 );
283 }
284 Some(SensitiveSettingsKind::Global) => {
285 let context = crate::ToolPermissionContext {
286 tool_name: tool_name.to_string(),
287 input_value: path_str.to_string(),
288 };
289 return event_stream.authorize(
290 format!("{} (settings)", display_description),
291 context,
292 cx,
293 );
294 }
295 None => {}
296 }
297
298 let Ok(project_path) = thread.read_with(cx, |thread, cx| {
299 thread.project().read(cx).find_project_path(path, cx)
300 }) else {
301 return Task::ready(Err(anyhow!("thread was dropped")));
302 };
303
304 if project_path.is_some() {
305 Task::ready(Ok(()))
306 } else {
307 let context = crate::ToolPermissionContext {
308 tool_name: tool_name.to_string(),
309 input_value: path_str.to_string(),
310 };
311 event_stream.authorize(display_description, context, cx)
312 }
313}
314
315impl AgentTool for EditFileTool {
316 type Input = EditFileToolInput;
317 type Output = EditFileToolOutput;
318
319 const NAME: &'static str = "edit_file";
320
    /// Reports this tool's category as `acp::ToolKind::Edit`.
    fn kind() -> acp::ToolKind {
        acp::ToolKind::Edit
    }
324
325 fn initial_title(
326 &self,
327 input: Result<Self::Input, serde_json::Value>,
328 cx: &mut App,
329 ) -> SharedString {
330 match input {
331 Ok(input) => self
332 .project
333 .read(cx)
334 .find_project_path(&input.path, cx)
335 .and_then(|project_path| {
336 self.project
337 .read(cx)
338 .short_full_path_for_project_path(&project_path, cx)
339 })
340 .unwrap_or(input.path.to_string_lossy().into_owned())
341 .into(),
342 Err(raw_input) => {
343 if let Some(input) =
344 serde_json::from_value::<EditFileToolPartialInput>(raw_input).ok()
345 {
346 let path = input.path.trim();
347 if !path.is_empty() {
348 return self
349 .project
350 .read(cx)
351 .find_project_path(&input.path, cx)
352 .and_then(|project_path| {
353 self.project
354 .read(cx)
355 .short_full_path_for_project_path(&project_path, cx)
356 })
357 .unwrap_or(input.path)
358 .into();
359 }
360
361 let description = input.display_description.trim();
362 if !description.is_empty() {
363 return description.to_string().into();
364 }
365 }
366
367 DEFAULT_UI_TEXT.into()
368 }
369 }
370 }
371
    /// Executes the edit described by `input` and returns the resulting texts and diff.
    ///
    /// Outline of the flow, in order:
    /// 1. Resolve the project from the thread and validate the path with `resolve_path`.
    /// 2. Publish the file's absolute path as the tool-call location and start authorization.
    /// 3. In a spawned task: await authorization, build the completion request, and
    ///    create an `EditAgent` for the thread's model.
    /// 4. Open the buffer and refuse to edit if it has unsaved changes or if its
    ///    on-disk mtime differs from the one recorded when the agent last read it.
    /// 5. Stream the agent's edits (`edit` for `EditFileMode::Edit`, `overwrite`
    ///    otherwise), stopping with an error if the user cancels.
    /// 6. Optionally format, then save the buffer, notify the action log, and record
    ///    the new mtime so consecutive edits are not rejected as stale.
    /// 7. Compute the unified diff; if it is empty and the agent reported unresolved
    ///    or ambiguous edit ranges, fail with a message asking to re-read the file.
    ///
    /// # Errors
    ///
    /// Fails when the thread is dropped, the path is invalid, authorization is
    /// rejected, no model is configured, the buffer is dirty or stale, the user
    /// cancels, saving fails, or no edit could be applied.
    fn run(
        self: Arc<Self>,
        input: Self::Input,
        event_stream: ToolCallEventStream,
        cx: &mut App,
    ) -> Task<Result<Self::Output>> {
        let Ok(project) = self
            .thread
            .read_with(cx, |thread, _cx| thread.project().clone())
        else {
            return Task::ready(Err(anyhow!("thread was dropped")));
        };
        let project_path = match resolve_path(&input, project.clone(), cx) {
            Ok(path) => path,
            Err(err) => return Task::ready(Err(anyhow!(err))),
        };
        // Report the file location right away, before any edit has been produced.
        let abs_path = project.read(cx).absolute_path(&project_path, cx);
        if let Some(abs_path) = abs_path.clone() {
            event_stream.update_fields(
                ToolCallUpdateFields::new().locations(vec![acp::ToolCallLocation::new(abs_path)]),
            );
        }

        // Created here (needs `&mut App`) and awaited first thing inside the task.
        let authorize = self.authorize(&input, &event_stream, cx);
        cx.spawn(async move |cx: &mut AsyncApp| {
            authorize.await?;

            let (request, model, action_log) = self.thread.update(cx, |thread, cx| {
                let request = thread.build_completion_request(CompletionIntent::ToolResults, cx);
                (request, thread.model().cloned(), thread.action_log().clone())
            })?;
            let request = request?;
            let model = model.context("No language model configured")?;

            let edit_format = EditFormat::from_model(model.clone())?;
            let edit_agent = EditAgent::new(
                model,
                project.clone(),
                action_log.clone(),
                self.templates.clone(),
                edit_format,
            );

            let buffer = project
                .update(cx, |project, cx| {
                    project.open_buffer(project_path.clone(), cx)
                })
                .await?;

            // Check if the file has been modified since the agent last read it
            if let Some(abs_path) = abs_path.as_ref() {
                let (last_read_mtime, current_mtime, is_dirty, has_save_tool, has_restore_tool) = self.thread.update(cx, |thread, cx| {
                    let last_read = thread.file_read_times.get(abs_path).copied();
                    let current = buffer.read(cx).file().and_then(|file| file.disk_state().mtime());
                    let dirty = buffer.read(cx).is_dirty();
                    let has_save = thread.has_tool(SaveFileTool::NAME);
                    let has_restore = thread.has_tool(RestoreFileFromDiskTool::NAME);
                    (last_read, current, dirty, has_save, has_restore)
                })?;

                // Check for unsaved changes first - these indicate modifications we don't know about
                if is_dirty {
                    // The guidance only mentions the save/restore tools that are
                    // actually available on this thread.
                    let message = match (has_save_tool, has_restore_tool) {
                        (true, true) => {
                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
                        }
                        (true, false) => {
                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
                             If they want to keep them, ask for confirmation then use the save_file tool to save the file, then retry this edit. \
                             If they want to discard them, ask the user to manually revert the file, then inform you when it's ok to proceed."
                        }
                        (false, true) => {
                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes. \
                             If they want to keep them, ask the user to manually save the file, then inform you when it's ok to proceed. \
                             If they want to discard them, ask for confirmation then use the restore_file_from_disk tool to restore the on-disk contents, then retry this edit."
                        }
                        (false, false) => {
                            "This file has unsaved changes. Ask the user whether they want to keep or discard those changes, \
                             then ask them to save or revert the file manually and inform you when it's ok to proceed."
                        }
                    };
                    anyhow::bail!("{}", message);
                }

                // Check if the file was modified on disk since we last read it.
                // Skipped when either mtime is unknown (e.g. the file was never read
                // by the agent in this thread).
                if let (Some(last_read), Some(current)) = (last_read_mtime, current_mtime) {
                    // MTime can be unreliable for ordering comparisons, so our newtype
                    // intentionally doesn't support them. If the mtime is at all different
                    // (which could be because of a modification or because e.g. the system
                    // clock changed), we pessimistically assume the file was modified.
                    if current != last_read {
                        anyhow::bail!(
                            "The file {} has been modified since you last read it. \
                             Please read the file again to get the current state before editing it.",
                            input.path.display()
                        );
                    }
                }
            }

            let diff = cx.new(|cx| Diff::new(buffer.clone(), cx));
            event_stream.update_diff(diff.clone());
            // NOTE(review): `util::defer` presumably runs the closure when the guard is
            // dropped, so the diff gets finalized on every exit path (including `?`
            // and `bail!`) — confirm against `util::defer`.
            let _finalize_diff = util::defer({
                let diff = diff.downgrade();
                let mut cx = cx.clone();
                move || {
                    diff.update(&mut cx, |diff, cx| diff.finalize(cx)).ok();
                }
            });

            // Capture the pre-edit text off the main thread for the final diff.
            let old_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
            let old_text = cx
                .background_spawn({
                    let old_snapshot = old_snapshot.clone();
                    async move { Arc::new(old_snapshot.text()) }
                })
                .await;

            // `Create` and `Overwrite` both go through `overwrite`.
            let (output, mut events) = if matches!(input.mode, EditFileMode::Edit) {
                edit_agent.edit(
                    buffer.clone(),
                    input.display_description.clone(),
                    &request,
                    cx,
                )
            } else {
                edit_agent.overwrite(
                    buffer.clone(),
                    input.display_description.clone(),
                    &request,
                    cx,
                )
            };

            let mut hallucinated_old_text = false;
            let mut ambiguous_ranges = Vec::new();
            let mut emitted_location = false;
            // Drain agent events until the stream ends, bailing out if the user cancels.
            loop {
                let event = futures::select! {
                    event = events.next().fuse() => match event {
                        Some(event) => event,
                        None => break,
                    },
                    _ = event_stream.cancelled_by_user().fuse() => {
                        anyhow::bail!("Edit cancelled by user");
                    }
                };
                match event {
                    EditAgentOutputEvent::Edited(range) => {
                        // Only the first applied edit refines the reported location
                        // with a line number.
                        if !emitted_location {
                            let line = Some(buffer.update(cx, |buffer, _cx| {
                                range.start.to_point(&buffer.snapshot()).row
                            }));
                            if let Some(abs_path) = abs_path.clone() {
                                event_stream.update_fields(ToolCallUpdateFields::new().locations(vec![ToolCallLocation::new(abs_path).line(line)]));
                            }
                            emitted_location = true;
                        }
                    },
                    EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
                    EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
                    EditAgentOutputEvent::ResolvingEditRange(range) => {
                        diff.update(cx, |card, cx| card.reveal_range(range.clone(), cx));
                        // if !emitted_location {
                        //     let line = buffer.update(cx, |buffer, _cx| {
                        //         range.start.to_point(&buffer.snapshot()).row
                        //     }).ok();
                        //     if let Some(abs_path) = abs_path.clone() {
                        //         event_stream.update_fields(ToolCallUpdateFields {
                        //             locations: Some(vec![ToolCallLocation { path: abs_path, line }]),
                        //             ..Default::default()
                        //         });
                        //     }
                        // }
                    }
                }
            }

            // If format_on_save is enabled, format the buffer
            let format_on_save_enabled = buffer.read_with(cx, |buffer, cx| {
                let settings = language_settings::language_settings(
                    buffer.language().map(|l| l.name()),
                    buffer.file(),
                    cx,
                );
                settings.format_on_save != FormatOnSave::Off
            });

            let edit_agent_output = output.await?;

            if format_on_save_enabled {
                action_log.update(cx, |log, cx| {
                    log.buffer_edited(buffer.clone(), cx);
                });

                let format_task = project.update(cx, |project, cx| {
                    project.format(
                        HashSet::from_iter([buffer.clone()]),
                        LspFormatTarget::Buffers,
                        false, // Don't push to history since the tool did it.
                        FormatTrigger::Save,
                        cx,
                    )
                });
                // Formatting is best-effort: failures are logged, not propagated.
                format_task.await.log_err();
            }

            project
                .update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
                .await?;

            action_log.update(cx, |log, cx| {
                log.buffer_edited(buffer.clone(), cx);
            });

            // Update the recorded read time after a successful edit so consecutive edits work
            if let Some(abs_path) = abs_path.as_ref() {
                if let Some(new_mtime) = buffer.read_with(cx, |buffer, _| {
                    buffer.file().and_then(|file| file.disk_state().mtime())
                }) {
                    self.thread.update(cx, |thread, _| {
                        thread.file_read_times.insert(abs_path.to_path_buf(), new_mtime);
                    })?;
                }
            }

            // Compute the final text and unified diff off the main thread.
            let new_snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
            let (new_text, unified_diff) = cx
                .background_spawn({
                    let new_snapshot = new_snapshot.clone();
                    let old_text = old_text.clone();
                    async move {
                        let new_text = new_snapshot.text();
                        let diff = language::unified_diff(&old_text, &new_text);
                        (new_text, diff)
                    }
                })
                .await;

            // An empty diff is only an error if the agent tried to edit and failed;
            // otherwise it is reported as "no edits were made".
            let input_path = input.path.display();
            if unified_diff.is_empty() {
                anyhow::ensure!(
                    !hallucinated_old_text,
                    formatdoc! {"
                        Some edits were produced but none of them could be applied.
                        Read the relevant sections of {input_path} again so that
                        I can perform the requested edits.
                    "}
                );
                anyhow::ensure!(
                    ambiguous_ranges.is_empty(),
                    {
                        let line_numbers = ambiguous_ranges
                            .iter()
                            .map(|range| range.start.to_string())
                            .collect::<Vec<_>>()
                            .join(", ");
                        formatdoc! {"
                            <old_text> matches more than one position in the file (lines: {line_numbers}). Read the
                            relevant sections of {input_path} again and extend <old_text> so
                            that I can perform the requested edits.
                        "}
                    }
                );
            }

            Ok(EditFileToolOutput {
                input_path: input.path,
                new_text,
                old_text,
                diff: unified_diff,
                edit_agent_output,
            })
        })
    }
649
650 fn replay(
651 &self,
652 _input: Self::Input,
653 output: Self::Output,
654 event_stream: ToolCallEventStream,
655 cx: &mut App,
656 ) -> Result<()> {
657 event_stream.update_diff(cx.new(|cx| {
658 Diff::finalized(
659 output.input_path.to_string_lossy().into_owned(),
660 Some(output.old_text.to_string()),
661 output.new_text,
662 self.language_registry.clone(),
663 cx,
664 )
665 }));
666 Ok(())
667 }
668
669 fn rebind_thread(
670 &self,
671 new_thread: gpui::WeakEntity<crate::Thread>,
672 ) -> Option<std::sync::Arc<dyn crate::AnyAgentTool>> {
673 Some(self.with_thread(new_thread).erase())
674 }
675}
676
677/// Validate that the file path is valid, meaning:
678///
679/// - For `edit` and `overwrite`, the path must point to an existing file.
680/// - For `create`, the file must not already exist, but it's parent dir must exist.
681fn resolve_path(
682 input: &EditFileToolInput,
683 project: Entity<Project>,
684 cx: &mut App,
685) -> Result<ProjectPath> {
686 let project = project.read(cx);
687
688 match input.mode {
689 EditFileMode::Edit | EditFileMode::Overwrite => {
690 let path = project
691 .find_project_path(&input.path, cx)
692 .context("Can't edit file: path not found")?;
693
694 let entry = project
695 .entry_for_path(&path, cx)
696 .context("Can't edit file: path not found")?;
697
698 anyhow::ensure!(entry.is_file(), "Can't edit file: path is a directory");
699 Ok(path)
700 }
701
702 EditFileMode::Create => {
703 if let Some(path) = project.find_project_path(&input.path, cx) {
704 anyhow::ensure!(
705 project.entry_for_path(&path, cx).is_none(),
706 "Can't create file: file already exists"
707 );
708 }
709
710 let parent_path = input
711 .path
712 .parent()
713 .context("Can't create file: incorrect path")?;
714
715 let parent_project_path = project.find_project_path(&parent_path, cx);
716
717 let parent_entry = parent_project_path
718 .as_ref()
719 .and_then(|path| project.entry_for_path(path, cx))
720 .context("Can't create file: parent directory doesn't exist")?;
721
722 anyhow::ensure!(
723 parent_entry.is_dir(),
724 "Can't create file: parent is not a directory"
725 );
726
727 let file_name = input
728 .path
729 .file_name()
730 .and_then(|file_name| file_name.to_str())
731 .and_then(|file_name| RelPath::unix(file_name).ok())
732 .context("Can't create file: invalid filename")?;
733
734 let new_file_path = parent_project_path.map(|parent| ProjectPath {
735 path: parent.path.join(file_name),
736 ..parent
737 });
738
739 new_file_path.context("Can't create file")
740 }
741 }
742}
743
744#[cfg(test)]
745mod tests {
746 use super::*;
747 use crate::{ContextServerRegistry, Templates};
748 use fs::Fs;
749 use gpui::{TestAppContext, UpdateGlobal};
750 use language_model::fake_provider::FakeLanguageModel;
751 use prompt_store::ProjectContext;
752 use serde_json::json;
753 use settings::SettingsStore;
754 use util::{path, rel_path::rel_path};
755
756 #[gpui::test]
757 async fn test_edit_nonexistent_file(cx: &mut TestAppContext) {
758 init_test(cx);
759
760 let fs = project::FakeFs::new(cx.executor());
761 fs.insert_tree("/root", json!({})).await;
762 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
763 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
764 let context_server_registry =
765 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
766 let model = Arc::new(FakeLanguageModel::default());
767 let thread = cx.new(|cx| {
768 Thread::new(
769 project.clone(),
770 cx.new(|_cx| ProjectContext::default()),
771 context_server_registry,
772 Templates::new(),
773 Some(model),
774 cx,
775 )
776 });
777 let result = cx
778 .update(|cx| {
779 let input = EditFileToolInput {
780 display_description: "Some edit".into(),
781 path: "root/nonexistent_file.txt".into(),
782 mode: EditFileMode::Edit,
783 };
784 Arc::new(EditFileTool::new(
785 project,
786 thread.downgrade(),
787 language_registry,
788 Templates::new(),
789 ))
790 .run(input, ToolCallEventStream::test().0, cx)
791 })
792 .await;
793 assert_eq!(
794 result.unwrap_err().to_string(),
795 "Can't edit file: path not found"
796 );
797 }
798
799 #[gpui::test]
800 async fn test_resolve_path_for_creating_file(cx: &mut TestAppContext) {
801 let mode = &EditFileMode::Create;
802
803 let result = test_resolve_path(mode, "root/new.txt", cx);
804 assert_resolved_path_eq(result.await, rel_path("new.txt"));
805
806 let result = test_resolve_path(mode, "new.txt", cx);
807 assert_resolved_path_eq(result.await, rel_path("new.txt"));
808
809 let result = test_resolve_path(mode, "dir/new.txt", cx);
810 assert_resolved_path_eq(result.await, rel_path("dir/new.txt"));
811
812 let result = test_resolve_path(mode, "root/dir/subdir/existing.txt", cx);
813 assert_eq!(
814 result.await.unwrap_err().to_string(),
815 "Can't create file: file already exists"
816 );
817
818 let result = test_resolve_path(mode, "root/dir/nonexistent_dir/new.txt", cx);
819 assert_eq!(
820 result.await.unwrap_err().to_string(),
821 "Can't create file: parent directory doesn't exist"
822 );
823 }
824
825 #[gpui::test]
826 async fn test_resolve_path_for_editing_file(cx: &mut TestAppContext) {
827 let mode = &EditFileMode::Edit;
828
829 let path_with_root = "root/dir/subdir/existing.txt";
830 let path_without_root = "dir/subdir/existing.txt";
831 let result = test_resolve_path(mode, path_with_root, cx);
832 assert_resolved_path_eq(result.await, rel_path(path_without_root));
833
834 let result = test_resolve_path(mode, path_without_root, cx);
835 assert_resolved_path_eq(result.await, rel_path(path_without_root));
836
837 let result = test_resolve_path(mode, "root/nonexistent.txt", cx);
838 assert_eq!(
839 result.await.unwrap_err().to_string(),
840 "Can't edit file: path not found"
841 );
842
843 let result = test_resolve_path(mode, "root/dir", cx);
844 assert_eq!(
845 result.await.unwrap_err().to_string(),
846 "Can't edit file: path is a directory"
847 );
848 }
849
850 async fn test_resolve_path(
851 mode: &EditFileMode,
852 path: &str,
853 cx: &mut TestAppContext,
854 ) -> anyhow::Result<ProjectPath> {
855 init_test(cx);
856
857 let fs = project::FakeFs::new(cx.executor());
858 fs.insert_tree(
859 "/root",
860 json!({
861 "dir": {
862 "subdir": {
863 "existing.txt": "hello"
864 }
865 }
866 }),
867 )
868 .await;
869 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
870
871 let input = EditFileToolInput {
872 display_description: "Some edit".into(),
873 path: path.into(),
874 mode: mode.clone(),
875 };
876
877 cx.update(|cx| resolve_path(&input, project, cx))
878 }
879
880 #[track_caller]
881 fn assert_resolved_path_eq(path: anyhow::Result<ProjectPath>, expected: &RelPath) {
882 let actual = path.expect("Should return valid path").path;
883 assert_eq!(actual.as_ref(), expected);
884 }
885
    /// Verifies that the tool formats the buffer through the language server before
    /// saving when `format_on_save` is on, does not mark the buffer stale as a
    /// result, and leaves the content unformatted when `format_on_save` is off.
    #[gpui::test]
    async fn test_format_on_save(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = project::FakeFs::new(cx.executor());
        fs.insert_tree("/root", json!({"src": {}})).await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;

        // Set up a Rust language with LSP formatting support
        let rust_language = Arc::new(language::Language::new(
            language::LanguageConfig {
                name: "Rust".into(),
                matcher: language::LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            None,
        ));

        // Register the language and fake LSP
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        language_registry.add(rust_language);

        let mut fake_language_servers = language_registry.register_fake_lsp(
            "Rust",
            language::FakeLspAdapter {
                capabilities: lsp::ServerCapabilities {
                    document_formatting_provider: Some(lsp::OneOf::Left(true)),
                    ..Default::default()
                },
                ..Default::default()
            },
        );

        // Create the file
        fs.save(
            path!("/root/src/main.rs").as_ref(),
            &"initial content".into(),
            language::LineEnding::Unix,
        )
        .await
        .unwrap();

        // Open the buffer to trigger LSP initialization
        let buffer = project
            .update(cx, |project, cx| {
                project.open_local_buffer(path!("/root/src/main.rs"), cx)
            })
            .await
            .unwrap();

        // Register the buffer with language servers
        let _handle = project.update(cx, |project, cx| {
            project.register_buffer_with_language_servers(&buffer, cx)
        });

        const UNFORMATTED_CONTENT: &str = "fn main() {println!(\"Hello!\");}\n";
        const FORMATTED_CONTENT: &str =
            "This file was formatted by the fake formatter in the test.\n";

        // Get the fake language server and set up formatting handler.
        // The handler replaces the whole first line with `FORMATTED_CONTENT`, which
        // makes a formatting pass easy to detect.
        let fake_language_server = fake_language_servers.next().await.unwrap();
        fake_language_server.set_request_handler::<lsp::request::Formatting, _, _>({
            |_, _| async move {
                Ok(Some(vec![lsp::TextEdit {
                    range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(1, 0)),
                    new_text: FORMATTED_CONTENT.to_string(),
                }]))
            }
        });

        let context_server_registry =
            cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
        let model = Arc::new(FakeLanguageModel::default());
        let thread = cx.new(|cx| {
            Thread::new(
                project.clone(),
                cx.new(|_cx| ProjectContext::default()),
                context_server_registry,
                Templates::new(),
                Some(model.clone()),
                cx,
            )
        });

        // First, test with format_on_save enabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings.project.all_languages.defaults.format_on_save = Some(FormatOnSave::On);
                    settings.project.all_languages.defaults.formatter =
                        Some(language::language_settings::FormatterList::default());
                });
            });
        });

        // Have the model stream unformatted content
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Create main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry.clone(),
                    Templates::new(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the unformatted content
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Read the file to verify it was formatted automatically
        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            new_content.replace("\r\n", "\n"),
            FORMATTED_CONTENT,
            "Code should be formatted when format_on_save is enabled"
        );

        let stale_buffer_count = thread
            .read_with(cx, |thread, _cx| thread.action_log.clone())
            .read_with(cx, |log, cx| log.stale_buffers(cx).count());

        assert_eq!(
            stale_buffer_count, 0,
            "BUG: Buffer is incorrectly marked as stale after format-on-save. Found {} stale buffers. \
             This causes the agent to think the file was modified externally when it was just formatted.",
            stale_buffer_count
        );

        // Next, test with format_on_save disabled
        cx.update(|cx| {
            SettingsStore::update_global(cx, |store, cx| {
                store.update_user_settings(cx, |settings| {
                    settings.project.all_languages.defaults.format_on_save =
                        Some(FormatOnSave::Off);
                });
            });
        });

        // Stream unformatted edits again
        let edit_result = {
            let edit_task = cx.update(|cx| {
                let input = EditFileToolInput {
                    display_description: "Update main function".into(),
                    path: "root/src/main.rs".into(),
                    mode: EditFileMode::Overwrite,
                };
                Arc::new(EditFileTool::new(
                    project.clone(),
                    thread.downgrade(),
                    language_registry,
                    Templates::new(),
                ))
                .run(input, ToolCallEventStream::test().0, cx)
            });

            // Stream the unformatted content
            cx.executor().run_until_parked();
            model.send_last_completion_stream_text_chunk(UNFORMATTED_CONTENT.to_string());
            model.end_last_completion_stream();

            edit_task.await
        };
        assert!(edit_result.is_ok());

        // Wait for any async operations (e.g. formatting) to complete
        cx.executor().run_until_parked();

        // Verify the file was not formatted
        let new_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
        assert_eq!(
            // Ignore carriage returns on Windows
            new_content.replace("\r\n", "\n"),
            UNFORMATTED_CONTENT,
            "Code should not be formatted when format_on_save is disabled"
        );
    }
1082
1083 #[gpui::test]
1084 async fn test_remove_trailing_whitespace(cx: &mut TestAppContext) {
1085 init_test(cx);
1086
1087 let fs = project::FakeFs::new(cx.executor());
1088 fs.insert_tree("/root", json!({"src": {}})).await;
1089
1090 // Create a simple file with trailing whitespace
1091 fs.save(
1092 path!("/root/src/main.rs").as_ref(),
1093 &"initial content".into(),
1094 language::LineEnding::Unix,
1095 )
1096 .await
1097 .unwrap();
1098
1099 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1100 let context_server_registry =
1101 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1102 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1103 let model = Arc::new(FakeLanguageModel::default());
1104 let thread = cx.new(|cx| {
1105 Thread::new(
1106 project.clone(),
1107 cx.new(|_cx| ProjectContext::default()),
1108 context_server_registry,
1109 Templates::new(),
1110 Some(model.clone()),
1111 cx,
1112 )
1113 });
1114
1115 // First, test with remove_trailing_whitespace_on_save enabled
1116 cx.update(|cx| {
1117 SettingsStore::update_global(cx, |store, cx| {
1118 store.update_user_settings(cx, |settings| {
1119 settings
1120 .project
1121 .all_languages
1122 .defaults
1123 .remove_trailing_whitespace_on_save = Some(true);
1124 });
1125 });
1126 });
1127
1128 const CONTENT_WITH_TRAILING_WHITESPACE: &str =
1129 "fn main() { \n println!(\"Hello!\"); \n}\n";
1130
1131 // Have the model stream content that contains trailing whitespace
1132 let edit_result = {
1133 let edit_task = cx.update(|cx| {
1134 let input = EditFileToolInput {
1135 display_description: "Create main function".into(),
1136 path: "root/src/main.rs".into(),
1137 mode: EditFileMode::Overwrite,
1138 };
1139 Arc::new(EditFileTool::new(
1140 project.clone(),
1141 thread.downgrade(),
1142 language_registry.clone(),
1143 Templates::new(),
1144 ))
1145 .run(input, ToolCallEventStream::test().0, cx)
1146 });
1147
1148 // Stream the content with trailing whitespace
1149 cx.executor().run_until_parked();
1150 model.send_last_completion_stream_text_chunk(
1151 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1152 );
1153 model.end_last_completion_stream();
1154
1155 edit_task.await
1156 };
1157 assert!(edit_result.is_ok());
1158
1159 // Wait for any async operations (e.g. formatting) to complete
1160 cx.executor().run_until_parked();
1161
1162 // Read the file to verify trailing whitespace was removed automatically
1163 assert_eq!(
1164 // Ignore carriage returns on Windows
1165 fs.load(path!("/root/src/main.rs").as_ref())
1166 .await
1167 .unwrap()
1168 .replace("\r\n", "\n"),
1169 "fn main() {\n println!(\"Hello!\");\n}\n",
1170 "Trailing whitespace should be removed when remove_trailing_whitespace_on_save is enabled"
1171 );
1172
1173 // Next, test with remove_trailing_whitespace_on_save disabled
1174 cx.update(|cx| {
1175 SettingsStore::update_global(cx, |store, cx| {
1176 store.update_user_settings(cx, |settings| {
1177 settings
1178 .project
1179 .all_languages
1180 .defaults
1181 .remove_trailing_whitespace_on_save = Some(false);
1182 });
1183 });
1184 });
1185
1186 // Stream edits again with trailing whitespace
1187 let edit_result = {
1188 let edit_task = cx.update(|cx| {
1189 let input = EditFileToolInput {
1190 display_description: "Update main function".into(),
1191 path: "root/src/main.rs".into(),
1192 mode: EditFileMode::Overwrite,
1193 };
1194 Arc::new(EditFileTool::new(
1195 project.clone(),
1196 thread.downgrade(),
1197 language_registry,
1198 Templates::new(),
1199 ))
1200 .run(input, ToolCallEventStream::test().0, cx)
1201 });
1202
1203 // Stream the content with trailing whitespace
1204 cx.executor().run_until_parked();
1205 model.send_last_completion_stream_text_chunk(
1206 CONTENT_WITH_TRAILING_WHITESPACE.to_string(),
1207 );
1208 model.end_last_completion_stream();
1209
1210 edit_task.await
1211 };
1212 assert!(edit_result.is_ok());
1213
1214 // Wait for any async operations (e.g. formatting) to complete
1215 cx.executor().run_until_parked();
1216
1217 // Verify the file still has trailing whitespace
1218 // Read the file again - it should still have trailing whitespace
1219 let final_content = fs.load(path!("/root/src/main.rs").as_ref()).await.unwrap();
1220 assert_eq!(
1221 // Ignore carriage returns on Windows
1222 final_content.replace("\r\n", "\n"),
1223 CONTENT_WITH_TRAILING_WHITESPACE,
1224 "Trailing whitespace should remain when remove_trailing_whitespace_on_save is disabled"
1225 );
1226 }
1227
1228 #[gpui::test]
1229 async fn test_authorize(cx: &mut TestAppContext) {
1230 init_test(cx);
1231 let fs = project::FakeFs::new(cx.executor());
1232 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1233 let context_server_registry =
1234 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1235 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1236 let model = Arc::new(FakeLanguageModel::default());
1237 let thread = cx.new(|cx| {
1238 Thread::new(
1239 project.clone(),
1240 cx.new(|_cx| ProjectContext::default()),
1241 context_server_registry,
1242 Templates::new(),
1243 Some(model.clone()),
1244 cx,
1245 )
1246 });
1247 let tool = Arc::new(EditFileTool::new(
1248 project.clone(),
1249 thread.downgrade(),
1250 language_registry,
1251 Templates::new(),
1252 ));
1253 fs.insert_tree("/root", json!({})).await;
1254
1255 // Test 1: Path with .zed component should require confirmation
1256 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1257 let _auth = cx.update(|cx| {
1258 tool.authorize(
1259 &EditFileToolInput {
1260 display_description: "test 1".into(),
1261 path: ".zed/settings.json".into(),
1262 mode: EditFileMode::Edit,
1263 },
1264 &stream_tx,
1265 cx,
1266 )
1267 });
1268
1269 let event = stream_rx.expect_authorization().await;
1270 assert_eq!(
1271 event.tool_call.fields.title,
1272 Some("test 1 (local settings)".into())
1273 );
1274
1275 // Test 2: Path outside project should require confirmation
1276 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1277 let _auth = cx.update(|cx| {
1278 tool.authorize(
1279 &EditFileToolInput {
1280 display_description: "test 2".into(),
1281 path: "/etc/hosts".into(),
1282 mode: EditFileMode::Edit,
1283 },
1284 &stream_tx,
1285 cx,
1286 )
1287 });
1288
1289 let event = stream_rx.expect_authorization().await;
1290 assert_eq!(event.tool_call.fields.title, Some("test 2".into()));
1291
1292 // Test 3: Relative path without .zed should not require confirmation
1293 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1294 cx.update(|cx| {
1295 tool.authorize(
1296 &EditFileToolInput {
1297 display_description: "test 3".into(),
1298 path: "root/src/main.rs".into(),
1299 mode: EditFileMode::Edit,
1300 },
1301 &stream_tx,
1302 cx,
1303 )
1304 })
1305 .await
1306 .unwrap();
1307 assert!(stream_rx.try_next().is_err());
1308
1309 // Test 4: Path with .zed in the middle should require confirmation
1310 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1311 let _auth = cx.update(|cx| {
1312 tool.authorize(
1313 &EditFileToolInput {
1314 display_description: "test 4".into(),
1315 path: "root/.zed/tasks.json".into(),
1316 mode: EditFileMode::Edit,
1317 },
1318 &stream_tx,
1319 cx,
1320 )
1321 });
1322 let event = stream_rx.expect_authorization().await;
1323 assert_eq!(
1324 event.tool_call.fields.title,
1325 Some("test 4 (local settings)".into())
1326 );
1327
1328 // Test 5: When always_allow_tool_actions is enabled, no confirmation needed
1329 cx.update(|cx| {
1330 let mut settings = agent_settings::AgentSettings::get_global(cx).clone();
1331 settings.always_allow_tool_actions = true;
1332 agent_settings::AgentSettings::override_global(settings, cx);
1333 });
1334
1335 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1336 cx.update(|cx| {
1337 tool.authorize(
1338 &EditFileToolInput {
1339 display_description: "test 5.1".into(),
1340 path: ".zed/settings.json".into(),
1341 mode: EditFileMode::Edit,
1342 },
1343 &stream_tx,
1344 cx,
1345 )
1346 })
1347 .await
1348 .unwrap();
1349 assert!(stream_rx.try_next().is_err());
1350
1351 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1352 cx.update(|cx| {
1353 tool.authorize(
1354 &EditFileToolInput {
1355 display_description: "test 5.2".into(),
1356 path: "/etc/hosts".into(),
1357 mode: EditFileMode::Edit,
1358 },
1359 &stream_tx,
1360 cx,
1361 )
1362 })
1363 .await
1364 .unwrap();
1365 assert!(stream_rx.try_next().is_err());
1366 }
1367
1368 #[gpui::test]
1369 async fn test_authorize_global_config(cx: &mut TestAppContext) {
1370 init_test(cx);
1371 let fs = project::FakeFs::new(cx.executor());
1372 fs.insert_tree("/project", json!({})).await;
1373 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1374 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1375 let context_server_registry =
1376 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1377 let model = Arc::new(FakeLanguageModel::default());
1378 let thread = cx.new(|cx| {
1379 Thread::new(
1380 project.clone(),
1381 cx.new(|_cx| ProjectContext::default()),
1382 context_server_registry,
1383 Templates::new(),
1384 Some(model.clone()),
1385 cx,
1386 )
1387 });
1388 let tool = Arc::new(EditFileTool::new(
1389 project.clone(),
1390 thread.downgrade(),
1391 language_registry,
1392 Templates::new(),
1393 ));
1394
1395 // Test global config paths - these should require confirmation if they exist and are outside the project
1396 let test_cases = vec![
1397 (
1398 "/etc/hosts",
1399 true,
1400 "System file should require confirmation",
1401 ),
1402 (
1403 "/usr/local/bin/script",
1404 true,
1405 "System bin file should require confirmation",
1406 ),
1407 (
1408 "project/normal_file.rs",
1409 false,
1410 "Normal project file should not require confirmation",
1411 ),
1412 ];
1413
1414 for (path, should_confirm, description) in test_cases {
1415 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1416 let auth = cx.update(|cx| {
1417 tool.authorize(
1418 &EditFileToolInput {
1419 display_description: "Edit file".into(),
1420 path: path.into(),
1421 mode: EditFileMode::Edit,
1422 },
1423 &stream_tx,
1424 cx,
1425 )
1426 });
1427
1428 if should_confirm {
1429 stream_rx.expect_authorization().await;
1430 } else {
1431 auth.await.unwrap();
1432 assert!(
1433 stream_rx.try_next().is_err(),
1434 "Failed for case: {} - path: {} - expected no confirmation but got one",
1435 description,
1436 path
1437 );
1438 }
1439 }
1440 }
1441
1442 #[gpui::test]
1443 async fn test_needs_confirmation_with_multiple_worktrees(cx: &mut TestAppContext) {
1444 init_test(cx);
1445 let fs = project::FakeFs::new(cx.executor());
1446
1447 // Create multiple worktree directories
1448 fs.insert_tree(
1449 "/workspace/frontend",
1450 json!({
1451 "src": {
1452 "main.js": "console.log('frontend');"
1453 }
1454 }),
1455 )
1456 .await;
1457 fs.insert_tree(
1458 "/workspace/backend",
1459 json!({
1460 "src": {
1461 "main.rs": "fn main() {}"
1462 }
1463 }),
1464 )
1465 .await;
1466 fs.insert_tree(
1467 "/workspace/shared",
1468 json!({
1469 ".zed": {
1470 "settings.json": "{}"
1471 }
1472 }),
1473 )
1474 .await;
1475
1476 // Create project with multiple worktrees
1477 let project = Project::test(
1478 fs.clone(),
1479 [
1480 path!("/workspace/frontend").as_ref(),
1481 path!("/workspace/backend").as_ref(),
1482 path!("/workspace/shared").as_ref(),
1483 ],
1484 cx,
1485 )
1486 .await;
1487 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1488 let context_server_registry =
1489 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1490 let model = Arc::new(FakeLanguageModel::default());
1491 let thread = cx.new(|cx| {
1492 Thread::new(
1493 project.clone(),
1494 cx.new(|_cx| ProjectContext::default()),
1495 context_server_registry.clone(),
1496 Templates::new(),
1497 Some(model.clone()),
1498 cx,
1499 )
1500 });
1501 let tool = Arc::new(EditFileTool::new(
1502 project.clone(),
1503 thread.downgrade(),
1504 language_registry,
1505 Templates::new(),
1506 ));
1507
1508 // Test files in different worktrees
1509 let test_cases = vec![
1510 ("frontend/src/main.js", false, "File in first worktree"),
1511 ("backend/src/main.rs", false, "File in second worktree"),
1512 (
1513 "shared/.zed/settings.json",
1514 true,
1515 ".zed file in third worktree",
1516 ),
1517 ("/etc/hosts", true, "Absolute path outside all worktrees"),
1518 (
1519 "../outside/file.txt",
1520 true,
1521 "Relative path outside worktrees",
1522 ),
1523 ];
1524
1525 for (path, should_confirm, description) in test_cases {
1526 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1527 let auth = cx.update(|cx| {
1528 tool.authorize(
1529 &EditFileToolInput {
1530 display_description: "Edit file".into(),
1531 path: path.into(),
1532 mode: EditFileMode::Edit,
1533 },
1534 &stream_tx,
1535 cx,
1536 )
1537 });
1538
1539 if should_confirm {
1540 stream_rx.expect_authorization().await;
1541 } else {
1542 auth.await.unwrap();
1543 assert!(
1544 stream_rx.try_next().is_err(),
1545 "Failed for case: {} - path: {} - expected no confirmation but got one",
1546 description,
1547 path
1548 );
1549 }
1550 }
1551 }
1552
1553 #[gpui::test]
1554 async fn test_needs_confirmation_edge_cases(cx: &mut TestAppContext) {
1555 init_test(cx);
1556 let fs = project::FakeFs::new(cx.executor());
1557 fs.insert_tree(
1558 "/project",
1559 json!({
1560 ".zed": {
1561 "settings.json": "{}"
1562 },
1563 "src": {
1564 ".zed": {
1565 "local.json": "{}"
1566 }
1567 }
1568 }),
1569 )
1570 .await;
1571 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1572 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1573 let context_server_registry =
1574 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1575 let model = Arc::new(FakeLanguageModel::default());
1576 let thread = cx.new(|cx| {
1577 Thread::new(
1578 project.clone(),
1579 cx.new(|_cx| ProjectContext::default()),
1580 context_server_registry.clone(),
1581 Templates::new(),
1582 Some(model.clone()),
1583 cx,
1584 )
1585 });
1586 let tool = Arc::new(EditFileTool::new(
1587 project.clone(),
1588 thread.downgrade(),
1589 language_registry,
1590 Templates::new(),
1591 ));
1592
1593 // Test edge cases
1594 let test_cases = vec![
1595 // Empty path - find_project_path returns Some for empty paths
1596 ("", false, "Empty path is treated as project root"),
1597 // Root directory
1598 ("/", true, "Root directory should be outside project"),
1599 // Parent directory references - find_project_path resolves these
1600 (
1601 "project/../other",
1602 true,
1603 "Path with .. that goes outside of root directory",
1604 ),
1605 (
1606 "project/./src/file.rs",
1607 false,
1608 "Path with . should work normally",
1609 ),
1610 // Windows-style paths (if on Windows)
1611 #[cfg(target_os = "windows")]
1612 ("C:\\Windows\\System32\\hosts", true, "Windows system path"),
1613 #[cfg(target_os = "windows")]
1614 ("project\\src\\main.rs", false, "Windows-style project path"),
1615 ];
1616
1617 for (path, should_confirm, description) in test_cases {
1618 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1619 let auth = cx.update(|cx| {
1620 tool.authorize(
1621 &EditFileToolInput {
1622 display_description: "Edit file".into(),
1623 path: path.into(),
1624 mode: EditFileMode::Edit,
1625 },
1626 &stream_tx,
1627 cx,
1628 )
1629 });
1630
1631 cx.run_until_parked();
1632
1633 if should_confirm {
1634 stream_rx.expect_authorization().await;
1635 } else {
1636 assert!(
1637 stream_rx.try_next().is_err(),
1638 "Failed for case: {} - path: {} - expected no confirmation but got one",
1639 description,
1640 path
1641 );
1642 auth.await.unwrap();
1643 }
1644 }
1645 }
1646
1647 #[gpui::test]
1648 async fn test_needs_confirmation_with_different_modes(cx: &mut TestAppContext) {
1649 init_test(cx);
1650 let fs = project::FakeFs::new(cx.executor());
1651 fs.insert_tree(
1652 "/project",
1653 json!({
1654 "existing.txt": "content",
1655 ".zed": {
1656 "settings.json": "{}"
1657 }
1658 }),
1659 )
1660 .await;
1661 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1662 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1663 let context_server_registry =
1664 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1665 let model = Arc::new(FakeLanguageModel::default());
1666 let thread = cx.new(|cx| {
1667 Thread::new(
1668 project.clone(),
1669 cx.new(|_cx| ProjectContext::default()),
1670 context_server_registry.clone(),
1671 Templates::new(),
1672 Some(model.clone()),
1673 cx,
1674 )
1675 });
1676 let tool = Arc::new(EditFileTool::new(
1677 project.clone(),
1678 thread.downgrade(),
1679 language_registry,
1680 Templates::new(),
1681 ));
1682
1683 // Test different EditFileMode values
1684 let modes = vec![
1685 EditFileMode::Edit,
1686 EditFileMode::Create,
1687 EditFileMode::Overwrite,
1688 ];
1689
1690 for mode in modes {
1691 // Test .zed path with different modes
1692 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1693 let _auth = cx.update(|cx| {
1694 tool.authorize(
1695 &EditFileToolInput {
1696 display_description: "Edit settings".into(),
1697 path: "project/.zed/settings.json".into(),
1698 mode: mode.clone(),
1699 },
1700 &stream_tx,
1701 cx,
1702 )
1703 });
1704
1705 stream_rx.expect_authorization().await;
1706
1707 // Test outside path with different modes
1708 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1709 let _auth = cx.update(|cx| {
1710 tool.authorize(
1711 &EditFileToolInput {
1712 display_description: "Edit file".into(),
1713 path: "/outside/file.txt".into(),
1714 mode: mode.clone(),
1715 },
1716 &stream_tx,
1717 cx,
1718 )
1719 });
1720
1721 stream_rx.expect_authorization().await;
1722
1723 // Test normal path with different modes
1724 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1725 cx.update(|cx| {
1726 tool.authorize(
1727 &EditFileToolInput {
1728 display_description: "Edit file".into(),
1729 path: "project/normal.txt".into(),
1730 mode: mode.clone(),
1731 },
1732 &stream_tx,
1733 cx,
1734 )
1735 })
1736 .await
1737 .unwrap();
1738 assert!(stream_rx.try_next().is_err());
1739 }
1740 }
1741
1742 #[gpui::test]
1743 async fn test_initial_title_with_partial_input(cx: &mut TestAppContext) {
1744 init_test(cx);
1745 let fs = project::FakeFs::new(cx.executor());
1746 let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
1747 let language_registry = project.read_with(cx, |project, _cx| project.languages().clone());
1748 let context_server_registry =
1749 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1750 let model = Arc::new(FakeLanguageModel::default());
1751 let thread = cx.new(|cx| {
1752 Thread::new(
1753 project.clone(),
1754 cx.new(|_cx| ProjectContext::default()),
1755 context_server_registry,
1756 Templates::new(),
1757 Some(model.clone()),
1758 cx,
1759 )
1760 });
1761 let tool = Arc::new(EditFileTool::new(
1762 project,
1763 thread.downgrade(),
1764 language_registry,
1765 Templates::new(),
1766 ));
1767
1768 cx.update(|cx| {
1769 // ...
1770 assert_eq!(
1771 tool.initial_title(
1772 Err(json!({
1773 "path": "src/main.rs",
1774 "display_description": "",
1775 "old_string": "old code",
1776 "new_string": "new code"
1777 })),
1778 cx
1779 ),
1780 "src/main.rs"
1781 );
1782 assert_eq!(
1783 tool.initial_title(
1784 Err(json!({
1785 "path": "",
1786 "display_description": "Fix error handling",
1787 "old_string": "old code",
1788 "new_string": "new code"
1789 })),
1790 cx
1791 ),
1792 "Fix error handling"
1793 );
1794 assert_eq!(
1795 tool.initial_title(
1796 Err(json!({
1797 "path": "src/main.rs",
1798 "display_description": "Fix error handling",
1799 "old_string": "old code",
1800 "new_string": "new code"
1801 })),
1802 cx
1803 ),
1804 "src/main.rs"
1805 );
1806 assert_eq!(
1807 tool.initial_title(
1808 Err(json!({
1809 "path": "",
1810 "display_description": "",
1811 "old_string": "old code",
1812 "new_string": "new code"
1813 })),
1814 cx
1815 ),
1816 DEFAULT_UI_TEXT
1817 );
1818 assert_eq!(
1819 tool.initial_title(Err(serde_json::Value::Null), cx),
1820 DEFAULT_UI_TEXT
1821 );
1822 });
1823 }
1824
1825 #[gpui::test]
1826 async fn test_diff_finalization(cx: &mut TestAppContext) {
1827 init_test(cx);
1828 let fs = project::FakeFs::new(cx.executor());
1829 fs.insert_tree("/", json!({"main.rs": ""})).await;
1830
1831 let project = Project::test(fs.clone(), [path!("/").as_ref()], cx).await;
1832 let languages = project.read_with(cx, |project, _cx| project.languages().clone());
1833 let context_server_registry =
1834 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1835 let model = Arc::new(FakeLanguageModel::default());
1836 let thread = cx.new(|cx| {
1837 Thread::new(
1838 project.clone(),
1839 cx.new(|_cx| ProjectContext::default()),
1840 context_server_registry.clone(),
1841 Templates::new(),
1842 Some(model.clone()),
1843 cx,
1844 )
1845 });
1846
1847 // Ensure the diff is finalized after the edit completes.
1848 {
1849 let tool = Arc::new(EditFileTool::new(
1850 project.clone(),
1851 thread.downgrade(),
1852 languages.clone(),
1853 Templates::new(),
1854 ));
1855 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1856 let edit = cx.update(|cx| {
1857 tool.run(
1858 EditFileToolInput {
1859 display_description: "Edit file".into(),
1860 path: path!("/main.rs").into(),
1861 mode: EditFileMode::Edit,
1862 },
1863 stream_tx,
1864 cx,
1865 )
1866 });
1867 stream_rx.expect_update_fields().await;
1868 let diff = stream_rx.expect_diff().await;
1869 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1870 cx.run_until_parked();
1871 model.end_last_completion_stream();
1872 edit.await.unwrap();
1873 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1874 }
1875
1876 // Ensure the diff is finalized if an error occurs while editing.
1877 {
1878 model.forbid_requests();
1879 let tool = Arc::new(EditFileTool::new(
1880 project.clone(),
1881 thread.downgrade(),
1882 languages.clone(),
1883 Templates::new(),
1884 ));
1885 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1886 let edit = cx.update(|cx| {
1887 tool.run(
1888 EditFileToolInput {
1889 display_description: "Edit file".into(),
1890 path: path!("/main.rs").into(),
1891 mode: EditFileMode::Edit,
1892 },
1893 stream_tx,
1894 cx,
1895 )
1896 });
1897 stream_rx.expect_update_fields().await;
1898 let diff = stream_rx.expect_diff().await;
1899 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1900 edit.await.unwrap_err();
1901 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1902 model.allow_requests();
1903 }
1904
1905 // Ensure the diff is finalized if the tool call gets dropped.
1906 {
1907 let tool = Arc::new(EditFileTool::new(
1908 project.clone(),
1909 thread.downgrade(),
1910 languages.clone(),
1911 Templates::new(),
1912 ));
1913 let (stream_tx, mut stream_rx) = ToolCallEventStream::test();
1914 let edit = cx.update(|cx| {
1915 tool.run(
1916 EditFileToolInput {
1917 display_description: "Edit file".into(),
1918 path: path!("/main.rs").into(),
1919 mode: EditFileMode::Edit,
1920 },
1921 stream_tx,
1922 cx,
1923 )
1924 });
1925 stream_rx.expect_update_fields().await;
1926 let diff = stream_rx.expect_diff().await;
1927 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Pending(_))));
1928 drop(edit);
1929 cx.run_until_parked();
1930 diff.read_with(cx, |diff, _| assert!(matches!(diff, Diff::Finalized(_))));
1931 }
1932 }
1933
1934 #[gpui::test]
1935 async fn test_file_read_times_tracking(cx: &mut TestAppContext) {
1936 init_test(cx);
1937
1938 let fs = project::FakeFs::new(cx.executor());
1939 fs.insert_tree(
1940 "/root",
1941 json!({
1942 "test.txt": "original content"
1943 }),
1944 )
1945 .await;
1946 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
1947 let context_server_registry =
1948 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
1949 let model = Arc::new(FakeLanguageModel::default());
1950 let thread = cx.new(|cx| {
1951 Thread::new(
1952 project.clone(),
1953 cx.new(|_cx| ProjectContext::default()),
1954 context_server_registry,
1955 Templates::new(),
1956 Some(model.clone()),
1957 cx,
1958 )
1959 });
1960 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
1961
1962 // Initially, file_read_times should be empty
1963 let is_empty = thread.read_with(cx, |thread, _| thread.file_read_times.is_empty());
1964 assert!(is_empty, "file_read_times should start empty");
1965
1966 // Create read tool
1967 let read_tool = Arc::new(crate::ReadFileTool::new(
1968 thread.downgrade(),
1969 project.clone(),
1970 action_log,
1971 ));
1972
1973 // Read the file to record the read time
1974 cx.update(|cx| {
1975 read_tool.clone().run(
1976 crate::ReadFileToolInput {
1977 path: "root/test.txt".to_string(),
1978 start_line: None,
1979 end_line: None,
1980 },
1981 ToolCallEventStream::test().0,
1982 cx,
1983 )
1984 })
1985 .await
1986 .unwrap();
1987
1988 // Verify that file_read_times now contains an entry for the file
1989 let has_entry = thread.read_with(cx, |thread, _| {
1990 thread.file_read_times.len() == 1
1991 && thread
1992 .file_read_times
1993 .keys()
1994 .any(|path| path.ends_with("test.txt"))
1995 });
1996 assert!(
1997 has_entry,
1998 "file_read_times should contain an entry after reading the file"
1999 );
2000
2001 // Read the file again - should update the entry
2002 cx.update(|cx| {
2003 read_tool.clone().run(
2004 crate::ReadFileToolInput {
2005 path: "root/test.txt".to_string(),
2006 start_line: None,
2007 end_line: None,
2008 },
2009 ToolCallEventStream::test().0,
2010 cx,
2011 )
2012 })
2013 .await
2014 .unwrap();
2015
2016 // Should still have exactly one entry
2017 let has_one_entry = thread.read_with(cx, |thread, _| thread.file_read_times.len() == 1);
2018 assert!(
2019 has_one_entry,
2020 "file_read_times should still have one entry after re-reading"
2021 );
2022 }
2023
2024 fn init_test(cx: &mut TestAppContext) {
2025 cx.update(|cx| {
2026 let settings_store = SettingsStore::test(cx);
2027 cx.set_global(settings_store);
2028 });
2029 }
2030
2031 #[gpui::test]
2032 async fn test_consecutive_edits_work(cx: &mut TestAppContext) {
2033 init_test(cx);
2034
2035 let fs = project::FakeFs::new(cx.executor());
2036 fs.insert_tree(
2037 "/root",
2038 json!({
2039 "test.txt": "original content"
2040 }),
2041 )
2042 .await;
2043 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2044 let context_server_registry =
2045 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2046 let model = Arc::new(FakeLanguageModel::default());
2047 let thread = cx.new(|cx| {
2048 Thread::new(
2049 project.clone(),
2050 cx.new(|_cx| ProjectContext::default()),
2051 context_server_registry,
2052 Templates::new(),
2053 Some(model.clone()),
2054 cx,
2055 )
2056 });
2057 let languages = project.read_with(cx, |project, _| project.languages().clone());
2058 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2059
2060 let read_tool = Arc::new(crate::ReadFileTool::new(
2061 thread.downgrade(),
2062 project.clone(),
2063 action_log,
2064 ));
2065 let edit_tool = Arc::new(EditFileTool::new(
2066 project.clone(),
2067 thread.downgrade(),
2068 languages,
2069 Templates::new(),
2070 ));
2071
2072 // Read the file first
2073 cx.update(|cx| {
2074 read_tool.clone().run(
2075 crate::ReadFileToolInput {
2076 path: "root/test.txt".to_string(),
2077 start_line: None,
2078 end_line: None,
2079 },
2080 ToolCallEventStream::test().0,
2081 cx,
2082 )
2083 })
2084 .await
2085 .unwrap();
2086
2087 // First edit should work
2088 let edit_result = {
2089 let edit_task = cx.update(|cx| {
2090 edit_tool.clone().run(
2091 EditFileToolInput {
2092 display_description: "First edit".into(),
2093 path: "root/test.txt".into(),
2094 mode: EditFileMode::Edit,
2095 },
2096 ToolCallEventStream::test().0,
2097 cx,
2098 )
2099 });
2100
2101 cx.executor().run_until_parked();
2102 model.send_last_completion_stream_text_chunk(
2103 "<old_text>original content</old_text><new_text>modified content</new_text>"
2104 .to_string(),
2105 );
2106 model.end_last_completion_stream();
2107
2108 edit_task.await
2109 };
2110 assert!(
2111 edit_result.is_ok(),
2112 "First edit should succeed, got error: {:?}",
2113 edit_result.as_ref().err()
2114 );
2115
2116 // Second edit should also work because the edit updated the recorded read time
2117 let edit_result = {
2118 let edit_task = cx.update(|cx| {
2119 edit_tool.clone().run(
2120 EditFileToolInput {
2121 display_description: "Second edit".into(),
2122 path: "root/test.txt".into(),
2123 mode: EditFileMode::Edit,
2124 },
2125 ToolCallEventStream::test().0,
2126 cx,
2127 )
2128 });
2129
2130 cx.executor().run_until_parked();
2131 model.send_last_completion_stream_text_chunk(
2132 "<old_text>modified content</old_text><new_text>further modified content</new_text>".to_string(),
2133 );
2134 model.end_last_completion_stream();
2135
2136 edit_task.await
2137 };
2138 assert!(
2139 edit_result.is_ok(),
2140 "Second consecutive edit should succeed, got error: {:?}",
2141 edit_result.as_ref().err()
2142 );
2143 }
2144
2145 #[gpui::test]
2146 async fn test_external_modification_detected(cx: &mut TestAppContext) {
2147 init_test(cx);
2148
2149 let fs = project::FakeFs::new(cx.executor());
2150 fs.insert_tree(
2151 "/root",
2152 json!({
2153 "test.txt": "original content"
2154 }),
2155 )
2156 .await;
2157 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2158 let context_server_registry =
2159 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2160 let model = Arc::new(FakeLanguageModel::default());
2161 let thread = cx.new(|cx| {
2162 Thread::new(
2163 project.clone(),
2164 cx.new(|_cx| ProjectContext::default()),
2165 context_server_registry,
2166 Templates::new(),
2167 Some(model.clone()),
2168 cx,
2169 )
2170 });
2171 let languages = project.read_with(cx, |project, _| project.languages().clone());
2172 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2173
2174 let read_tool = Arc::new(crate::ReadFileTool::new(
2175 thread.downgrade(),
2176 project.clone(),
2177 action_log,
2178 ));
2179 let edit_tool = Arc::new(EditFileTool::new(
2180 project.clone(),
2181 thread.downgrade(),
2182 languages,
2183 Templates::new(),
2184 ));
2185
2186 // Read the file first
2187 cx.update(|cx| {
2188 read_tool.clone().run(
2189 crate::ReadFileToolInput {
2190 path: "root/test.txt".to_string(),
2191 start_line: None,
2192 end_line: None,
2193 },
2194 ToolCallEventStream::test().0,
2195 cx,
2196 )
2197 })
2198 .await
2199 .unwrap();
2200
2201 // Simulate external modification - advance time and save file
2202 cx.background_executor
2203 .advance_clock(std::time::Duration::from_secs(2));
2204 fs.save(
2205 path!("/root/test.txt").as_ref(),
2206 &"externally modified content".into(),
2207 language::LineEnding::Unix,
2208 )
2209 .await
2210 .unwrap();
2211
2212 // Reload the buffer to pick up the new mtime
2213 let project_path = project
2214 .read_with(cx, |project, cx| {
2215 project.find_project_path("root/test.txt", cx)
2216 })
2217 .expect("Should find project path");
2218 let buffer = project
2219 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2220 .await
2221 .unwrap();
2222 buffer
2223 .update(cx, |buffer, cx| buffer.reload(cx))
2224 .await
2225 .unwrap();
2226
2227 cx.executor().run_until_parked();
2228
2229 // Try to edit - should fail because file was modified externally
2230 let result = cx
2231 .update(|cx| {
2232 edit_tool.clone().run(
2233 EditFileToolInput {
2234 display_description: "Edit after external change".into(),
2235 path: "root/test.txt".into(),
2236 mode: EditFileMode::Edit,
2237 },
2238 ToolCallEventStream::test().0,
2239 cx,
2240 )
2241 })
2242 .await;
2243
2244 assert!(
2245 result.is_err(),
2246 "Edit should fail after external modification"
2247 );
2248 let error_msg = result.unwrap_err().to_string();
2249 assert!(
2250 error_msg.contains("has been modified since you last read it"),
2251 "Error should mention file modification, got: {}",
2252 error_msg
2253 );
2254 }
2255
2256 #[gpui::test]
2257 async fn test_dirty_buffer_detected(cx: &mut TestAppContext) {
2258 init_test(cx);
2259
2260 let fs = project::FakeFs::new(cx.executor());
2261 fs.insert_tree(
2262 "/root",
2263 json!({
2264 "test.txt": "original content"
2265 }),
2266 )
2267 .await;
2268 let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
2269 let context_server_registry =
2270 cx.new(|cx| ContextServerRegistry::new(project.read(cx).context_server_store(), cx));
2271 let model = Arc::new(FakeLanguageModel::default());
2272 let thread = cx.new(|cx| {
2273 Thread::new(
2274 project.clone(),
2275 cx.new(|_cx| ProjectContext::default()),
2276 context_server_registry,
2277 Templates::new(),
2278 Some(model.clone()),
2279 cx,
2280 )
2281 });
2282 let languages = project.read_with(cx, |project, _| project.languages().clone());
2283 let action_log = thread.read_with(cx, |thread, _| thread.action_log().clone());
2284
2285 let read_tool = Arc::new(crate::ReadFileTool::new(
2286 thread.downgrade(),
2287 project.clone(),
2288 action_log,
2289 ));
2290 let edit_tool = Arc::new(EditFileTool::new(
2291 project.clone(),
2292 thread.downgrade(),
2293 languages,
2294 Templates::new(),
2295 ));
2296
2297 // Read the file first
2298 cx.update(|cx| {
2299 read_tool.clone().run(
2300 crate::ReadFileToolInput {
2301 path: "root/test.txt".to_string(),
2302 start_line: None,
2303 end_line: None,
2304 },
2305 ToolCallEventStream::test().0,
2306 cx,
2307 )
2308 })
2309 .await
2310 .unwrap();
2311
2312 // Open the buffer and make it dirty by editing without saving
2313 let project_path = project
2314 .read_with(cx, |project, cx| {
2315 project.find_project_path("root/test.txt", cx)
2316 })
2317 .expect("Should find project path");
2318 let buffer = project
2319 .update(cx, |project, cx| project.open_buffer(project_path, cx))
2320 .await
2321 .unwrap();
2322
2323 // Make an in-memory edit to the buffer (making it dirty)
2324 buffer.update(cx, |buffer, cx| {
2325 let end_point = buffer.max_point();
2326 buffer.edit([(end_point..end_point, " added text")], None, cx);
2327 });
2328
2329 // Verify buffer is dirty
2330 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
2331 assert!(is_dirty, "Buffer should be dirty after in-memory edit");
2332
2333 // Try to edit - should fail because buffer has unsaved changes
2334 let result = cx
2335 .update(|cx| {
2336 edit_tool.clone().run(
2337 EditFileToolInput {
2338 display_description: "Edit with dirty buffer".into(),
2339 path: "root/test.txt".into(),
2340 mode: EditFileMode::Edit,
2341 },
2342 ToolCallEventStream::test().0,
2343 cx,
2344 )
2345 })
2346 .await;
2347
2348 assert!(result.is_err(), "Edit should fail when buffer is dirty");
2349 let error_msg = result.unwrap_err().to_string();
2350 assert!(
2351 error_msg.contains("This file has unsaved changes."),
2352 "Error should mention unsaved changes, got: {}",
2353 error_msg
2354 );
2355 assert!(
2356 error_msg.contains("keep or discard"),
2357 "Error should ask whether to keep or discard changes, got: {}",
2358 error_msg
2359 );
2360 // Since save_file and restore_file_from_disk tools aren't added to the thread,
2361 // the error message should ask the user to manually save or revert
2362 assert!(
2363 error_msg.contains("save or revert the file manually"),
2364 "Error should ask user to manually save or revert when tools aren't available, got: {}",
2365 error_msg
2366 );
2367 }
2368
/// A file inside a subdirectory of the config dir that is not expected to exist
/// must still be classified as `SensitiveSettingsKind::Global`.
#[test]
fn test_sensitive_settings_kind_detects_nonexistent_subdirectory() {
    let path = paths::config_dir()
        .join("nonexistent_subdir_xyz")
        .join("evil.json");

    let detected = sensitive_settings_kind(&path);

    assert!(
        matches!(detected, Some(SensitiveSettingsKind::Global)),
        "Path in non-existent subdirectory of config dir should be detected as sensitive: {:?}",
        path
    );
}
2382
/// A file several directory levels below the config dir (`a/b/c/evil.json`)
/// must still be classified as `SensitiveSettingsKind::Global`.
#[test]
fn test_sensitive_settings_kind_detects_deeply_nested_nonexistent_subdirectory() {
    // Push each component onto a copy of the config dir, one at a time.
    let mut path = paths::config_dir().to_path_buf();
    path.extend(["a", "b", "c", "evil.json"]);

    let detected = sensitive_settings_kind(&path);

    assert!(
        matches!(detected, Some(SensitiveSettingsKind::Global)),
        "Path in deeply nested non-existent subdirectory of config dir should be detected as sensitive: {:?}",
        path
    );
}
2396
/// A path that does not live under the config dir must not be classified as
/// a sensitive settings location.
#[test]
fn test_sensitive_settings_kind_returns_none_for_non_config_path() {
    let path: PathBuf = "/tmp/not_a_config_dir/some_file.json".into();

    let detected = sensitive_settings_kind(&path);

    assert!(
        detected.is_none(),
        "Path outside config dir should not be detected as sensitive: {:?}",
        path
    );
}
2406}