restore_file_from_disk_tool.rs

  1use agent_client_protocol as acp;
  2use anyhow::Result;
  3use collections::FxHashSet;
  4use gpui::{App, Entity, SharedString, Task};
  5use language::Buffer;
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::path::PathBuf;
 10use std::sync::Arc;
 11
 12use crate::{AgentTool, ToolCallEventStream};
 13
// NOTE(review): `JsonSchema` is derived on this type, so the doc comments below are
// presumably exported as the schema description that the model sees — confirm before
// rewording them, since that would change what the tool advertises at runtime.
/// Discards unsaved changes in open buffers by reloading file contents from disk.
///
/// Use this tool when:
/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
/// - You want to reset files to the on-disk state before retrying an edit.
///
/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RestoreFileFromDiskToolInput {
    // Each entry is looked up with `Project::find_project_path` in `run`; entries that do
    // not resolve to a project path are reported under "Not found" instead of failing the call.
    /// The paths of the files to restore from disk.
    pub paths: Vec<PathBuf>,
}
 26
/// Agent tool that reloads dirty buffers from disk, discarding their unsaved edits.
pub struct RestoreFileFromDiskTool {
    /// Project used to resolve the requested paths, open their buffers, and reload them.
    project: Entity<Project>,
}
 30
 31impl RestoreFileFromDiskTool {
 32    pub fn new(project: Entity<Project>) -> Self {
 33        Self { project }
 34    }
 35}
 36
 37impl AgentTool for RestoreFileFromDiskTool {
 38    type Input = RestoreFileFromDiskToolInput;
 39    type Output = String;
 40
 41    fn name() -> &'static str {
 42        "restore_file_from_disk"
 43    }
 44
 45    fn kind() -> acp::ToolKind {
 46        acp::ToolKind::Other
 47    }
 48
 49    fn initial_title(
 50        &self,
 51        input: Result<Self::Input, serde_json::Value>,
 52        _cx: &mut App,
 53    ) -> SharedString {
 54        match input {
 55            Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
 56            Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
 57            Err(_) => "Restore files from disk".into(),
 58        }
 59    }
 60
 61    fn run(
 62        self: Arc<Self>,
 63        input: Self::Input,
 64        _event_stream: ToolCallEventStream,
 65        cx: &mut App,
 66    ) -> Task<Result<String>> {
 67        let project = self.project.clone();
 68        let input_paths = input.paths;
 69
 70        cx.spawn(async move |cx| {
 71            let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
 72
 73            let mut restored_paths: Vec<PathBuf> = Vec::new();
 74            let mut clean_paths: Vec<PathBuf> = Vec::new();
 75            let mut not_found_paths: Vec<PathBuf> = Vec::new();
 76            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
 77            let mut dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
 78            let mut reload_errors: Vec<String> = Vec::new();
 79
 80            for path in input_paths {
 81                let project_path =
 82                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx));
 83
 84                let project_path = match project_path {
 85                    Ok(Some(project_path)) => project_path,
 86                    Ok(None) => {
 87                        not_found_paths.push(path);
 88                        continue;
 89                    }
 90                    Err(error) => {
 91                        open_errors.push((path, error.to_string()));
 92                        continue;
 93                    }
 94                };
 95
 96                let open_buffer_task =
 97                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 98
 99                let buffer = match open_buffer_task {
100                    Ok(task) => match task.await {
101                        Ok(buffer) => buffer,
102                        Err(error) => {
103                            open_errors.push((path, error.to_string()));
104                            continue;
105                        }
106                    },
107                    Err(error) => {
108                        open_errors.push((path, error.to_string()));
109                        continue;
110                    }
111                };
112
113                let is_dirty = match buffer.read_with(cx, |buffer, _| buffer.is_dirty()) {
114                    Ok(is_dirty) => is_dirty,
115                    Err(error) => {
116                        dirty_check_errors.push((path, error.to_string()));
117                        continue;
118                    }
119                };
120
121                if is_dirty {
122                    buffers_to_reload.insert(buffer);
123                    restored_paths.push(path);
124                } else {
125                    clean_paths.push(path);
126                }
127            }
128
129            if !buffers_to_reload.is_empty() {
130                let reload_task = project.update(cx, |project, cx| {
131                    project.reload_buffers(buffers_to_reload, true, cx)
132                });
133
134                match reload_task {
135                    Ok(task) => {
136                        if let Err(error) = task.await {
137                            reload_errors.push(error.to_string());
138                        }
139                    }
140                    Err(error) => {
141                        reload_errors.push(error.to_string());
142                    }
143                }
144            }
145
146            let mut lines: Vec<String> = Vec::new();
147
148            if !restored_paths.is_empty() {
149                lines.push(format!("Restored {} file(s).", restored_paths.len()));
150            }
151            if !clean_paths.is_empty() {
152                lines.push(format!("{} clean.", clean_paths.len()));
153            }
154
155            if !not_found_paths.is_empty() {
156                lines.push(format!("Not found ({}):", not_found_paths.len()));
157                for path in &not_found_paths {
158                    lines.push(format!("- {}", path.display()));
159                }
160            }
161            if !open_errors.is_empty() {
162                lines.push(format!("Open failed ({}):", open_errors.len()));
163                for (path, error) in &open_errors {
164                    lines.push(format!("- {}: {}", path.display(), error));
165                }
166            }
167            if !dirty_check_errors.is_empty() {
168                lines.push(format!(
169                    "Dirty check failed ({}):",
170                    dirty_check_errors.len()
171                ));
172                for (path, error) in &dirty_check_errors {
173                    lines.push(format!("- {}: {}", path.display(), error));
174                }
175            }
176            if !reload_errors.is_empty() {
177                lines.push(format!("Reload failed ({}):", reload_errors.len()));
178                for error in &reload_errors {
179                    lines.push(format!("- {}", error));
180                }
181            }
182
183            if lines.is_empty() {
184                Ok("No paths provided.".to_string())
185            } else {
186                Ok(lines.join("\n"))
187            }
188        })
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test `SettingsStore` as a global before the test creates its project.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Runs the tool on `paths` with a throwaway event stream and returns its report text.
    async fn run_tool(
        tool: &Arc<RestoreFileFromDiskTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        cx.update(|cx| {
            tool.clone().run(
                RestoreFileFromDiskToolInput { paths },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap()
    }

    /// Covers the report text and buffer/disk effects for a dirty file, a clean file, an
    /// empty request, and a path that does not resolve.
    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));

        // Make dirty.txt dirty by editing its buffer in memory without saving to disk.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should be restored back to disk content and become clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk should not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: clean buffer should remain clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );

        // Empty request: nothing to report.
        let output = run_tool(&tool, vec![], cx).await;
        assert_eq!(output, "No paths provided.");

        // A path the project cannot resolve is listed under "Not found".
        let output = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}
352}