restore_file_from_disk_tool.rs

  1use agent_client_protocol as acp;
  2use anyhow::Result;
  3use collections::FxHashSet;
  4use futures::FutureExt as _;
  5use gpui::{App, Entity, SharedString, Task};
  6use language::Buffer;
  7use project::Project;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::path::PathBuf;
 11use std::sync::Arc;
 12
 13use crate::{AgentTool, ToolCallEventStream};
 14
// NOTE(review): schemars' `JsonSchema` derive turns `///` doc comments into schema
// `description`s, so the doc text on this struct and its fields is presumably what the model
// sees as the tool description — treat it as user-facing text, not just rustdoc.
/// Discards unsaved changes in open buffers by reloading file contents from disk.
///
/// Use this tool when:
/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
/// - You want to reset files to the on-disk state before retrying an edit.
///
/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RestoreFileFromDiskToolInput {
    /// The paths of the files to restore from disk.
    // Each path is resolved with `Project::find_project_path` when the tool runs; paths that do
    // not resolve are reported back as "Not found" rather than failing the whole call.
    pub paths: Vec<PathBuf>,
}
 27
 28pub struct RestoreFileFromDiskTool {
 29    project: Entity<Project>,
 30}
 31
 32impl RestoreFileFromDiskTool {
 33    pub fn new(project: Entity<Project>) -> Self {
 34        Self { project }
 35    }
 36}
 37
 38impl AgentTool for RestoreFileFromDiskTool {
 39    type Input = RestoreFileFromDiskToolInput;
 40    type Output = String;
 41
 42    fn name() -> &'static str {
 43        "restore_file_from_disk"
 44    }
 45
 46    fn kind() -> acp::ToolKind {
 47        acp::ToolKind::Other
 48    }
 49
 50    fn initial_title(
 51        &self,
 52        input: Result<Self::Input, serde_json::Value>,
 53        _cx: &mut App,
 54    ) -> SharedString {
 55        match input {
 56            Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
 57            Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
 58            Err(_) => "Restore files from disk".into(),
 59        }
 60    }
 61
 62    fn run(
 63        self: Arc<Self>,
 64        input: Self::Input,
 65        event_stream: ToolCallEventStream,
 66        cx: &mut App,
 67    ) -> Task<Result<String>> {
 68        let project = self.project.clone();
 69        let input_paths = input.paths;
 70
 71        cx.spawn(async move |cx| {
 72            let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
 73
 74            let mut restored_paths: Vec<PathBuf> = Vec::new();
 75            let mut clean_paths: Vec<PathBuf> = Vec::new();
 76            let mut not_found_paths: Vec<PathBuf> = Vec::new();
 77            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
 78            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
 79            let mut reload_errors: Vec<String> = Vec::new();
 80
 81            for path in input_paths {
 82                let Some(project_path) =
 83                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
 84                else {
 85                    not_found_paths.push(path);
 86                    continue;
 87                };
 88
 89                let open_buffer_task =
 90                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 91
 92                let buffer = futures::select! {
 93                    result = open_buffer_task.fuse() => {
 94                        match result {
 95                            Ok(buffer) => buffer,
 96                            Err(error) => {
 97                                open_errors.push((path, error.to_string()));
 98                                continue;
 99                            }
100                        }
101                    }
102                    _ = event_stream.cancelled_by_user().fuse() => {
103                        anyhow::bail!("Restore cancelled by user");
104                    }
105                };
106
107                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
108
109                if is_dirty {
110                    buffers_to_reload.insert(buffer);
111                    restored_paths.push(path);
112                } else {
113                    clean_paths.push(path);
114                }
115            }
116
117            if !buffers_to_reload.is_empty() {
118                let reload_task = project.update(cx, |project, cx| {
119                    project.reload_buffers(buffers_to_reload, true, cx)
120                });
121
122                let result = futures::select! {
123                    result = reload_task.fuse() => result,
124                    _ = event_stream.cancelled_by_user().fuse() => {
125                        anyhow::bail!("Restore cancelled by user");
126                    }
127                };
128                if let Err(error) = result {
129                    reload_errors.push(error.to_string());
130                }
131            }
132
133            let mut lines: Vec<String> = Vec::new();
134
135            if !restored_paths.is_empty() {
136                lines.push(format!("Restored {} file(s).", restored_paths.len()));
137            }
138            if !clean_paths.is_empty() {
139                lines.push(format!("{} clean.", clean_paths.len()));
140            }
141
142            if !not_found_paths.is_empty() {
143                lines.push(format!("Not found ({}):", not_found_paths.len()));
144                for path in &not_found_paths {
145                    lines.push(format!("- {}", path.display()));
146                }
147            }
148            if !open_errors.is_empty() {
149                lines.push(format!("Open failed ({}):", open_errors.len()));
150                for (path, error) in &open_errors {
151                    lines.push(format!("- {}: {}", path.display(), error));
152                }
153            }
154            if !dirty_check_errors.is_empty() {
155                lines.push(format!(
156                    "Dirty check failed ({}):",
157                    dirty_check_errors.len()
158                ));
159                for (path, error) in &dirty_check_errors {
160                    lines.push(format!("- {}: {}", path.display(), error));
161                }
162            }
163            if !reload_errors.is_empty() {
164                lines.push(format!("Reload failed ({}):", reload_errors.len()));
165                for error in &reload_errors {
166                    lines.push(format!("- {}", error));
167                }
168            }
169
170            if lines.is_empty() {
171                Ok("No paths provided.".to_string())
172            } else {
173                Ok(lines.join("\n"))
174            }
175        })
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Creates a fake filesystem containing `/root/dirty.txt` and `/root/clean.txt`, a project
    /// rooted at `/root`, and the tool under test bound to that project.
    async fn setup(
        cx: &mut TestAppContext,
    ) -> (Arc<FakeFs>, Entity<Project>, Arc<RestoreFileFromDiskTool>) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Use `path!` here too so the tree lands under the same root as the project on Windows.
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));
        (fs, project, tool)
    }

    /// Opens the buffer for a project path such as `root/dirty.txt`, panicking if it is missing.
    async fn open_buffer(
        project: &Entity<Project>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path(path, cx)
                .unwrap_or_else(|| panic!("{path} should exist in project"))
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Runs the tool on `paths` and returns its textual output.
    async fn run_tool(
        tool: &Arc<RestoreFileFromDiskTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        cx.update(|cx| {
            tool.clone().run(
                RestoreFileFromDiskToolInput { paths },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap()
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        let (fs, project, tool) = setup(cx).await;

        // Make dirty.txt dirty by editing its buffer in memory without saving it to disk.
        let dirty_buffer = open_buffer(&project, "root/dirty.txt", cx).await;
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Open clean.txt but leave it unmodified.
        let clean_buffer = open_buffer(&project, "root/clean.txt", cx).await;
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: the dirty buffer is back to the disk contents and no longer dirty.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk must not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: the clean buffer stays clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_no_paths(cx: &mut TestAppContext) {
        let (_fs, _project, tool) = setup(cx).await;

        let output = run_tool(&tool, Vec::new(), cx).await;
        assert_eq!(output, "No paths provided.");
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_path_outside_project(cx: &mut TestAppContext) {
        let (_fs, _project, tool) = setup(cx).await;

        let output = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}
339}