save_file_tool.rs

  1use agent_client_protocol as acp;
  2use anyhow::Result;
  3use collections::FxHashSet;
  4use gpui::{App, Entity, SharedString, Task};
  5use language::Buffer;
  6use project::Project;
  7use schemars::JsonSchema;
  8use serde::{Deserialize, Serialize};
  9use std::path::PathBuf;
 10use std::sync::Arc;
 11
 12use crate::{AgentTool, ToolCallEventStream};
 13
 14/// Saves files that have unsaved changes.
 15///
 16/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
 17/// Only use this tool after asking the user for permission to save their unsaved changes.
 18#[derive(Debug, Serialize, Deserialize, JsonSchema)]
 19pub struct SaveFileToolInput {
 20    /// The paths of the files to save.
 21    pub paths: Vec<PathBuf>,
 22}
 23
 24pub struct SaveFileTool {
 25    project: Entity<Project>,
 26}
 27
 28impl SaveFileTool {
 29    pub fn new(project: Entity<Project>) -> Self {
 30        Self { project }
 31    }
 32}
 33
 34impl AgentTool for SaveFileTool {
 35    type Input = SaveFileToolInput;
 36    type Output = String;
 37
 38    fn name() -> &'static str {
 39        "save_file"
 40    }
 41
 42    fn kind() -> acp::ToolKind {
 43        acp::ToolKind::Other
 44    }
 45
 46    fn initial_title(
 47        &self,
 48        input: Result<Self::Input, serde_json::Value>,
 49        _cx: &mut App,
 50    ) -> SharedString {
 51        match input {
 52            Ok(input) if input.paths.len() == 1 => "Save file".into(),
 53            Ok(input) => format!("Save {} files", input.paths.len()).into(),
 54            Err(_) => "Save files".into(),
 55        }
 56    }
 57
 58    fn run(
 59        self: Arc<Self>,
 60        input: Self::Input,
 61        _event_stream: ToolCallEventStream,
 62        cx: &mut App,
 63    ) -> Task<Result<String>> {
 64        let project = self.project.clone();
 65        let input_paths = input.paths;
 66
 67        cx.spawn(async move |cx| {
 68            let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
 69
 70            let mut saved_paths: Vec<PathBuf> = Vec::new();
 71            let mut clean_paths: Vec<PathBuf> = Vec::new();
 72            let mut not_found_paths: Vec<PathBuf> = Vec::new();
 73            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
 74            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
 75            let mut save_errors: Vec<(String, String)> = Vec::new();
 76
 77            for path in input_paths {
 78                let Some(project_path) =
 79                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
 80                else {
 81                    not_found_paths.push(path);
 82                    continue;
 83                };
 84
 85                let open_buffer_task =
 86                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
 87
 88                let buffer = match open_buffer_task.await {
 89                    Ok(buffer) => buffer,
 90                    Err(error) => {
 91                        open_errors.push((path, error.to_string()));
 92                        continue;
 93                    }
 94                };
 95
 96                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
 97
 98                if is_dirty {
 99                    buffers_to_save.insert(buffer);
100                    saved_paths.push(path);
101                } else {
102                    clean_paths.push(path);
103                }
104            }
105
106            // Save each buffer individually since there's no batch save API.
107            for buffer in buffers_to_save {
108                let path_for_buffer = buffer
109                    .read_with(cx, |buffer, _| {
110                        buffer
111                            .file()
112                            .map(|file| file.path().to_rel_path_buf())
113                            .map(|path| path.as_rel_path().as_unix_str().to_owned())
114                    })
115                    .unwrap_or_else(|| "<unknown>".to_string());
116
117                let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
118
119                if let Err(error) = save_task.await {
120                    save_errors.push((path_for_buffer, error.to_string()));
121                }
122            }
123
124            let mut lines: Vec<String> = Vec::new();
125
126            if !saved_paths.is_empty() {
127                lines.push(format!("Saved {} file(s).", saved_paths.len()));
128            }
129            if !clean_paths.is_empty() {
130                lines.push(format!("{} clean.", clean_paths.len()));
131            }
132
133            if !not_found_paths.is_empty() {
134                lines.push(format!("Not found ({}):", not_found_paths.len()));
135                for path in &not_found_paths {
136                    lines.push(format!("- {}", path.display()));
137                }
138            }
139            if !open_errors.is_empty() {
140                lines.push(format!("Open failed ({}):", open_errors.len()));
141                for (path, error) in &open_errors {
142                    lines.push(format!("- {}: {}", path.display(), error));
143                }
144            }
145            if !dirty_check_errors.is_empty() {
146                lines.push(format!(
147                    "Dirty check failed ({}):",
148                    dirty_check_errors.len()
149                ));
150                for (path, error) in &dirty_check_errors {
151                    lines.push(format!("- {}: {}", path.display(), error));
152                }
153            }
154            if !save_errors.is_empty() {
155                lines.push(format!("Save failed ({}):", save_errors.len()));
156                for (path, error) in &save_errors {
157                    lines.push(format!("- {}: {}", path, error));
158                }
159            }
160
161            if lines.is_empty() {
162                Ok("No paths provided.".to_string())
163            } else {
164                Ok(lines.join("\n"))
165            }
166        })
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs the test settings store required by the project.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
    }

    /// Resolves `path` (root-prefixed, e.g. `root/a.txt`) in `project` and
    /// opens its buffer, panicking with `missing_message` if it is absent.
    async fn open_project_buffer(
        project: &Entity<Project>,
        path: &str,
        missing_message: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project.find_project_path(path, cx).expect(missing_message)
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Reports whether `buffer` currently has unsaved changes.
    fn buffer_is_dirty(buffer: &Entity<Buffer>, cx: &mut TestAppContext) -> bool {
        buffer.read_with(cx, |buffer, _| buffer.is_dirty())
    }

    /// Runs the save tool on `paths` and returns its textual summary.
    async fn run_tool(
        tool: &Arc<SaveFileTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        let task = cx.update(|cx| {
            let input = SaveFileToolInput { paths };
            tool.clone().run(input, ToolCallEventStream::test().0, cx)
        });
        task.await.unwrap()
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Give dirty.txt in-memory edits that are not yet on disk.
        let dirty_buffer = open_project_buffer(
            &project,
            "root/dirty.txt",
            "dirty.txt should exist in project",
            cx,
        )
        .await;
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            buffer_is_dirty(&dirty_buffer, cx),
            "dirty.txt buffer should be dirty before save"
        );

        // Open clean.txt but leave it untouched.
        let clean_buffer = open_project_buffer(
            &project,
            "root/clean.txt",
            "clean.txt should exist in project",
            cx,
        )
        .await;
        assert!(
            !buffer_is_dirty(&clean_buffer, cx),
            "clean.txt buffer should start clean"
        );

        let output = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // The summary reports one saved file and one clean file.
        for (needle, description) in [
            ("Saved 1 file(s).", "saved count line"),
            ("1 clean.", "clean count line"),
        ] {
            assert!(
                output.contains(needle),
                "expected {description}, got:\n{output}"
            );
        }

        // The dirty buffer is now clean and its edits reached the disk.
        assert!(
            !buffer_is_dirty(&dirty_buffer, cx),
            "dirty.txt buffer should not be dirty after save"
        );
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            disk_dirty, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // The clean file on disk is unchanged.
        let disk_clean = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(disk_clean, "on disk: clean\n");

        // No input paths at all.
        let output = run_tool(&tool, Vec::new(), cx).await;
        assert_eq!(output, "No paths provided.");

        // A path that is not part of the project.
        let output = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}