1use agent_client_protocol as acp;
2use anyhow::Result;
3use collections::FxHashSet;
4use gpui::{App, Entity, SharedString, Task};
5use language::Buffer;
6use project::Project;
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use crate::{AgentTool, ToolCallEventStream};
13
// NOTE(review): the `///` doc comments on this struct and its field are picked up
// by the `JsonSchema` derive as schema descriptions. This is presumably the text
// the agent is shown for this tool, so changing them changes the tool's
// advertised contract. Put maintenance notes in plain `//` comments like this one.
/// Discards unsaved changes in open buffers by reloading file contents from disk.
///
/// Use this tool when:
/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
/// - You want to reset files to the on-disk state before retrying an edit.
///
/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RestoreFileFromDiskToolInput {
    /// The paths of the files to restore from disk.
    // Each entry is resolved via `Project::find_project_path` when the tool runs;
    // entries that do not resolve are reported under "Not found" instead of
    // failing the whole call.
    pub paths: Vec<PathBuf>,
}
26
27pub struct RestoreFileFromDiskTool {
28 project: Entity<Project>,
29}
30
31impl RestoreFileFromDiskTool {
32 pub fn new(project: Entity<Project>) -> Self {
33 Self { project }
34 }
35}
36
37impl AgentTool for RestoreFileFromDiskTool {
38 type Input = RestoreFileFromDiskToolInput;
39 type Output = String;
40
41 fn name() -> &'static str {
42 "restore_file_from_disk"
43 }
44
45 fn kind() -> acp::ToolKind {
46 acp::ToolKind::Other
47 }
48
49 fn initial_title(
50 &self,
51 input: Result<Self::Input, serde_json::Value>,
52 _cx: &mut App,
53 ) -> SharedString {
54 match input {
55 Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
56 Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
57 Err(_) => "Restore files from disk".into(),
58 }
59 }
60
61 fn run(
62 self: Arc<Self>,
63 input: Self::Input,
64 _event_stream: ToolCallEventStream,
65 cx: &mut App,
66 ) -> Task<Result<String>> {
67 let project = self.project.clone();
68 let input_paths = input.paths;
69
70 cx.spawn(async move |cx| {
71 let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
72
73 let mut restored_paths: Vec<PathBuf> = Vec::new();
74 let mut clean_paths: Vec<PathBuf> = Vec::new();
75 let mut not_found_paths: Vec<PathBuf> = Vec::new();
76 let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
77 let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
78 let mut reload_errors: Vec<String> = Vec::new();
79
80 for path in input_paths {
81 let Some(project_path) =
82 project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
83 else {
84 not_found_paths.push(path);
85 continue;
86 };
87
88 let open_buffer_task =
89 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
90
91 let buffer = match open_buffer_task.await {
92 Ok(buffer) => buffer,
93 Err(error) => {
94 open_errors.push((path, error.to_string()));
95 continue;
96 }
97 };
98
99 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
100
101 if is_dirty {
102 buffers_to_reload.insert(buffer);
103 restored_paths.push(path);
104 } else {
105 clean_paths.push(path);
106 }
107 }
108
109 if !buffers_to_reload.is_empty() {
110 let reload_task = project.update(cx, |project, cx| {
111 project.reload_buffers(buffers_to_reload, true, cx)
112 });
113
114 if let Err(error) = reload_task.await {
115 reload_errors.push(error.to_string());
116 }
117 }
118
119 let mut lines: Vec<String> = Vec::new();
120
121 if !restored_paths.is_empty() {
122 lines.push(format!("Restored {} file(s).", restored_paths.len()));
123 }
124 if !clean_paths.is_empty() {
125 lines.push(format!("{} clean.", clean_paths.len()));
126 }
127
128 if !not_found_paths.is_empty() {
129 lines.push(format!("Not found ({}):", not_found_paths.len()));
130 for path in ¬_found_paths {
131 lines.push(format!("- {}", path.display()));
132 }
133 }
134 if !open_errors.is_empty() {
135 lines.push(format!("Open failed ({}):", open_errors.len()));
136 for (path, error) in &open_errors {
137 lines.push(format!("- {}: {}", path.display(), error));
138 }
139 }
140 if !dirty_check_errors.is_empty() {
141 lines.push(format!(
142 "Dirty check failed ({}):",
143 dirty_check_errors.len()
144 ));
145 for (path, error) in &dirty_check_errors {
146 lines.push(format!("- {}: {}", path.display(), error));
147 }
148 }
149 if !reload_errors.is_empty() {
150 lines.push(format!("Reload failed ({}):", reload_errors.len()));
151 for error in &reload_errors {
152 lines.push(format!("- {}", error));
153 }
154 }
155
156 if lines.is_empty() {
157 Ok("No paths provided.".to_string())
158 } else {
159 Ok(lines.join("\n"))
160 }
161 })
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs the test `SettingsStore` global before any project is created.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Use `path!` here as well so the fake tree and the project root below are
        // spelled the same way on every platform.
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));

        // Make dirty.txt dirty by editing its buffer in memory without saving to disk.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should be restored back to disk content and become clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk should not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: clean buffer should remain clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput { paths: vec![] },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case (path outside the project root).
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}