restore_file_from_disk_tool.rs

  1use super::tool_permissions::{
  2    ResolvedProjectPath, SensitiveSettingsKind, authorize_symlink_access,
  3    canonicalize_worktree_roots, path_has_symlink_escape, resolve_project_path,
  4    sensitive_settings_kind,
  5};
  6use agent_client_protocol as acp;
  7use agent_settings::AgentSettings;
  8use collections::FxHashSet;
  9use futures::FutureExt as _;
 10use gpui::{App, Entity, SharedString, Task};
 11use language::Buffer;
 12use project::Project;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use settings::Settings;
 16use std::path::{Path, PathBuf};
 17use std::sync::Arc;
 18use util::markdown::MarkdownInlineCode;
 19
 20use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path};
 21
// NOTE(review): `JsonSchema` is derived below, and schemars copies `///` doc comments into the
// generated schema's `description` fields. The `///` text on this type and its field is
// therefore presumably what the model sees as the tool/parameter description — confirm how the
// tool schema is built before editing it. Plain `//` comments like this one are not exported.
/// Discards unsaved changes in open buffers by reloading file contents from disk.
///
/// Use this tool when:
/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
/// - You want to reset files to the on-disk state before retrying an edit.
///
/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RestoreFileFromDiskToolInput {
    // Each entry is resolved against the project's worktrees by `resolve_project_path` in
    // `RestoreFileFromDiskTool::run`. Entries that do not resolve are reported as "Not found"
    // in the tool output; they do not fail the whole call.
    /// The paths of the files to restore from disk.
    pub paths: Vec<PathBuf>,
}
 34
 35pub struct RestoreFileFromDiskTool {
 36    project: Entity<Project>,
 37}
 38
 39impl RestoreFileFromDiskTool {
 40    pub fn new(project: Entity<Project>) -> Self {
 41        Self { project }
 42    }
 43}
 44
 45impl AgentTool for RestoreFileFromDiskTool {
 46    type Input = RestoreFileFromDiskToolInput;
 47    type Output = String;
 48
 49    const NAME: &'static str = "restore_file_from_disk";
 50
 51    fn kind() -> acp::ToolKind {
 52        acp::ToolKind::Other
 53    }
 54
 55    fn initial_title(
 56        &self,
 57        input: Result<Self::Input, serde_json::Value>,
 58        _cx: &mut App,
 59    ) -> SharedString {
 60        match input {
 61            Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
 62            Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
 63            Err(_) => "Restore files from disk".into(),
 64        }
 65    }
 66
 67    fn run(
 68        self: Arc<Self>,
 69        input: Self::Input,
 70        event_stream: ToolCallEventStream,
 71        cx: &mut App,
 72    ) -> Task<Result<String, String>> {
 73        let settings = AgentSettings::get_global(cx).clone();
 74
 75        // Check for any immediate deny before spawning async work.
 76        for path in &input.paths {
 77            let path_str = path.to_string_lossy();
 78            let decision = decide_permission_for_path(Self::NAME, &path_str, &settings);
 79            if let ToolPermissionDecision::Deny(reason) = decision {
 80                return Task::ready(Err(reason));
 81            }
 82        }
 83
 84        let project = self.project.clone();
 85        let input_paths = input.paths;
 86
 87        cx.spawn(async move |cx| {
 88            let fs = project.read_with(cx, |project, _cx| project.fs().clone());
 89            let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await;
 90
 91            let mut confirmation_paths: Vec<String> = Vec::new();
 92
 93            for path in &input_paths {
 94                let path_str = path.to_string_lossy();
 95                let decision = decide_permission_for_path(Self::NAME, &path_str, &settings);
 96                let symlink_escape = project.read_with(cx, |project, cx| {
 97                    path_has_symlink_escape(project, path, &canonical_roots, cx)
 98                });
 99
100                match decision {
101                    ToolPermissionDecision::Allow => {
102                        if !symlink_escape {
103                            let is_sensitive = super::tool_permissions::is_sensitive_settings_path(
104                                Path::new(&*path_str),
105                                fs.as_ref(),
106                            )
107                            .await;
108                            if is_sensitive {
109                                confirmation_paths.push(path_str.to_string());
110                            }
111                        }
112                    }
113                    ToolPermissionDecision::Deny(reason) => {
114                        return Err(reason);
115                    }
116                    ToolPermissionDecision::Confirm => {
117                        if !symlink_escape {
118                            confirmation_paths.push(path_str.to_string());
119                        }
120                    }
121                }
122            }
123
124            if !confirmation_paths.is_empty() {
125                let title = if confirmation_paths.len() == 1 {
126                    format!(
127                        "Restore {} from disk",
128                        MarkdownInlineCode(&confirmation_paths[0])
129                    )
130                } else {
131                    let paths: Vec<_> = confirmation_paths
132                        .iter()
133                        .take(3)
134                        .map(|p| p.as_str())
135                        .collect();
136                    if confirmation_paths.len() > 3 {
137                        format!(
138                            "Restore {}, and {} more from disk",
139                            paths.join(", "),
140                            confirmation_paths.len() - 3
141                        )
142                    } else {
143                        format!("Restore {} from disk", paths.join(", "))
144                    }
145                };
146
147                let mut settings_kind = None;
148                for p in &confirmation_paths {
149                    if let Some(kind) = sensitive_settings_kind(Path::new(p), fs.as_ref()).await {
150                        settings_kind = Some(kind);
151                        break;
152                    }
153                }
154                let title = match settings_kind {
155                    Some(SensitiveSettingsKind::Local) => format!("{title} (local settings)"),
156                    Some(SensitiveSettingsKind::Global) => format!("{title} (settings)"),
157                    None => title,
158                };
159                let context = crate::ToolPermissionContext::new(Self::NAME, confirmation_paths);
160                let authorize = cx.update(|cx| event_stream.authorize(title, context, cx));
161                authorize.await.map_err(|e| e.to_string())?;
162            }
163            let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
164
165            let mut restored_paths: Vec<PathBuf> = Vec::new();
166            let mut clean_paths: Vec<PathBuf> = Vec::new();
167            let mut not_found_paths: Vec<PathBuf> = Vec::new();
168            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
169            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
170            let mut reload_errors: Vec<String> = Vec::new();
171
172            for path in input_paths {
173                let project_path = match project.read_with(cx, |project, cx| {
174                    resolve_project_path(project, &path, &canonical_roots, cx)
175                }) {
176                    Ok(resolved) => {
177                        let (project_path, symlink_canonical_target) = match resolved {
178                            ResolvedProjectPath::Safe(path) => (path, None),
179                            ResolvedProjectPath::SymlinkEscape {
180                                project_path,
181                                canonical_target,
182                            } => (project_path, Some(canonical_target)),
183                        };
184                        if let Some(canonical_target) = &symlink_canonical_target {
185                            let path_str = path.to_string_lossy();
186                            let authorize_task = cx.update(|cx| {
187                                authorize_symlink_access(
188                                    Self::NAME,
189                                    &path_str,
190                                    canonical_target,
191                                    &event_stream,
192                                    cx,
193                                )
194                            });
195                            let result = authorize_task.await;
196                            if let Err(err) = result {
197                                reload_errors.push(format!("{}: {}", path.to_string_lossy(), err));
198                                continue;
199                            }
200                        }
201                        project_path
202                    }
203                    Err(_) => {
204                        not_found_paths.push(path);
205                        continue;
206                    }
207                };
208
209                let open_buffer_task =
210                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
211
212                let buffer = futures::select! {
213                    result = open_buffer_task.fuse() => {
214                        match result {
215                            Ok(buffer) => buffer,
216                            Err(error) => {
217                                open_errors.push((path, error.to_string()));
218                                continue;
219                            }
220                        }
221                    }
222                    _ = event_stream.cancelled_by_user().fuse() => {
223                        return Err("Restore cancelled by user".to_string());
224                    }
225                };
226
227                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
228
229                if is_dirty {
230                    buffers_to_reload.insert(buffer);
231                    restored_paths.push(path);
232                } else {
233                    clean_paths.push(path);
234                }
235            }
236
237            if !buffers_to_reload.is_empty() {
238                let reload_task = project.update(cx, |project, cx| {
239                    project.reload_buffers(buffers_to_reload, true, cx)
240                });
241
242                let result = futures::select! {
243                    result = reload_task.fuse() => result,
244                    _ = event_stream.cancelled_by_user().fuse() => {
245                        return Err("Restore cancelled by user".to_string());
246                    }
247                };
248                if let Err(error) = result {
249                    reload_errors.push(error.to_string());
250                }
251            }
252
253            let mut lines: Vec<String> = Vec::new();
254
255            if !restored_paths.is_empty() {
256                lines.push(format!("Restored {} file(s).", restored_paths.len()));
257            }
258            if !clean_paths.is_empty() {
259                lines.push(format!("{} clean.", clean_paths.len()));
260            }
261
262            if !not_found_paths.is_empty() {
263                lines.push(format!("Not found ({}):", not_found_paths.len()));
264                for path in &not_found_paths {
265                    lines.push(format!("- {}", path.display()));
266                }
267            }
268            if !open_errors.is_empty() {
269                lines.push(format!("Open failed ({}):", open_errors.len()));
270                for (path, error) in &open_errors {
271                    lines.push(format!("- {}: {}", path.display(), error));
272                }
273            }
274            if !dirty_check_errors.is_empty() {
275                lines.push(format!(
276                    "Dirty check failed ({}):",
277                    dirty_check_errors.len()
278                ));
279                for (path, error) in &dirty_check_errors {
280                    lines.push(format!("- {}: {}", path.display(), error));
281                }
282            }
283            if !reload_errors.is_empty() {
284                lines.push(format!("Reload failed ({}):", reload_errors.len()));
285                for error in &reload_errors {
286                    lines.push(format!("- {}", error));
287                }
288            }
289
290            if lines.is_empty() {
291                Ok("No paths provided.".to_string())
292            } else {
293                Ok(lines.join("\n"))
294            }
295        })
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test settings store and sets the default tool permission to `Allow`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            AgentSettings::override_global(settings, cx);
        });
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));

        // Make dirty.txt dirty in memory by editing its buffer without saving to disk.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should be restored back to disk content and become clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk should not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: clean buffer should remain clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput { paths: vec![] },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case (a path that does not resolve to any project worktree).
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_requests_authorization(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "project": {
                    "src": {}
                },
                "external": {
                    "secret.txt": "secret content"
                }
            }),
        )
        .await;

        fs.create_symlink(
            path!("/root/project/link.txt").as_ref(),
            PathBuf::from("../external/secret.txt"),
        )
        .await
        .unwrap();

        let project = Project::test(fs.clone(), [path!("/root/project").as_ref()], cx).await;
        cx.executor().run_until_parked();

        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                RestoreFileFromDiskToolInput {
                    paths: vec![PathBuf::from("project/link.txt")],
                },
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        let _result = task.await;
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_honors_deny_policy(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "restore_file_from_disk".into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    ..Default::default()
                },
            );
            AgentSettings::override_global(settings, cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "project": {
                    "src": {}
                },
                "external": {
                    "secret.txt": "secret content"
                }
            }),
        )
        .await;

        fs.create_symlink(
            path!("/root/project/link.txt").as_ref(),
            PathBuf::from("../external/secret.txt"),
        )
        .await
        .unwrap();

        let project = Project::test(fs.clone(), [path!("/root/project").as_ref()], cx).await;
        cx.executor().run_until_parked();

        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let result = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![PathBuf::from("project/link.txt")],
                    },
                    event_stream,
                    cx,
                )
            })
            .await;

        assert!(result.is_err(), "Tool should fail when policy denies");
        assert!(
            !matches!(
                event_rx.try_next(),
                Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_))))
            ),
            "Deny policy should not emit symlink authorization prompt",
        );
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_confirm_requires_single_approval(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Confirm;
            AgentSettings::override_global(settings, cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "project": {
                    "src": {}
                },
                "external": {
                    "secret.txt": "secret content"
                }
            }),
        )
        .await;

        fs.create_symlink(
            path!("/root/project/link.txt").as_ref(),
            PathBuf::from("../external/secret.txt"),
        )
        .await
        .unwrap();

        let project = Project::test(fs.clone(), [path!("/root/project").as_ref()], cx).await;
        cx.executor().run_until_parked();

        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                RestoreFileFromDiskToolInput {
                    paths: vec![PathBuf::from("project/link.txt")],
                },
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        // Under `Confirm`, the symlink prompt must be the only prompt: the escaping path is
        // excluded from the batched confirmation.
        assert!(
            !matches!(
                event_rx.try_next(),
                Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_))))
            ),
            "Expected a single authorization prompt",
        );

        let _result = task.await;
    }
}