save_file_tool.rs

  1use agent_client_protocol as acp;
  2use agent_settings::AgentSettings;
  3use anyhow::Result;
  4use collections::FxHashSet;
  5use futures::FutureExt as _;
  6use gpui::{App, Entity, SharedString, Task};
  7use language::Buffer;
  8use project::Project;
  9use schemars::JsonSchema;
 10use serde::{Deserialize, Serialize};
 11use settings::Settings;
 12use std::path::{Path, PathBuf};
 13use std::sync::Arc;
 14use util::markdown::MarkdownInlineCode;
 15
 16use super::tool_permissions::{
 17    ResolvedProjectPath, SensitiveSettingsKind, authorize_symlink_access,
 18    canonicalize_worktree_roots, path_has_symlink_escape, resolve_project_path,
 19    sensitive_settings_kind,
 20};
 21use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path};
 22
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
// NOTE(review): the `///` text on this struct and its field is read by the `JsonSchema`
// derive (schemars turns doc comments into schema descriptions), so it is presumably the
// usage guidance the model sees for this tool. Treat edits to it as behavior changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    /// The paths of the files to save.
    // Each entry is checked against the tool permission settings and then resolved against
    // the project's worktrees in `SaveFileTool::run`; the tests pass worktree-prefixed
    // relative paths such as `root/dirty.txt`.
    pub paths: Vec<PathBuf>,
}
 32
 33pub struct SaveFileTool {
 34    project: Entity<Project>,
 35}
 36
 37impl SaveFileTool {
 38    pub fn new(project: Entity<Project>) -> Self {
 39        Self { project }
 40    }
 41}
 42
 43impl AgentTool for SaveFileTool {
 44    type Input = SaveFileToolInput;
 45    type Output = String;
 46
 47    const NAME: &'static str = "save_file";
 48
 49    fn kind() -> acp::ToolKind {
 50        acp::ToolKind::Other
 51    }
 52
 53    fn initial_title(
 54        &self,
 55        input: Result<Self::Input, serde_json::Value>,
 56        _cx: &mut App,
 57    ) -> SharedString {
 58        match input {
 59            Ok(input) if input.paths.len() == 1 => "Save file".into(),
 60            Ok(input) => format!("Save {} files", input.paths.len()).into(),
 61            Err(_) => "Save files".into(),
 62        }
 63    }
 64
 65    fn run(
 66        self: Arc<Self>,
 67        input: Self::Input,
 68        event_stream: ToolCallEventStream,
 69        cx: &mut App,
 70    ) -> Task<Result<String>> {
 71        let settings = AgentSettings::get_global(cx).clone();
 72
 73        // Check for any immediate deny before spawning async work.
 74        for path in &input.paths {
 75            let path_str = path.to_string_lossy();
 76            let decision = decide_permission_for_path(Self::NAME, &path_str, &settings);
 77            if let ToolPermissionDecision::Deny(reason) = decision {
 78                return Task::ready(Err(anyhow::anyhow!("{}", reason)));
 79            }
 80        }
 81
 82        let project = self.project.clone();
 83        let input_paths = input.paths;
 84
 85        cx.spawn(async move |cx| {
 86            let fs = project.read_with(cx, |project, _cx| project.fs().clone());
 87            let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await;
 88
 89            let mut confirmation_paths: Vec<String> = Vec::new();
 90
 91            for path in &input_paths {
 92                let path_str = path.to_string_lossy();
 93                let decision = decide_permission_for_path(Self::NAME, &path_str, &settings);
 94                let symlink_escape = project.read_with(cx, |project, cx| {
 95                    path_has_symlink_escape(project, path, &canonical_roots, cx)
 96                });
 97
 98                match decision {
 99                    ToolPermissionDecision::Allow => {
100                        if !symlink_escape {
101                            let is_sensitive = super::tool_permissions::is_sensitive_settings_path(
102                                Path::new(&*path_str),
103                                fs.as_ref(),
104                            )
105                            .await;
106                            if is_sensitive {
107                                confirmation_paths.push(path_str.to_string());
108                            }
109                        }
110                    }
111                    ToolPermissionDecision::Deny(reason) => {
112                        return Err(anyhow::anyhow!("{}", reason));
113                    }
114                    ToolPermissionDecision::Confirm => {
115                        if !symlink_escape {
116                            confirmation_paths.push(path_str.to_string());
117                        }
118                    }
119                }
120            }
121
122            if !confirmation_paths.is_empty() {
123                let title = if confirmation_paths.len() == 1 {
124                    format!("Save {}", MarkdownInlineCode(&confirmation_paths[0]))
125                } else {
126                    let paths: Vec<_> = confirmation_paths
127                        .iter()
128                        .take(3)
129                        .map(|p| p.as_str())
130                        .collect();
131                    if confirmation_paths.len() > 3 {
132                        format!(
133                            "Save {}, and {} more",
134                            paths.join(", "),
135                            confirmation_paths.len() - 3
136                        )
137                    } else {
138                        format!("Save {}", paths.join(", "))
139                    }
140                };
141
142                let mut settings_kind = None;
143                for p in &confirmation_paths {
144                    if let Some(kind) = sensitive_settings_kind(Path::new(p), fs.as_ref()).await {
145                        settings_kind = Some(kind);
146                        break;
147                    }
148                }
149                let title = match settings_kind {
150                    Some(SensitiveSettingsKind::Local) => format!("{title} (local settings)"),
151                    Some(SensitiveSettingsKind::Global) => format!("{title} (settings)"),
152                    None => title,
153                };
154                let context =
155                    crate::ToolPermissionContext::new(Self::NAME, confirmation_paths.clone());
156                let authorize = cx.update(|cx| event_stream.authorize(title, context, cx));
157                authorize.await?;
158            }
159
160            let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
161
162            let mut dirty_count: usize = 0;
163            let mut clean_paths: Vec<PathBuf> = Vec::new();
164            let mut not_found_paths: Vec<PathBuf> = Vec::new();
165            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
166            let mut authorization_errors: Vec<(PathBuf, String)> = Vec::new();
167            let mut save_errors: Vec<(String, String)> = Vec::new();
168
169            for path in input_paths {
170                let project_path = match project.read_with(cx, |project, cx| {
171                    resolve_project_path(project, &path, &canonical_roots, cx)
172                }) {
173                    Ok(resolved) => {
174                        let (project_path, symlink_canonical_target) = match resolved {
175                            ResolvedProjectPath::Safe(path) => (path, None),
176                            ResolvedProjectPath::SymlinkEscape {
177                                project_path,
178                                canonical_target,
179                            } => (project_path, Some(canonical_target)),
180                        };
181                        if let Some(canonical_target) = &symlink_canonical_target {
182                            let path_str = path.to_string_lossy();
183                            let authorize_task = cx.update(|cx| {
184                                authorize_symlink_access(
185                                    Self::NAME,
186                                    &path_str,
187                                    canonical_target,
188                                    &event_stream,
189                                    cx,
190                                )
191                            });
192                            let result = authorize_task.await;
193                            if let Err(err) = result {
194                                authorization_errors.push((path.clone(), err.to_string()));
195                                continue;
196                            }
197                        }
198                        project_path
199                    }
200                    Err(_) => {
201                        not_found_paths.push(path);
202                        continue;
203                    }
204                };
205
206                let open_buffer_task =
207                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
208
209                let buffer = futures::select! {
210                    result = open_buffer_task.fuse() => {
211                        match result {
212                            Ok(buffer) => buffer,
213                            Err(error) => {
214                                open_errors.push((path, error.to_string()));
215                                continue;
216                            }
217                        }
218                    }
219                    _ = event_stream.cancelled_by_user().fuse() => {
220                        anyhow::bail!("Save cancelled by user");
221                    }
222                };
223
224                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
225
226                if is_dirty {
227                    buffers_to_save.insert(buffer);
228                    dirty_count += 1;
229                } else {
230                    clean_paths.push(path);
231                }
232            }
233
234            // Save each buffer individually since there's no batch save API.
235            for buffer in buffers_to_save {
236                let path_for_buffer = buffer
237                    .read_with(cx, |buffer, _| {
238                        buffer
239                            .file()
240                            .map(|file| file.path().to_rel_path_buf())
241                            .map(|path| path.as_rel_path().as_unix_str().to_owned())
242                    })
243                    .unwrap_or_else(|| "<unknown>".to_string());
244
245                let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
246
247                let save_result = futures::select! {
248                    result = save_task.fuse() => result,
249                    _ = event_stream.cancelled_by_user().fuse() => {
250                        anyhow::bail!("Save cancelled by user");
251                    }
252                };
253                if let Err(error) = save_result {
254                    save_errors.push((path_for_buffer, error.to_string()));
255                }
256            }
257
258            let mut lines: Vec<String> = Vec::new();
259
260            let successful_saves = dirty_count.saturating_sub(save_errors.len());
261            if successful_saves > 0 {
262                lines.push(format!("Saved {} file(s).", successful_saves));
263            }
264            if !clean_paths.is_empty() {
265                lines.push(format!("{} clean.", clean_paths.len()));
266            }
267
268            if !not_found_paths.is_empty() {
269                lines.push(format!("Not found ({}):", not_found_paths.len()));
270                for path in &not_found_paths {
271                    lines.push(format!("- {}", path.display()));
272                }
273            }
274            if !open_errors.is_empty() {
275                lines.push(format!("Open failed ({}):", open_errors.len()));
276                for (path, error) in &open_errors {
277                    lines.push(format!("- {}: {}", path.display(), error));
278                }
279            }
280            if !authorization_errors.is_empty() {
281                lines.push(format!(
282                    "Authorization failed ({}):",
283                    authorization_errors.len()
284                ));
285                for (path, error) in &authorization_errors {
286                    lines.push(format!("- {}: {}", path.display(), error));
287                }
288            }
289            if !save_errors.is_empty() {
290                lines.push(format!("Save failed ({}):", save_errors.len()));
291                for (path, error) in &save_errors {
292                    lines.push(format!("- {}: {}", path, error));
293                }
294            }
295
296            if lines.is_empty() {
297                Ok("No paths provided.".to_string())
298            } else {
299                Ok(lines.join("\n"))
300            }
301        })
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test settings store whose default tool permission mode is `Allow`.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            AgentSettings::override_global(settings, cx);
        });
    }

    /// Builds a project whose single worktree is `/root/project` (populated with
    /// `project_tree`), plus `/root/external/secret.txt` outside the worktree and a
    /// `/root/project/link.txt` symlink pointing at that external file.
    async fn project_with_escaping_symlink(
        project_tree: serde_json::Value,
        cx: &mut TestAppContext,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "project": project_tree,
                "external": {
                    "secret.txt": "secret content"
                }
            }),
        )
        .await;

        fs.create_symlink(
            path!("/root/project/link.txt").as_ref(),
            PathBuf::from("../external/secret.txt"),
        )
        .await
        .unwrap();

        let project = Project::test(fs, [path!("/root/project").as_ref()], cx).await;
        cx.executor().run_until_parked();
        project
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // Use `path!` so the tree matches the worktree root on every platform.
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Make dirty.txt dirty in-memory.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before save"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention saved + clean.
        assert!(
            output.contains("Saved 1 file(s)."),
            "expected saved count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should now be clean and disk should have new content.
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after save"
        );

        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            disk_dirty, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // Sanity: clean buffer should remain clean and disk unchanged.
        let disk_clean = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(disk_clean, "on disk: clean\n");

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput { paths: vec![] },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }

    #[gpui::test]
    async fn test_save_file_symlink_escape_requests_authorization(cx: &mut TestAppContext) {
        init_test(cx);

        let project = project_with_escaping_symlink(json!({ "src": {} }), cx).await;
        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                SaveFileToolInput {
                    paths: vec![PathBuf::from("project/link.txt")],
                },
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        let _result = task.await;
    }

    #[gpui::test]
    async fn test_save_file_symlink_escape_honors_deny_policy(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "save_file".into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    ..Default::default()
                },
            );
            AgentSettings::override_global(settings, cx);
        });

        let project = project_with_escaping_symlink(json!({ "src": {} }), cx).await;
        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let result = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![PathBuf::from("project/link.txt")],
                    },
                    event_stream,
                    cx,
                )
            })
            .await;

        assert!(result.is_err(), "Tool should fail when policy denies");

        // Inspect every queued event, not just the first one.
        let mut prompted = false;
        while let Ok(Some(event)) = event_rx.try_next() {
            prompted |= matches!(event, Ok(crate::ThreadEvent::ToolCallAuthorization(_)));
        }
        assert!(
            !prompted,
            "Deny policy should not emit symlink authorization prompt",
        );
    }

    #[gpui::test]
    async fn test_save_file_symlink_escape_confirm_requires_single_approval(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Confirm;
            AgentSettings::override_global(settings, cx);
        });

        let project = project_with_escaping_symlink(json!({ "src": {} }), cx).await;
        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                SaveFileToolInput {
                    paths: vec![PathBuf::from("project/link.txt")],
                },
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        // Let the tool make progress after the approval; otherwise a second prompt could
        // not have been emitted yet and the check below would pass trivially.
        cx.run_until_parked();

        let mut prompted_again = false;
        while let Ok(Some(event)) = event_rx.try_next() {
            prompted_again |= matches!(event, Ok(crate::ThreadEvent::ToolCallAuthorization(_)));
        }
        assert!(!prompted_again, "Expected a single authorization prompt");

        let _result = task.await;
    }

    #[gpui::test]
    async fn test_save_file_symlink_denial_does_not_reduce_success_count(cx: &mut TestAppContext) {
        init_test(cx);

        let project =
            project_with_escaping_symlink(json!({ "dirty.txt": "on disk value\n" }), cx).await;

        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("project/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });
        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory value\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt should be dirty before save"
        );

        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                SaveFileToolInput {
                    paths: vec![
                        PathBuf::from("project/dirty.txt"),
                        PathBuf::from("project/link.txt"),
                    ],
                },
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        auth.response
            .send(acp::PermissionOptionId::new("deny"))
            .unwrap();

        let output = task.await.unwrap();
        assert!(
            output.contains("Saved 1 file(s)."),
            "Expected successful save count to remain accurate, got:\n{output}",
        );
        assert!(
            output.contains("Authorization failed (1):"),
            "Expected authorization failure section, got:\n{output}",
        );
        assert!(
            !output.contains("Save failed"),
            "Authorization denials should not be counted as save failures, got:\n{output}",
        );
    }
}
737    }
738}