save_file_tool.rs

  1use agent_client_protocol as acp;
  2use agent_settings::AgentSettings;
  3use collections::FxHashSet;
  4use futures::FutureExt as _;
  5use gpui::{App, Entity, SharedString, Task};
  6use language::Buffer;
  7use project::Project;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use settings::Settings;
 11use std::path::{Path, PathBuf};
 12use std::sync::Arc;
 13use util::markdown::MarkdownInlineCode;
 14
 15use super::tool_permissions::{
 16    ResolvedProjectPath, SensitiveSettingsKind, authorize_symlink_access,
 17    canonicalize_worktree_roots, path_has_symlink_escape, resolve_project_path,
 18    sensitive_settings_kind,
 19};
 20use crate::{AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_for_path};
 21
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
// NOTE: with `#[derive(JsonSchema)]`, schemars turns the `///` comments on this type and its
// fields into schema `description`s (presumably what the model sees as the tool's docs), so
// their wording is effectively runtime text. Plain `//` comments like this one are not exported.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    /// The paths of the files to save.
    // Each entry is resolved with `resolve_project_path` in `SaveFileTool::run`; entries that
    // fail to resolve are reported under "Not found" instead of failing the whole call. The
    // tests pass paths prefixed with the worktree root name (e.g. `root/dirty.txt`).
    pub paths: Vec<PathBuf>,
}
 31
 32pub struct SaveFileTool {
 33    project: Entity<Project>,
 34}
 35
 36impl SaveFileTool {
 37    pub fn new(project: Entity<Project>) -> Self {
 38        Self { project }
 39    }
 40}
 41
 42impl AgentTool for SaveFileTool {
 43    type Input = SaveFileToolInput;
 44    type Output = String;
 45
 46    const NAME: &'static str = "save_file";
 47
 48    fn kind() -> acp::ToolKind {
 49        acp::ToolKind::Other
 50    }
 51
 52    fn initial_title(
 53        &self,
 54        input: Result<Self::Input, serde_json::Value>,
 55        _cx: &mut App,
 56    ) -> SharedString {
 57        match input {
 58            Ok(input) if input.paths.len() == 1 => "Save file".into(),
 59            Ok(input) => format!("Save {} files", input.paths.len()).into(),
 60            Err(_) => "Save files".into(),
 61        }
 62    }
 63
 64    fn run(
 65        self: Arc<Self>,
 66        input: Self::Input,
 67        event_stream: ToolCallEventStream,
 68        cx: &mut App,
 69    ) -> Task<Result<String, String>> {
 70        let settings = AgentSettings::get_global(cx).clone();
 71
 72        // Check for any immediate deny before spawning async work.
 73        for path in &input.paths {
 74            let path_str = path.to_string_lossy();
 75            let decision = decide_permission_for_path(Self::NAME, &path_str, &settings);
 76            if let ToolPermissionDecision::Deny(reason) = decision {
 77                return Task::ready(Err(reason));
 78            }
 79        }
 80
 81        let project = self.project.clone();
 82        let input_paths = input.paths;
 83
 84        cx.spawn(async move |cx| {
 85            let fs = project.read_with(cx, |project, _cx| project.fs().clone());
 86            let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await;
 87
 88            let mut confirmation_paths: Vec<String> = Vec::new();
 89
 90            for path in &input_paths {
 91                let path_str = path.to_string_lossy();
 92                let decision = decide_permission_for_path(Self::NAME, &path_str, &settings);
 93                let symlink_escape = project.read_with(cx, |project, cx| {
 94                    path_has_symlink_escape(project, path, &canonical_roots, cx)
 95                });
 96
 97                match decision {
 98                    ToolPermissionDecision::Allow => {
 99                        if !symlink_escape {
100                            let is_sensitive = super::tool_permissions::is_sensitive_settings_path(
101                                Path::new(&*path_str),
102                                fs.as_ref(),
103                            )
104                            .await;
105                            if is_sensitive {
106                                confirmation_paths.push(path_str.to_string());
107                            }
108                        }
109                    }
110                    ToolPermissionDecision::Deny(reason) => {
111                        return Err(reason);
112                    }
113                    ToolPermissionDecision::Confirm => {
114                        if !symlink_escape {
115                            confirmation_paths.push(path_str.to_string());
116                        }
117                    }
118                }
119            }
120
121            if !confirmation_paths.is_empty() {
122                let title = if confirmation_paths.len() == 1 {
123                    format!("Save {}", MarkdownInlineCode(&confirmation_paths[0]))
124                } else {
125                    let paths: Vec<_> = confirmation_paths
126                        .iter()
127                        .take(3)
128                        .map(|p| p.as_str())
129                        .collect();
130                    if confirmation_paths.len() > 3 {
131                        format!(
132                            "Save {}, and {} more",
133                            paths.join(", "),
134                            confirmation_paths.len() - 3
135                        )
136                    } else {
137                        format!("Save {}", paths.join(", "))
138                    }
139                };
140
141                let mut settings_kind = None;
142                for p in &confirmation_paths {
143                    if let Some(kind) = sensitive_settings_kind(Path::new(p), fs.as_ref()).await {
144                        settings_kind = Some(kind);
145                        break;
146                    }
147                }
148                let title = match settings_kind {
149                    Some(SensitiveSettingsKind::Local) => format!("{title} (local settings)"),
150                    Some(SensitiveSettingsKind::Global) => format!("{title} (settings)"),
151                    None => title,
152                };
153                let context =
154                    crate::ToolPermissionContext::new(Self::NAME, confirmation_paths.clone());
155                let authorize = cx.update(|cx| event_stream.authorize(title, context, cx));
156                authorize.await.map_err(|e| e.to_string())?;
157            }
158
159            let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
160
161            let mut dirty_count: usize = 0;
162            let mut clean_paths: Vec<PathBuf> = Vec::new();
163            let mut not_found_paths: Vec<PathBuf> = Vec::new();
164            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
165            let mut authorization_errors: Vec<(PathBuf, String)> = Vec::new();
166            let mut save_errors: Vec<(String, String)> = Vec::new();
167
168            for path in input_paths {
169                let project_path = match project.read_with(cx, |project, cx| {
170                    resolve_project_path(project, &path, &canonical_roots, cx)
171                }) {
172                    Ok(resolved) => {
173                        let (project_path, symlink_canonical_target) = match resolved {
174                            ResolvedProjectPath::Safe(path) => (path, None),
175                            ResolvedProjectPath::SymlinkEscape {
176                                project_path,
177                                canonical_target,
178                            } => (project_path, Some(canonical_target)),
179                        };
180                        if let Some(canonical_target) = &symlink_canonical_target {
181                            let path_str = path.to_string_lossy();
182                            let authorize_task = cx.update(|cx| {
183                                authorize_symlink_access(
184                                    Self::NAME,
185                                    &path_str,
186                                    canonical_target,
187                                    &event_stream,
188                                    cx,
189                                )
190                            });
191                            let result = authorize_task.await;
192                            if let Err(err) = result {
193                                authorization_errors.push((path.clone(), err.to_string()));
194                                continue;
195                            }
196                        }
197                        project_path
198                    }
199                    Err(_) => {
200                        not_found_paths.push(path);
201                        continue;
202                    }
203                };
204
205                let open_buffer_task =
206                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
207
208                let buffer = futures::select! {
209                    result = open_buffer_task.fuse() => {
210                        match result {
211                            Ok(buffer) => buffer,
212                            Err(error) => {
213                                open_errors.push((path, error.to_string()));
214                                continue;
215                            }
216                        }
217                    }
218                    _ = event_stream.cancelled_by_user().fuse() => {
219                        return Err("Save cancelled by user".to_string());
220                    }
221                };
222
223                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
224
225                if is_dirty {
226                    buffers_to_save.insert(buffer);
227                    dirty_count += 1;
228                } else {
229                    clean_paths.push(path);
230                }
231            }
232
233            // Save each buffer individually since there's no batch save API.
234            for buffer in buffers_to_save {
235                let path_for_buffer = buffer
236                    .read_with(cx, |buffer, _| {
237                        buffer
238                            .file()
239                            .map(|file| file.path().to_rel_path_buf())
240                            .map(|path| path.as_rel_path().as_unix_str().to_owned())
241                    })
242                    .unwrap_or_else(|| "<unknown>".to_string());
243
244                let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
245
246                let save_result = futures::select! {
247                    result = save_task.fuse() => result,
248                    _ = event_stream.cancelled_by_user().fuse() => {
249                        return Err("Save cancelled by user".to_string());
250                    }
251                };
252                if let Err(error) = save_result {
253                    save_errors.push((path_for_buffer, error.to_string()));
254                }
255            }
256
257            let mut lines: Vec<String> = Vec::new();
258
259            let successful_saves = dirty_count.saturating_sub(save_errors.len());
260            if successful_saves > 0 {
261                lines.push(format!("Saved {} file(s).", successful_saves));
262            }
263            if !clean_paths.is_empty() {
264                lines.push(format!("{} clean.", clean_paths.len()));
265            }
266
267            if !not_found_paths.is_empty() {
268                lines.push(format!("Not found ({}):", not_found_paths.len()));
269                for path in &not_found_paths {
270                    lines.push(format!("- {}", path.display()));
271                }
272            }
273            if !open_errors.is_empty() {
274                lines.push(format!("Open failed ({}):", open_errors.len()));
275                for (path, error) in &open_errors {
276                    lines.push(format!("- {}: {}", path.display(), error));
277                }
278            }
279            if !authorization_errors.is_empty() {
280                lines.push(format!(
281                    "Authorization failed ({}):",
282                    authorization_errors.len()
283                ));
284                for (path, error) in &authorization_errors {
285                    lines.push(format!("- {}: {}", path.display(), error));
286                }
287            }
288            if !save_errors.is_empty() {
289                lines.push(format!("Save failed ({}):", save_errors.len()));
290                for (path, error) in &save_errors {
291                    lines.push(format!("- {}: {}", path, error));
292                }
293            }
294
295            if lines.is_empty() {
296                Ok("No paths provided.".to_string())
297            } else {
298                Ok(lines.join("\n"))
299            }
300        })
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test settings store and makes `Allow` the default tool permission.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Allow;
            AgentSettings::override_global(settings, cx);
        });
    }

    /// Builds a project rooted at `/root/project` containing `project_tree`, plus a
    /// `project/link.txt` symlink that points outside the worktree at
    /// `/root/external/secret.txt`.
    async fn project_with_escaping_symlink(
        cx: &mut TestAppContext,
        project_tree: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "project": project_tree,
                "external": {
                    "secret.txt": "secret content"
                }
            }),
        )
        .await;

        fs.create_symlink(
            path!("/root/project/link.txt").as_ref(),
            PathBuf::from("../external/secret.txt"),
        )
        .await
        .unwrap();

        let project = Project::test(fs, [path!("/root/project").as_ref()], cx).await;
        cx.executor().run_until_parked();
        project
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        // `path!` keeps the tree consistent with the `path!`-based project root and loads
        // below (the bare "/root" literal did not match them on Windows).
        fs.insert_tree(
            path!("/root"),
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Make dirty.txt dirty in-memory.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before save"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention saved + clean.
        assert!(
            output.contains("Saved 1 file(s)."),
            "expected saved count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should now be clean and disk should have new content.
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after save"
        );

        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            disk_dirty, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // Sanity: clean buffer should remain clean and disk unchanged.
        let disk_clean = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(disk_clean, "on disk: clean\n");

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput { paths: vec![] },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }

    #[gpui::test]
    async fn test_save_file_symlink_escape_requests_authorization(cx: &mut TestAppContext) {
        init_test(cx);

        let project = project_with_escaping_symlink(cx, json!({ "src": {} })).await;
        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                SaveFileToolInput {
                    paths: vec![PathBuf::from("project/link.txt")],
                },
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        let _result = task.await;
    }

    #[gpui::test]
    async fn test_save_file_symlink_escape_honors_deny_policy(cx: &mut TestAppContext) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.tools.insert(
                "save_file".into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    ..Default::default()
                },
            );
            AgentSettings::override_global(settings, cx);
        });

        let project = project_with_escaping_symlink(cx, json!({ "src": {} })).await;
        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let result = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![PathBuf::from("project/link.txt")],
                    },
                    event_stream,
                    cx,
                )
            })
            .await;

        assert!(result.is_err(), "Tool should fail when policy denies");
        // The task has finished, so inspect every emitted event, not just the first one.
        while let Ok(Some(event)) = event_rx.try_next() {
            assert!(
                !matches!(event, Ok(crate::ThreadEvent::ToolCallAuthorization(_))),
                "Deny policy should not emit symlink authorization prompt",
            );
        }
    }

    #[gpui::test]
    async fn test_save_file_symlink_escape_confirm_requires_single_approval(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.tool_permissions.default = settings::ToolPermissionMode::Confirm;
            AgentSettings::override_global(settings, cx);
        });

        let project = project_with_escaping_symlink(cx, json!({ "src": {} })).await;
        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                SaveFileToolInput {
                    paths: vec![PathBuf::from("project/link.txt")],
                },
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        let title = auth.tool_call.fields.title.as_deref().unwrap_or("");
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );

        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        // Let the tool continue past the approval so a second prompt, if any, would be
        // emitted before we check for it.
        cx.run_until_parked();
        while let Ok(Some(event)) = event_rx.try_next() {
            assert!(
                !matches!(event, Ok(crate::ThreadEvent::ToolCallAuthorization(_))),
                "Expected a single authorization prompt",
            );
        }

        let _result = task.await;
    }

    #[gpui::test]
    async fn test_save_file_symlink_denial_does_not_reduce_success_count(cx: &mut TestAppContext) {
        init_test(cx);

        let project =
            project_with_escaping_symlink(cx, json!({ "dirty.txt": "on disk value\n" })).await;

        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("project/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });
        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory value\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt should be dirty before save"
        );

        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = cx.update(|cx| {
            tool.clone().run(
                SaveFileToolInput {
                    paths: vec![
                        PathBuf::from("project/dirty.txt"),
                        PathBuf::from("project/link.txt"),
                    ],
                },
                event_stream,
                cx,
            )
        });

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        auth.response
            .send(acp::PermissionOptionId::new("deny"))
            .unwrap();

        let output = task.await.unwrap();
        assert!(
            output.contains("Saved 1 file(s)."),
            "Expected successful save count to remain accurate, got:\n{output}",
        );
        assert!(
            output.contains("Authorization failed (1):"),
            "Expected authorization failure section, got:\n{output}",
        );
        assert!(
            !output.contains("Save failed"),
            "Authorization denials should not be counted as save failures, got:\n{output}",
        );
    }
}