restore_file_from_disk_tool.rs

  1use super::edit_file_tool::{
  2    SensitiveSettingsKind, is_sensitive_settings_path, sensitive_settings_kind,
  3};
  4use agent_client_protocol as acp;
  5use agent_settings::AgentSettings;
  6use anyhow::Result;
  7use collections::FxHashSet;
  8use futures::FutureExt as _;
  9use gpui::{App, Entity, SharedString, Task};
 10use language::Buffer;
 11use project::Project;
 12use schemars::JsonSchema;
 13use serde::{Deserialize, Serialize};
 14use settings::Settings;
 15use std::path::{Path, PathBuf};
 16use std::sync::Arc;
 17use util::markdown::MarkdownInlineCode;
 18
 19use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path};
 20
// NOTE(review): the `///` doc comments on this type and its field are not just maintainer docs.
// `#[derive(JsonSchema)]` (schemars) copies doc comments into the generated schema's
// `description` fields, which are presumably what the model sees as the tool/parameter
// description. Treat edits to that text as behavior changes.
/// Discards unsaved changes in open buffers by reloading file contents from disk.
///
/// Use this tool when:
/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
/// - You want to reset files to the on-disk state before retrying an edit.
///
/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RestoreFileFromDiskToolInput {
    // Each path is resolved with `Project::find_project_path` when the tool runs; a path that
    // does not resolve is reported as "Not found" instead of failing the whole call.
    /// The paths of the files to restore from disk.
    pub paths: Vec<PathBuf>,
}
 33
 34pub struct RestoreFileFromDiskTool {
 35    project: Entity<Project>,
 36}
 37
 38impl RestoreFileFromDiskTool {
 39    pub fn new(project: Entity<Project>) -> Self {
 40        Self { project }
 41    }
 42}
 43
 44impl AgentTool for RestoreFileFromDiskTool {
 45    type Input = RestoreFileFromDiskToolInput;
 46    type Output = String;
 47
 48    const NAME: &'static str = "restore_file_from_disk";
 49
 50    fn kind() -> acp::ToolKind {
 51        acp::ToolKind::Other
 52    }
 53
 54    fn initial_title(
 55        &self,
 56        input: Result<Self::Input, serde_json::Value>,
 57        _cx: &mut App,
 58    ) -> SharedString {
 59        match input {
 60            Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
 61            Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
 62            Err(_) => "Restore files from disk".into(),
 63        }
 64    }
 65
 66    fn run(
 67        self: Arc<Self>,
 68        input: Self::Input,
 69        event_stream: ToolCallEventStream,
 70        cx: &mut App,
 71    ) -> Task<Result<String>> {
 72        let settings = AgentSettings::get_global(cx);
 73        let mut confirmation_paths: Vec<String> = Vec::new();
 74
 75        for path in &input.paths {
 76            let path_str = path.to_string_lossy();
 77            let decision = decide_permission_for_path(Self::NAME, &path_str, settings);
 78            match decision {
 79                ToolPermissionDecision::Allow => {
 80                    if is_sensitive_settings_path(Path::new(&*path_str)) {
 81                        confirmation_paths.push(path_str.to_string());
 82                    }
 83                }
 84                ToolPermissionDecision::Deny(reason) => {
 85                    return Task::ready(Err(anyhow::anyhow!("{}", reason)));
 86                }
 87                ToolPermissionDecision::Confirm => {
 88                    confirmation_paths.push(path_str.to_string());
 89                }
 90            }
 91        }
 92
 93        let authorize = if !confirmation_paths.is_empty() {
 94            let title = if confirmation_paths.len() == 1 {
 95                format!(
 96                    "Restore {} from disk",
 97                    MarkdownInlineCode(&confirmation_paths[0])
 98                )
 99            } else {
100                let paths: Vec<_> = confirmation_paths
101                    .iter()
102                    .take(3)
103                    .map(|p| p.as_str())
104                    .collect();
105                if confirmation_paths.len() > 3 {
106                    format!(
107                        "Restore {}, and {} more from disk",
108                        paths.join(", "),
109                        confirmation_paths.len() - 3
110                    )
111                } else {
112                    format!("Restore {} from disk", paths.join(", "))
113                }
114            };
115            let sensitive_kind = confirmation_paths
116                .iter()
117                .find_map(|p| sensitive_settings_kind(Path::new(p)));
118            let title = match sensitive_kind {
119                Some(SensitiveSettingsKind::Local) => format!("{title} (local settings)"),
120                Some(SensitiveSettingsKind::Global) => format!("{title} (settings)"),
121                None => title,
122            };
123            let context = crate::ToolPermissionContext {
124                tool_name: Self::NAME.to_string(),
125                input_values: confirmation_paths,
126            };
127            Some(event_stream.authorize(title, context, cx))
128        } else {
129            None
130        };
131
132        let project = self.project.clone();
133        let input_paths = input.paths;
134
135        cx.spawn(async move |cx| {
136            if let Some(authorize) = authorize {
137                authorize.await?;
138            }
139            let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
140
141            let mut restored_paths: Vec<PathBuf> = Vec::new();
142            let mut clean_paths: Vec<PathBuf> = Vec::new();
143            let mut not_found_paths: Vec<PathBuf> = Vec::new();
144            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
145            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
146            let mut reload_errors: Vec<String> = Vec::new();
147
148            for path in input_paths {
149                let Some(project_path) =
150                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
151                else {
152                    not_found_paths.push(path);
153                    continue;
154                };
155
156                let open_buffer_task =
157                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
158
159                let buffer = futures::select! {
160                    result = open_buffer_task.fuse() => {
161                        match result {
162                            Ok(buffer) => buffer,
163                            Err(error) => {
164                                open_errors.push((path, error.to_string()));
165                                continue;
166                            }
167                        }
168                    }
169                    _ = event_stream.cancelled_by_user().fuse() => {
170                        anyhow::bail!("Restore cancelled by user");
171                    }
172                };
173
174                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
175
176                if is_dirty {
177                    buffers_to_reload.insert(buffer);
178                    restored_paths.push(path);
179                } else {
180                    clean_paths.push(path);
181                }
182            }
183
184            if !buffers_to_reload.is_empty() {
185                let reload_task = project.update(cx, |project, cx| {
186                    project.reload_buffers(buffers_to_reload, true, cx)
187                });
188
189                let result = futures::select! {
190                    result = reload_task.fuse() => result,
191                    _ = event_stream.cancelled_by_user().fuse() => {
192                        anyhow::bail!("Restore cancelled by user");
193                    }
194                };
195                if let Err(error) = result {
196                    reload_errors.push(error.to_string());
197                }
198            }
199
200            let mut lines: Vec<String> = Vec::new();
201
202            if !restored_paths.is_empty() {
203                lines.push(format!("Restored {} file(s).", restored_paths.len()));
204            }
205            if !clean_paths.is_empty() {
206                lines.push(format!("{} clean.", clean_paths.len()));
207            }
208
209            if !not_found_paths.is_empty() {
210                lines.push(format!("Not found ({}):", not_found_paths.len()));
211                for path in &not_found_paths {
212                    lines.push(format!("- {}", path.display()));
213                }
214            }
215            if !open_errors.is_empty() {
216                lines.push(format!("Open failed ({}):", open_errors.len()));
217                for (path, error) in &open_errors {
218                    lines.push(format!("- {}: {}", path.display(), error));
219                }
220            }
221            if !dirty_check_errors.is_empty() {
222                lines.push(format!(
223                    "Dirty check failed ({}):",
224                    dirty_check_errors.len()
225                ));
226                for (path, error) in &dirty_check_errors {
227                    lines.push(format!("- {}: {}", path.display(), error));
228                }
229            }
230            if !reload_errors.is_empty() {
231                lines.push(format!("Reload failed ({}):", reload_errors.len()));
232                for error in &reload_errors {
233                    lines.push(format!("- {}", error));
234                }
235            }
236
237            if lines.is_empty() {
238                Ok("No paths provided.".to_string())
239            } else {
240                Ok(lines.join("\n"))
241            }
242        })
243    }
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test settings store and sets the default tool permission mode to `Allow`.
    /// This is expected to keep ordinary (non-settings) paths from being sent for confirmation.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            AgentSettings::override_global(settings, cx);
        });
    }

    /// Creates a fake project rooted at `/root` containing `dirty.txt` and `clean.txt`, and
    /// returns the fake fs, the project, and the tool under test bound to that project.
    async fn setup_project(
        cx: &mut TestAppContext,
    ) -> (Arc<FakeFs>, Entity<Project>, Arc<RestoreFileFromDiskTool>) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Use `path!` here too so the tree and the project root agree on every platform.
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));
        (fs, project, tool)
    }

    /// Opens the buffer for a worktree-prefixed path such as `"root/dirty.txt"`.
    async fn open_buffer(
        project: &Entity<Project>,
        relative_path: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path(relative_path, cx)
                .unwrap_or_else(|| panic!("{relative_path} should exist in project"))
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Runs the tool on `paths` and returns its textual output, panicking if the run fails.
    async fn run_tool(
        tool: &Arc<RestoreFileFromDiskTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        cx.update(|cx| {
            tool.clone().run(
                RestoreFileFromDiskToolInput { paths },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap()
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        let (fs, project, tool) = setup_project(cx).await;

        // Make dirty.txt dirty by editing its buffer in memory without saving it to disk.
        let dirty_buffer = open_buffer(&project, "root/dirty.txt", cx).await;
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Open clean.txt without editing it, so it stays clean.
        let clean_buffer = open_buffer(&project, "root/clean.txt", cx).await;
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should be restored back to disk content and become clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk should not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: clean buffer should remain clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_with_no_paths(cx: &mut TestAppContext) {
        let (_fs, _project, tool) = setup_project(cx).await;

        let output = run_tool(&tool, Vec::new(), cx).await;
        assert_eq!(output, "No paths provided.");
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_reports_paths_outside_project(cx: &mut TestAppContext) {
        let (_fs, _project, tool) = setup_project(cx).await;

        let output = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}