save_file_tool.rs

  1use agent_client_protocol as acp;
  2use agent_settings::AgentSettings;
  3use collections::FxHashSet;
  4use futures::FutureExt as _;
  5use gpui::{App, Entity, SharedString, Task};
  6use language::Buffer;
  7use project::Project;
  8use schemars::JsonSchema;
  9use serde::{Deserialize, Serialize};
 10use settings::Settings;
 11use std::path::{Path, PathBuf};
 12use std::sync::Arc;
 13use util::markdown::MarkdownInlineCode;
 14
 15use super::tool_permissions::{
 16    ResolvedProjectPath, SensitiveSettingsKind, authorize_symlink_access,
 17    canonicalize_worktree_roots, path_has_symlink_escape, resolve_project_path,
 18    sensitive_settings_kind,
 19};
 20use crate::{
 21    AgentTool, ToolCallEventStream, ToolInput, ToolPermissionDecision, decide_permission_for_path,
 22};
 23
// NOTE: the `///` comments on this struct and its field are read by the
// `JsonSchema` derive and become descriptions in the generated input schema
// (presumably the text the model sees for this tool — confirm), so changing
// them changes runtime behavior, not just documentation.
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    // Each entry is resolved against the project's worktrees by
    // `resolve_project_path` in `SaveFileTool::run`; entries that fail to
    // resolve are reported back under "Not found".
    /// The paths of the files to save.
    pub paths: Vec<PathBuf>,
}
 33
 34pub struct SaveFileTool {
 35    project: Entity<Project>,
 36}
 37
 38impl SaveFileTool {
 39    pub fn new(project: Entity<Project>) -> Self {
 40        Self { project }
 41    }
 42}
 43
 44impl AgentTool for SaveFileTool {
 45    type Input = SaveFileToolInput;
 46    type Output = String;
 47
 48    const NAME: &'static str = "save_file";
 49
 50    fn kind() -> acp::ToolKind {
 51        acp::ToolKind::Other
 52    }
 53
 54    fn initial_title(
 55        &self,
 56        input: Result<Self::Input, serde_json::Value>,
 57        _cx: &mut App,
 58    ) -> SharedString {
 59        match input {
 60            Ok(input) if input.paths.len() == 1 => "Save file".into(),
 61            Ok(input) => format!("Save {} files", input.paths.len()).into(),
 62            Err(_) => "Save files".into(),
 63        }
 64    }
 65
 66    fn run(
 67        self: Arc<Self>,
 68        input: ToolInput<Self::Input>,
 69        event_stream: ToolCallEventStream,
 70        cx: &mut App,
 71    ) -> Task<Result<String, String>> {
 72        let project = self.project.clone();
 73
 74        cx.spawn(async move |cx| {
 75            let input = input
 76                .recv()
 77                .await
 78                .map_err(|e| format!("Failed to receive tool input: {e}"))?;
 79
 80            // Check for any immediate deny before doing async work.
 81            for path in &input.paths {
 82                let path_str = path.to_string_lossy();
 83                let decision = cx.update(|cx| {
 84                    decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx))
 85                });
 86                if let ToolPermissionDecision::Deny(reason) = decision {
 87                    return Err(reason);
 88                }
 89            }
 90
 91            let input_paths = input.paths;
 92
 93            let fs = project.read_with(cx, |project, _cx| project.fs().clone());
 94            let canonical_roots = canonicalize_worktree_roots(&project, &fs, cx).await;
 95
 96            let mut confirmation_paths: Vec<String> = Vec::new();
 97
 98            for path in &input_paths {
 99                let path_str = path.to_string_lossy();
100                let decision = cx.update(|cx| {
101                    decide_permission_for_path(Self::NAME, &path_str, AgentSettings::get_global(cx))
102                });
103                let symlink_escape = project.read_with(cx, |project, cx| {
104                    path_has_symlink_escape(project, path, &canonical_roots, cx)
105                });
106
107                match decision {
108                    ToolPermissionDecision::Allow => {
109                        if !symlink_escape {
110                            let is_sensitive = super::tool_permissions::is_sensitive_settings_path(
111                                Path::new(&*path_str),
112                                fs.as_ref(),
113                            )
114                            .await;
115                            if is_sensitive {
116                                confirmation_paths.push(path_str.to_string());
117                            }
118                        }
119                    }
120                    ToolPermissionDecision::Deny(reason) => {
121                        return Err(reason);
122                    }
123                    ToolPermissionDecision::Confirm => {
124                        if !symlink_escape {
125                            confirmation_paths.push(path_str.to_string());
126                        }
127                    }
128                }
129            }
130
131            if !confirmation_paths.is_empty() {
132                let title = if confirmation_paths.len() == 1 {
133                    format!("Save {}", MarkdownInlineCode(&confirmation_paths[0]))
134                } else {
135                    let paths: Vec<_> = confirmation_paths
136                        .iter()
137                        .take(3)
138                        .map(|p| p.as_str())
139                        .collect();
140                    if confirmation_paths.len() > 3 {
141                        format!(
142                            "Save {}, and {} more",
143                            paths.join(", "),
144                            confirmation_paths.len() - 3
145                        )
146                    } else {
147                        format!("Save {}", paths.join(", "))
148                    }
149                };
150
151                let mut settings_kind = None;
152                for p in &confirmation_paths {
153                    if let Some(kind) = sensitive_settings_kind(Path::new(p), fs.as_ref()).await {
154                        settings_kind = Some(kind);
155                        break;
156                    }
157                }
158                let title = match settings_kind {
159                    Some(SensitiveSettingsKind::Local) => format!("{title} (local settings)"),
160                    Some(SensitiveSettingsKind::Global) => format!("{title} (settings)"),
161                    None => title,
162                };
163                let context =
164                    crate::ToolPermissionContext::new(Self::NAME, confirmation_paths.clone());
165                let authorize = cx.update(|cx| event_stream.authorize(title, context, cx));
166                authorize.await.map_err(|e| e.to_string())?;
167            }
168
169            let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
170
171            let mut dirty_count: usize = 0;
172            let mut clean_paths: Vec<PathBuf> = Vec::new();
173            let mut not_found_paths: Vec<PathBuf> = Vec::new();
174            let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
175            let mut authorization_errors: Vec<(PathBuf, String)> = Vec::new();
176            let mut save_errors: Vec<(String, String)> = Vec::new();
177
178            for path in input_paths {
179                let project_path = match project.read_with(cx, |project, cx| {
180                    resolve_project_path(project, &path, &canonical_roots, cx)
181                }) {
182                    Ok(resolved) => {
183                        let (project_path, symlink_canonical_target) = match resolved {
184                            ResolvedProjectPath::Safe(path) => (path, None),
185                            ResolvedProjectPath::SymlinkEscape {
186                                project_path,
187                                canonical_target,
188                            } => (project_path, Some(canonical_target)),
189                        };
190                        if let Some(canonical_target) = &symlink_canonical_target {
191                            let path_str = path.to_string_lossy();
192                            let authorize_task = cx.update(|cx| {
193                                authorize_symlink_access(
194                                    Self::NAME,
195                                    &path_str,
196                                    canonical_target,
197                                    &event_stream,
198                                    cx,
199                                )
200                            });
201                            let result = authorize_task.await;
202                            if let Err(err) = result {
203                                authorization_errors.push((path.clone(), err.to_string()));
204                                continue;
205                            }
206                        }
207                        project_path
208                    }
209                    Err(_) => {
210                        not_found_paths.push(path);
211                        continue;
212                    }
213                };
214
215                let open_buffer_task =
216                    project.update(cx, |project, cx| project.open_buffer(project_path, cx));
217
218                let buffer = futures::select! {
219                    result = open_buffer_task.fuse() => {
220                        match result {
221                            Ok(buffer) => buffer,
222                            Err(error) => {
223                                open_errors.push((path, error.to_string()));
224                                continue;
225                            }
226                        }
227                    }
228                    _ = event_stream.cancelled_by_user().fuse() => {
229                        return Err("Save cancelled by user".to_string());
230                    }
231                };
232
233                let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
234
235                if is_dirty {
236                    buffers_to_save.insert(buffer);
237                    dirty_count += 1;
238                } else {
239                    clean_paths.push(path);
240                }
241            }
242
243            // Save each buffer individually since there's no batch save API.
244            for buffer in buffers_to_save {
245                let path_for_buffer = buffer
246                    .read_with(cx, |buffer, _| {
247                        buffer
248                            .file()
249                            .map(|file| file.path().to_rel_path_buf())
250                            .map(|path| path.as_rel_path().as_unix_str().to_owned())
251                    })
252                    .unwrap_or_else(|| "<unknown>".to_string());
253
254                let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
255
256                let save_result = futures::select! {
257                    result = save_task.fuse() => result,
258                    _ = event_stream.cancelled_by_user().fuse() => {
259                        return Err("Save cancelled by user".to_string());
260                    }
261                };
262                if let Err(error) = save_result {
263                    save_errors.push((path_for_buffer, error.to_string()));
264                }
265            }
266
267            let mut lines: Vec<String> = Vec::new();
268
269            let successful_saves = dirty_count.saturating_sub(save_errors.len());
270            if successful_saves > 0 {
271                lines.push(format!("Saved {} file(s).", successful_saves));
272            }
273            if !clean_paths.is_empty() {
274                lines.push(format!("{} clean.", clean_paths.len()));
275            }
276
277            if !not_found_paths.is_empty() {
278                lines.push(format!("Not found ({}):", not_found_paths.len()));
279                for path in &not_found_paths {
280                    lines.push(format!("- {}", path.display()));
281                }
282            }
283            if !open_errors.is_empty() {
284                lines.push(format!("Open failed ({}):", open_errors.len()));
285                for (path, error) in &open_errors {
286                    lines.push(format!("- {}: {}", path.display(), error));
287                }
288            }
289            if !authorization_errors.is_empty() {
290                lines.push(format!(
291                    "Authorization failed ({}):",
292                    authorization_errors.len()
293                ));
294                for (path, error) in &authorization_errors {
295                    lines.push(format!("- {}: {}", path.display(), error));
296                }
297            }
298            if !save_errors.is_empty() {
299                lines.push(format!("Save failed ({}):", save_errors.len()));
300                for (path, error) in &save_errors {
301                    lines.push(format!("- {}: {}", path, error));
302                }
303            }
304
305            if lines.is_empty() {
306                Ok("No paths provided.".to_string())
307            } else {
308                Ok(lines.join("\n"))
309            }
310        })
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs as _;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test settings store and sets the agent's default tool
    /// permission to `Allow`; individual tests opt into stricter modes.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
        override_agent_settings(cx, |overrides| {
            overrides.tool_permissions.default = settings::ToolPermissionMode::Allow;
        });
    }

    /// Applies `edit` to a copy of the global `AgentSettings` and installs it.
    fn override_agent_settings(cx: &mut TestAppContext, edit: impl FnOnce(&mut AgentSettings)) {
        cx.update(|cx| {
            let mut updated = AgentSettings::get_global(cx).clone();
            edit(&mut updated);
            AgentSettings::override_global(updated, cx);
        });
    }

    /// Starts the tool on `paths`, reporting through `event_stream`.
    fn run_tool(
        tool: &Arc<SaveFileTool>,
        paths: &[&str],
        event_stream: ToolCallEventStream,
        cx: &mut TestAppContext,
    ) -> Task<Result<String, String>> {
        let input = SaveFileToolInput {
            paths: paths.iter().map(|path| PathBuf::from(*path)).collect(),
        };
        cx.update(|cx| tool.clone().run(ToolInput::resolved(input), event_stream, cx))
    }

    /// Opens the buffer for a worktree-qualified `path`, panicking with
    /// `missing_message` when the project cannot find it.
    async fn open_project_buffer(
        project: &Entity<Project>,
        path: &str,
        missing_message: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project.find_project_path(path, cx).expect(missing_message)
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Creates `/root/project` (containing `project_tree`) beside
    /// `/root/external/secret.txt`, symlinks `project/link.txt` to that external
    /// file, and opens `/root/project` as the only worktree.
    async fn project_with_escaping_symlink(
        project_tree: serde_json::Value,
        cx: &mut TestAppContext,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "project": project_tree,
                "external": {
                    "secret.txt": "secret content"
                }
            }),
        )
        .await;
        fs.create_symlink(
            path!("/root/project/link.txt").as_ref(),
            PathBuf::from("../external/secret.txt"),
        )
        .await
        .unwrap();

        let project = Project::test(fs.clone(), [path!("/root/project").as_ref()], cx).await;
        cx.executor().run_until_parked();
        project
    }

    /// Asserts that an authorization prompt title is the symlink-escape prompt.
    fn assert_symlink_escape_title(title: &str) {
        assert!(
            title.contains("points outside the project"),
            "Expected symlink escape authorization, got: {title}",
        );
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Give dirty.txt an unsaved in-memory edit.
        let dirty_buffer = open_project_buffer(
            &project,
            "root/dirty.txt",
            "dirty.txt should exist in project",
            cx,
        )
        .await;
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before save"
        );

        // Open clean.txt but leave it untouched.
        let clean_buffer = open_project_buffer(
            &project,
            "root/clean.txt",
            "clean.txt should exist in project",
            cx,
        )
        .await;
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        let task = run_tool(
            &tool,
            &["root/dirty.txt", "root/clean.txt"],
            ToolCallEventStream::test().0,
            cx,
        );
        let output = task.await.unwrap();

        // The summary reports one save and one clean file.
        assert!(
            output.contains("Saved 1 file(s)."),
            "expected saved count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // The edit reached disk and the buffer is no longer dirty...
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after save"
        );
        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            disk_dirty, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // ...while clean.txt is unchanged on disk.
        let disk_clean = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(disk_clean, "on disk: clean\n");

        // An empty path list yields a placeholder message.
        let task = run_tool(&tool, &[], ToolCallEventStream::test().0, cx);
        let output = task.await.unwrap();
        assert_eq!(output, "No paths provided.");

        // A path outside every worktree is reported as not found.
        let task = run_tool(
            &tool,
            &["nonexistent/path.txt"],
            ToolCallEventStream::test().0,
            cx,
        );
        let output = task.await.unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }

    #[gpui::test]
    async fn test_save_file_symlink_escape_requests_authorization(cx: &mut TestAppContext) {
        init_test(cx);

        let project = project_with_escaping_symlink(json!({ "src": {} }), cx).await;
        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = run_tool(&tool, &["project/link.txt"], event_stream, cx);

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        assert_symlink_escape_title(auth.tool_call.fields.title.as_deref().unwrap_or(""));
        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        let _result = task.await;
    }

    #[gpui::test]
    async fn test_save_file_symlink_escape_honors_deny_policy(cx: &mut TestAppContext) {
        init_test(cx);
        override_agent_settings(cx, |overrides| {
            overrides.tool_permissions.tools.insert(
                "save_file".into(),
                agent_settings::ToolRules {
                    default: Some(settings::ToolPermissionMode::Deny),
                    ..Default::default()
                },
            );
        });

        let project = project_with_escaping_symlink(json!({ "src": {} }), cx).await;
        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let result = run_tool(&tool, &["project/link.txt"], event_stream, cx).await;

        assert!(result.is_err(), "Tool should fail when policy denies");
        let prompted = matches!(
            event_rx.try_next(),
            Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_))))
        );
        assert!(
            !prompted,
            "Deny policy should not emit symlink authorization prompt",
        );
    }

    #[gpui::test]
    async fn test_save_file_symlink_escape_confirm_requires_single_approval(
        cx: &mut TestAppContext,
    ) {
        init_test(cx);
        override_agent_settings(cx, |overrides| {
            overrides.tool_permissions.default = settings::ToolPermissionMode::Confirm;
        });

        let project = project_with_escaping_symlink(json!({ "src": {} }), cx).await;
        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = run_tool(&tool, &["project/link.txt"], event_stream, cx);

        cx.run_until_parked();

        let auth = event_rx.expect_authorization().await;
        assert_symlink_escape_title(auth.tool_call.fields.title.as_deref().unwrap_or(""));
        auth.response
            .send(acp::PermissionOptionId::new("allow"))
            .unwrap();

        let prompted_again = matches!(
            event_rx.try_next(),
            Ok(Some(Ok(crate::ThreadEvent::ToolCallAuthorization(_))))
        );
        assert!(!prompted_again, "Expected a single authorization prompt",);

        let _result = task.await;
    }

    #[gpui::test]
    async fn test_save_file_symlink_denial_does_not_reduce_success_count(cx: &mut TestAppContext) {
        init_test(cx);

        let project =
            project_with_escaping_symlink(json!({ "dirty.txt": "on disk value\n" }), cx).await;

        let dirty_buffer = open_project_buffer(
            &project,
            "project/dirty.txt",
            "dirty.txt should exist in project",
            cx,
        )
        .await;
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory value\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt should be dirty before save"
        );

        let tool = Arc::new(SaveFileTool::new(project));

        let (event_stream, mut event_rx) = ToolCallEventStream::test();
        let task = run_tool(
            &tool,
            &["project/dirty.txt", "project/link.txt"],
            event_stream,
            cx,
        );

        cx.run_until_parked();

        // Refuse the symlink prompt; the regular file must still be saved.
        let auth = event_rx.expect_authorization().await;
        auth.response
            .send(acp::PermissionOptionId::new("deny"))
            .unwrap();

        let output = task.await.unwrap();
        assert!(
            output.contains("Saved 1 file(s)."),
            "Expected successful save count to remain accurate, got:\n{output}",
        );
        assert!(
            output.contains("Authorization failed (1):"),
            "Expected authorization failure section, got:\n{output}",
        );
        assert!(
            !output.contains("Save failed"),
            "Authorization denials should not be counted as save failures, got:\n{output}",
        );
    }
}