1use agent_client_protocol as acp;
2use anyhow::Result;
3use collections::FxHashSet;
4use gpui::{App, Entity, SharedString, Task};
5use language::Buffer;
6use project::Project;
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use crate::{AgentTool, ToolCallEventStream};
13
14/// Discards unsaved changes in open buffers by reloading file contents from disk.
15///
16/// Use this tool when:
17/// - You attempted to edit files but they have unsaved changes the user does not want to keep.
18/// - You want to reset files to the on-disk state before retrying an edit.
19///
20/// Only use this tool after asking the user for permission, because it will discard unsaved changes.
21#[derive(Debug, Serialize, Deserialize, JsonSchema)]
22pub struct RestoreFileFromDiskToolInput {
23 /// The paths of the files to restore from disk.
24 pub paths: Vec<PathBuf>,
25}
26
27pub struct RestoreFileFromDiskTool {
28 project: Entity<Project>,
29}
30
31impl RestoreFileFromDiskTool {
32 pub fn new(project: Entity<Project>) -> Self {
33 Self { project }
34 }
35}
36
37impl AgentTool for RestoreFileFromDiskTool {
38 type Input = RestoreFileFromDiskToolInput;
39 type Output = String;
40
41 fn name() -> &'static str {
42 "restore_file_from_disk"
43 }
44
45 fn kind() -> acp::ToolKind {
46 acp::ToolKind::Other
47 }
48
49 fn initial_title(
50 &self,
51 input: Result<Self::Input, serde_json::Value>,
52 _cx: &mut App,
53 ) -> SharedString {
54 match input {
55 Ok(input) if input.paths.len() == 1 => "Restore file from disk".into(),
56 Ok(input) => format!("Restore {} files from disk", input.paths.len()).into(),
57 Err(_) => "Restore files from disk".into(),
58 }
59 }
60
61 fn run(
62 self: Arc<Self>,
63 input: Self::Input,
64 _event_stream: ToolCallEventStream,
65 cx: &mut App,
66 ) -> Task<Result<String>> {
67 let project = self.project.clone();
68 let input_paths = input.paths;
69
70 cx.spawn(async move |cx| {
71 let mut buffers_to_reload: FxHashSet<Entity<Buffer>> = FxHashSet::default();
72
73 let mut restored_paths: Vec<PathBuf> = Vec::new();
74 let mut clean_paths: Vec<PathBuf> = Vec::new();
75 let mut not_found_paths: Vec<PathBuf> = Vec::new();
76 let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
77 let mut dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
78 let mut reload_errors: Vec<String> = Vec::new();
79
80 for path in input_paths {
81 let project_path =
82 project.read_with(cx, |project, cx| project.find_project_path(&path, cx));
83
84 let project_path = match project_path {
85 Ok(Some(project_path)) => project_path,
86 Ok(None) => {
87 not_found_paths.push(path);
88 continue;
89 }
90 Err(error) => {
91 open_errors.push((path, error.to_string()));
92 continue;
93 }
94 };
95
96 let open_buffer_task =
97 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
98
99 let buffer = match open_buffer_task {
100 Ok(task) => match task.await {
101 Ok(buffer) => buffer,
102 Err(error) => {
103 open_errors.push((path, error.to_string()));
104 continue;
105 }
106 },
107 Err(error) => {
108 open_errors.push((path, error.to_string()));
109 continue;
110 }
111 };
112
113 let is_dirty = match buffer.read_with(cx, |buffer, _| buffer.is_dirty()) {
114 Ok(is_dirty) => is_dirty,
115 Err(error) => {
116 dirty_check_errors.push((path, error.to_string()));
117 continue;
118 }
119 };
120
121 if is_dirty {
122 buffers_to_reload.insert(buffer);
123 restored_paths.push(path);
124 } else {
125 clean_paths.push(path);
126 }
127 }
128
129 if !buffers_to_reload.is_empty() {
130 let reload_task = project.update(cx, |project, cx| {
131 project.reload_buffers(buffers_to_reload, true, cx)
132 });
133
134 match reload_task {
135 Ok(task) => {
136 if let Err(error) = task.await {
137 reload_errors.push(error.to_string());
138 }
139 }
140 Err(error) => {
141 reload_errors.push(error.to_string());
142 }
143 }
144 }
145
146 let mut lines: Vec<String> = Vec::new();
147
148 if !restored_paths.is_empty() {
149 lines.push(format!("Restored {} file(s).", restored_paths.len()));
150 }
151 if !clean_paths.is_empty() {
152 lines.push(format!("{} clean.", clean_paths.len()));
153 }
154
155 if !not_found_paths.is_empty() {
156 lines.push(format!("Not found ({}):", not_found_paths.len()));
157 for path in ¬_found_paths {
158 lines.push(format!("- {}", path.display()));
159 }
160 }
161 if !open_errors.is_empty() {
162 lines.push(format!("Open failed ({}):", open_errors.len()));
163 for (path, error) in &open_errors {
164 lines.push(format!("- {}: {}", path.display(), error));
165 }
166 }
167 if !dirty_check_errors.is_empty() {
168 lines.push(format!(
169 "Dirty check failed ({}):",
170 dirty_check_errors.len()
171 ));
172 for (path, error) in &dirty_check_errors {
173 lines.push(format!("- {}: {}", path.display(), error));
174 }
175 }
176 if !reload_errors.is_empty() {
177 lines.push(format!("Reload failed ({}):", reload_errors.len()));
178 for error in &reload_errors {
179 lines.push(format!("- {}", error));
180 }
181 }
182
183 if lines.is_empty() {
184 Ok("No paths provided.".to_string())
185 } else {
186 Ok(lines.join("\n"))
187 }
188 })
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// Creates a project rooted at `/root` containing `dirty.txt` and `clean.txt`,
    /// and the tool under test bound to it.
    async fn setup(
        cx: &mut TestAppContext,
    ) -> (Arc<FakeFs>, Entity<Project>, Arc<RestoreFileFromDiskTool>) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Use `path!` for both the tree and the project root so they agree on Windows.
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(RestoreFileFromDiskTool::new(project.clone()));
        (fs, project, tool)
    }

    /// Opens the buffer for a root-relative `path` (e.g. `root/dirty.txt`).
    async fn open_buffer(
        project: &Entity<Project>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path(path, cx)
                .unwrap_or_else(|| panic!("{path} should exist in project"))
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Runs the tool on `paths` and returns its textual summary.
    async fn run_tool(
        tool: &Arc<RestoreFileFromDiskTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        cx.update(|cx| {
            tool.clone().run(
                RestoreFileFromDiskToolInput { paths },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap()
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_output_and_effects(cx: &mut TestAppContext) {
        let (fs, project, tool) = setup(cx).await;

        // Make dirty.txt dirty by editing its buffer in memory without saving to disk.
        let dirty_buffer = open_buffer(&project, "root/dirty.txt", cx).await;
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before restore"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_buffer = open_buffer(&project, "root/clean.txt", cx).await;
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // Output should mention restored + clean.
        assert!(
            output.contains("Restored 1 file(s)."),
            "expected restored count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should be restored back to disk content and become clean.
        let dirty_text = dirty_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(
            dirty_text, "on disk: dirty\n",
            "dirty.txt buffer should be restored to disk contents"
        );
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after restore"
        );

        // Disk contents should be unchanged (restore-from-disk should not write).
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(disk_dirty, "on disk: dirty\n");

        // Sanity: clean buffer should remain clean and unchanged.
        let clean_text = clean_buffer.read_with(cx, |buffer, _| buffer.text());
        assert_eq!(clean_text, "on disk: clean\n");
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should remain clean"
        );
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_with_no_paths(cx: &mut TestAppContext) {
        let (_fs, _project, tool) = setup(cx).await;

        let output = run_tool(&tool, vec![], cx).await;
        assert_eq!(output, "No paths provided.");
    }

    #[gpui::test]
    async fn test_restore_file_from_disk_with_unknown_path(cx: &mut TestAppContext) {
        let (_fs, _project, tool) = setup(cx).await;

        // A path whose first component is not a project root cannot be resolved.
        let output = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}