restore_file_from_disk_tool.rs

  1use agent_client_protocol as acp;
  2use anyhow::Result;
  3use collections::FxHashSet;
  4use futures::FutureExt as _;
  5use gpui::{App, Entity, SharedString, Task};
  6use language::Buffer;
  7use project::Project;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::path::PathBuf;
 11use std::sync::Arc;
 12
 13use crate::{AgentTool, ToolCallEventStream};
 14
// NOTE(review): `#[derive(JsonSchema)]` turns the `///` comments below into the schema
// `description` fields. They are presumably surfaced to the model as the tool's
// documentation (confirm against the AgentTool plumbing), so treat edits to them as
// behavior changes, not cosmetic ones.
/// Discards unsaved changes in open buffers by reloading file contents from disk.
///
/// Use this tool when:
/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
/// - You want to reset files to the on-disk state before retrying an edit.
///
/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RestoreFileFromDiskToolInput {
    /// The paths of the files to restore from disk.
    pub paths: Vec<PathBuf>,
}
 27
 28pub struct RestoreFileFromDiskTool {
 29    project: Entity<Project>,
 30}
 31
 32impl RestoreFileFromDiskTool {
 33    pub fn new(project: Entity<Project>) -> Self {
 34        Self { project }
 35    }
 36}
 37
 38impl AgentTool for RestoreFileFromDiskTool {
 39    type Input = RestoreFileFromDiskToolInput;
 40    type Output = String;
 41
 42    const NAME: &'static str = "restore_file_from_disk";
 43
 44    fn kind() -> acp::ToolKind {
 45        acp::ToolKind::Other
 46    }
 47
 48    fn initial_title(
 49        &self,
 50        input: Result<Self::Input, serde_json::Value>,
 51        _cx: &mut App,
 52    ) -> SharedString {
 53        match input {
 54            Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
 55            Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
 56            Err(_) => "Restore files from disk".into(),
 57        }
 58    }
 59
 60    fn run(
 61        self: Arc<Self>,
 62        input: Self::Input,
 63        event_stream: ToolCallEventStream,
 64        cx: &mut App,
 65    ) -> Task<Result<String>> {
 66        let project = self.project.clone();
 67        let input_paths = input.paths;
 68
 69        cx.spawn(async move |cx| {
 70            let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
 71
 72            let mut restored_paths: Vec<PathBuf> = Vec::new();
 73            let mut clean_paths: Vec<PathBuf> = Vec::new();
 74            let mut not_found_paths: Vec<PathBuf> = Vec::new();
 75            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
 76            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
 77            let mut reload_errors: Vec<String> = Vec::new();
 78
 79            for path in input_paths {
 80                let Some(project_path) =
 81                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
 82                else {
 83                    not_found_paths.push(path);
 84                    continue;
 85                };
 86
 87                let open_buffer_task =
 88                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 89
 90                let buffer = futures::select! {
 91                    result = open_buffer_task.fuse() => {
 92                        match result {
 93                            Ok(buffer) => buffer,
 94                            Err(error) => {
 95                                open_errors.push((path, error.to_string()));
 96                                continue;
 97                            }
 98                        }
 99                    }
100                    _ = event_stream.cancelled_by_user().fuse() => {
101                        anyhow::bail!("Restore cancelled by user");
102                    }
103                };
104
105                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
106
107                if is_dirty {
108                    buffers_to_reload.insert(buffer);
109                    restored_paths.push(path);
110                } else {
111                    clean_paths.push(path);
112                }
113            }
114
115            if !buffers_to_reload.is_empty() {
116                let reload_task = project.update(cx, |project, cx| {
117                    project.reload_buffers(buffers_to_reload, true, cx)
118                });
119
120                let result = futures::select! {
121                    result = reload_task.fuse() => result,
122                    _ = event_stream.cancelled_by_user().fuse() => {
123                        anyhow::bail!("Restore cancelled by user");
124                    }
125                };
126                if let Err(error) = result {
127                    reload_errors.push(error.to_string());
128                }
129            }
130
131            let mut lines: Vec<String> = Vec::new();
132
133            if !restored_paths.is_empty() {
134                lines.push(format!("Restored {} file(s).", restored_paths.len()));
135            }
136            if !clean_paths.is_empty() {
137                lines.push(format!("{} clean.", clean_paths.len()));
138            }
139
140            if !not_found_paths.is_empty() {
141                lines.push(format!("Not found ({}):", not_found_paths.len()));
142                for path in &not_found_paths {
143                    lines.push(format!("- {}", path.display()));
144                }
145            }
146            if !open_errors.is_empty() {
147                lines.push(format!("Open failed ({}):", open_errors.len()));
148                for (path, error) in &open_errors {
149                    lines.push(format!("- {}: {}", path.display(), error));
150                }
151            }
152            if !dirty_check_errors.is_empty() {
153                lines.push(format!(
154                    "Dirty check failed ({}):",
155                    dirty_check_errors.len()
156                ));
157                for (path, error) in &dirty_check_errors {
158                    lines.push(format!("- {}: {}", path.display(), error));
159                }
160            }
161            if !reload_errors.is_empty() {
162                lines.push(format!("Reload failed ({}):", reload_errors.len()));
163                for error in &reload_errors {
164                    lines.push(format!("- {}", error));
165                }
166            }
167
168            if lines.is_empty() {
169                Ok("No paths provided.".to_string())
170            } else {
171                Ok(lines.join("\n"))
172            }
173        })
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs the test settings store the project machinery needs.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Covers the restored/clean summary lines and their effect on buffers and disk.
    /// Also covers the empty-input and not-found cases.
    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Use `path!` here too so the tree root matches the project root on every platform.
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));

        // Make dirty.txt dirty in memory by editing its buffer without saving to disk.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should be restored back to disk content and become clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk should not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: clean buffer should remain clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput { paths: vec![] },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case (path outside the project root).
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}