restore_file_from_disk_tool.rs

  1use super::tool_permissions::{
  2    ResolvedProjectPath, SensitiveSettingsKind, authorize_symlink_access,
  3    canonicalize_worktree_roots, path_has_symlink_escape, resolve_project_path,
  4    sensitive_settings_kind,
  5};
  6use agent_client_protocol as acp;
  7use agent_settings::AgentSettings;
  8use anyhow::Result;
  9use collections::FxHashSet;
 10use futures::FutureExt as _;
 11use gpui::{App, Entity, SharedString, Task};
 12use language::Buffer;
 13use project::Project;
 14use schemars::JsonSchema;
 15use serde::{Deserialize, Serialize};
 16use settings::Settings;
 17use std::path::{Path, PathBuf};
 18use std::sync::Arc;
 19use util::markdown::MarkdownInlineCode;
 20
 21use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path};
 22
 23/// Discards unsaved changes in open buffers by reloading file contents from disk.
 24///
 25/// Use this tool when:
 26/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
 27/// - You want to reset files to the on-disk state before retrying an edit.
 28///
 29/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
 30#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 31pub struct RestoreFileFromDiskToolInput {
 32    /// The paths of the files to restore from disk.
 33    pub paths: Vec<PathBuf>,
 34}
 35
 36pub struct RestoreFileFromDiskTool {
 37    project: Entity<Project>,
 38}
 39
 40impl RestoreFileFromDiskTool {
 41    pub fn new(project: Entity<Project>) -> Self {
 42        Self { project }
 43    }
 44}
 45
 46impl AgentTool for RestoreFileFromDiskTool {
 47    type Input = RestoreFileFromDiskToolInput;
 48    type Output = String;
 49
 50    const NAME: &'static str = "restore_file_from_disk";
 51
 52    fn kind() -> acp::ToolKind {
 53        acp::ToolKind::Other
 54    }
 55
 56    fn initial_title(
 57        &self,
 58        input: Result<Self::Input, serde_json::Value>,
 59        _cx: &mut App,
 60    ) -> SharedString {
 61        match input {
 62            Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
 63            Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
 64            Err(_) => "Restore files from disk".into(),
 65        }
 66    }
 67
 68    fn run(
 69        self: Arc<Self>,
 70        input: Self::Input,
 71        event_stream: ToolCallEventStream,
 72        cx: &mut App,
 73    ) -> Task<Result<String>> {
 74        let settings = AgentSettings::get_global(cx).clone();
 75
 76        // Check for any immediate deny before spawning async work.
 77        for path in &input.paths {
 78            let path_str = path.to_string_lossy();
 79            let decision = decide_permission_for_path(Self::NAME, &path_str, &settings);
 80            if let ToolPermissionDecision::Deny(reason) = decision {
 81                return Task::ready(Err(anyhow::anyhow!("{}", reason)));
 82            }
 83        }
 84
 85        let project = self.project.clone();
 86        let input_paths = input.paths;
 87
 88        cx.spawn(async move |cx| {
 89            let fs = project.read_with(cx, |project, _cx| project.fs().clone());
 90            let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await;
 91
 92            let mut confirmation_paths: Vec<String> = Vec::new();
 93
 94            for path in &input_paths {
 95                let path_str = path.to_string_lossy();
 96                let decision = decide_permission_for_path(Self::NAME, &path_str, &settings);
 97                let symlink_escape = project.read_with(cx, |project, cx| {
 98                    path_has_symlink_escape(project, path, &canonical_roots, cx)
 99                });
100
101                match decision {
102                    ToolPermissionDecision::Allow => {
103                        if !symlink_escape {
104                            let is_sensitive = super::tool_permissions::is_sensitive_settings_path(
105                                Path::new(&*path_str),
106                                fs.as_ref(),
107                            )
108                            .await;
109                            if is_sensitive {
110                                confirmation_paths.push(path_str.to_string());
111                            }
112                        }
113                    }
114                    ToolPermissionDecision::Deny(reason) => {
115                        return Err(anyhow::anyhow!("{}", reason));
116                    }
117                    ToolPermissionDecision::Confirm => {
118                        if !symlink_escape {
119                            confirmation_paths.push(path_str.to_string());
120                        }
121                    }
122                }
123            }
124
125            if !confirmation_paths.is_empty() {
126                let title = if confirmation_paths.len() == 1 {
127                    format!(
128                        "Restore {} from disk",
129                        MarkdownInlineCode(&confirmation_paths[0])
130                    )
131                } else {
132                    let paths: Vec<_> = confirmation_paths
133                        .iter()
134                        .take(3)
135                        .map(|p| p.as_str())
136                        .collect();
137                    if confirmation_paths.len() > 3 {
138                        format!(
139                            "Restore {}, and {} more from disk",
140                            paths.join(", "),
141                            confirmation_paths.len() - 3
142                        )
143                    } else {
144                        format!("Restore {} from disk", paths.join(", "))
145                    }
146                };
147
148                let mut settings_kind = None;
149                for p in &confirmation_paths {
150                    if let Some(kind) = sensitive_settings_kind(Path::new(p), fs.as_ref()).await {
151                        settings_kind = Some(kind);
152                        break;
153                    }
154                }
155                let title = match settings_kind {
156                    Some(SensitiveSettingsKind::Local) => format!("{title} (local settings)"),
157                    Some(SensitiveSettingsKind::Global) => format!("{title} (settings)"),
158                    None => title,
159                };
160                let context = crate::ToolPermissionContext::new(Self::NAME, confirmation_paths);
161                let authorize = cx.update(|cx| event_stream.authorize(title, context, cx));
162                authorize.await?;
163            }
164            let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
165
166            let mut restored_paths: Vec<PathBuf> = Vec::new();
167            let mut clean_paths: Vec<PathBuf> = Vec::new();
168            let mut not_found_paths: Vec<PathBuf> = Vec::new();
169            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
170            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
171            let mut reload_errors: Vec<String> = Vec::new();
172
173            for path in input_paths {
174                let project_path = match project.read_with(cx, |project, cx| {
175                    resolve_project_path(project, &path, &canonical_roots, cx)
176                }) {
177                    Ok(resolved) => {
178                        let (project_path, symlink_canonical_target) = match resolved {
179                            ResolvedProjectPath::Safe(path) => (path, None),
180                            ResolvedProjectPath::SymlinkEscape {
181                                project_path,
182                                canonical_target,
183                            } => (project_path, Some(canonical_target)),
184                        };
185                        if let Some(canonical_target) = &symlink_canonical_target {
186                            let path_str = path.to_string_lossy();
187                            let authorize_task = cx.update(|cx| {
188                                authorize_symlink_access(
189                                    Self::NAME,
190                                    &path_str,
191                                    canonical_target,
192                                    &event_stream,
193                                    cx,
194                                )
195                            });
196                            let result = authorize_task.await;
197                            if let Err(err) = result {
198                                reload_errors.push(format!("{}: {}", path.to_string_lossy(), err));
199                                continue;
200                            }
201                        }
202                        project_path
203                    }
204                    Err(_) => {
205                        not_found_paths.push(path);
206                        continue;
207                    }
208                };
209
210                let open_buffer_task =
211                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
212
213                let buffer = futures::select! {
214                    result = open_buffer_task.fuse() => {
215                        match result {
216                            Ok(buffer) => buffer,
217                            Err(error) => {
218                                open_errors.push((path, error.to_string()));
219                                continue;
220                            }
221                        }
222                    }
223                    _ = event_stream.cancelled_by_user().fuse() => {
224                        anyhow::bail!("Restore cancelled by user");
225                    }
226                };
227
228                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
229
230                if is_dirty {
231                    buffers_to_reload.insert(buffer);
232                    restored_paths.push(path);
233                } else {
234                    clean_paths.push(path);
235                }
236            }
237
238            if !buffers_to_reload.is_empty() {
239                let reload_task = project.update(cx, |project, cx| {
240                    project.reload_buffers(buffers_to_reload, true, cx)
241                });
242
243                let result = futures::select! {
244                    result = reload_task.fuse() => result,
245                    _ = event_stream.cancelled_by_user().fuse() => {
246                        anyhow::bail!("Restore cancelled by user");
247                    }
248                };
249                if let Err(error) = result {
250                    reload_errors.push(error.to_string());
251                }
252            }
253
254            let mut lines: Vec<String> = Vec::new();
255
256            if !restored_paths.is_empty() {
257                lines.push(format!("Restored {} file(s).", restored_paths.len()));
258            }
259            if !clean_paths.is_empty() {
260                lines.push(format!("{} clean.", clean_paths.len()));
261            }
262
263            if !not_found_paths.is_empty() {
264                lines.push(format!("Not found ({}):", not_found_paths.len()));
265                for path in &not_found_paths {
266                    lines.push(format!("- {}", path.display()));
267                }
268            }
269            if !open_errors.is_empty() {
270                lines.push(format!("Open failed ({}):", open_errors.len()));
271                for (path, error) in &open_errors {
272                    lines.push(format!("- {}: {}", path.display(), error));
273                }
274            }
275            if !dirty_check_errors.is_empty() {
276                lines.push(format!(
277                    "Dirty check failed ({}):",
278                    dirty_check_errors.len()
279                ));
280                for (path, error) in &dirty_check_errors {
281                    lines.push(format!("- {}: {}", path.display(), error));
282                }
283            }
284            if !reload_errors.is_empty() {
285                lines.push(format!("Reload failed ({}):", reload_errors.len()));
286                for error in &reload_errors {
287                    lines.push(format!("- {}", error));
288                }
289            }
290
291            if lines.is_empty() {
292                Ok("No paths provided.".to_string())
293            } else {
294                Ok(lines.join("\n"))
295            }
296        })
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test `SettingsStore` and sets the default tool-permission mode to
    /// `Allow`; individual tests override the policy where they need another one.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            AgentSettings::override_global(settings, cx);
        });
    }

    /// Builds a project whose only worktree is `/root/project`. The worktree
    /// contains `link.txt`, a symlink to `/root/external/secret.txt`, which lies
    /// outside the worktree.
    async fn project_with_escaping_symlink(cx: &mut TestAppContext) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "project": {
                    "src": {}
                },
                "external": {
                    "secret.txt": "secret content"
                }
            }),
        )
        .await;

        fs.create_symlink(
            path!("/root/project/link.txt").as_ref(),
            PathBuf::from("../external/secret.txt"),
        )
        .await
        .unwrap();

        let project = Project::test(fs, [path!("/root/project").as_ref()], cx).await;
        cx.executor().run_until_parked();
        project
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));

        // Make dirty.txt dirty by editing its buffer in memory without saving to disk.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should be restored back to disk content and become clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk should not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: clean buffer should remain clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput { paths: vec![] },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case (path outside the project root).
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_requests_authorization(cx: &mut TestAppContext) {
        init_test(cx);

        let project = project_with_escaping_symlink(cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                RestoreFileFromDiskToolInput {
                    paths: vec![PathBuf::from("project/link.txt")],
                },
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        let _result = task.await;
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_honors_deny_policy(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "restore_file_from_disk".into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    ..Default::default()
                },
            );
            AgentSettings::override_global(settings, cx);
        });

        let project = project_with_escaping_symlink(cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let result = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![PathBuf::from("project/link.txt")],
                    },
                    event_stream,
                    cx,
                )
            })
            .await;

        assert!(result.is_err(), "Tool should fail when policy denies");
        assert!(
            !matches!(
                event_rx.try_next(),
                Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_))))
            ),
            "Deny policy should not emit symlink authorization prompt",
        );
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_confirm_requires_single_approval(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Confirm;
            AgentSettings::override_global(settings, cx);
        });

        let project = project_with_escaping_symlink(cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                RestoreFileFromDiskToolInput {
                    paths: vec![PathBuf::from("project/link.txt")],
                },
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        // Let the tool process the approval first. Otherwise a second prompt would
        // not have been emitted yet, and the check below could never fail.
        cx.run_until_parked();

        assert!(
            !matches!(
                event_rx.try_next(),
                Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_))))
            ),
            "Expected a single authorization prompt",
        );

        let _result = task.await;
    }
}