restore_file_from_disk_tool.rs

  1use super::tool_permissions::{
  2    ResolvedProjectPath, SensitiveSettingsKind, authorize_symlink_access,
  3    canonicalize_worktree_roots, path_has_symlink_escape, resolve_project_path,
  4    sensitive_settings_kind,
  5};
  6use agent_client_protocol as acp;
  7use agent_settings::AgentSettings;
  8use collections::FxHashSet;
  9use futures::FutureExt as _;
 10use gpui::{App, Entity, SharedString, Task};
 11use language::Buffer;
 12use project::Project;
 13use schemars::JsonSchema;
 14use serde::{Deserialize, Serialize};
 15use settings::Settings;
 16use std::path::{Path, PathBuf};
 17use std::sync::Arc;
 18use util::markdown::MarkdownInlineCode;
 19
 20use crate::{
 21    AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_path,
 22};
 23
// NOTE(review): `JsonSchema` is derived below, and schemars turns `///` doc comments
// into schema `description`s. The `///` text on this struct and its field is therefore
// runtime data (presumably the tool/parameter description the model sees), so treat
// edits to it as prompt changes. Plain `//` comments like this one are not exported.
/// Discards unsaved changes in open buffers by reloading file contents from disk.
///
/// Use this tool when:
/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
/// - You want to reset files to the on-disk state before retrying an edit.
///
/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RestoreFileFromDiskToolInput {
    // Each entry is resolved against the project by the tool's `run`; entries that
    // cannot be resolved are reported as "Not found" instead of failing the call.
    /// The paths of the files to restore from disk.
    pub paths: Vec<PathBuf>,
}
 36
 37pub struct RestoreFileFromDiskTool {
 38    project: Entity<Project>,
 39}
 40
 41impl RestoreFileFromDiskTool {
 42    pub fn new(project: Entity<Project>) -> Self {
 43        Self { project }
 44    }
 45}
 46
 47impl AgentTool for RestoreFileFromDiskTool {
 48    type Input = RestoreFileFromDiskToolInput;
 49    type Output = String;
 50
 51    const NAME: &'static str = "restore_file_from_disk";
 52
 53    fn kind() -> acp::ToolKind {
 54        acp::ToolKind::Other
 55    }
 56
 57    fn initial_title(
 58        &self,
 59        input: Result<Self::Input, serde_json::Value>,
 60        _cx: &mut App,
 61    ) -> SharedString {
 62        match input {
 63            Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
 64            Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
 65            Err(_) => "Restore files from disk".into(),
 66        }
 67    }
 68
 69    fn run(
 70        self: Arc<Self>,
 71        input: ToolInput<Self::Input>,
 72        event_stream: ToolCallEventStream,
 73        cx: &mut App,
 74    ) -> Task<Result<String, String>> {
 75        let project = self.project.clone();
 76
 77        cx.spawn(async move |cx| {
 78            let input = input
 79                .recv()
 80                .await
 81                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
 82
 83            // Check for any immediate deny before doing async work.
 84            for path in &input.paths {
 85                let path_str = path.to_string_lossy();
 86                let decision = cx.update(|cx| {
 87                    decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx))
 88                });
 89                if let ToolPermissionDecision::Deny(reason) = decision {
 90                    return Err(reason);
 91                }
 92            }
 93
 94            let input_paths = input.paths;
 95
 96            let fs = project.read_with(cx, |project, _cx| project.fs().clone());
 97            let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await;
 98
 99            let mut confirmation_paths: Vec<String> = Vec::new();
100
101            for path in &input_paths {
102                let path_str = path.to_string_lossy();
103                let decision = cx.update(|cx| {
104                    decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx))
105                });
106                let symlink_escape = project.read_with(cx, |project, cx| {
107                    path_has_symlink_escape(project, path, &canonical_roots, cx)
108                });
109
110                match decision {
111                    ToolPermissionDecision::Allow => {
112                        if !symlink_escape {
113                            let is_sensitive = super::tool_permissions::is_sensitive_settings_path(
114                                Path::new(&*path_str),
115                                fs.as_ref(),
116                            )
117                            .await;
118                            if is_sensitive {
119                                confirmation_paths.push(path_str.to_string());
120                            }
121                        }
122                    }
123                    ToolPermissionDecision::Deny(reason) => {
124                        return Err(reason);
125                    }
126                    ToolPermissionDecision::Confirm => {
127                        if !symlink_escape {
128                            confirmation_paths.push(path_str.to_string());
129                        }
130                    }
131                }
132            }
133
134            if !confirmation_paths.is_empty() {
135                let title = if confirmation_paths.len() == 1 {
136                    format!(
137                        "Restore {} from disk",
138                        MarkdownInlineCode(&confirmation_paths[0])
139                    )
140                } else {
141                    let paths: Vec<_> = confirmation_paths
142                        .iter()
143                        .take(3)
144                        .map(|p| p.as_str())
145                        .collect();
146                    if confirmation_paths.len() > 3 {
147                        format!(
148                            "Restore {}, and {} more from disk",
149                            paths.join(", "),
150                            confirmation_paths.len() - 3
151                        )
152                    } else {
153                        format!("Restore {} from disk", paths.join(", "))
154                    }
155                };
156
157                let mut settings_kind = None;
158                for p in &confirmation_paths {
159                    if let Some(kind) = sensitive_settings_kind(Path::new(p), fs.as_ref()).await {
160                        settings_kind = Some(kind);
161                        break;
162                    }
163                }
164                let title = match settings_kind {
165                    Some(SensitiveSettingsKind::Local) => format!("{title} (local settings)"),
166                    Some(SensitiveSettingsKind::Global) => format!("{title} (settings)"),
167                    None => title,
168                };
169                let context = crate::ToolPermissionContext::new(Self::NAME, confirmation_paths);
170                let authorize = cx.update(|cx| event_stream.authorize(title, context, cx));
171                authorize.await.map_err(|e| e.to_string())?;
172            }
173            let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
174
175            let mut restored_paths: Vec<PathBuf> = Vec::new();
176            let mut clean_paths: Vec<PathBuf> = Vec::new();
177            let mut not_found_paths: Vec<PathBuf> = Vec::new();
178            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
179            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
180            let mut reload_errors: Vec<String> = Vec::new();
181
182            for path in input_paths {
183                let project_path = match project.read_with(cx, |project, cx| {
184                    resolve_project_path(project, &path, &canonical_roots, cx)
185                }) {
186                    Ok(resolved) => {
187                        let (project_path, symlink_canonical_target) = match resolved {
188                            ResolvedProjectPath::Safe(path) => (path, None),
189                            ResolvedProjectPath::SymlinkEscape {
190                                project_path,
191                                canonical_target,
192                            } => (project_path, Some(canonical_target)),
193                        };
194                        if let Some(canonical_target) = &symlink_canonical_target {
195                            let path_str = path.to_string_lossy();
196                            let authorize_task = cx.update(|cx| {
197                                authorize_symlink_access(
198                                    Self::NAME,
199                                    &path_str,
200                                    canonical_target,
201                                    &event_stream,
202                                    cx,
203                                )
204                            });
205                            let result = authorize_task.await;
206                            if let Err(err) = result {
207                                reload_errors.push(format!("{}: {}", path.to_string_lossy(), err));
208                                continue;
209                            }
210                        }
211                        project_path
212                    }
213                    Err(_) => {
214                        not_found_paths.push(path);
215                        continue;
216                    }
217                };
218
219                let open_buffer_task =
220                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
221
222                let buffer = futures::select! {
223                    result = open_buffer_task.fuse() => {
224                        match result {
225                            Ok(buffer) => buffer,
226                            Err(error) => {
227                                open_errors.push((path, error.to_string()));
228                                continue;
229                            }
230                        }
231                    }
232                    _ = event_stream.cancelled_by_user().fuse() => {
233                        return Err("Restore cancelled by user".to_string());
234                    }
235                };
236
237                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
238
239                if is_dirty {
240                    buffers_to_reload.insert(buffer);
241                    restored_paths.push(path);
242                } else {
243                    clean_paths.push(path);
244                }
245            }
246
247            if !buffers_to_reload.is_empty() {
248                let reload_task = project.update(cx, |project, cx| {
249                    project.reload_buffers(buffers_to_reload, true, cx)
250                });
251
252                let result = futures::select! {
253                    result = reload_task.fuse() => result,
254                    _ = event_stream.cancelled_by_user().fuse() => {
255                        return Err("Restore cancelled by user".to_string());
256                    }
257                };
258                if let Err(error) = result {
259                    reload_errors.push(error.to_string());
260                }
261            }
262
263            let mut lines: Vec<String> = Vec::new();
264
265            if !restored_paths.is_empty() {
266                lines.push(format!("Restored {} file(s).", restored_paths.len()));
267            }
268            if !clean_paths.is_empty() {
269                lines.push(format!("{} clean.", clean_paths.len()));
270            }
271
272            if !not_found_paths.is_empty() {
273                lines.push(format!("Not found ({}):", not_found_paths.len()));
274                for path in &not_found_paths {
275                    lines.push(format!("- {}", path.display()));
276                }
277            }
278            if !open_errors.is_empty() {
279                lines.push(format!("Open failed ({}):", open_errors.len()));
280                for (path, error) in &open_errors {
281                    lines.push(format!("- {}: {}", path.display(), error));
282                }
283            }
284            if !dirty_check_errors.is_empty() {
285                lines.push(format!(
286                    "Dirty check failed ({}):",
287                    dirty_check_errors.len()
288                ));
289                for (path, error) in &dirty_check_errors {
290                    lines.push(format!("- {}: {}", path.display(), error));
291                }
292            }
293            if !reload_errors.is_empty() {
294                lines.push(format!("Reload failed ({}):", reload_errors.len()));
295                for error in &reload_errors {
296                    lines.push(format!("- {}", error));
297                }
298            }
299
300            if lines.is_empty() {
301                Ok("No paths provided.".to_string())
302            } else {
303                Ok(lines.join("\n"))
304            }
305        })
306    }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test settings store and makes `Allow` the default tool permission,
    /// so each test opts into stricter policies explicitly.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            AgentSettings::override_global(settings, cx);
        });
    }

    /// Starts the tool on `paths` and returns its (not yet awaited) task.
    fn run_tool(
        tool: &Arc<RestoreFileFromDiskTool>,
        paths: Vec<PathBuf>,
        event_stream: ToolCallEventStream,
        cx: &mut TestAppContext,
    ) -> Task<Result<String, String>> {
        cx.update(|cx| {
            tool.clone().run(
                ToolInput::resolved(RestoreFileFromDiskToolInput { paths }),
                event_stream,
                cx,
            )
        })
    }

    /// Builds a project whose only worktree is `/root/project`. The worktree contains
    /// `link.txt`, a symlink to `/root/external/secret.txt`, which lies outside it.
    async fn project_with_escaping_symlink(cx: &mut TestAppContext) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "project": {
                    "src": {}
                },
                "external": {
                    "secret.txt": "secret content"
                }
            }),
        )
        .await;

        fs.create_symlink(
            path!("/root/project/link.txt").as_ref(),
            PathBuf::from("../external/secret.txt"),
        )
        .await
        .unwrap();

        let project = Project::test(fs, [path!("/root/project").as_ref()], cx).await;
        cx.executor().run_until_parked();
        project
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Use `path!` so the fixture root matches the worktree root opened below on
        // every platform.
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));

        // Make dirty.txt dirty in memory by editing its buffer without saving it.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            ToolCallEventStream::test().0,
            cx,
        )
        .await
        .unwrap();

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should be restored back to disk content and become clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk should not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: clean buffer should remain clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );

        // Test empty paths case.
        let output = run_tool(&tool, vec![], ToolCallEventStream::test().0, cx)
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case (path outside the project root).
        let output = run_tool(
            &tool,
            vec![PathBuf::from("nonexistent/path.txt")],
            ToolCallEventStream::test().0,
            cx,
        )
        .await
        .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_requests_authorization(cx: &mut TestAppContext) {
        init_test(cx);
        let project = project_with_escaping_symlink(cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = run_tool(&tool, vec![PathBuf::from("project/link.txt")], event_stream, cx);

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        let _result = task.await;
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_honors_deny_policy(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "restore_file_from_disk".into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    ..Default::default()
                },
            );
            AgentSettings::override_global(settings, cx);
        });

        let project = project_with_escaping_symlink(cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let result =
            run_tool(&tool, vec![PathBuf::from("project/link.txt")], event_stream, cx).await;

        assert!(result.is_err(), "Tool should fail when policy denies");
        // The task has finished, so every event it emitted is already queued.
        while let Ok(Some(event)) = event_rx.try_next() {
            assert!(
                !matches!(event, Ok(crate::ThreadEvent::ToolCallAuthorization(_))),
                "Deny policy should not emit symlink authorization prompt",
            );
        }
    }

    #[gpui::test]
    async fn test_restore_file_symlink_escape_confirm_requires_single_approval(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Confirm;
            AgentSettings::override_global(settings, cx);
        });

        let project = project_with_escaping_symlink(cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = run_tool(&tool, vec![PathBuf::from("project/link.txt")], event_stream, cx);

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        // Let the tool continue past the first approval. Otherwise a second prompt
        // could not have been emitted yet, and the check below would be vacuous.
        cx.run_until_parked();
        while let Ok(Some(event)) = event_rx.try_next() {
            assert!(
                !matches!(event, Ok(crate::ThreadEvent::ToolCallAuthorization(_))),
                "Expected a single authorization prompt",
            );
        }

        let _result = task.await;
    }
}