save_file_tool.rs

  1use agent_client_protocol as acp;
  2use anyhow::Result;
  3use collections::FxHashSet;
  4use gpui::{App, Entity, SharedString, Task};
  5use language::Buffer;
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::path::PathBuf;
 10use std::sync::Arc;
 11
 12use crate::{AgentTool, ToolCallEventStream};
 13
// NOTE: the `///` doc comments on this struct and its field are more than
// rustdoc. The `JsonSchema` derive copies them into the generated JSON schema
// (title/description) for the tool input. Changing their wording changes what
// the model is told about the tool.
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    // Each entry is resolved with `Project::find_project_path` when the tool
    // runs. Entries that resolve to nothing are reported as "Not found".
    /// The paths of the files to save.
    pub paths: Vec<PathBuf>,
}
 23
 24pub struct SaveFileTool {
 25    project: Entity<Project>,
 26}
 27
 28impl SaveFileTool {
 29    pub fn new(project: Entity<Project>) -> Self {
 30        Self { project }
 31    }
 32}
 33
 34impl AgentTool for SaveFileTool {
 35    type Input = SaveFileToolInput;
 36    type Output = String;
 37
 38    fn name() -> &'static str {
 39        "save_file"
 40    }
 41
 42    fn kind() -> acp::ToolKind {
 43        acp::ToolKind::Other
 44    }
 45
 46    fn initial_title(
 47        &self,
 48        input: Result<Self::Input, serde_json::Value>,
 49        _cx: &mut App,
 50    ) -> SharedString {
 51        match input {
 52            Ok(input) if input.paths.len() == 1 => "Save file".into(),
 53            Ok(input) => format!("Save {} files", input.paths.len()).into(),
 54            Err(_) => "Save files".into(),
 55        }
 56    }
 57
 58    fn run(
 59        self: Arc<Self>,
 60        input: Self::Input,
 61        _event_stream: ToolCallEventStream,
 62        cx: &mut App,
 63    ) -> Task<Result<String>> {
 64        let project = self.project.clone();
 65        let input_paths = input.paths;
 66
 67        cx.spawn(async move |cx| {
 68            let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
 69
 70            let mut saved_paths: Vec<PathBuf> = Vec::new();
 71            let mut clean_paths: Vec<PathBuf> = Vec::new();
 72            let mut not_found_paths: Vec<PathBuf> = Vec::new();
 73            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
 74            let mut dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
 75            let mut save_errors: Vec<(String, String)> = Vec::new();
 76
 77            for path in input_paths {
 78                let project_path =
 79                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx));
 80
 81                let project_path = match project_path {
 82                    Ok(Some(project_path)) => project_path,
 83                    Ok(None) => {
 84                        not_found_paths.push(path);
 85                        continue;
 86                    }
 87                    Err(error) => {
 88                        open_errors.push((path, error.to_string()));
 89                        continue;
 90                    }
 91                };
 92
 93                let open_buffer_task =
 94                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 95
 96                let buffer = match open_buffer_task {
 97                    Ok(task) => match task.await {
 98                        Ok(buffer) => buffer,
 99                        Err(error) => {
100                            open_errors.push((path, error.to_string()));
101                            continue;
102                        }
103                    },
104                    Err(error) => {
105                        open_errors.push((path, error.to_string()));
106                        continue;
107                    }
108                };
109
110                let is_dirty = match buffer.read_with(cx, |buffer, _| buffer.is_dirty()) {
111                    Ok(is_dirty) => is_dirty,
112                    Err(error) => {
113                        dirty_check_errors.push((path, error.to_string()));
114                        continue;
115                    }
116                };
117
118                if is_dirty {
119                    buffers_to_save.insert(buffer);
120                    saved_paths.push(path);
121                } else {
122                    clean_paths.push(path);
123                }
124            }
125
126            // Save each buffer individually since there's no batch save API.
127            for buffer in buffers_to_save {
128                let path_for_buffer = match buffer.read_with(cx, |buffer, _| {
129                    buffer
130                        .file()
131                        .map(|file| file.path().to_rel_path_buf())
132                        .map(|path| path.as_rel_path().as_unix_str().to_owned())
133                }) {
134                    Ok(path) => path.unwrap_or_else(|| "<unknown>".to_string()),
135                    Err(error) => {
136                        save_errors.push(("<unknown>".to_string(), error.to_string()));
137                        continue;
138                    }
139                };
140
141                let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
142
143                match save_task {
144                    Ok(task) => {
145                        if let Err(error) = task.await {
146                            save_errors.push((path_for_buffer, error.to_string()));
147                        }
148                    }
149                    Err(error) => {
150                        save_errors.push((path_for_buffer, error.to_string()));
151                    }
152                }
153            }
154
155            let mut lines: Vec<String> = Vec::new();
156
157            if !saved_paths.is_empty() {
158                lines.push(format!("Saved {} file(s).", saved_paths.len()));
159            }
160            if !clean_paths.is_empty() {
161                lines.push(format!("{} clean.", clean_paths.len()));
162            }
163
164            if !not_found_paths.is_empty() {
165                lines.push(format!("Not found ({}):", not_found_paths.len()));
166                for path in &not_found_paths {
167                    lines.push(format!("- {}", path.display()));
168                }
169            }
170            if !open_errors.is_empty() {
171                lines.push(format!("Open failed ({}):", open_errors.len()));
172                for (path, error) in &open_errors {
173                    lines.push(format!("- {}: {}", path.display(), error));
174                }
175            }
176            if !dirty_check_errors.is_empty() {
177                lines.push(format!(
178                    "Dirty check failed ({}):",
179                    dirty_check_errors.len()
180                ));
181                for (path, error) in &dirty_check_errors {
182                    lines.push(format!("- {}: {}", path.display(), error));
183                }
184            }
185            if !save_errors.is_empty() {
186                lines.push(format!("Save failed ({}):", save_errors.len()));
187                for (path, error) in &save_errors {
188                    lines.push(format!("- {}: {}", path, error));
189                }
190            }
191
192            if lines.is_empty() {
193                Ok("No paths provided.".to_string())
194            } else {
195                Ok(lines.join("\n"))
196            }
197        })
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Registers a test `SettingsStore` as a global before any project is built.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
    }

    /// Resolves `rel_path` inside `project` and opens its buffer. Panics with
    /// `missing_msg` if the path does not belong to the project.
    async fn open_project_buffer(
        project: &Entity<Project>,
        rel_path: &str,
        missing_msg: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project.find_project_path(rel_path, cx).expect(missing_msg)
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Runs the save tool on `paths` and returns its textual report.
    async fn run_save(
        tool: &Arc<SaveFileTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        cx.update(|cx| {
            tool.clone().run(
                SaveFileToolInput { paths },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap()
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Give dirty.txt unsaved in-memory edits.
        let dirty_buffer = open_project_buffer(
            &project,
            "root/dirty.txt",
            "dirty.txt should exist in project",
            cx,
        )
        .await;
        dirty_buffer.update(cx, |buffer, cx| {
            let whole_buffer = 0..buffer.len();
            buffer.edit([(whole_buffer, "in memory: dirty\n")], None, cx);
        });
        let dirty_before = dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty());
        assert!(dirty_before, "dirty.txt buffer should be dirty before save");

        // Open clean.txt too, without touching its contents.
        let clean_buffer = open_project_buffer(
            &project,
            "root/clean.txt",
            "clean.txt should exist in project",
            cx,
        )
        .await;
        let clean_before = clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty());
        assert!(!clean_before, "clean.txt buffer should start clean");

        let output = run_save(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // The report should mention both the saved and the clean file.
        assert!(
            output.contains("Saved 1 file(s)."),
            "expected saved count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // The dirty buffer is now clean and its new text reached the disk.
        let dirty_after = dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty());
        assert!(!dirty_after, "dirty.txt buffer should not be dirty after save");

        let dirty_on_disk = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            dirty_on_disk, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // clean.txt on disk is untouched.
        let clean_on_disk = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(clean_on_disk, "on disk: clean\n");

        // An empty request produces the dedicated message.
        let output = run_save(&tool, Vec::new(), cx).await;
        assert_eq!(output, "No paths provided.");

        // A path outside the project is listed under "Not found".
        let output = run_save(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}