save_file_tool.rs

  1use agent_client_protocol as acp;
  2use agent_settings::AgentSettings;
  3use anyhow::Result;
  4use collections::FxHashSet;
  5use futures::FutureExt as _;
  6use gpui::{App, Entity, SharedString, Task};
  7use language::Buffer;
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use settings::Settings;
 12use std::path::PathBuf;
 13use std::sync::Arc;
 14use util::markdown::MarkdownInlineCode;
 15
 16use crate::{
 17    AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings,
 18};
 19
// NOTE(review): the `///` comments on this struct and its field are not only
// rustdoc — the `JsonSchema` derive copies doc comments into the generated
// schema as `description`s. They are presumably what the model sees as the
// tool/argument description (confirm against how the schema is surfaced), so
// treat their wording as prompt text rather than ordinary documentation.
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    // Each entry is resolved with `Project::find_project_path` in
    // `SaveFileTool::run`; entries that do not resolve are reported as
    // "Not found" in the tool output instead of failing the whole call.
    /// The paths of the files to save.
    pub paths: Vec<PathBuf>,
}
 29
 30pub struct SaveFileTool {
 31    project: Entity<Project>,
 32}
 33
 34impl SaveFileTool {
 35    pub fn new(project: Entity<Project>) -> Self {
 36        Self { project }
 37    }
 38}
 39
 40impl AgentTool for SaveFileTool {
 41    type Input = SaveFileToolInput;
 42    type Output = String;
 43
 44    fn name() -> &'static str {
 45        "save_file"
 46    }
 47
 48    fn kind() -> acp::ToolKind {
 49        acp::ToolKind::Other
 50    }
 51
 52    fn initial_title(
 53        &self,
 54        input: Result<Self::Input, serde_json::Value>,
 55        _cx: &mut App,
 56    ) -> SharedString {
 57        match input {
 58            Ok(input) if input.paths.len() == 1 => "Save file".into(),
 59            Ok(input) => format!("Save {} files", input.paths.len()).into(),
 60            Err(_) => "Save files".into(),
 61        }
 62    }
 63
 64    fn run(
 65        self: Arc<Self>,
 66        input: Self::Input,
 67        event_stream: ToolCallEventStream,
 68        cx: &mut App,
 69    ) -> Task<Result<String>> {
 70        let settings = AgentSettings::get_global(cx);
 71        let mut needs_confirmation = false;
 72
 73        for path in &input.paths {
 74            let path_str = path.to_string_lossy();
 75            let decision = decide_permission_from_settings(Self::name(), &path_str, settings);
 76            match decision {
 77                ToolPermissionDecision::Allow => {}
 78                ToolPermissionDecision::Deny(reason) => {
 79                    return Task::ready(Err(anyhow::anyhow!("{}", reason)));
 80                }
 81                ToolPermissionDecision::Confirm => {
 82                    needs_confirmation = true;
 83                }
 84            }
 85        }
 86
 87        let authorize = if needs_confirmation {
 88            let title = if input.paths.len() == 1 {
 89                format!(
 90                    "Save {}",
 91                    MarkdownInlineCode(&input.paths[0].to_string_lossy())
 92                )
 93            } else {
 94                let paths: Vec<_> = input
 95                    .paths
 96                    .iter()
 97                    .take(3)
 98                    .map(|p| p.to_string_lossy().to_string())
 99                    .collect();
100                if input.paths.len() > 3 {
101                    format!(
102                        "Save {}, and {} more",
103                        paths.join(", "),
104                        input.paths.len() - 3
105                    )
106                } else {
107                    format!("Save {}", paths.join(", "))
108                }
109            };
110            Some(event_stream.authorize(title, cx))
111        } else {
112            None
113        };
114
115        let project = self.project.clone();
116        let input_paths = input.paths;
117
118        cx.spawn(async move |cx| {
119            if let Some(authorize) = authorize {
120                authorize.await?;
121            }
122
123            let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
124
125            let mut saved_paths: Vec<PathBuf> = Vec::new();
126            let mut clean_paths: Vec<PathBuf> = Vec::new();
127            let mut not_found_paths: Vec<PathBuf> = Vec::new();
128            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
129            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
130            let mut save_errors: Vec<(String, String)> = Vec::new();
131
132            for path in input_paths {
133                let Some(project_path) =
134                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
135                else {
136                    not_found_paths.push(path);
137                    continue;
138                };
139
140                let open_buffer_task =
141                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
142
143                let buffer = futures::select! {
144                    result = open_buffer_task.fuse() => {
145                        match result {
146                            Ok(buffer) => buffer,
147                            Err(error) => {
148                                open_errors.push((path, error.to_string()));
149                                continue;
150                            }
151                        }
152                    }
153                    _ = event_stream.cancelled_by_user().fuse() => {
154                        anyhow::bail!("Save cancelled by user");
155                    }
156                };
157
158                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
159
160                if is_dirty {
161                    buffers_to_save.insert(buffer);
162                    saved_paths.push(path);
163                } else {
164                    clean_paths.push(path);
165                }
166            }
167
168            // Save each buffer individually since there's no batch save API.
169            for buffer in buffers_to_save {
170                let path_for_buffer = buffer
171                    .read_with(cx, |buffer, _| {
172                        buffer
173                            .file()
174                            .map(|file| file.path().to_rel_path_buf())
175                            .map(|path| path.as_rel_path().as_unix_str().to_owned())
176                    })
177                    .unwrap_or_else(|| "<unknown>".to_string());
178
179                let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
180
181                let save_result = futures::select! {
182                    result = save_task.fuse() => result,
183                    _ = event_stream.cancelled_by_user().fuse() => {
184                        anyhow::bail!("Save cancelled by user");
185                    }
186                };
187                if let Err(error) = save_result {
188                    save_errors.push((path_for_buffer, error.to_string()));
189                }
190            }
191
192            let mut lines: Vec<String> = Vec::new();
193
194            if !saved_paths.is_empty() {
195                lines.push(format!("Saved {} file(s).", saved_paths.len()));
196            }
197            if !clean_paths.is_empty() {
198                lines.push(format!("{} clean.", clean_paths.len()));
199            }
200
201            if !not_found_paths.is_empty() {
202                lines.push(format!("Not found ({}):", not_found_paths.len()));
203                for path in &not_found_paths {
204                    lines.push(format!("- {}", path.display()));
205                }
206            }
207            if !open_errors.is_empty() {
208                lines.push(format!("Open failed ({}):", open_errors.len()));
209                for (path, error) in &open_errors {
210                    lines.push(format!("- {}: {}", path.display(), error));
211                }
212            }
213            if !dirty_check_errors.is_empty() {
214                lines.push(format!(
215                    "Dirty check failed ({}):",
216                    dirty_check_errors.len()
217                ));
218                for (path, error) in &dirty_check_errors {
219                    lines.push(format!("- {}: {}", path.display(), error));
220                }
221            }
222            if !save_errors.is_empty() {
223                lines.push(format!("Save failed ({}):", save_errors.len()));
224                for (path, error) in &save_errors {
225                    lines.push(format!("- {}: {}", path, error));
226                }
227            }
228
229            if lines.is_empty() {
230                Ok("No paths provided.".to_string())
231            } else {
232                Ok(lines.join("\n"))
233            }
234        })
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test settings store, then enables `always_allow_tool_actions`
    /// so the tool never stops to ask for confirmation.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            agent_settings.always_allow_tool_actions = true;
            AgentSettings::override_global(agent_settings, cx);
        });
    }

    /// Runs `tool` on `paths` with a throwaway event stream and returns the
    /// summary text, panicking if the tool reports an error.
    async fn run_tool(
        tool: &Arc<SaveFileTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        let task = cx.update(|cx| {
            tool.clone().run(
                SaveFileToolInput { paths },
                ToolCallEventStream::test().0,
                cx,
            )
        });
        task.await.unwrap()
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Open dirty.txt and replace its contents in memory only.
        let dirty_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });
        let dirty_buffer = project
            .update(cx, |project, cx| project.open_buffer(dirty_path, cx))
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            let whole = 0..buffer.len();
            buffer.edit([(whole, "in memory: dirty\n")], None, cx);
        });
        let dirty_before = dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty());
        assert!(dirty_before, "dirty.txt buffer should be dirty before save");

        // Open clean.txt without touching it.
        let clean_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });
        let clean_buffer = project
            .update(cx, |project, cx| project.open_buffer(clean_path, cx))
            .await
            .unwrap();
        let clean_before = clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty());
        assert!(!clean_before, "clean.txt buffer should start clean");

        // One dirty and one clean file: the summary reports both groups.
        let output = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;
        assert!(
            output.contains("Saved 1 file(s)."),
            "expected saved count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // The dirty buffer was written out and is now clean.
        let dirty_after = dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty());
        assert!(!dirty_after, "dirty.txt buffer should not be dirty after save");
        let dirty_on_disk = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            dirty_on_disk, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // The clean file's contents on disk are untouched.
        let clean_on_disk = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(clean_on_disk, "on disk: clean\n");

        // No paths at all.
        let output = run_tool(&tool, Vec::new(), cx).await;
        assert_eq!(output, "No paths provided.");

        // A path that does not exist in the project.
        let output = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}