save_file_tool.rs

  1use agent_client_protocol as acp;
  2use agent_settings::AgentSettings;
  3use anyhow::Result;
  4use collections::FxHashSet;
  5use futures::FutureExt as _;
  6use gpui::{App, Entity, SharedString, Task};
  7use language::Buffer;
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use settings::Settings;
 12use std::path::{Path, PathBuf};
 13use std::sync::Arc;
 14use util::markdown::MarkdownInlineCode;
 15
 16use super::edit_file_tool::{
 17    SensitiveSettingsKind, is_sensitive_settings_path, sensitive_settings_kind,
 18};
 19use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path};
 20
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
// NOTE(review): the `///` text on this struct and on its fields is not just developer docs —
// `#[derive(JsonSchema)]` (schemars) copies doc comments into the generated schema's
// `description` fields, which is presumably what the model sees for this tool. Treat edits to
// that text as behavior changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    /// The paths of the files to save.
    // Each entry is resolved with `Project::find_project_path` when the tool runs; entries that
    // do not resolve are reported under "Not found" rather than failing the whole call.
    pub paths: Vec<PathBuf>,
}
 30
 31pub struct SaveFileTool {
 32    project: Entity<Project>,
 33}
 34
 35impl SaveFileTool {
 36    pub fn new(project: Entity<Project>) -> Self {
 37        Self { project }
 38    }
 39}
 40
 41impl AgentTool for SaveFileTool {
 42    type Input = SaveFileToolInput;
 43    type Output = String;
 44
 45    const NAME: &'static str = "save_file";
 46
 47    fn kind() -> acp::ToolKind {
 48        acp::ToolKind::Other
 49    }
 50
 51    fn initial_title(
 52        &self,
 53        input: Result<Self::Input, serde_json::Value>,
 54        _cx: &mut App,
 55    ) -> SharedString {
 56        match input {
 57            Ok(input) if input.paths.len() == 1 => "Save file".into(),
 58            Ok(input) => format!("Save {} files", input.paths.len()).into(),
 59            Err(_) => "Save files".into(),
 60        }
 61    }
 62
 63    fn run(
 64        self: Arc<Self>,
 65        input: Self::Input,
 66        event_stream: ToolCallEventStream,
 67        cx: &mut App,
 68    ) -> Task<Result<String>> {
 69        let settings = AgentSettings::get_global(cx);
 70        let mut confirmation_paths: Vec<String> = Vec::new();
 71
 72        for path in &input.paths {
 73            let path_str = path.to_string_lossy();
 74            let decision = decide_permission_for_path(Self::NAME, &path_str, settings);
 75            match decision {
 76                ToolPermissionDecision::Allow => {
 77                    if is_sensitive_settings_path(Path::new(&*path_str)) {
 78                        confirmation_paths.push(path_str.to_string());
 79                    }
 80                }
 81                ToolPermissionDecision::Deny(reason) => {
 82                    return Task::ready(Err(anyhow::anyhow!("{}", reason)));
 83                }
 84                ToolPermissionDecision::Confirm => {
 85                    confirmation_paths.push(path_str.to_string());
 86                }
 87            }
 88        }
 89
 90        let authorize = if !confirmation_paths.is_empty() {
 91            let title = if confirmation_paths.len() == 1 {
 92                format!("Save {}", MarkdownInlineCode(&confirmation_paths[0]))
 93            } else {
 94                let paths: Vec<_> = confirmation_paths
 95                    .iter()
 96                    .take(3)
 97                    .map(|p| p.as_str())
 98                    .collect();
 99                if confirmation_paths.len() > 3 {
100                    format!(
101                        "Save {}, and {} more",
102                        paths.join(", "),
103                        confirmation_paths.len() - 3
104                    )
105                } else {
106                    format!("Save {}", paths.join(", "))
107                }
108            };
109            let sensitive_kind = confirmation_paths
110                .iter()
111                .find_map(|p| sensitive_settings_kind(Path::new(p)));
112            let title = match sensitive_kind {
113                Some(SensitiveSettingsKind::Local) => format!("{title} (local settings)"),
114                Some(SensitiveSettingsKind::Global) => format!("{title} (settings)"),
115                None => title,
116            };
117            let context = crate::ToolPermissionContext {
118                tool_name: Self::NAME.to_string(),
119                input_values: confirmation_paths.clone(),
120            };
121            Some(event_stream.authorize(title, context, cx))
122        } else {
123            None
124        };
125
126        let project = self.project.clone();
127        let input_paths = input.paths;
128
129        cx.spawn(async move |cx| {
130            if let Some(authorize) = authorize {
131                authorize.await?;
132            }
133
134            let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
135
136            let mut dirty_count: usize = 0;
137            let mut clean_paths: Vec<PathBuf> = Vec::new();
138            let mut not_found_paths: Vec<PathBuf> = Vec::new();
139            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
140            let mut save_errors: Vec<(String, String)> = Vec::new();
141
142            for path in input_paths {
143                let Some(project_path) =
144                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
145                else {
146                    not_found_paths.push(path);
147                    continue;
148                };
149
150                let open_buffer_task =
151                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
152
153                let buffer = futures::select! {
154                    result = open_buffer_task.fuse() => {
155                        match result {
156                            Ok(buffer) => buffer,
157                            Err(error) => {
158                                open_errors.push((path, error.to_string()));
159                                continue;
160                            }
161                        }
162                    }
163                    _ = event_stream.cancelled_by_user().fuse() => {
164                        anyhow::bail!("Save cancelled by user");
165                    }
166                };
167
168                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
169
170                if is_dirty {
171                    buffers_to_save.insert(buffer);
172                    dirty_count += 1;
173                } else {
174                    clean_paths.push(path);
175                }
176            }
177
178            // Save each buffer individually since there's no batch save API.
179            for buffer in buffers_to_save {
180                let path_for_buffer = buffer
181                    .read_with(cx, |buffer, _| {
182                        buffer
183                            .file()
184                            .map(|file| file.path().to_rel_path_buf())
185                            .map(|path| path.as_rel_path().as_unix_str().to_owned())
186                    })
187                    .unwrap_or_else(|| "<unknown>".to_string());
188
189                let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
190
191                let save_result = futures::select! {
192                    result = save_task.fuse() => result,
193                    _ = event_stream.cancelled_by_user().fuse() => {
194                        anyhow::bail!("Save cancelled by user");
195                    }
196                };
197                if let Err(error) = save_result {
198                    save_errors.push((path_for_buffer, error.to_string()));
199                }
200            }
201
202            let mut lines: Vec<String> = Vec::new();
203
204            let successful_saves = dirty_count.saturating_sub(save_errors.len());
205            if successful_saves > 0 {
206                lines.push(format!("Saved {} file(s).", successful_saves));
207            }
208            if !clean_paths.is_empty() {
209                lines.push(format!("{} clean.", clean_paths.len()));
210            }
211
212            if !not_found_paths.is_empty() {
213                lines.push(format!("Not found ({}):", not_found_paths.len()));
214                for path in &not_found_paths {
215                    lines.push(format!("- {}", path.display()));
216                }
217            }
218            if !open_errors.is_empty() {
219                lines.push(format!("Open failed ({}):", open_errors.len()));
220                for (path, error) in &open_errors {
221                    lines.push(format!("- {}: {}", path.display(), error));
222                }
223            }
224            if !save_errors.is_empty() {
225                lines.push(format!("Save failed ({}):", save_errors.len()));
226                for (path, error) in &save_errors {
227                    lines.push(format!("- {}: {}", path, error));
228                }
229            }
230
231            if lines.is_empty() {
232                Ok("No paths provided.".to_string())
233            } else {
234                Ok(lines.join("\n"))
235            }
236        })
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test settings store and sets the default tool permission mode to `Allow`,
    /// so ordinary (non-settings) paths should not require an authorization prompt.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            agent_settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            AgentSettings::override_global(agent_settings, cx);
        });
    }

    /// Resolves `path` in `project` and opens its buffer, panicking with `missing` if the
    /// path is not part of the project.
    async fn open_project_buffer(
        project: &Entity<Project>,
        path: &str,
        missing: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project.find_project_path(path, cx).expect(missing)
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Reports whether `buffer` currently has unsaved changes.
    fn is_dirty(buffer: &Entity<Buffer>, cx: &mut TestAppContext) -> bool {
        buffer.read_with(cx, |buffer, _| buffer.is_dirty())
    }

    /// Runs the tool on `paths` with a fresh test event stream and returns its text output.
    async fn run_save(
        tool: &Arc<SaveFileTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        cx.update(|cx| {
            tool.clone()
                .run(SaveFileToolInput { paths }, ToolCallEventStream::test().0, cx)
        })
        .await
        .unwrap()
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Give dirty.txt unsaved in-memory edits.
        let dirty_buffer = open_project_buffer(
            &project,
            "root/dirty.txt",
            "dirty.txt should exist in project",
            cx,
        )
        .await;
        dirty_buffer.update(cx, |buffer, cx| {
            let whole_range = 0..buffer.len();
            buffer.edit([(whole_range, "in memory: dirty\n")], None, cx);
        });
        assert!(
            is_dirty(&dirty_buffer, cx),
            "dirty.txt buffer should be dirty before save"
        );

        // Keep clean.txt open but untouched.
        let clean_buffer = open_project_buffer(
            &project,
            "root/clean.txt",
            "clean.txt should exist in project",
            cx,
        )
        .await;
        assert!(
            !is_dirty(&clean_buffer, cx),
            "clean.txt buffer should start clean"
        );

        let output = run_save(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // The summary should report one save and one clean file.
        for (expected, description) in [
            ("Saved 1 file(s).", "saved count line"),
            ("1 clean.", "clean count line"),
        ] {
            assert!(
                output.contains(expected),
                "expected {description}, got:\n{output}"
            );
        }

        // The dirty buffer is now clean and its edits reached the disk.
        assert!(
            !is_dirty(&dirty_buffer, cx),
            "dirty.txt buffer should not be dirty after save"
        );
        let dirty_on_disk = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            dirty_on_disk, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // The clean file's disk content is untouched.
        let clean_on_disk = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(clean_on_disk, "on disk: clean\n");

        // An empty request yields the fallback message.
        let output = run_save(&tool, Vec::new(), cx).await;
        assert_eq!(output, "No paths provided.");

        // A path outside the project is listed under "Not found".
        let output = run_save(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        for (expected, description) in [
            ("Not found (1):", "not-found header line"),
            ("- nonexistent/path.txt", "not-found path bullet"),
        ] {
            assert!(
                output.contains(expected),
                "expected {description}, got:\n{output}"
            );
        }
    }
}