1use agent_client_protocol as acp;
2use agent_settings::AgentSettings;
3use anyhow::Result;
4use collections::FxHashSet;
5use futures::FutureExt as _;
6use gpui::{App, Entity, SharedString, Task};
7use language::Buffer;
8use project::Project;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use settings::Settings;
12use std::path::PathBuf;
13use std::sync::Arc;
14use util::markdown::MarkdownInlineCode;
15
16use crate::{
17 AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings,
18};
19
// NOTE: the `///` comments on this type are runtime data, not just docs:
// `#[derive(JsonSchema)]` (schemars) turns them into the schema's `description`
// strings, presumably surfaced to the model as the tool/field descriptions.
// Edit their wording deliberately.
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    // Each entry is resolved with `Project::find_project_path` when the tool runs;
    // entries that do not resolve are reported as "Not found" instead of failing
    // the whole call.
    /// The paths of the files to save.
    pub paths: Vec<PathBuf>,
}
29
/// Agent tool that saves buffers with unsaved changes for a list of paths.
///
/// Holds the [`Project`] used to resolve the requested paths, open their
/// buffers, and save them.
pub struct SaveFileTool {
    project: Entity<Project>,
}
33
34impl SaveFileTool {
35 pub fn new(project: Entity<Project>) -> Self {
36 Self { project }
37 }
38}
39
40impl AgentTool for SaveFileTool {
41 type Input = SaveFileToolInput;
42 type Output = String;
43
44 const NAME: &'static str = "save_file";
45
46 fn kind() -> acp::ToolKind {
47 acp::ToolKind::Other
48 }
49
50 fn initial_title(
51 &self,
52 input: Result<Self::Input, serde_json::Value>,
53 _cx: &mut App,
54 ) -> SharedString {
55 match input {
56 Ok(input) if input.paths.len() == 1 => "Save file".into(),
57 Ok(input) => format!("Save {} files", input.paths.len()).into(),
58 Err(_) => "Save files".into(),
59 }
60 }
61
62 fn run(
63 self: Arc<Self>,
64 input: Self::Input,
65 event_stream: ToolCallEventStream,
66 cx: &mut App,
67 ) -> Task<Result<String>> {
68 let settings = AgentSettings::get_global(cx);
69 let mut needs_confirmation = false;
70
71 for path in &input.paths {
72 let path_str = path.to_string_lossy();
73 let decision = decide_permission_from_settings(Self::NAME, &path_str, settings);
74 match decision {
75 ToolPermissionDecision::Allow => {}
76 ToolPermissionDecision::Deny(reason) => {
77 return Task::ready(Err(anyhow::anyhow!("{}", reason)));
78 }
79 ToolPermissionDecision::Confirm => {
80 needs_confirmation = true;
81 }
82 }
83 }
84
85 let authorize = if needs_confirmation {
86 let title = if input.paths.len() == 1 {
87 format!(
88 "Save {}",
89 MarkdownInlineCode(&input.paths[0].to_string_lossy())
90 )
91 } else {
92 let paths: Vec<_> = input
93 .paths
94 .iter()
95 .take(3)
96 .map(|p| p.to_string_lossy().to_string())
97 .collect();
98 if input.paths.len() > 3 {
99 format!(
100 "Save {}, and {} more",
101 paths.join(", "),
102 input.paths.len() - 3
103 )
104 } else {
105 format!("Save {}", paths.join(", "))
106 }
107 };
108 let first_path = input
109 .paths
110 .first()
111 .map(|p| p.to_string_lossy().to_string())
112 .unwrap_or_default();
113 let context = crate::ToolPermissionContext {
114 tool_name: Self::NAME.to_string(),
115 input_value: first_path,
116 };
117 Some(event_stream.authorize(title, context, cx))
118 } else {
119 None
120 };
121
122 let project = self.project.clone();
123 let input_paths = input.paths;
124
125 cx.spawn(async move |cx| {
126 if let Some(authorize) = authorize {
127 authorize.await?;
128 }
129
130 let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
131
132 let mut saved_paths: Vec<PathBuf> = Vec::new();
133 let mut clean_paths: Vec<PathBuf> = Vec::new();
134 let mut not_found_paths: Vec<PathBuf> = Vec::new();
135 let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
136 let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
137 let mut save_errors: Vec<(String, String)> = Vec::new();
138
139 for path in input_paths {
140 let Some(project_path) =
141 project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
142 else {
143 not_found_paths.push(path);
144 continue;
145 };
146
147 let open_buffer_task =
148 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
149
150 let buffer = futures::select! {
151 result = open_buffer_task.fuse() => {
152 match result {
153 Ok(buffer) => buffer,
154 Err(error) => {
155 open_errors.push((path, error.to_string()));
156 continue;
157 }
158 }
159 }
160 _ = event_stream.cancelled_by_user().fuse() => {
161 anyhow::bail!("Save cancelled by user");
162 }
163 };
164
165 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
166
167 if is_dirty {
168 buffers_to_save.insert(buffer);
169 saved_paths.push(path);
170 } else {
171 clean_paths.push(path);
172 }
173 }
174
175 // Save each buffer individually since there's no batch save API.
176 for buffer in buffers_to_save {
177 let path_for_buffer = buffer
178 .read_with(cx, |buffer, _| {
179 buffer
180 .file()
181 .map(|file| file.path().to_rel_path_buf())
182 .map(|path| path.as_rel_path().as_unix_str().to_owned())
183 })
184 .unwrap_or_else(|| "<unknown>".to_string());
185
186 let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
187
188 let save_result = futures::select! {
189 result = save_task.fuse() => result,
190 _ = event_stream.cancelled_by_user().fuse() => {
191 anyhow::bail!("Save cancelled by user");
192 }
193 };
194 if let Err(error) = save_result {
195 save_errors.push((path_for_buffer, error.to_string()));
196 }
197 }
198
199 let mut lines: Vec<String> = Vec::new();
200
201 if !saved_paths.is_empty() {
202 lines.push(format!("Saved {} file(s).", saved_paths.len()));
203 }
204 if !clean_paths.is_empty() {
205 lines.push(format!("{} clean.", clean_paths.len()));
206 }
207
208 if !not_found_paths.is_empty() {
209 lines.push(format!("Not found ({}):", not_found_paths.len()));
210 for path in ¬_found_paths {
211 lines.push(format!("- {}", path.display()));
212 }
213 }
214 if !open_errors.is_empty() {
215 lines.push(format!("Open failed ({}):", open_errors.len()));
216 for (path, error) in &open_errors {
217 lines.push(format!("- {}: {}", path.display(), error));
218 }
219 }
220 if !dirty_check_errors.is_empty() {
221 lines.push(format!(
222 "Dirty check failed ({}):",
223 dirty_check_errors.len()
224 ));
225 for (path, error) in &dirty_check_errors {
226 lines.push(format!("- {}: {}", path.display(), error));
227 }
228 }
229 if !save_errors.is_empty() {
230 lines.push(format!("Save failed ({}):", save_errors.len()));
231 for (path, error) in &save_errors {
232 lines.push(format!("- {}: {}", path, error));
233 }
234 }
235
236 if lines.is_empty() {
237 Ok("No paths provided.".to_string())
238 } else {
239 Ok(lines.join("\n"))
240 }
241 })
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test `SettingsStore` global, then overrides `AgentSettings` with
    /// `always_allow_tool_actions = true`.
    // NOTE(review): presumably this makes `decide_permission_from_settings` return
    // `Allow`, so `run` never waits on an authorization prompt — confirm.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
        cx.update(|cx| {
            let mut settings = AgentSettings::get_global(cx).clone();
            settings.always_allow_tool_actions = true;
            AgentSettings::override_global(settings, cx);
        });
    }

    /// End-to-end check of `SaveFileTool::run`. It covers:
    /// - a dirty file is saved to disk and reported;
    /// - a clean file is reported as clean and left untouched;
    /// - an empty input produces "No paths provided.";
    /// - a path outside the project is reported under "Not found".
    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Make dirty.txt dirty in-memory.
        // Paths are given relative to the worktree, prefixed with its root name.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before save"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        // Run the tool on both files; `.0` of `ToolCallEventStream::test()` is the
        // stream handed to the tool.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention saved + clean.
        assert!(
            output.contains("Saved 1 file(s)."),
            "expected saved count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should now be clean and disk should have new content.
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after save"
        );

        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            disk_dirty, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // Sanity: clean buffer should remain clean and disk unchanged.
        let disk_clean = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(disk_clean, "on disk: clean\n");

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput { paths: vec![] },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case: a path outside every worktree is reported,
        // not treated as an error for the whole call.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}