save_file_tool.rs

  1use agent_client_protocol as acp;
  2use agent_settings::AgentSettings;
  3use anyhow::Result;
  4use collections::FxHashSet;
  5use futures::FutureExt as _;
  6use gpui::{App, Entity, SharedString, Task};
  7use language::Buffer;
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use settings::Settings;
 12use std::path::PathBuf;
 13use std::sync::Arc;
 14use util::markdown::MarkdownInlineCode;
 15
 16use crate::{
 17    AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings,
 18};
 19
// NOTE(review): the `///` comments on this struct and its field are not only
// rustdoc. `#[derive(JsonSchema)]` (schemars) emits doc comments as schema
// `description`s. This is presumably the text the model reads when deciding
// whether and how to call `save_file`, so edit the wording as prompt text,
// not just as documentation.
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    /// The paths of the files to save.
    // `SaveFileTool::run` resolves each entry with `Project::find_project_path`.
    // The tests in this file pass worktree-name-prefixed paths such as
    // `root/dirty.txt`. Entries that do not resolve are reported as "Not found".
    pub paths: Vec<PathBuf>,
}
 29
 30pub struct SaveFileTool {
 31    project: Entity<Project>,
 32}
 33
 34impl SaveFileTool {
 35    pub fn new(project: Entity<Project>) -> Self {
 36        Self { project }
 37    }
 38}
 39
 40impl AgentTool for SaveFileTool {
 41    type Input = SaveFileToolInput;
 42    type Output = String;
 43
 44    fn name() -> &'static str {
 45        "save_file"
 46    }
 47
 48    fn kind() -> acp::ToolKind {
 49        acp::ToolKind::Other
 50    }
 51
 52    fn initial_title(
 53        &self,
 54        input: Result<Self::Input, serde_json::Value>,
 55        _cx: &mut App,
 56    ) -> SharedString {
 57        match input {
 58            Ok(input) if input.paths.len() == 1 => "Save file".into(),
 59            Ok(input) => format!("Save {} files", input.paths.len()).into(),
 60            Err(_) => "Save files".into(),
 61        }
 62    }
 63
 64    fn run(
 65        self: Arc<Self>,
 66        input: Self::Input,
 67        event_stream: ToolCallEventStream,
 68        cx: &mut App,
 69    ) -> Task<Result<String>> {
 70        let settings = AgentSettings::get_global(cx);
 71        let mut needs_confirmation = false;
 72
 73        for path in &input.paths {
 74            let path_str = path.to_string_lossy();
 75            let decision = decide_permission_from_settings(Self::name(), &path_str, settings);
 76            match decision {
 77                ToolPermissionDecision::Allow => {}
 78                ToolPermissionDecision::Deny(reason) => {
 79                    return Task::ready(Err(anyhow::anyhow!("{}", reason)));
 80                }
 81                ToolPermissionDecision::Confirm => {
 82                    needs_confirmation = true;
 83                }
 84            }
 85        }
 86
 87        let authorize = if needs_confirmation {
 88            let title = if input.paths.len() == 1 {
 89                format!(
 90                    "Save {}",
 91                    MarkdownInlineCode(&input.paths[0].to_string_lossy())
 92                )
 93            } else {
 94                let paths: Vec<_> = input
 95                    .paths
 96                    .iter()
 97                    .take(3)
 98                    .map(|p| p.to_string_lossy().to_string())
 99                    .collect();
100                if input.paths.len() > 3 {
101                    format!(
102                        "Save {}, and {} more",
103                        paths.join(", "),
104                        input.paths.len() - 3
105                    )
106                } else {
107                    format!("Save {}", paths.join(", "))
108                }
109            };
110            let first_path = input
111                .paths
112                .first()
113                .map(|p| p.to_string_lossy().to_string())
114                .unwrap_or_default();
115            let context = crate::ToolPermissionContext {
116                tool_name: "save_file".to_string(),
117                input_value: first_path,
118            };
119            Some(event_stream.authorize(title, context, cx))
120        } else {
121            None
122        };
123
124        let project = self.project.clone();
125        let input_paths = input.paths;
126
127        cx.spawn(async move |cx| {
128            if let Some(authorize) = authorize {
129                authorize.await?;
130            }
131
132            let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
133
134            let mut saved_paths: Vec<PathBuf> = Vec::new();
135            let mut clean_paths: Vec<PathBuf> = Vec::new();
136            let mut not_found_paths: Vec<PathBuf> = Vec::new();
137            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
138            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
139            let mut save_errors: Vec<(String, String)> = Vec::new();
140
141            for path in input_paths {
142                let Some(project_path) =
143                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
144                else {
145                    not_found_paths.push(path);
146                    continue;
147                };
148
149                let open_buffer_task =
150                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
151
152                let buffer = futures::select! {
153                    result = open_buffer_task.fuse() => {
154                        match result {
155                            Ok(buffer) => buffer,
156                            Err(error) => {
157                                open_errors.push((path, error.to_string()));
158                                continue;
159                            }
160                        }
161                    }
162                    _ = event_stream.cancelled_by_user().fuse() => {
163                        anyhow::bail!("Save cancelled by user");
164                    }
165                };
166
167                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
168
169                if is_dirty {
170                    buffers_to_save.insert(buffer);
171                    saved_paths.push(path);
172                } else {
173                    clean_paths.push(path);
174                }
175            }
176
177            // Save each buffer individually since there's no batch save API.
178            for buffer in buffers_to_save {
179                let path_for_buffer = buffer
180                    .read_with(cx, |buffer, _| {
181                        buffer
182                            .file()
183                            .map(|file| file.path().to_rel_path_buf())
184                            .map(|path| path.as_rel_path().as_unix_str().to_owned())
185                    })
186                    .unwrap_or_else(|| "<unknown>".to_string());
187
188                let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
189
190                let save_result = futures::select! {
191                    result = save_task.fuse() => result,
192                    _ = event_stream.cancelled_by_user().fuse() => {
193                        anyhow::bail!("Save cancelled by user");
194                    }
195                };
196                if let Err(error) = save_result {
197                    save_errors.push((path_for_buffer, error.to_string()));
198                }
199            }
200
201            let mut lines: Vec<String> = Vec::new();
202
203            if !saved_paths.is_empty() {
204                lines.push(format!("Saved {} file(s).", saved_paths.len()));
205            }
206            if !clean_paths.is_empty() {
207                lines.push(format!("{} clean.", clean_paths.len()));
208            }
209
210            if !not_found_paths.is_empty() {
211                lines.push(format!("Not found ({}):", not_found_paths.len()));
212                for path in &not_found_paths {
213                    lines.push(format!("- {}", path.display()));
214                }
215            }
216            if !open_errors.is_empty() {
217                lines.push(format!("Open failed ({}):", open_errors.len()));
218                for (path, error) in &open_errors {
219                    lines.push(format!("- {}: {}", path.display(), error));
220                }
221            }
222            if !dirty_check_errors.is_empty() {
223                lines.push(format!(
224                    "Dirty check failed ({}):",
225                    dirty_check_errors.len()
226                ));
227                for (path, error) in &dirty_check_errors {
228                    lines.push(format!("- {}: {}", path.display(), error));
229                }
230            }
231            if !save_errors.is_empty() {
232                lines.push(format!("Save failed ({}):", save_errors.len()));
233                for (path, error) in &save_errors {
234                    lines.push(format!("- {}: {}", path, error));
235                }
236            }
237
238            if lines.is_empty() {
239                Ok("No paths provided.".to_string())
240            } else {
241                Ok(lines.join("\n"))
242            }
243        })
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test `SettingsStore` global and enables
    /// `always_allow_tool_actions`.
    // NOTE(review): enabling `always_allow_tool_actions` is presumably what
    // makes `decide_permission_from_settings` return `Allow`, so `run` never
    // awaits an authorization prompt in these tests. Confirm against the
    // permission helper.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            AgentSettings::override_global(settings, cx);
        });
    }

    /// End-to-end check of `SaveFileTool::run`:
    /// - a dirty buffer is written to disk and reported as saved;
    /// - a clean buffer is left alone and reported as clean;
    /// - an empty request yields "No paths provided.";
    /// - an unresolvable path is listed under "Not found".
    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        // In-memory filesystem with one file to dirty and one to leave clean.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Make dirty.txt dirty in-memory.
        // Paths are given relative to the worktree name ("root/..."), which is
        // the same form the tool input uses below.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        // Replace the whole buffer contents so the buffer diverges from disk.
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before save"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        // `ToolCallEventStream::test()` returns a tuple. Only the stream (`.0`)
        // is passed to `run`; the other half is dropped here.
        // NOTE(review): dropping it is presumably fine because no authorization
        // is requested under `always_allow_tool_actions`.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention saved + clean.
        assert!(
            output.contains("Saved 1 file(s)."),
            "expected saved count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should now be clean and disk should have new content.
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after save"
        );

        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            disk_dirty, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // Sanity: clean buffer should remain clean and disk unchanged.
        let disk_clean = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(disk_clean, "on disk: clean\n");

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput { paths: vec![] },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case: the path is outside every worktree, so
        // `find_project_path` yields nothing and the path is reported verbatim.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}