1use agent_client_protocol as acp;
2use anyhow::Result;
3use collections::FxHashSet;
4use futures::FutureExt as _;
5use gpui::{App, Entity, SharedString, Task};
6use language::Buffer;
7use project::Project;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11use std::sync::Arc;
12
13use crate::{AgentTool, ToolCallEventStream};
14
// NOTE: the `///` comments on this struct and its field are not just rustdoc.
// `#[derive(JsonSchema)]` (schemars) turns them into the schema's `description`
// strings, which presumably are what the model sees as the tool description.
// Treat edits to this text as behavior changes.
/// Discards unsaved changes in open buffers by reloading file contents from disk.
///
/// Use this tool when:
/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
/// - You want to reset files to the on-disk state before retrying an edit.
///
/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RestoreFileFromDiskToolInput {
    // Each entry is resolved with `Project::find_project_path` when the tool runs.
    // The tests pass worktree-prefixed relative paths such as `root/dirty.txt`.
    /// The paths of the files to restore from disk.
    pub paths: Vec<PathBuf>,
}
27
28pub struct RestoreFileFromDiskTool {
29 project: Entity<Project>,
30}
31
32impl RestoreFileFromDiskTool {
33 pub fn new(project: Entity<Project>) -> Self {
34 Self { project }
35 }
36}
37
38impl AgentTool for RestoreFileFromDiskTool {
39 type Input = RestoreFileFromDiskToolInput;
40 type Output = String;
41
42 fn name() -> &'static str {
43 "restore_file_from_disk"
44 }
45
46 fn kind() -> acp::ToolKind {
47 acp::ToolKind::Other
48 }
49
50 fn initial_title(
51 &self,
52 input: Result<Self::Input, serde_json::Value>,
53 _cx: &mut App,
54 ) -> SharedString {
55 match input {
56 Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
57 Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
58 Err(_) => "Restore files from disk".into(),
59 }
60 }
61
62 fn run(
63 self: Arc<Self>,
64 input: Self::Input,
65 event_stream: ToolCallEventStream,
66 cx: &mut App,
67 ) -> Task<Result<String>> {
68 let project = self.project.clone();
69 let input_paths = input.paths;
70
71 cx.spawn(async move |cx| {
72 let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
73
74 let mut restored_paths: Vec<PathBuf> = Vec::new();
75 let mut clean_paths: Vec<PathBuf> = Vec::new();
76 let mut not_found_paths: Vec<PathBuf> = Vec::new();
77 let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
78 let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
79 let mut reload_errors: Vec<String> = Vec::new();
80
81 for path in input_paths {
82 let Some(project_path) =
83 project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
84 else {
85 not_found_paths.push(path);
86 continue;
87 };
88
89 let open_buffer_task =
90 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
91
92 let buffer = futures::select! {
93 result = open_buffer_task.fuse() => {
94 match result {
95 Ok(buffer) => buffer,
96 Err(error) => {
97 open_errors.push((path, error.to_string()));
98 continue;
99 }
100 }
101 }
102 _ = event_stream.cancelled_by_user().fuse() => {
103 anyhow::bail!("Restore cancelled by user");
104 }
105 };
106
107 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
108
109 if is_dirty {
110 buffers_to_reload.insert(buffer);
111 restored_paths.push(path);
112 } else {
113 clean_paths.push(path);
114 }
115 }
116
117 if !buffers_to_reload.is_empty() {
118 let reload_task = project.update(cx, |project, cx| {
119 project.reload_buffers(buffers_to_reload, true, cx)
120 });
121
122 let result = futures::select! {
123 result = reload_task.fuse() => result,
124 _ = event_stream.cancelled_by_user().fuse() => {
125 anyhow::bail!("Restore cancelled by user");
126 }
127 };
128 if let Err(error) = result {
129 reload_errors.push(error.to_string());
130 }
131 }
132
133 let mut lines: Vec<String> = Vec::new();
134
135 if !restored_paths.is_empty() {
136 lines.push(format!("Restored {} file(s).", restored_paths.len()));
137 }
138 if !clean_paths.is_empty() {
139 lines.push(format!("{} clean.", clean_paths.len()));
140 }
141
142 if !not_found_paths.is_empty() {
143 lines.push(format!("Not found ({}):", not_found_paths.len()));
144 for path in ¬_found_paths {
145 lines.push(format!("- {}", path.display()));
146 }
147 }
148 if !open_errors.is_empty() {
149 lines.push(format!("Open failed ({}):", open_errors.len()));
150 for (path, error) in &open_errors {
151 lines.push(format!("- {}: {}", path.display(), error));
152 }
153 }
154 if !dirty_check_errors.is_empty() {
155 lines.push(format!(
156 "Dirty check failed ({}):",
157 dirty_check_errors.len()
158 ));
159 for (path, error) in &dirty_check_errors {
160 lines.push(format!("- {}: {}", path.display(), error));
161 }
162 }
163 if !reload_errors.is_empty() {
164 lines.push(format!("Reload failed ({}):", reload_errors.len()));
165 for error in &reload_errors {
166 lines.push(format!("- {}", error));
167 }
168 }
169
170 if lines.is_empty() {
171 Ok("No paths provided.".to_string())
172 } else {
173 Ok(lines.join("\n"))
174 }
175 })
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs the test settings store that project/buffer code expects.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));

        // Make dirty.txt dirty in memory by editing its buffer without saving to disk.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should be restored back to disk content and become clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk should not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: clean buffer should remain clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput { paths: vec![] },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case (path outside the project root).
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    RestoreFileFromDiskToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}