save_file_tool.rs

  1use agent_client_protocol as acp;
  2use agent_settings::AgentSettings;
  3use anyhow::Result;
  4use collections::FxHashSet;
  5use futures::FutureExt as _;
  6use gpui::{App, Entity, SharedString, Task};
  7use language::Buffer;
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use settings::Settings;
 12use std::path::{Component, Path, PathBuf};
 13use std::sync::Arc;
 14use util::markdown::MarkdownInlineCode;
 15
 16use super::edit_file_tool::{
 17    SensitiveSettingsKind, is_sensitive_settings_path, sensitive_settings_kind,
 18};
 19use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path};
 20
 21fn common_parent_for_paths(paths: &[String]) -> Option<PathBuf> {
 22    let first = paths.first()?;
 23    let mut common: Vec<Component<'_>> = Path::new(first).parent()?.components().collect();
 24    for path in &paths[1..] {
 25        let parent: Vec<Component<'_>> = Path::new(path).parent()?.components().collect();
 26        let prefix_len = common
 27            .iter()
 28            .zip(parent.iter())
 29            .take_while(|(a, b)| a == b)
 30            .count();
 31        common.truncate(prefix_len);
 32    }
 33    if common.is_empty() {
 34        return None;
 35    }
 36    Some(common.iter().collect())
 37}
 38
