1use agent_client_protocol as acp;
2use anyhow::Result;
3use collections::FxHashSet;
4use gpui::{App, Entity, SharedString, Task};
5use language::Buffer;
6use project::Project;
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::path::PathBuf;
10use std::sync::Arc;
11
12use crate::{AgentTool, ToolCallEventStream};
13
// NOTE: the `///` doc comments on this struct and its field are turned into JSON
// schema `description`s by `#[derive(JsonSchema)]` (schemars), so editing them changes
// the tool's generated schema text. Keep maintainer-only notes in plain `//` comments.
/// Saves files that have unsaved changes.
///
/// Use this tool when you need to edit files but they have unsaved changes that must be saved first.
/// Only use this tool after asking the user for permission to save their unsaved changes.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct SaveFileToolInput {
    // Each entry is resolved with `Project::find_project_path` in `SaveFileTool::run`;
    // entries that do not resolve are reported as "Not found" instead of failing the call.
    /// The paths of the files to save.
    pub paths: Vec<PathBuf>,
}
23
24pub struct SaveFileTool {
25 project: Entity<Project>,
26}
27
28impl SaveFileTool {
29 pub fn new(project: Entity<Project>) -> Self {
30 Self { project }
31 }
32}
33
34impl AgentTool for SaveFileTool {
35 type Input = SaveFileToolInput;
36 type Output = String;
37
38 fn name() -> &'static str {
39 "save_file"
40 }
41
42 fn kind() -> acp::ToolKind {
43 acp::ToolKind::Other
44 }
45
46 fn initial_title(
47 &self,
48 input: Result<Self::Input, serde_json::Value>,
49 _cx: &mut App,
50 ) -> SharedString {
51 match input {
52 Ok(input) if input.paths.len() == 1 => "Save file".into(),
53 Ok(input) => format!("Save {} files", input.paths.len()).into(),
54 Err(_) => "Save files".into(),
55 }
56 }
57
58 fn run(
59 self: Arc<Self>,
60 input: Self::Input,
61 _event_stream: ToolCallEventStream,
62 cx: &mut App,
63 ) -> Task<Result<String>> {
64 let project = self.project.clone();
65 let input_paths = input.paths;
66
67 cx.spawn(async move |cx| {
68 let mut buffers_to_save: FxHashSet<Entity<Buffer>> = FxHashSet::default();
69
70 let mut saved_paths: Vec<PathBuf> = Vec::new();
71 let mut clean_paths: Vec<PathBuf> = Vec::new();
72 let mut not_found_paths: Vec<PathBuf> = Vec::new();
73 let mut open_errors: Vec<(PathBuf, String)> = Vec::new();
74 let dirty_check_errors: Vec<(PathBuf, String)> = Vec::new();
75 let mut save_errors: Vec<(String, String)> = Vec::new();
76
77 for path in input_paths {
78 let Some(project_path) =
79 project.read_with(cx, |project, cx| project.find_project_path(&path, cx))
80 else {
81 not_found_paths.push(path);
82 continue;
83 };
84
85 let open_buffer_task =
86 project.update(cx, |project, cx| project.open_buffer(project_path, cx));
87
88 let buffer = match open_buffer_task.await {
89 Ok(buffer) => buffer,
90 Err(error) => {
91 open_errors.push((path, error.to_string()));
92 continue;
93 }
94 };
95
96 let is_dirty = buffer.read_with(cx, |buffer, _| buffer.is_dirty());
97
98 if is_dirty {
99 buffers_to_save.insert(buffer);
100 saved_paths.push(path);
101 } else {
102 clean_paths.push(path);
103 }
104 }
105
106 // Save each buffer individually since there's no batch save API.
107 for buffer in buffers_to_save {
108 let path_for_buffer = buffer
109 .read_with(cx, |buffer, _| {
110 buffer
111 .file()
112 .map(|file| file.path().to_rel_path_buf())
113 .map(|path| path.as_rel_path().as_unix_str().to_owned())
114 })
115 .unwrap_or_else(|| "<unknown>".to_string());
116
117 let save_task = project.update(cx, |project, cx| project.save_buffer(buffer, cx));
118
119 if let Err(error) = save_task.await {
120 save_errors.push((path_for_buffer, error.to_string()));
121 }
122 }
123
124 let mut lines: Vec<String> = Vec::new();
125
126 if !saved_paths.is_empty() {
127 lines.push(format!("Saved {} file(s).", saved_paths.len()));
128 }
129 if !clean_paths.is_empty() {
130 lines.push(format!("{} clean.", clean_paths.len()));
131 }
132
133 if !not_found_paths.is_empty() {
134 lines.push(format!("Not found ({}):", not_found_paths.len()));
135 for path in ¬_found_paths {
136 lines.push(format!("- {}", path.display()));
137 }
138 }
139 if !open_errors.is_empty() {
140 lines.push(format!("Open failed ({}):", open_errors.len()));
141 for (path, error) in &open_errors {
142 lines.push(format!("- {}: {}", path.display(), error));
143 }
144 }
145 if !dirty_check_errors.is_empty() {
146 lines.push(format!(
147 "Dirty check failed ({}):",
148 dirty_check_errors.len()
149 ));
150 for (path, error) in &dirty_check_errors {
151 lines.push(format!("- {}: {}", path.display(), error));
152 }
153 }
154 if !save_errors.is_empty() {
155 lines.push(format!("Save failed ({}):", save_errors.len()));
156 for (path, error) in &save_errors {
157 lines.push(format!("- {}: {}", path, error));
158 }
159 }
160
161 if lines.is_empty() {
162 Ok("No paths provided.".to_string())
163 } else {
164 Ok(lines.join("\n"))
165 }
166 })
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use fs::Fs;
    use gpui::TestAppContext;
    use project::FakeFs;
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;

    /// Installs a test `SettingsStore` as a global before any project is created.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
        });
    }

    /// End-to-end check of `SaveFileTool::run`:
    /// - a dirty buffer is written to disk and becomes clean;
    /// - a clean buffer is reported as clean and its file is unchanged;
    /// - an empty path list yields "No paths provided.";
    /// - an unresolvable path is listed under "Not found".
    #[gpui::test]
    async fn test_save_file_output_and_effects(cx: &mut TestAppContext) {
        init_test(cx);

        // In-memory filesystem with one file to dirty and one to leave alone.
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "dirty.txt": "on disk: dirty\n",
                "clean.txt": "on disk: clean\n",
            }),
        )
        .await;

        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let tool = Arc::new(SaveFileTool::new(project.clone()));

        // Make dirty.txt dirty in-memory.
        let dirty_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/dirty.txt", cx)
                .expect("dirty.txt should exist in project")
        });

        let dirty_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(dirty_project_path, cx)
            })
            .await
            .unwrap();
        // Replace the whole buffer so its contents differ from what is on disk.
        dirty_buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..buffer.len(), "in memory: dirty\n")], None, cx);
        });
        assert!(
            dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should be dirty before save"
        );

        // Ensure clean.txt is opened but remains clean.
        let clean_project_path = project.read_with(cx, |project, cx| {
            project
                .find_project_path("root/clean.txt", cx)
                .expect("clean.txt should exist in project")
        });

        let clean_buffer = project
            .update(cx, |project, cx| {
                project.open_buffer(clean_project_path, cx)
            })
            .await
            .unwrap();
        assert!(
            !clean_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "clean.txt buffer should start clean"
        );

        // Run the tool on both files in one call.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![
                            PathBuf::from("root/dirty.txt"),
                            PathBuf::from("root/clean.txt"),
                        ],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();

        // Output should mention saved + clean.
        assert!(
            output.contains("Saved 1 file(s)."),
            "expected saved count line, got:\n{output}"
        );
        assert!(
            output.contains("1 clean."),
            "expected clean count line, got:\n{output}"
        );

        // Effect: dirty buffer should now be clean and disk should have new content.
        assert!(
            !dirty_buffer.read_with(cx, |buffer, _| buffer.is_dirty()),
            "dirty.txt buffer should not be dirty after save"
        );

        let disk_dirty = fs.load(path!("/root/dirty.txt").as_ref()).await.unwrap();
        assert_eq!(
            disk_dirty, "in memory: dirty\n",
            "dirty.txt disk content should be updated"
        );

        // Sanity: clean buffer should remain clean and disk unchanged.
        let disk_clean = fs.load(path!("/root/clean.txt").as_ref()).await.unwrap();
        assert_eq!(disk_clean, "on disk: clean\n");

        // Test empty paths case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput { paths: vec![] },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert_eq!(output, "No paths provided.");

        // Test not-found path case.
        let output = cx
            .update(|cx| {
                tool.clone().run(
                    SaveFileToolInput {
                        paths: vec![PathBuf::from("nonexistent/path.txt")],
                    },
                    ToolCallEventStream::test().0,
                    cx,
                )
            })
            .await
            .unwrap();
        assert!(
            output.contains("Not found (1):"),
            "expected not-found header line, got:\n{output}"
        );
        assert!(
            output.contains("- nonexistent/path.txt"),
            "expected not-found path bullet, got:\n{output}"
        );
    }
}