save_file_tool.rs

  1use agent_client_protocol as acp;
  2use agent_settings::AgentSettings;
  3use anyhow::Result;
  4use collections::FxHashSet;
  5use futures::FutureExt as _;
  6use gpui::{App, Entity, SharedString, Task};
  7use language::Buffer;
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use settings::Settings;
 12use std::path::PathBuf;
 13use std::sync::Arc;
 14use util::markdown::MarkdownInlineCode;
 15
 16use crate::{
 17    AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings,
 18};
 19
// NOTE(review): `JsonSchema` is derived below, and schemars copies `///` doc
// comments into the generated schema's `description` fields. The wording of the
// `///` lines in this block is therefore runtime-visible (presumably this is what
// the model reads as the tool/parameter description — confirm), so edit it deliberately.
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    // Each entry is resolved with `Project::find_project_path` in `SaveFileTool::run`;
    // the tests pass worktree-prefixed relative paths such as `root/dirty.txt`.
    /// The paths of the files to save.
    pub paths: Vec<PathBuf>,
}
 29
/// Agent tool that writes the unsaved changes of buffers in `project` to disk.
pub struct SaveFileTool {
    /// Project used to resolve, open, and save the requested paths.
    project: Entity<Project>,
}
 33
 34impl SaveFileTool {
 35    pub fn new(project: Entity<Project>) -> Self {
 36        Self { project }
 37    }
 38}
 39
 40impl AgentTool for SaveFileTool {
 41    type Input = SaveFileToolInput;
 42    type Output = String;
 43
 44    const NAME: &'static str = "save_file";
 45
 46    fn kind() -> acp::ToolKind {
 47        acp::ToolKind::Other
 48    }
 49
 50    fn initial_title(
 51        &self,
 52        input: Result<Self::Input, serde_json::Value>,
 53        _cx: &mut App,
 54    ) -> SharedString {
 55        match input {
 56            Ok(input) if input.paths.len() == 1 => "Save file".into(),
 57            Ok(input) => format!("Save {} files", input.paths.len()).into(),
 58            Err(_) => "Save files".into(),
 59        }
 60    }
 61
 62    fn run(
 63        self: Arc<Self>,
 64        input: Self::Input,
 65        event_stream: ToolCallEventStream,
 66        cx: &mut App,
 67    ) -> Task<Result<String>> {
 68        let settings = AgentSettings::get_global(cx);
 69        let mut needs_confirmation = false;
 70
 71        for path in &input.paths {
 72            let path_str = path.to_string_lossy();
 73            let decision = decide_permission_from_settings(Self::NAME, &path_str, settings);
 74            match decision {
 75                ToolPermissionDecision::Allow => {}
 76                ToolPermissionDecision::Deny(reason) => {
 77                    return Task::ready(Err(anyhow::anyhow!("{}", reason)));
 78                }
 79                ToolPermissionDecision::Confirm => {
 80                    needs_confirmation = true;
 81                }
 82            }
 83        }
 84
 85        let authorize = if needs_confirmation {
 86            let title = if input.paths.len() == 1 {
 87                format!(
 88                    "Save {}",
 89                    MarkdownInlineCode(&input.paths[0].to_string_lossy())
 90                )
 91            } else {
 92                let paths: Vec<_> = input
 93                    .paths
 94                    .iter()
 95                    .take(3)
 96                    .map(|p| p.to_string_lossy().to_string())
 97                    .collect();
 98                if input.paths.len() > 3 {
 99                    format!(
100                        "Save {}, and {} more",
101                        paths.join(", "),
102                        input.paths.len() - 3
103                    )
104                } else {
105                    format!("Save {}", paths.join(", "))
106                }
107            };
108            let first_path = input
109                .paths
110                .first()
111                .map(|p| p.to_string_lossy().to_string())
112                .unwrap_or_default();
113            let context = crate::ToolPermissionContext {
114                tool_name: Self::NAME.to_string(),
115                input_value: first_path,
116            };
117            Some(event_stream.authorize(title, context, cx))
118        } else {
119            None
120        };
121
122        let project = self.project.clone();
123        let input_paths = input.paths;
124
125        cx.spawn(async move |cx| {
126            if let Some(authorize) = authorize {
127                authorize.await?;
128            }
129
130            let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
131
132            let mut saved_paths: Vec<PathBuf> = Vec::new();
133            let mut clean_paths: Vec<PathBuf> = Vec::new();
134            let mut not_found_paths: Vec<PathBuf> = Vec::new();
135            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
136            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
137            let mut save_errors: Vec<(String, String)> = Vec::new();
138
139            for path in input_paths {
140                let Some(project_path) =
141                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
142                else {
143                    not_found_paths.push(path);
144                    continue;
145                };
146
147                let open_buffer_task =
148                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
149
150                let buffer = futures::select! {
151                    result = open_buffer_task.fuse() => {
152                        match result {
153                            Ok(buffer) => buffer,
154                            Err(error) => {
155                                open_errors.push((path, error.to_string()));
156                                continue;
157                            }
158                        }
159                    }
160                    _ = event_stream.cancelled_by_user().fuse() => {
161                        anyhow::bail!("Save cancelled by user");
162                    }
163                };
164
165                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
166
167                if is_dirty {
168                    buffers_to_save.insert(buffer);
169                    saved_paths.push(path);
170                } else {
171                    clean_paths.push(path);
172                }
173            }
174
175            // Save each buffer individually since there's no batch save API.
176            for buffer in buffers_to_save {
177                let path_for_buffer = buffer
178                    .read_with(cx, |buffer, _| {
179                        buffer
180                            .file()
181                            .map(|file| file.path().to_rel_path_buf())
182                            .map(|path| path.as_rel_path().as_unix_str().to_owned())
183                    })
184                    .unwrap_or_else(|| "<unknown>".to_string());
185
186                let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
187
188                let save_result = futures::select! {
189                    result = save_task.fuse() => result,
190                    _ = event_stream.cancelled_by_user().fuse() => {
191                        anyhow::bail!("Save cancelled by user");
192                    }
193                };
194                if let Err(error) = save_result {
195                    save_errors.push((path_for_buffer, error.to_string()));
196                }
197            }
198
199            let mut lines: Vec<String> = Vec::new();
200
201            if !saved_paths.is_empty() {
202                lines.push(format!("Saved {} file(s).", saved_paths.len()));
203            }
204            if !clean_paths.is_empty() {
205                lines.push(format!("{} clean.", clean_paths.len()));
206            }
207
208            if !not_found_paths.is_empty() {
209                lines.push(format!("Not found ({}):", not_found_paths.len()));
210                for path in &not_found_paths {
211                    lines.push(format!("- {}", path.display()));
212                }
213            }
214            if !open_errors.is_empty() {
215                lines.push(format!("Open failed ({}):", open_errors.len()));
216                for (path, error) in &open_errors {
217                    lines.push(format!("- {}: {}", path.display(), error));
218                }
219            }
220            if !dirty_check_errors.is_empty() {
221                lines.push(format!(
222                    "Dirty check failed ({}):",
223                    dirty_check_errors.len()
224                ));
225                for (path, error) in &dirty_check_errors {
226                    lines.push(format!("- {}: {}", path.display(), error));
227                }
228            }
229            if !save_errors.is_empty() {
230                lines.push(format!("Save failed ({}):", save_errors.len()));
231                for (path, error) in &save_errors {
232                    lines.push(format!("- {}: {}", path, error));
233                }
234            }
235
236            if lines.is_empty() {
237                Ok("No paths provided.".to_string())
238            } else {
239                Ok(lines.join("\n"))
240            }
241        })
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test settings store and turns on `always_allow_tool_actions`
    /// so the tool never asks for confirmation during tests.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            agent_settings.always_allow_tool_actions = true;
            AgentSettings::override_global(agent_settings, cx);
        });
    }

    /// Resolves `relative_path` inside `project` and opens its buffer.
    ///
    /// Panics with `missing_message` if the path is not part of the project.
    async fn open_test_buffer(
        project: &Entity<Project>,
        relative_path: &str,
        missing_message: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path(relative_path, cx)
                .expect(missing_message)
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Reports whether `buffer` currently has unsaved changes.
    fn buffer_is_dirty(buffer: &Entity<Buffer>, cx: &mut TestAppContext) -> bool {
        buffer.read_with(cx, |buffer, _| buffer.is_dirty())
    }

    /// Runs `tool` over `paths` with a fresh test event stream and returns its output.
    async fn run_tool(
        tool: &Arc<SaveFileTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        cx.update(|cx| {
            tool.clone().run(
                SaveFileToolInput { paths },
                ToolCallEventStream::test().0,
                cx,
            )
        })
        .await
        .unwrap()
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Give dirty.txt in-memory edits so there is something to save.
        let dirty_buffer = open_test_buffer(
            &project,
            "root/dirty.txt",
            "dirty.txt should exist in project",
            cx,
        )
        .await;
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            buffer_is_dirty(&dirty_buffer, cx),
            "dirty.txt buffer should be dirty before save"
        );

        // Open clean.txt but leave it untouched.
        let clean_buffer = open_test_buffer(
            &project,
            "root/clean.txt",
            "clean.txt should exist in project",
            cx,
        )
        .await;
        assert!(
            !buffer_is_dirty(&clean_buffer, cx),
            "clean.txt buffer should start clean"
        );

        let output = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // The summary reports both the saved file and the clean one.
        for (needle, what) in [
            ("Saved 1 file(s).", "saved count line"),
            ("1 clean.", "clean count line"),
        ] {
            assert!(output.contains(needle), "expected {what}, got:\n{output}");
        }

        // The dirty buffer is now clean and its contents reached the disk.
        assert!(
            !buffer_is_dirty(&dirty_buffer, cx),
            "dirty.txt buffer should not be dirty after save"
        );

        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            disk_dirty, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // The clean file on disk is left alone.
        let disk_clean = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(disk_clean, "on disk: clean\n");

        // With no paths at all, the tool says so.
        let output = run_tool(&tool, vec![], cx).await;
        assert_eq!(output, "No paths provided.");

        // A path outside the project is listed under "Not found".
        let output = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        for (needle, what) in [
            ("Not found (1):", "not-found header line"),
            ("- nonexistent/path.txt", "not-found path bullet"),
        ] {
            assert!(output.contains(needle), "expected {what}, got:\n{output}");
        }
    }
}