// NOTE(review): `JsonSchema` is derived below. schemars turns `///` doc comments
// into schema `description`s, so the `///` text on this struct and its field is
// presumably what the model sees as the tool/parameter description. Edit it
// deliberately, and keep maintainer-only notes in `//` comments like this one.
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    // Each entry is resolved with `Project::find_project_path` in `SaveFileTool::run`.
    // Entries that do not resolve are reported under "Not found". The tests pass
    // worktree-prefixed paths such as `root/dirty.txt`.
    /// The paths of the files to save.
    pub paths: Vec<PathBuf>,
}
 48
 49pub struct SaveFileTool {
 50    project: Entity<Project>,
 51}
 52
 53impl SaveFileTool {
 54    pub fn new(project: Entity<Project>) -> Self {
 55        Self { project }
 56    }
 57}
 58
 59impl AgentTool for SaveFileTool {
 60    type Input = SaveFileToolInput;
 61    type Output = String;
 62
 63    const NAME: &'static str = "save_file";
 64
 65    fn kind() -> acp::ToolKind {
 66        acp::ToolKind::Other
 67    }
 68
 69    fn initial_title(
 70        &self,
 71        input: Result<Self::Input, serde_json::Value>,
 72        _cx: &mut App,
 73    ) -> SharedString {
 74        match input {
 75            Ok(input) if input.paths.len() == 1 => "Save file".into(),
 76            Ok(input) => format!("Save {} files", input.paths.len()).into(),
 77            Err(_) => "Save files".into(),
 78        }
 79    }
 80
 81    fn run(
 82        self: Arc<Self>,
 83        input: Self::Input,
 84        event_stream: ToolCallEventStream,
 85        cx: &mut App,
 86    ) -> Task<Result<String>> {
 87        let settings = AgentSettings::get_global(cx);
 88        let mut confirmation_paths: Vec<String> = Vec::new();
 89
 90        for path in &input.paths {
 91            let path_str = path.to_string_lossy();
 92            let decision = decide_permission_for_path(Self::NAME, &path_str, settings);
 93            match decision {
 94                ToolPermissionDecision::Allow => {
 95                    if !settings.always_allow_tool_actions
 96                        && is_sensitive_settings_path(Path::new(&*path_str))
 97                    {
 98                        confirmation_paths.push(path_str.to_string());
 99                    }
100                }
101                ToolPermissionDecision::Deny(reason) => {
102                    return Task::ready(Err(anyhow::anyhow!("{}", reason)));
103                }
104                ToolPermissionDecision::Confirm => {
105                    confirmation_paths.push(path_str.to_string());
106                }
107            }
108        }
109
110        let authorize = if !confirmation_paths.is_empty() {
111            let title = if confirmation_paths.len() == 1 {
112                format!("Save {}", MarkdownInlineCode(&confirmation_paths[0]))
113            } else {
114                let paths: Vec<_> = confirmation_paths
115                    .iter()
116                    .take(3)
117                    .map(|p| p.as_str())
118                    .collect();
119                if confirmation_paths.len() > 3 {
120                    format!(
121                        "Save {}, and {} more",
122                        paths.join(", "),
123                        confirmation_paths.len() - 3
124                    )
125                } else {
126                    format!("Save {}", paths.join(", "))
127                }
128            };
129            let sensitive_kind = confirmation_paths
130                .iter()
131                .find_map(|p| sensitive_settings_kind(Path::new(p)));
132            let title = match sensitive_kind {
133                Some(SensitiveSettingsKind::Local) => format!("{title} (local settings)"),
134                Some(SensitiveSettingsKind::Global) => format!("{title} (settings)"),
135                None => title,
136            };
137            let input_value = if confirmation_paths.len() == 1 {
138                confirmation_paths[0].clone()
139            } else {
140                common_parent_for_paths(&confirmation_paths)
141                    .map(|parent| format!("{}/_", parent.display()))
142                    .unwrap_or_else(|| confirmation_paths[0].clone())
143            };
144            let context = crate::ToolPermissionContext {
145                tool_name: Self::NAME.to_string(),
146                input_value,
147            };
148            Some(event_stream.authorize(title, context, cx))
149        } else {
150            None
151        };
152
153        let project = self.project.clone();
154        let input_paths = input.paths;
155
156        cx.spawn(async move |cx| {
157            if let Some(authorize) = authorize {
158                authorize.await?;
159            }
160
161            let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
162
163            let mut saved_paths: Vec<PathBuf> = Vec::new();
164            let mut clean_paths: Vec<PathBuf> = Vec::new();
165            let mut not_found_paths: Vec<PathBuf> = Vec::new();
166            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
167            let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
168            let mut save_errors: Vec<(String, String)> = Vec::new();
169
170            for path in input_paths {
171                let Some(project_path) =
172                    project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
173                else {
174                    not_found_paths.push(path);
175                    continue;
176                };
177
178                let open_buffer_task =
179                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
180
181                let buffer = futures::select! {
182                    result = open_buffer_task.fuse() => {
183                        match result {
184                            Ok(buffer) => buffer,
185                            Err(error) => {
186                                open_errors.push((path, error.to_string()));
187                                continue;
188                            }
189                        }
190                    }
191                    _ = event_stream.cancelled_by_user().fuse() => {
192                        anyhow::bail!("Save cancelled by user");
193                    }
194                };
195
196                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
197
198                if is_dirty {
199                    buffers_to_save.insert(buffer);
200                    saved_paths.push(path);
201                } else {
202                    clean_paths.push(path);
203                }
204            }
205
206            // Save each buffer individually since there's no batch save API.
207            for buffer in buffers_to_save {
208                let path_for_buffer = buffer
209                    .read_with(cx, |buffer, _| {
210                        buffer
211                            .file()
212                            .map(|file| file.path().to_rel_path_buf())
213                            .map(|path| path.as_rel_path().as_unix_str().to_owned())
214                    })
215                    .unwrap_or_else(|| "<unknown>".to_string());
216
217                let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
218
219                let save_result = futures::select! {
220                    result = save_task.fuse() => result,
221                    _ = event_stream.cancelled_by_user().fuse() => {
222                        anyhow::bail!("Save cancelled by user");
223                    }
224                };
225                if let Err(error) = save_result {
226                    save_errors.push((path_for_buffer, error.to_string()));
227                }
228            }
229
230            let mut lines: Vec<String> = Vec::new();
231
232            if !saved_paths.is_empty() {
233                lines.push(format!("Saved {} file(s).", saved_paths.len()));
234            }
235            if !clean_paths.is_empty() {
236                lines.push(format!("{} clean.", clean_paths.len()));
237            }
238
239            if !not_found_paths.is_empty() {
240                lines.push(format!("Not found ({}):", not_found_paths.len()));
241                for path in &not_found_paths {
242                    lines.push(format!("- {}", path.display()));
243                }
244            }
245            if !open_errors.is_empty() {
246                lines.push(format!("Open failed ({}):", open_errors.len()));
247                for (path, error) in &open_errors {
248                    lines.push(format!("- {}: {}", path.display(), error));
249                }
250            }
251            if !dirty_check_errors.is_empty() {
252                lines.push(format!(
253                    "Dirty check failed ({}):",
254                    dirty_check_errors.len()
255                ));
256                for (path, error) in &dirty_check_errors {
257                    lines.push(format!("- {}: {}", path.display(), error));
258                }
259            }
260            if !save_errors.is_empty() {
261                lines.push(format!("Save failed ({}):", save_errors.len()));
262                for (path, error) in &save_errors {
263                    lines.push(format!("- {}: {}", path, error));
264                }
265            }
266
267            if lines.is_empty() {
268                Ok("No paths provided.".to_string())
269            } else {
270                Ok(lines.join("\n"))
271            }
272        })
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test `SettingsStore` and enables `always_allow_tool_actions`.
    /// `SaveFileTool::run` consults that flag before requesting confirmation for
    /// sensitive settings paths.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            agent_settings.always_allow_tool_actions = true;
            AgentSettings::override_global(agent_settings, cx);
        });
    }

    /// Resolves `relative_path` in `project` and opens its buffer. Panics with
    /// `missing_message` if the path is not part of the project.
    async fn open_test_buffer(
        project: &Entity<Project>,
        relative_path: &str,
        missing_message: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path(relative_path, cx)
                .expect(missing_message)
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Runs `tool` on `paths` with a fresh test event stream and returns the
    /// tool's text output.
    async fn run_tool(
        tool: &Arc<SaveFileTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        let task = cx.update(|cx| {
            tool.clone()
                .run(SaveFileToolInput { paths }, ToolCallEventStream::test().0, cx)
        });
        task.await.unwrap()
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Give dirty.txt unsaved in-memory edits.
        let dirty_buffer = open_test_buffer(
            &project,
            "root/dirty.txt",
            "dirty.txt should exist in project",
            cx,
        )
        .await;
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before save"
        );

        // Open clean.txt but leave its contents untouched.
        let clean_buffer = open_test_buffer(
            &project,
            "root/clean.txt",
            "clean.txt should exist in project",
            cx,
        )
        .await;
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        // Saving both reports one saved file and one clean file.
        let output = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;
        assert!(
            output.contains("Saved 1 file(s)."),
            "expected saved count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // The edit reached disk and the buffer is no longer dirty.
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after save"
        );
        let dirty_on_disk = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            dirty_on_disk, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // clean.txt is unchanged on disk.
        let clean_on_disk = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(clean_on_disk, "on disk: clean\n");

        // An empty request produces the dedicated message.
        let output = run_tool(&tool, Vec::new(), cx).await;
        assert_eq!(output, "No paths provided.");

        // A path outside the project is listed under "Not found".
        let output = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}