save_file_tool.rs

  1use agent_client_protocol as acp;
  2use anyhow::Result;
  3use collections::FxHashSet;
  4use futures::FutureExt as _;
  5use gpui::{App, Entity, SharedString, Task};
  6use language::Buffer;
  7use project::Project;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use std::path::PathBuf;
 11use std::sync::Arc;
 12
 13use crate::{AgentTool, ToolCallEventStream};
 14
 15/// Saves files that have unsaved changes.
 16///
 17/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
 18/// Only use this tool after asking the user for permission to save their unsaved changes.
 19#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 20pub struct SaveFileToolInput {
 21    /// The paths of the files to save.
 22    pub paths: Vec<PathBuf>,
 23}
 24
 25pub struct SaveFileTool {
 26    project: Entity<Project>,
 27}
 28
 29impl SaveFileTool {
 30    pub fn new(project: Entity<Project>) -> Self {
 31        Self { project }
 32    }
 33}
 34
 35impl AgentTool for SaveFileTool {
 36    type Input = SaveFileToolInput;
 37    type Output = String;
 38
 39    fn name() -> &'static str {
 40        "save_file"
 41    }
 42
 43    fn kind() -> acp::ToolKind {
 44        acp::ToolKind::Other
 45    }
 46
 47    fn initial_title(
 48        &self,
 49        input: Result<Self::Input, serde_json::Value>,
 50        _cx: &mut App,
 51    ) -> SharedString {
 52        match input {
 53            Ok(input) if input.paths.len() == 1 => "Save file".into(),
 54            Ok(input) => format!("Save {} files", input.paths.len()).into(),
 55            Err(_) => "Save files".into(),
 56        }
 57    }
 58
 59    fn run(
 60        self: Arc<Self>,
 61        input: Self::Input,
 62        event_stream: ToolCallEventStream,
 63        cx: &mut App,
 64    ) -> Task<Result<String>> {
 65        let project = self.project.clone();
 66        let input_paths = input.paths;
 67
 68        cx.spawn(async move |cx| {
 69            let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
 70
 71            let mut saved_paths: Vec<PathBuf> = Vec::new();
 72            let mut clean_paths: Vec<PathBuf> = Vec::new();
 73            let mut not_found_paths: Vec<PathBuf> = Vec::new();
 74            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
 75            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
 76            let mut save_errors: Vec<(String, String)> = Vec::new();
 77
 78            for path in input_paths {
 79                let Some(project_path) =
 80                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
 81                else {
 82                    not_found_paths.push(path);
 83                    continue;
 84                };
 85
 86                let open_buffer_task =
 87                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 88
 89                let buffer = futures::select! {
 90                    result = open_buffer_task.fuse() => {
 91                        match result {
 92                            Ok(buffer) => buffer,
 93                            Err(error) => {
 94                                open_errors.push((path, error.to_string()));
 95                                continue;
 96                            }
 97                        }
 98                    }
 99                    _ = event_stream.cancelled_by_user().fuse() => {
100                        anyhow::bail!("Save cancelled by user");
101                    }
102                };
103
104                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
105
106                if is_dirty {
107                    buffers_to_save.insert(buffer);
108                    saved_paths.push(path);
109                } else {
110                    clean_paths.push(path);
111                }
112            }
113
114            // Save each buffer individually since there's no batch save API.
115            for buffer in buffers_to_save {
116                let path_for_buffer = buffer
117                    .read_with(cx, |buffer, _| {
118                        buffer
119                            .file()
120                            .map(|file| file.path().to_rel_path_buf())
121                            .map(|path| path.as_rel_path().as_unix_str().to_owned())
122                    })
123                    .unwrap_or_else(|| "<unknown>".to_string());
124
125                let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
126
127                let save_result = futures::select! {
128                    result = save_task.fuse() => result,
129                    _ = event_stream.cancelled_by_user().fuse() => {
130                        anyhow::bail!("Save cancelled by user");
131                    }
132                };
133                if let Err(error) = save_result {
134                    save_errors.push((path_for_buffer, error.to_string()));
135                }
136            }
137
138            let mut lines: Vec<String> = Vec::new();
139
140            if !saved_paths.is_empty() {
141                lines.push(format!("Saved {} file(s).", saved_paths.len()));
142            }
143            if !clean_paths.is_empty() {
144                lines.push(format!("{} clean.", clean_paths.len()));
145            }
146
147            if !not_found_paths.is_empty() {
148                lines.push(format!("Not found ({}):", not_found_paths.len()));
149                for path in &not_found_paths {
150                    lines.push(format!("- {}", path.display()));
151                }
152            }
153            if !open_errors.is_empty() {
154                lines.push(format!("Open failed ({}):", open_errors.len()));
155                for (path, error) in &open_errors {
156                    lines.push(format!("- {}: {}", path.display(), error));
157                }
158            }
159            if !dirty_check_errors.is_empty() {
160                lines.push(format!(
161                    "Dirty check failed ({}):",
162                    dirty_check_errors.len()
163                ));
164                for (path, error) in &dirty_check_errors {
165                    lines.push(format!("- {}: {}", path.display(), error));
166                }
167            }
168            if !save_errors.is_empty() {
169                lines.push(format!("Save failed ({}):", save_errors.len()));
170                for (path, error) in &save_errors {
171                    lines.push(format!("- {}: {}", path, error));
172                }
173            }
174
175            if lines.is_empty() {
176                Ok("No paths provided.".to_string())
177            } else {
178                Ok(lines.join("\n"))
179            }
180        })
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test `SettingsStore` global before any project is created.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
    }

    /// Opens `root/<file_name>` in `project` and returns its buffer.
    ///
    /// Panics if the file is not part of the project or cannot be opened.
    async fn open_root_buffer(
        project: &Entity<Project>,
        file_name: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path(format!("root/{file_name}"), cx)
                .unwrap_or_else(|| panic!("{file_name} should exist in project"))
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Runs `tool` on `paths` with a fresh test event stream and returns its output.
    async fn run_tool(
        tool: &Arc<SaveFileTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        let input = SaveFileToolInput { paths };
        cx.update(|cx| tool.clone().run(input, ToolCallEventStream::test().0, cx))
            .await
            .unwrap()
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // dirty.txt gets an unsaved in-memory edit.
        let dirty_buffer = open_root_buffer(&project, "dirty.txt", cx).await;
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before save"
        );

        // clean.txt is opened but left untouched.
        let clean_buffer = open_root_buffer(&project, "clean.txt", cx).await;
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // The summary reports one saved file and one clean file.
        for (needle, what) in [
            ("Saved 1 file(s).", "saved count line"),
            ("1 clean.", "clean count line"),
        ] {
            assert!(output.contains(needle), "expected {what}, got:\n{output}");
        }

        // The dirty buffer was written out and is clean again.
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after save"
        );
        assert_eq!(
            fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap(),
            "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // The clean file's on-disk contents are unchanged.
        assert_eq!(
            fs.load(path!("/root/clean.txt").as_ref()).await.unwrap(),
            "on disk: clean\n"
        );

        // An empty request produces the dedicated message.
        let output = run_tool(&tool, Vec::new(), cx).await;
        assert_eq!(output, "No paths provided.");

        // A path outside the project is listed under "Not found".
        let output = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}