1use agent_client_protocol as acp;
2use agent_settings::AgentSettings;
3use anyhow::Result;
4use collections::FxHashSet;
5use futures::FutureExt as _;
6use gpui::{App, Entity, SharedString, Task};
7use language::Buffer;
8use project::Project;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use settings::Settings;
12use std::path::PathBuf;
13use std::sync::Arc;
14use util::markdown::MarkdownInlineCode;
15
16use crate::{
17 AgentTool, ToolCallEventStream, ToolPermissionDecision, decide_permission_from_settings,
18};
19
// NOTE(review): because this struct derives `JsonSchema`, the `///` comments below are
// presumably emitted as the schema description shown to the model (schemars behavior),
// so treat them as prompt text rather than ordinary rustdoc when editing.
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    // Each entry is resolved with `Project::find_project_path` in `SaveFileTool::run`;
    // entries that do not resolve are reported as "Not found" instead of failing the call.
    /// The paths of the files to save.
    pub paths: Vec<PathBuf>,
}
29
30pub struct SaveFileTool {
31 project: Entity<Project>,
32}
33
34impl SaveFileTool {
35 pub fn new(project: Entity<Project>) -> Self {
36 Self { project }
37 }
38}
39
40impl AgentTool for SaveFileTool {
41 type Input = SaveFileToolInput;
42 type Output = String;
43
44 fn name() -> &'static str {
45 "save_file"
46 }
47
48 fn kind() -> acp::ToolKind {
49 acp::ToolKind::Other
50 }
51
52 fn initial_title(
53 &self,
54 input: Result<Self::Input, serde_json::Value>,
55 _cx: &mut App,
56 ) -> SharedString {
57 match input {
58 Ok(input) if input.paths.len() == 1 => "Save file".into(),
59 Ok(input) => format!("Save {} files", input.paths.len()).into(),
60 Err(_) => "Save files".into(),
61 }
62 }
63
64 fn run(
65 self: Arc<Self>,
66 input: Self::Input,
67 event_stream: ToolCallEventStream,
68 cx: &mut App,
69 ) -> Task<Result<String>> {
70 let settings = AgentSettings::get_global(cx);
71 let mut needs_confirmation = false;
72
73 for path in &input.paths {
74 let path_str = path.to_string_lossy();
75 let decision = decide_permission_from_settings(Self::name(), &path_str, settings);
76 match decision {
77 ToolPermissionDecision::Allow => {}
78 ToolPermissionDecision::Deny(reason) => {
79 return Task::ready(Err(anyhow::anyhow!("{}", reason)));
80 }
81 ToolPermissionDecision::Confirm => {
82 needs_confirmation = true;
83 }
84 }
85 }
86
87 let authorize = if needs_confirmation {
88 let title = if input.paths.len() == 1 {
89 format!(
90 "Save {}",
91 MarkdownInlineCode(&input.paths[0].to_string_lossy())
92 )
93 } else {
94 let paths: Vec<_> = input
95 .paths
96 .iter()
97 .take(3)
98 .map(|p| p.to_string_lossy().to_string())
99 .collect();
100 if input.paths.len() > 3 {
101 format!(
102 "Save {}, and {} more",
103 paths.join(", "),
104 input.paths.len() - 3
105 )
106 } else {
107 format!("Save {}", paths.join(", "))
108 }
109 };
110 let first_path = input
111 .paths
112 .first()
113 .map(|p| p.to_string_lossy().to_string())
114 .unwrap_or_default();
115 let context = crate::ToolPermissionContext {
116 tool_name: "save_file".to_string(),
117 input_value: first_path,
118 };
119 Some(event_stream.authorize(title, context, cx))
120 } else {
121 None
122 };
123
124 let project = self.project.clone();
125 let input_paths = input.paths;
126
127 cx.spawn(async move |cx| {
128 if let Some(authorize) = authorize {
129 authorize.await?;
130 }
131
132 let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
133
134 let mut saved_paths: Vec<PathBuf> = Vec::new();
135 let mut clean_paths: Vec<PathBuf> = Vec::new();
136 let mut not_found_paths: Vec<PathBuf> = Vec::new();
137 let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
138 let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
139 let mut save_errors: Vec<(String, String)> = Vec::new();
140
141 for path in input_paths {
142 let Some(project_path) =
143 project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
144 else {
145 not_found_paths.push(path);
146 continue;
147 };
148
149 let open_buffer_task =
150 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
151
152 let buffer = futures::select! {
153 result = open_buffer_task.fuse() => {
154 match result {
155 Ok(buffer) => buffer,
156 Err(error) => {
157 open_errors.push((path, error.to_string()));
158 continue;
159 }
160 }
161 }
162 _ = event_stream.cancelled_by_user().fuse() => {
163 anyhow::bail!("Save cancelled by user");
164 }
165 };
166
167 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
168
169 if is_dirty {
170 buffers_to_save.insert(buffer);
171 saved_paths.push(path);
172 } else {
173 clean_paths.push(path);
174 }
175 }
176
177 // Save each buffer individually since there's no batch save API.
178 for buffer in buffers_to_save {
179 let path_for_buffer = buffer
180 .read_with(cx, |buffer, _| {
181 buffer
182 .file()
183 .map(|file| file.path().to_rel_path_buf())
184 .map(|path| path.as_rel_path().as_unix_str().to_owned())
185 })
186 .unwrap_or_else(|| "<unknown>".to_string());
187
188 let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
189
190 let save_result = futures::select! {
191 result = save_task.fuse() => result,
192 _ = event_stream.cancelled_by_user().fuse() => {
193 anyhow::bail!("Save cancelled by user");
194 }
195 };
196 if let Err(error) = save_result {
197 save_errors.push((path_for_buffer, error.to_string()));
198 }
199 }
200
201 let mut lines: Vec<String> = Vec::new();
202
203 if !saved_paths.is_empty() {
204 lines.push(format!("Saved {} file(s).", saved_paths.len()));
205 }
206 if !clean_paths.is_empty() {
207 lines.push(format!("{} clean.", clean_paths.len()));
208 }
209
210 if !not_found_paths.is_empty() {
211 lines.push(format!("Not found ({}):", not_found_paths.len()));
212 for path in ¬_found_paths {
213 lines.push(format!("- {}", path.display()));
214 }
215 }
216 if !open_errors.is_empty() {
217 lines.push(format!("Open failed ({}):", open_errors.len()));
218 for (path, error) in &open_errors {
219 lines.push(format!("- {}: {}", path.display(), error));
220 }
221 }
222 if !dirty_check_errors.is_empty() {
223 lines.push(format!(
224 "Dirty check failed ({}):",
225 dirty_check_errors.len()
226 ));
227 for (path, error) in &dirty_check_errors {
228 lines.push(format!("- {}: {}", path.display(), error));
229 }
230 }
231 if !save_errors.is_empty() {
232 lines.push(format!("Save failed ({}):", save_errors.len()));
233 for (path, error) in &save_errors {
234 lines.push(format!("- {}: {}", path, error));
235 }
236 }
237
238 if lines.is_empty() {
239 Ok("No paths provided.".to_string())
240 } else {
241 Ok(lines.join("\n"))
242 }
243 })
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test `SettingsStore` global and turns on
    /// `always_allow_tool_actions`. This is presumably so `SaveFileTool::run`
    /// never waits on a confirmation that the test event stream would not answer.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
        cx.update(|cx| {
            let mut agent_settings = AgentSettings::get_global(cx).clone();
            agent_settings.always_allow_tool_actions = true;
            AgentSettings::override_global(agent_settings, cx);
        });
    }

    /// Resolves `path` inside `project` and opens its buffer.
    ///
    /// Panics with `missing_message` if `path` is not part of the project.
    async fn open_buffer_at(
        project: &Entity<Project>,
        path: &str,
        missing_message: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let project_path = project.read_with(cx, |project, cx| {
            project.find_project_path(path, cx).expect(missing_message)
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Reports whether `buffer` currently has unsaved changes.
    fn buffer_is_dirty(buffer: &Entity<Buffer>, cx: &TestAppContext) -> bool {
        buffer.read_with(cx, |buffer, _| buffer.is_dirty())
    }

    /// Runs `tool` on `paths` with a throwaway test event stream and unwraps the
    /// summary it returns.
    async fn run_tool(
        tool: &Arc<SaveFileTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        let task = cx.update(|cx| {
            tool.clone()
                .run(SaveFileToolInput { paths }, ToolCallEventStream::test().0, cx)
        });
        task.await.unwrap()
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // dirty.txt receives in-memory edits that are not yet on disk.
        let dirty_buffer = open_buffer_at(
            &project,
            "root/dirty.txt",
            "dirty.txt should exist in project",
            cx,
        )
        .await;
        dirty_buffer.update(cx, |buffer, cx| {
            let whole_buffer = 0..buffer.len();
            buffer.edit([(whole_buffer, "in memory: dirty\n")], None, cx);
        });
        assert!(
            buffer_is_dirty(&dirty_buffer, cx),
            "dirty.txt buffer should be dirty before save"
        );

        // clean.txt is opened too, but left untouched.
        let clean_buffer = open_buffer_at(
            &project,
            "root/clean.txt",
            "clean.txt should exist in project",
            cx,
        )
        .await;
        assert!(
            !buffer_is_dirty(&clean_buffer, cx),
            "clean.txt buffer should start clean"
        );

        let summary = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // The summary reports one saved file and one clean file.
        for (expected, what) in [
            ("Saved 1 file(s).", "saved count"),
            ("1 clean.", "clean count"),
        ] {
            assert!(
                summary.contains(expected),
                "expected {what} line, got:\n{summary}"
            );
        }

        // The dirty buffer was written out...
        assert!(
            !buffer_is_dirty(&dirty_buffer, cx),
            "dirty.txt buffer should not be dirty after save"
        );
        let dirty_on_disk = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            dirty_on_disk, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // ...while the clean file on disk is unchanged.
        let clean_on_disk = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(clean_on_disk, "on disk: clean\n");

        // An empty path list yields the fallback message.
        let summary = run_tool(&tool, Vec::new(), cx).await;
        assert_eq!(summary, "No paths provided.");

        // A path outside the project is listed under "Not found".
        let summary = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            summary.contains("Not found (1):"),
            "expected not-found header line, got:\n{summary}"
        );
        assert!(
            summary.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{summary}"
        );
    }
}