restore_file_from_disk_tool.rs

  1use agent_client_protocol as acp;
  2use anyhow::Result;
  3use collections::FxHashSet;
  4use gpui::{App, Entity, SharedString, Task};
  5use language::Buffer;
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::path::PathBuf;
 10use std::sync::Arc;
 11
 12use crate::{AgentTool, ToolCallEventStream};
 13
// NOTE(review): `JsonSchema` is derived below, and schemars copies `///` doc
// comments into the generated schema's `description` fields, so the wording of
// the doc comments on this type and its field is presumably part of what the
// model is shown. Confirm how the schema is consumed before rewording them.
/// Discards unsaved changes in open buffers by reloading file contents from disk.
///
/// Use this tool when:
/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
/// - You want to reset files to the on-disk state before retrying an edit.
///
/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RestoreFileFromDiskToolInput {
    /// The paths of the files to restore from disk.
    pub paths: Vec<PathBuf>,
}
 26
/// Agent tool that reloads dirty buffers from disk, discarding their unsaved
/// in-memory changes.
///
/// Paths passed to the tool are resolved against, and buffers opened through,
/// the wrapped [`Project`].
pub struct RestoreFileFromDiskTool {
    /// Project used to resolve input paths, open buffers and reload them.
    project: Entity<Project>,
}
 30
 31impl RestoreFileFromDiskTool {
 32    pub fn new(project: Entity<Project>) -> Self {
 33        Self { project }
 34    }
 35}
 36
 37impl AgentTool for RestoreFileFromDiskTool {
 38    type Input = RestoreFileFromDiskToolInput;
 39    type Output = String;
 40
 41    fn name() -> &'static str {
 42        "restore_file_from_disk"
 43    }
 44
 45    fn kind() -> acp::ToolKind {
 46        acp::ToolKind::Other
 47    }
 48
 49    fn initial_title(
 50        &self,
 51        input: Result<Self::Input, serde_json::Value>,
 52        _cx: &mut App,
 53    ) -> SharedString {
 54        match input {
 55            Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
 56            Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
 57            Err(_) => "Restore files from disk".into(),
 58        }
 59    }
 60
 61    fn run(
 62        self: Arc<Self>,
 63        input: Self::Input,
 64        _event_stream: ToolCallEventStream,
 65        cx: &mut App,
 66    ) -> Task<Result<String>> {
 67        let project = self.project.clone();
 68        let input_paths = input.paths;
 69
 70        cx.spawn(async move |cx| {
 71            let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
 72
 73            let mut restored_paths: Vec<PathBuf> = Vec::new();
 74            let mut clean_paths: Vec<PathBuf> = Vec::new();
 75            let mut not_found_paths: Vec<PathBuf> = Vec::new();
 76            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
 77            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
 78            let mut reload_errors: Vec<String> = Vec::new();
 79
 80            for path in input_paths {
 81                let Some(project_path) =
 82                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
 83                else {
 84                    not_found_paths.push(path);
 85                    continue;
 86                };
 87
 88                let open_buffer_task =
 89                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 90
 91                let buffer = match open_buffer_task.await {
 92                    Ok(buffer) => buffer,
 93                    Err(error) => {
 94                        open_errors.push((path, error.to_string()));
 95                        continue;
 96                    }
 97                };
 98
 99                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
100
101                if is_dirty {
102                    buffers_to_reload.insert(buffer);
103                    restored_paths.push(path);
104                } else {
105                    clean_paths.push(path);
106                }
107            }
108
109            if !buffers_to_reload.is_empty() {
110                let reload_task = project.update(cx, |project, cx| {
111                    project.reload_buffers(buffers_to_reload, true, cx)
112                });
113
114                if let Err(error) = reload_task.await {
115                    reload_errors.push(error.to_string());
116                }
117            }
118
119            let mut lines: Vec<String> = Vec::new();
120
121            if !restored_paths.is_empty() {
122                lines.push(format!("Restored {} file(s).", restored_paths.len()));
123            }
124            if !clean_paths.is_empty() {
125                lines.push(format!("{} clean.", clean_paths.len()));
126            }
127
128            if !not_found_paths.is_empty() {
129                lines.push(format!("Not found ({}):", not_found_paths.len()));
130                for path in &not_found_paths {
131                    lines.push(format!("- {}", path.display()));
132                }
133            }
134            if !open_errors.is_empty() {
135                lines.push(format!("Open failed ({}):", open_errors.len()));
136                for (path, error) in &open_errors {
137                    lines.push(format!("- {}: {}", path.display(), error));
138                }
139            }
140            if !dirty_check_errors.is_empty() {
141                lines.push(format!(
142                    "Dirty check failed ({}):",
143                    dirty_check_errors.len()
144                ));
145                for (path, error) in &dirty_check_errors {
146                    lines.push(format!("- {}: {}", path.display(), error));
147                }
148            }
149            if !reload_errors.is_empty() {
150                lines.push(format!("Reload failed ({}):", reload_errors.len()));
151                for error in &reload_errors {
152                    lines.push(format!("- {}", error));
153                }
154            }
155
156            if lines.is_empty() {
157                Ok("No paths provided.".to_string())
158            } else {
159                Ok(lines.join("\n"))
160            }
161        })
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test `SettingsStore` as a global before the project is created.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Runs the tool against one dirty and one clean buffer, then checks the
    /// empty-input and not-found summaries.
    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Use `path!` here as well so the fake filesystem root matches the
        // worktree root passed to `Project::test` below.
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));

        // Make dirty.txt dirty by editing its buffer in memory without saving to disk.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should be restored back to disk content and become clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk should not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: clean buffer should remain clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput { paths: vec![] },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case (path outside the project root).
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}