1use super::*;
2use crate::{AgentTool, EditFileTool, ReadFileTool};
3use acp_thread::UserMessageId;
4use fs::FakeFs;
5use language_model::{
6 LanguageModelCompletionEvent, LanguageModelToolUse, StopReason,
7 fake_provider::FakeLanguageModel,
8};
9use prompt_store::ProjectContext;
10use serde_json::json;
11use std::{sync::Arc, time::Duration};
12use util::path;
13
14#[gpui::test]
15async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
16 // This test verifies that the edit_file tool works correctly when invoked
17 // through the full thread flow (model sends ToolUse event -> tool runs -> result sent back).
18 // This is different from tests that call tool.run() directly.
19 super::init_test(cx);
20 super::always_allow_tools(cx);
21
22 let fs = FakeFs::new(cx.executor());
23 fs.insert_tree(
24 path!("/project"),
25 json!({
26 "src": {
27 "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n"
28 }
29 }),
30 )
31 .await;
32
33 let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
34 let project_context = cx.new(|_cx| ProjectContext::default());
35 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
36 let context_server_registry =
37 cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
38 let model = Arc::new(FakeLanguageModel::default());
39 let fake_model = model.as_fake();
40
41 let thread = cx.new(|cx| {
42 let mut thread = crate::Thread::new(
43 project.clone(),
44 project_context,
45 context_server_registry,
46 crate::Templates::new(),
47 Some(model.clone()),
48 cx,
49 );
50 // Add just the tools we need for this test
51 let language_registry = project.read(cx).languages().clone();
52 thread.add_tool(crate::ReadFileTool::new(
53 project.clone(),
54 thread.action_log().clone(),
55 true,
56 ));
57 thread.add_tool(crate::EditFileTool::new(
58 project.clone(),
59 cx.weak_entity(),
60 language_registry,
61 crate::Templates::new(),
62 ));
63 thread
64 });
65
66 // First, read the file so the thread knows about its contents
67 let _events = thread
68 .update(cx, |thread, cx| {
69 thread.send(UserMessageId::new(), ["Read the file src/main.rs"], cx)
70 })
71 .unwrap();
72 cx.run_until_parked();
73
74 // Model calls read_file tool
75 let read_tool_use = LanguageModelToolUse {
76 id: "read_tool_1".into(),
77 name: ReadFileTool::NAME.into(),
78 raw_input: json!({"path": "project/src/main.rs"}).to_string(),
79 input: json!({"path": "project/src/main.rs"}),
80 is_input_complete: true,
81 thought_signature: None,
82 };
83 fake_model
84 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
85 fake_model
86 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
87 fake_model.end_last_completion_stream();
88 cx.run_until_parked();
89
90 // Wait for the read tool to complete and model to be called again
91 while fake_model.pending_completions().is_empty() {
92 cx.run_until_parked();
93 }
94
95 // Model responds after seeing the file content, then calls edit_file
96 fake_model.send_last_completion_stream_text_chunk("I'll edit the file now.");
97 let edit_tool_use = LanguageModelToolUse {
98 id: "edit_tool_1".into(),
99 name: EditFileTool::NAME.into(),
100 raw_input: json!({
101 "display_description": "Change greeting message",
102 "path": "project/src/main.rs",
103 "mode": "edit"
104 })
105 .to_string(),
106 input: json!({
107 "display_description": "Change greeting message",
108 "path": "project/src/main.rs",
109 "mode": "edit"
110 }),
111 is_input_complete: true,
112 thought_signature: None,
113 };
114 fake_model
115 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use));
116 fake_model
117 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
118 fake_model.end_last_completion_stream();
119 cx.run_until_parked();
120
121 // The edit_file tool creates an EditAgent which makes its own model request.
122 // We need to respond to that request with the edit instructions.
123 // Wait for the edit agent's completion request
124 let deadline = std::time::Instant::now() + Duration::from_secs(5);
125 while fake_model.pending_completions().is_empty() {
126 if std::time::Instant::now() >= deadline {
127 panic!(
128 "Timed out waiting for edit agent completion request. Pending: {}",
129 fake_model.pending_completions().len()
130 );
131 }
132 cx.run_until_parked();
133 cx.background_executor
134 .timer(Duration::from_millis(10))
135 .await;
136 }
137
138 // Send the edit agent's response with the XML format it expects
139 let edit_response = "<old_text>println!(\"Hello, world!\");</old_text>\n<new_text>println!(\"Hello, Zed!\");</new_text>";
140 fake_model.send_last_completion_stream_text_chunk(edit_response);
141 fake_model.end_last_completion_stream();
142 cx.run_until_parked();
143
144 // Wait for the edit to complete and the thread to call the model again with tool results
145 let deadline = std::time::Instant::now() + Duration::from_secs(5);
146 while fake_model.pending_completions().is_empty() {
147 if std::time::Instant::now() >= deadline {
148 panic!("Timed out waiting for model to be called after edit completion");
149 }
150 cx.run_until_parked();
151 cx.background_executor
152 .timer(Duration::from_millis(10))
153 .await;
154 }
155
156 // Verify the file was edited
157 let file_content = fs
158 .load(path!("/project/src/main.rs").as_ref())
159 .await
160 .expect("file should exist");
161 assert!(
162 file_content.contains("Hello, Zed!"),
163 "File should have been edited. Content: {}",
164 file_content
165 );
166 assert!(
167 !file_content.contains("Hello, world!"),
168 "Old content should be replaced. Content: {}",
169 file_content
170 );
171
172 // Verify the tool result was sent back to the model
173 let pending = fake_model.pending_completions();
174 assert!(
175 !pending.is_empty(),
176 "Model should have been called with tool result"
177 );
178
179 let last_request = pending.last().unwrap();
180 let has_tool_result = last_request.messages.iter().any(|m| {
181 m.content
182 .iter()
183 .any(|c| matches!(c, language_model::MessageContent::ToolResult(_)))
184 });
185 assert!(
186 has_tool_result,
187 "Tool result should be in the messages sent back to the model"
188 );
189
190 // Complete the turn
191 fake_model.send_last_completion_stream_text_chunk("I've updated the greeting message.");
192 fake_model
193 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
194 fake_model.end_last_completion_stream();
195 cx.run_until_parked();
196
197 // Verify the thread completed successfully
198 thread.update(cx, |thread, _cx| {
199 assert!(
200 thread.is_turn_complete(),
201 "Thread should be complete after the turn ends"
202 );
203 });
204}
205
206#[gpui::test]
207async fn test_streaming_edit_json_parse_error_does_not_cause_unsaved_changes(
208 cx: &mut TestAppContext,
209) {
210 super::init_test(cx);
211 super::always_allow_tools(cx);
212
213 // Enable the streaming edit file tool feature flag.
214 cx.update(|cx| {
215 cx.update_flags(true, vec!["streaming-edit-file-tool".to_string()]);
216 });
217
218 let fs = FakeFs::new(cx.executor());
219 fs.insert_tree(
220 path!("/project"),
221 json!({
222 "src": {
223 "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n"
224 }
225 }),
226 )
227 .await;
228
229 let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
230 let project_context = cx.new(|_cx| ProjectContext::default());
231 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
232 let context_server_registry =
233 cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
234 let model = Arc::new(FakeLanguageModel::default());
235 model.as_fake().set_supports_streaming_tools(true);
236 let fake_model = model.as_fake();
237
238 let thread = cx.new(|cx| {
239 let mut thread = crate::Thread::new(
240 project.clone(),
241 project_context,
242 context_server_registry,
243 crate::Templates::new(),
244 Some(model.clone()),
245 cx,
246 );
247 let language_registry = project.read(cx).languages().clone();
248 thread.add_tool(crate::StreamingEditFileTool::new(
249 project.clone(),
250 cx.weak_entity(),
251 thread.action_log().clone(),
252 language_registry,
253 ));
254 thread
255 });
256
257 let _events = thread
258 .update(cx, |thread, cx| {
259 thread.send(
260 UserMessageId::new(),
261 ["Write new content to src/main.rs"],
262 cx,
263 )
264 })
265 .unwrap();
266 cx.run_until_parked();
267
268 let tool_use_id = "edit_1";
269 let partial_1 = LanguageModelToolUse {
270 id: tool_use_id.into(),
271 name: EditFileTool::NAME.into(),
272 raw_input: json!({
273 "display_description": "Rewrite main.rs",
274 "path": "project/src/main.rs",
275 "mode": "write"
276 })
277 .to_string(),
278 input: json!({
279 "display_description": "Rewrite main.rs",
280 "path": "project/src/main.rs",
281 "mode": "write"
282 }),
283 is_input_complete: false,
284 thought_signature: None,
285 };
286 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_1));
287 cx.run_until_parked();
288
289 let partial_2 = LanguageModelToolUse {
290 id: tool_use_id.into(),
291 name: EditFileTool::NAME.into(),
292 raw_input: json!({
293 "display_description": "Rewrite main.rs",
294 "path": "project/src/main.rs",
295 "mode": "write",
296 "content": "fn main() { /* rewritten */ }"
297 })
298 .to_string(),
299 input: json!({
300 "display_description": "Rewrite main.rs",
301 "path": "project/src/main.rs",
302 "mode": "write",
303 "content": "fn main() { /* rewritten */ }"
304 }),
305 is_input_complete: false,
306 thought_signature: None,
307 };
308 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_2));
309 cx.run_until_parked();
310
311 // Now send a json parse error. At this point we have started writing content to the buffer.
312 fake_model.send_last_completion_stream_event(
313 LanguageModelCompletionEvent::ToolUseJsonParseError {
314 id: tool_use_id.into(),
315 tool_name: EditFileTool::NAME.into(),
316 raw_input: r#"{"display_description":"Rewrite main.rs","path":"project/src/main.rs","mode":"write","content":"fn main() { /* rewritten "#.into(),
317 json_parse_error: "EOF while parsing a string at line 1 column 95".into(),
318 },
319 );
320 fake_model
321 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
322 fake_model.end_last_completion_stream();
323 cx.run_until_parked();
324
325 // cx.executor().advance_clock(Duration::from_secs(5));
326 // cx.run_until_parked();
327
328 assert!(
329 !fake_model.pending_completions().is_empty(),
330 "Thread should have retried after the error"
331 );
332
333 // Respond with a new, well-formed, complete edit_file tool use.
334 let tool_use = LanguageModelToolUse {
335 id: "edit_2".into(),
336 name: EditFileTool::NAME.into(),
337 raw_input: json!({
338 "display_description": "Rewrite main.rs",
339 "path": "project/src/main.rs",
340 "mode": "write",
341 "content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n"
342 })
343 .to_string(),
344 input: json!({
345 "display_description": "Rewrite main.rs",
346 "path": "project/src/main.rs",
347 "mode": "write",
348 "content": "fn main() {\n println!(\"Hello, rewritten!\");\n}\n"
349 }),
350 is_input_complete: true,
351 thought_signature: None,
352 };
353 fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
354 fake_model
355 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
356 fake_model.end_last_completion_stream();
357 cx.run_until_parked();
358
359 let pending_completions = fake_model.pending_completions();
360 assert!(
361 pending_completions.len() == 1,
362 "Expected only the follow-up completion containing the successful tool result"
363 );
364
365 let completion = pending_completions
366 .into_iter()
367 .last()
368 .expect("Expected a completion containing the tool result for edit_2");
369
370 let tool_result = completion
371 .messages
372 .iter()
373 .flat_map(|msg| &msg.content)
374 .find_map(|content| match content {
375 language_model::MessageContent::ToolResult(result)
376 if result.tool_use_id == language_model::LanguageModelToolUseId::from("edit_2") =>
377 {
378 Some(result)
379 }
380 _ => None,
381 })
382 .expect("Should have a tool result for edit_2");
383
384 // Ensure that the second tool call completed successfully and edits were applied.
385 assert!(
386 !tool_result.is_error,
387 "Tool result should succeed, got: {:?}",
388 tool_result
389 );
390 let content_text = match &tool_result.content {
391 language_model::LanguageModelToolResultContent::Text(t) => t.to_string(),
392 other => panic!("Expected text content, got: {:?}", other),
393 };
394 assert!(
395 !content_text.contains("file has been modified since you last read it"),
396 "Did not expect a stale last-read error, got: {content_text}"
397 );
398 assert!(
399 !content_text.contains("This file has unsaved changes"),
400 "Did not expect an unsaved-changes error, got: {content_text}"
401 );
402
403 let file_content = fs
404 .load(path!("/project/src/main.rs").as_ref())
405 .await
406 .expect("file should exist");
407 super::assert_eq!(
408 file_content,
409 "fn main() {\n println!(\"Hello, rewritten!\");\n}\n",
410 "The second edit should be applied and saved gracefully"
411 );
412
413 fake_model.end_last_completion_stream();
414 cx.run_until_parked();
415}