restore_file_from_disk_tool.rs

  1use super::tool_permissions::{
  2    ResolvedProjectPath, SensitiveSettingsKind, authorize_symlink_access,
  3    canonicalize_worktree_roots, path_has_symlink_escape, resolve_project_path,
  4    sensitive_settings_kind,
  5};
  6use agent_client_protocol as acp;
  7use agent_settings::AgentSettings;
  8use collections::FxHashSet;
  9use futures::FutureExt as _;
 10use gpui::{App, Entity, SharedString, Task};
 11use language::Buffer;
 12use project::Project;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use settings::Settings;
 16use std::path::{Path, PathBuf};
 17use std::sync::Arc;
 18use util::markdown::MarkdownInlineCode;
 19
 20use crate::{
 21    AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_path,
 22};
 23
// NOTE: the `///` comments on this struct and on `paths` are not only rustdoc.
// `#[derive(JsonSchema)]` (schemars) copies doc comments into the generated
// schema's `description` fields, which are presumably what the model reads as
// this tool's description. Treat them as user-facing text and edit deliberately.
/// Discards unsaved changes in open buffers by reloading file contents from disk.
///
/// Use this tool when:
/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
/// - You want to reset files to the on-disk state before retrying an edit.
///
/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RestoreFileFromDiskToolInput {
    /// The paths of the files to restore from disk.
    pub paths: Vec<PathBuf>,
}
 36
 37pub struct RestoreFileFromDiskTool {
 38    project: Entity<Project>,
 39}
 40
 41impl RestoreFileFromDiskTool {
 42    pub fn new(project: Entity<Project>) -> Self {
 43        Self { project }
 44    }
 45}
 46
 47impl AgentTool for RestoreFileFromDiskTool {
 48    type Input = RestoreFileFromDiskToolInput;
 49    type Output = String;
 50
 51    const NAME: &'static str = "restore_file_from_disk";
 52
 53    fn kind() -> acp::ToolKind {
 54        acp::ToolKind::Other
 55    }
 56
 57    fn initial_title(
 58        &self,
 59        input: Result<Self::Input, serde_json::Value>,
 60        _cx: &mut App,
 61    ) -> SharedString {
 62        match input {
 63            Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
 64            Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
 65            Err(_) => "Restore files from disk".into(),
 66        }
 67    }
 68
 69    fn run(
 70        self: Arc<Self>,
 71        input: ToolInput<Self::Input>,
 72        event_stream: ToolCallEventStream,
 73        cx: &mut App,
 74    ) -> Task<Result<String, String>> {
 75        let project = self.project.clone();
 76
 77        cx.spawn(async move |cx| {
 78            let input = input
 79                .recv()
 80                .await
 81                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
 82
 83            // Check for any immediate deny before doing async work.
 84            for path in &input.paths {
 85                let path_str = path.to_string_lossy();
 86                let decision = cx.update(|cx| {
 87                    decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx))
 88                });
 89                if let ToolPermissionDecision::Deny(reason) = decision {
 90                    return Err(reason);
 91                }
 92            }
 93
 94            let input_paths = input.paths;
 95
 96            let fs = project.read_with(cx, |project, _cx| project.fs().clone());
 97            let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await;
 98
 99            let mut confirmation_paths: Vec<String> = Vec::new();
100
101            for path in &input_paths {
102                let path_str = path.to_string_lossy();
103                let decision = cx.update(|cx| {
104                    decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx))
105                });
106                let symlink_escape = project.read_with(cx, |project, cx| {
107                    path_has_symlink_escape(project, path, &canonical_roots, cx)
108                });
109
110                match decision {
111                    ToolPermissionDecision::Allow => {
112                        if !symlink_escape {
113                            let is_sensitive = super::tool_permissions::is_sensitive_settings_path(
114                                Path::new(&*path_str),
115                                fs.as_ref(),
116                            )
117                            .await;
118                            if is_sensitive {
119                                confirmation_paths.push(path_str.to_string());
120                            }
121                        }
122                    }
123                    ToolPermissionDecision::Deny(reason) => {
124                        return Err(reason);
125                    }
126                    ToolPermissionDecision::Confirm => {
127                        if !symlink_escape {
128                            confirmation_paths.push(path_str.to_string());
129                        }
130                    }
131                }
132            }
133
134            if !confirmation_paths.is_empty() {
135                let title = if confirmation_paths.len() == 1 {
136                    format!(
137                        "Restore {} from disk",
138                        MarkdownInlineCode(&confirmation_paths[0])
139                    )
140                } else {
141                    let paths: Vec<_> = confirmation_paths
142                        .iter()
143                        .take(3)
144                        .map(|p| p.as_str())
145                        .collect();
146                    if confirmation_paths.len() > 3 {
147                        format!(
148                            "Restore {}, and {} more from disk",
149                            paths.join(", "),
150                            confirmation_paths.len() - 3
151                        )
152                    } else {
153                        format!("Restore {} from disk", paths.join(", "))
154                    }
155                };
156
157                let mut settings_kind = None;
158                for p in &confirmation_paths {
159                    if let Some(kind) = sensitive_settings_kind(Path::new(p), fs.as_ref()).await {
160                        settings_kind = Some(kind);
161                        break;
162                    }
163                }
164                let title = match settings_kind {
165                    Some(SensitiveSettingsKind::Local) => format!("{title} (local settings)"),
166                    Some(SensitiveSettingsKind::Global) => format!("{title} (settings)"),
167                    None => title,
168                };
169                let context = crate::ToolPermissionContext::new(Self::NAME, confirmation_paths);
170                let authorize = cx.update(|cx| event_stream.authorize(title, context, cx));
171                authorize.await.map_err(|e| e.to_string())?;
172            }
173            let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
174
175            let mut restored_paths: Vec<PathBuf> = Vec::new();
176            let mut clean_paths: Vec<PathBuf> = Vec::new();
177            let mut not_found_paths: Vec<PathBuf> = Vec::new();
178            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
179            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
180            let mut reload_errors: Vec<String> = Vec::new();
181
182            for path in input_paths {
183                let project_path = match project.read_with(cx, |project, cx| {
184                    resolve_project_path(project, &path, &canonical_roots, cx)
185                }) {
186                    Ok(resolved) => {
187                        let (project_path, symlink_canonical_target) = match resolved {
188                            ResolvedProjectPath::Safe(path) => (path, None),
189                            ResolvedProjectPath::SymlinkEscape {
190                                project_path,
191                                canonical_target,
192                            } => (project_path, Some(canonical_target)),
193                        };
194                        if let Some(canonical_target) = &symlink_canonical_target {
195                            let path_str = path.to_string_lossy();
196                            let authorize_task = cx.update(|cx| {
197                                authorize_symlink_access(
198                                    Self::NAME,
199                                    &path_str,
200                                    canonical_target,
201                                    &event_stream,
202                                    cx,
203                                )
204                            });
205                            let result = authorize_task.await;
206                            if let Err(err) = result {
207                                reload_errors.push(format!("{}: {}", path.to_string_lossy(), err));
208                                continue;
209                            }
210                        }
211                        project_path
212                    }
213                    Err(_) => {
214                        not_found_paths.push(path);
215                        continue;
216                    }
217                };
218
219                let open_buffer_task =
220                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
221
222                let buffer = futures::select! {
223                    result = open_buffer_task.fuse() => {
224                        match result {
225                            Ok(buffer) => buffer,
226                            Err(error) => {
227                                open_errors.push((path, error.to_string()));
228                                continue;
229                            }
230                        }
231                    }
232                    _ = event_stream.cancelled_by_user().fuse() => {
233                        return Err("Restore cancelled by user".to_string());
234                    }
235                };
236
237                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
238
239                if is_dirty {
240                    buffers_to_reload.insert(buffer);
241                    restored_paths.push(path);
242                } else {
243                    clean_paths.push(path);
244                }
245            }
246
247            if !buffers_to_reload.is_empty() {
248                let reload_task = project.update(cx, |project, cx| {
249                    project.reload_buffers(buffers_to_reload, true, cx)
250                });
251
252                let result = futures::select! {
253                    result = reload_task.fuse() => result,
254                    _ = event_stream.cancelled_by_user().fuse() => {
255                        return Err("Restore cancelled by user".to_string());
256                    }
257                };
258                if let Err(error) = result {
259                    reload_errors.push(error.to_string());
260                }
261            }
262
263            let mut lines: Vec<String> = Vec::new();
264
265            if !restored_paths.is_empty() {
266                lines.push(format!("Restored {} file(s).", restored_paths.len()));
267            }
268            if !clean_paths.is_empty() {
269                lines.push(format!("{} clean.", clean_paths.len()));
270            }
271
272            if !not_found_paths.is_empty() {
273                lines.push(format!("Not found ({}):", not_found_paths.len()));
274                for path in &not_found_paths {
275                    lines.push(format!("- {}", path.display()));
276                }
277            }
278            if !open_errors.is_empty() {
279                lines.push(format!("Open failed ({}):", open_errors.len()));
280                for (path, error) in &open_errors {
281                    lines.push(format!("- {}: {}", path.display(), error));
282                }
283            }
284            if !dirty_check_errors.is_empty() {
285                lines.push(format!(
286                    "Dirty check failed ({}):",
287                    dirty_check_errors.len()
288                ));
289                for (path, error) in &dirty_check_errors {
290                    lines.push(format!("- {}: {}", path.display(), error));
291                }
292            }
293            if !reload_errors.is_empty() {
294                lines.push(format!("Reload failed ({}):", reload_errors.len()));
295                for error in &reload_errors {
296                    lines.push(format!("- {}", error));
297                }
298            }
299
300            if lines.is_empty() {
301                Ok("No paths provided.".to_string())
302            } else {
303                Ok(lines.join("\n"))
304            }
305        })
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test settings store and sets the default tool permission to
    /// `Allow`, so tests only see prompts they configure explicitly.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            AgentSettings::override_global(settings, cx);
        });
    }

    /// Creates a project rooted at `/root/project` whose `link.txt` is a symlink
    /// to `/root/external/secret.txt`, a file outside the worktree.
    async fn symlink_escape_project(cx: &mut TestAppContext) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "project": {
                    "src": {}
                },
                "external": {
                    "secret.txt": "secret content"
                }
            }),
        )
        .await;

        fs.create_symlink(
            path!("/root/project/link.txt").as_ref(),
            PathBuf::from("../external/secret.txt"),
        )
        .await
        .unwrap();

        let project = Project::test(fs, [path!("/root/project").as_ref()], cx).await;
        cx.executor().run_until_parked();
        project
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));

        // Make dirty.txt dirty in memory by editing its buffer without saving to disk.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = cx
            .update(|cx| {
                tool.clone().run(
                    ToolInput::resolved(RestoreFileFromDiskToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    }),
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should be restored back to disk content and become clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk should not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: clean buffer should remain clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    ToolInput::resolved(RestoreFileFromDiskToolInput { paths: vec![] }),
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case (path outside the project root).
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    ToolInput::resolved(RestoreFileFromDiskToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    }),
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_requests_authorization(cx: &mut TestAppContext) {
        init_test(cx);
        let project = symlink_escape_project(cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                ToolInput::resolved(RestoreFileFromDiskToolInput {
                    paths: vec![PathBuf::from("project/link.txt")],
                }),
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp_thread::SelectedPermissionOutcome::new(
                acp::PermissionOptionId::new("allow"),
                acp::PermissionOptionKind::AllowOnce,
            ))
            .unwrap();

        let result = task.await;
        assert!(
            result.is_ok(),
            "Restore should succeed once symlink access is approved: {result:?}",
        );
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_honors_deny_policy(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "restore_file_from_disk".into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    ..Default::default()
                },
            );
            AgentSettings::override_global(settings, cx);
        });

        let project = symlink_escape_project(cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let result = cx
            .update(|cx| {
                tool.clone().run(
                    ToolInput::resolved(RestoreFileFromDiskToolInput {
                        paths: vec![PathBuf::from("project/link.txt")],
                    }),
                    event_stream,
                    cx,
                )
            })
            .await;

        assert!(result.is_err(), "Tool should fail when policy denies");
        assert!(
            !matches!(
                event_rx.try_recv(),
                Ok(Ok(crate::ThreadEvent::ToolCallAuthorization(_)))
            ),
            "Deny policy should not emit symlink authorization prompt",
        );
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_confirm_requires_single_approval(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Confirm;
            AgentSettings::override_global(settings, cx);
        });

        let project = symlink_escape_project(cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                ToolInput::resolved(RestoreFileFromDiskToolInput {
                    paths: vec![PathBuf::from("project/link.txt")],
                }),
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp_thread::SelectedPermissionOutcome::new(
                acp::PermissionOptionId::new("allow"),
                acp::PermissionOptionKind::AllowOnce,
            ))
            .unwrap();

        assert!(
            !matches!(
                event_rx.try_recv(),
                Ok(Ok(crate::ThreadEvent::ToolCallAuthorization(_)))
            ),
            "Expected a single authorization prompt",
        );

        let result = task.await;
        assert!(
            result.is_ok(),
            "Restore should succeed after the single symlink approval: {result:?}",
        );
    }
}