1use agent_client_protocol as acp;
2use anyhow::Result;
3use collections::FxHashSet;
4use futures::FutureExt as _;
5use gpui::{App, Entity, SharedString, Task};
6use language::Buffer;
7use project::Project;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11use std::sync::Arc;
12
13use crate::{AgentTool, ToolCallEventStream};
14
// NOTE(review): the `///` doc comments on this type and its fields are picked up
// by the `JsonSchema` derive (schemars) as schema descriptions, and presumably
// form the tool description shown to the model — treat them as prompt text and
// edit with care. Internal notes belong in plain `//` comments like this one.
/// Discards unsaved changes in open buffers by reloading file contents from disk.
///
/// Use this tool when:
/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
/// - You want to reset files to the on-disk state before retrying an edit.
///
/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RestoreFileFromDiskToolInput {
    // Each entry is resolved with `Project::find_project_path` when the tool
    // runs; entries that do not resolve are reported back as "Not found".
    /// The paths of the files to restore from disk.
    pub paths: Vec<PathBuf>,
}
27
28pub struct RestoreFileFromDiskTool {
29 project: Entity<Project>,
30}
31
32impl RestoreFileFromDiskTool {
33 pub fn new(project: Entity<Project>) -> Self {
34 Self { project }
35 }
36}
37
38impl AgentTool for RestoreFileFromDiskTool {
39 type Input = RestoreFileFromDiskToolInput;
40 type Output = String;
41
42 const NAME: &'static str = "restore_file_from_disk";
43
44 fn kind() -> acp::ToolKind {
45 acp::ToolKind::Other
46 }
47
48 fn initial_title(
49 &self,
50 input: Result<Self::Input, serde_json::Value>,
51 _cx: &mut App,
52 ) -> SharedString {
53 match input {
54 Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
55 Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
56 Err(_) => "Restore files from disk".into(),
57 }
58 }
59
60 fn run(
61 self: Arc<Self>,
62 input: Self::Input,
63 event_stream: ToolCallEventStream,
64 cx: &mut App,
65 ) -> Task<Result<String>> {
66 let project = self.project.clone();
67 let input_paths = input.paths;
68
69 cx.spawn(async move |cx| {
70 let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
71
72 let mut restored_paths: Vec<PathBuf> = Vec::new();
73 let mut clean_paths: Vec<PathBuf> = Vec::new();
74 let mut not_found_paths: Vec<PathBuf> = Vec::new();
75 let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
76 let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
77 let mut reload_errors: Vec<String> = Vec::new();
78
79 for path in input_paths {
80 let Some(project_path) =
81 project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
82 else {
83 not_found_paths.push(path);
84 continue;
85 };
86
87 let open_buffer_task =
88 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
89
90 let buffer = futures::select! {
91 result = open_buffer_task.fuse() => {
92 match result {
93 Ok(buffer) => buffer,
94 Err(error) => {
95 open_errors.push((path, error.to_string()));
96 continue;
97 }
98 }
99 }
100 _ = event_stream.cancelled_by_user().fuse() => {
101 anyhow::bail!("Restore cancelled by user");
102 }
103 };
104
105 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
106
107 if is_dirty {
108 buffers_to_reload.insert(buffer);
109 restored_paths.push(path);
110 } else {
111 clean_paths.push(path);
112 }
113 }
114
115 if !buffers_to_reload.is_empty() {
116 let reload_task = project.update(cx, |project, cx| {
117 project.reload_buffers(buffers_to_reload, true, cx)
118 });
119
120 let result = futures::select! {
121 result = reload_task.fuse() => result,
122 _ = event_stream.cancelled_by_user().fuse() => {
123 anyhow::bail!("Restore cancelled by user");
124 }
125 };
126 if let Err(error) = result {
127 reload_errors.push(error.to_string());
128 }
129 }
130
131 let mut lines: Vec<String> = Vec::new();
132
133 if !restored_paths.is_empty() {
134 lines.push(format!("Restored {} file(s).", restored_paths.len()));
135 }
136 if !clean_paths.is_empty() {
137 lines.push(format!("{} clean.", clean_paths.len()));
138 }
139
140 if !not_found_paths.is_empty() {
141 lines.push(format!("Not found ({}):", not_found_paths.len()));
142 for path in ¬_found_paths {
143 lines.push(format!("- {}", path.display()));
144 }
145 }
146 if !open_errors.is_empty() {
147 lines.push(format!("Open failed ({}):", open_errors.len()));
148 for (path, error) in &open_errors {
149 lines.push(format!("- {}: {}", path.display(), error));
150 }
151 }
152 if !dirty_check_errors.is_empty() {
153 lines.push(format!(
154 "Dirty check failed ({}):",
155 dirty_check_errors.len()
156 ));
157 for (path, error) in &dirty_check_errors {
158 lines.push(format!("- {}: {}", path.display(), error));
159 }
160 }
161 if !reload_errors.is_empty() {
162 lines.push(format!("Reload failed ({}):", reload_errors.len()));
163 for error in &reload_errors {
164 lines.push(format!("- {}", error));
165 }
166 }
167
168 if lines.is_empty() {
169 Ok("No paths provided.".to_string())
170 } else {
171 Ok(lines.join("\n"))
172 }
173 })
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Creates a project rooted at `/root` containing `dirty.txt` and
    /// `clean.txt`, plus a tool bound to that project.
    async fn setup(
        cx: &mut TestAppContext,
    ) -> (Arc<FakeFs>, Entity<Project>, Arc<RestoreFileFromDiskTool>) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));
        (fs, project, tool)
    }

    /// Opens the worktree-rooted `relative_path` (e.g. `"root/a.txt"`) as a buffer.
    async fn open_buffer(
        project: &Entity<Project>,
        relative_path: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path(relative_path, cx)
                .unwrap_or_else(|| panic!("{relative_path} should exist in project"))
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Runs the tool on `paths` and returns its (successful) text output.
    async fn run_tool(
        tool: &Arc<RestoreFileFromDiskTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        cx.update(|cx| {
            tool.clone().run(
                RestoreFileFromDiskToolInput { paths },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap()
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        let (fs, project, tool) = setup(cx).await;

        // Make dirty.txt dirty by editing its buffer without saving it to disk.
        let dirty_buffer = open_buffer(&project, "root/dirty.txt", cx).await;
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Open clean.txt but leave it unmodified.
        let clean_buffer = open_buffer(&project, "root/clean.txt", cx).await;
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: the dirty buffer is reverted to disk content and becomes clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents must be unchanged (restoring from disk never writes).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: the clean buffer stays clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_empty_paths(cx: &mut TestAppContext) {
        let (_fs, _project, tool) = setup(cx).await;

        let output = run_tool(&tool, Vec::new(), cx).await;
        assert_eq!(output, "No paths provided.");
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_not_found(cx: &mut TestAppContext) {
        let (_fs, _project, tool) = setup(cx).await;

        // The path lies outside the project root, so it cannot be resolved.
        let output = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}