1use agent_client_protocol as acp;
2use anyhow::Result;
3use collections::FxHashSet;
4use futures::FutureExt as _;
5use gpui::{App, Entity, SharedString, Task};
6use language::Buffer;
7use project::Project;
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11use std::sync::Arc;
12
13use crate::{AgentTool, ToolCallEventStream};
14
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
// NOTE(review): the `///` doc comments on this struct and its field are picked up by the
// `JsonSchema` derive as schema `description`s, and are presumably what the model sees as the
// tool description — treat them as runtime text and edit with care. Plain `//` comments like
// this one are not included in the schema.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    /// The paths of the files to save.
    // Each entry is resolved with `Project::find_project_path` when the tool runs; entries that
    // do not resolve are listed under "Not found" in the tool's report rather than failing it.
    pub paths: Vec<PathBuf>,
}
24
25pub struct SaveFileTool {
26 project: Entity<Project>,
27}
28
29impl SaveFileTool {
30 pub fn new(project: Entity<Project>) -> Self {
31 Self { project }
32 }
33}
34
35impl AgentTool for SaveFileTool {
36 type Input = SaveFileToolInput;
37 type Output = String;
38
39 fn name() -> &'static str {
40 "save_file"
41 }
42
43 fn kind() -> acp::ToolKind {
44 acp::ToolKind::Other
45 }
46
47 fn initial_title(
48 &self,
49 input: Result<Self::Input, serde_json::Value>,
50 _cx: &mut App,
51 ) -> SharedString {
52 match input {
53 Ok(input) if input.paths.len() == 1 => "Save file".into(),
54 Ok(input) => format!("Save {} files", input.paths.len()).into(),
55 Err(_) => "Save files".into(),
56 }
57 }
58
59 fn run(
60 self: Arc<Self>,
61 input: Self::Input,
62 event_stream: ToolCallEventStream,
63 cx: &mut App,
64 ) -> Task<Result<String>> {
65 let project = self.project.clone();
66 let input_paths = input.paths;
67
68 cx.spawn(async move |cx| {
69 let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
70
71 let mut saved_paths: Vec<PathBuf> = Vec::new();
72 let mut clean_paths: Vec<PathBuf> = Vec::new();
73 let mut not_found_paths: Vec<PathBuf> = Vec::new();
74 let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
75 let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
76 let mut save_errors: Vec<(String, String)> = Vec::new();
77
78 for path in input_paths {
79 let Some(project_path) =
80 project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
81 else {
82 not_found_paths.push(path);
83 continue;
84 };
85
86 let open_buffer_task =
87 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
88
89 let buffer = futures::select! {
90 result = open_buffer_task.fuse() => {
91 match result {
92 Ok(buffer) => buffer,
93 Err(error) => {
94 open_errors.push((path, error.to_string()));
95 continue;
96 }
97 }
98 }
99 _ = event_stream.cancelled_by_user().fuse() => {
100 anyhow::bail!("Save cancelled by user");
101 }
102 };
103
104 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
105
106 if is_dirty {
107 buffers_to_save.insert(buffer);
108 saved_paths.push(path);
109 } else {
110 clean_paths.push(path);
111 }
112 }
113
114 // Save each buffer individually since there's no batch save API.
115 for buffer in buffers_to_save {
116 let path_for_buffer = buffer
117 .read_with(cx, |buffer, _| {
118 buffer
119 .file()
120 .map(|file| file.path().to_rel_path_buf())
121 .map(|path| path.as_rel_path().as_unix_str().to_owned())
122 })
123 .unwrap_or_else(|| "<unknown>".to_string());
124
125 let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
126
127 let save_result = futures::select! {
128 result = save_task.fuse() => result,
129 _ = event_stream.cancelled_by_user().fuse() => {
130 anyhow::bail!("Save cancelled by user");
131 }
132 };
133 if let Err(error) = save_result {
134 save_errors.push((path_for_buffer, error.to_string()));
135 }
136 }
137
138 let mut lines: Vec<String> = Vec::new();
139
140 if !saved_paths.is_empty() {
141 lines.push(format!("Saved {} file(s).", saved_paths.len()));
142 }
143 if !clean_paths.is_empty() {
144 lines.push(format!("{} clean.", clean_paths.len()));
145 }
146
147 if !not_found_paths.is_empty() {
148 lines.push(format!("Not found ({}):", not_found_paths.len()));
149 for path in ¬_found_paths {
150 lines.push(format!("- {}", path.display()));
151 }
152 }
153 if !open_errors.is_empty() {
154 lines.push(format!("Open failed ({}):", open_errors.len()));
155 for (path, error) in &open_errors {
156 lines.push(format!("- {}: {}", path.display(), error));
157 }
158 }
159 if !dirty_check_errors.is_empty() {
160 lines.push(format!(
161 "Dirty check failed ({}):",
162 dirty_check_errors.len()
163 ));
164 for (path, error) in &dirty_check_errors {
165 lines.push(format!("- {}: {}", path.display(), error));
166 }
167 }
168 if !save_errors.is_empty() {
169 lines.push(format!("Save failed ({}):", save_errors.len()));
170 for (path, error) in &save_errors {
171 lines.push(format!("- {}: {}", path, error));
172 }
173 }
174
175 if lines.is_empty() {
176 Ok("No paths provided.".to_string())
177 } else {
178 Ok(lines.join("\n"))
179 }
180 })
181 }
182}
183
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Registers the test `SettingsStore` as a global before any project is created.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let store = SettingsStore::test(cx);
            cx.set_global(store);
        });
    }

    /// Resolves `root/<file_name>` inside `project` and opens it as a buffer.
    async fn open_root_buffer(
        project: &Entity<Project>,
        file_name: &str,
        cx: &mut TestAppContext,
    ) -> Entity<Buffer> {
        let relative = format!("root/{file_name}");
        let project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path(relative.as_str(), cx)
                .unwrap_or_else(|| panic!("{file_name} should exist in project"))
        });
        project
            .update(cx, |project, cx| project.open_buffer(project_path, cx))
            .await
            .unwrap()
    }

    /// Runs the tool on `paths` with a throwaway test event stream and returns its report.
    async fn run_tool(
        tool: &Arc<SaveFileTool>,
        paths: Vec<PathBuf>,
        cx: &mut TestAppContext,
    ) -> String {
        let task = cx.update(|cx| {
            tool.clone().run(
                SaveFileToolInput { paths },
                ToolCallEventStream::test().0,
                cx,
            )
        });
        task.await.unwrap()
    }

    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Give dirty.txt unsaved in-memory edits.
        let dirty_buffer = open_root_buffer(&project, "dirty.txt", cx).await;
        dirty_buffer.update(cx, |buffer, cx| {
            let whole_range = 0..buffer.len();
            buffer.edit([(whole_range, "in memory: dirty\n")], None, cx);
        });
        let dirty_before = dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty());
        assert!(dirty_before, "dirty.txt buffer should be dirty before save");

        // Open clean.txt without touching it, so it stays clean.
        let clean_buffer = open_root_buffer(&project, "clean.txt", cx).await;
        let clean_before = clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty());
        assert!(!clean_before, "clean.txt buffer should start clean");

        let report = run_tool(
            &tool,
            vec![
                PathBuf::from("root/dirty.txt"),
                PathBuf::from("root/clean.txt"),
            ],
            cx,
        )
        .await;

        // The report should mention both the saved and the clean file.
        assert!(
            report.contains("Saved 1 file(s)."),
            "expected saved count line, got:\n{report}"
        );
        assert!(
            report.contains("1 clean."),
            "expected clean count line, got:\n{report}"
        );

        // Side effects: the dirty buffer is now clean and its text reached the disk.
        let dirty_after = dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty());
        assert!(!dirty_after, "dirty.txt buffer should not be dirty after save");

        let dirty_on_disk = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            dirty_on_disk, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // The clean file's on-disk content is untouched.
        let clean_on_disk = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(clean_on_disk, "on disk: clean\n");

        // No paths at all.
        let report = run_tool(&tool, vec![], cx).await;
        assert_eq!(report, "No paths provided.");

        // A path that does not resolve inside the project.
        let report = run_tool(&tool, vec![PathBuf::from("nonexistent/path.txt")], cx).await;
        assert!(
            report.contains("Not found (1):"),
            "expected not-found header line, got:\n{report}"
        );
        assert!(
            report.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{report}"
        );
    }
}