edit_file_thread_test.rs

  1use super::*;
  2use crate::{AgentTool, EditFileTool, ReadFileTool};
  3use acp_thread::UserMessageId;
  4use fs::FakeFs;
  5use language_model::{
  6    LanguageModelCompletionEvent, LanguageModelToolUse, StopReason,
  7    fake_provider::FakeLanguageModel,
  8};
  9use prompt_store::ProjectContext;
 10use serde_json::json;
 11use std::{sync::Arc, time::Duration};
 12use util::path;
 13
 14#[gpui::test]
 15async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
 16    // This test verifies that the edit_file tool works correctly when invoked
 17    // through the full thread flow (model sends ToolUse event -> tool runs -> result sent back).
 18    // This is different from tests that call tool.run() directly.
 19    super::init_test(cx);
 20    super::always_allow_tools(cx);
 21
 22    let fs = FakeFs::new(cx.executor());
 23    fs.insert_tree(
 24        path!("/project"),
 25        json!({
 26            "src": {
 27                "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}\n"
 28            }
 29        }),
 30    )
 31    .await;
 32
 33    let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
 34    let project_context = cx.new(|_cx| ProjectContext::default());
 35    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
 36    let context_server_registry =
 37        cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
 38    let model = Arc::new(FakeLanguageModel::default());
 39    let fake_model = model.as_fake();
 40
 41    let thread = cx.new(|cx| {
 42        let mut thread = crate::Thread::new(
 43            project.clone(),
 44            project_context,
 45            context_server_registry,
 46            crate::Templates::new(),
 47            Some(model.clone()),
 48            cx,
 49        );
 50        // Add just the tools we need for this test
 51        let language_registry = project.read(cx).languages().clone();
 52        thread.add_tool(crate::ReadFileTool::new(
 53            project.clone(),
 54            thread.action_log().clone(),
 55            true,
 56        ));
 57        thread.add_tool(crate::EditFileTool::new(
 58            project.clone(),
 59            cx.weak_entity(),
 60            language_registry,
 61            crate::Templates::new(),
 62        ));
 63        thread
 64    });
 65
 66    // First, read the file so the thread knows about its contents
 67    let _events = thread
 68        .update(cx, |thread, cx| {
 69            thread.send(UserMessageId::new(), ["Read the file src/main.rs"], cx)
 70        })
 71        .unwrap();
 72    cx.run_until_parked();
 73
 74    // Model calls read_file tool
 75    let read_tool_use = LanguageModelToolUse {
 76        id: "read_tool_1".into(),
 77        name: ReadFileTool::NAME.into(),
 78        raw_input: json!({"path": "project/src/main.rs"}).to_string(),
 79        input: json!({"path": "project/src/main.rs"}),
 80        is_input_complete: true,
 81        thought_signature: None,
 82    };
 83    fake_model
 84        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
 85    fake_model
 86        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
 87    fake_model.end_last_completion_stream();
 88    cx.run_until_parked();
 89
 90    // Wait for the read tool to complete and model to be called again
 91    while fake_model.pending_completions().is_empty() {
 92        cx.run_until_parked();
 93    }
 94
 95    // Model responds after seeing the file content, then calls edit_file
 96    fake_model.send_last_completion_stream_text_chunk("I'll edit the file now.");
 97    let edit_tool_use = LanguageModelToolUse {
 98        id: "edit_tool_1".into(),
 99        name: EditFileTool::NAME.into(),
100        raw_input: json!({
101            "display_description": "Change greeting message",
102            "path": "project/src/main.rs",
103            "mode": "edit"
104        })
105        .to_string(),
106        input: json!({
107            "display_description": "Change greeting message",
108            "path": "project/src/main.rs",
109            "mode": "edit"
110        }),
111        is_input_complete: true,
112        thought_signature: None,
113    };
114    fake_model
115        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use));
116    fake_model
117        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
118    fake_model.end_last_completion_stream();
119    cx.run_until_parked();
120
121    // The edit_file tool creates an EditAgent which makes its own model request.
122    // We need to respond to that request with the edit instructions.
123    // Wait for the edit agent's completion request
124    let deadline = std::time::Instant::now() + Duration::from_secs(5);
125    while fake_model.pending_completions().is_empty() {
126        if std::time::Instant::now() >= deadline {
127            panic!(
128                "Timed out waiting for edit agent completion request. Pending: {}",
129                fake_model.pending_completions().len()
130            );
131        }
132        cx.run_until_parked();
133        cx.background_executor
134            .timer(Duration::from_millis(10))
135            .await;
136    }
137
138    // Send the edit agent's response with the XML format it expects
139    let edit_response = "<old_text>println!(\"Hello, world!\");</old_text>\n<new_text>println!(\"Hello, Zed!\");</new_text>";
140    fake_model.send_last_completion_stream_text_chunk(edit_response);
141    fake_model.end_last_completion_stream();
142    cx.run_until_parked();
143
144    // Wait for the edit to complete and the thread to call the model again with tool results
145    let deadline = std::time::Instant::now() + Duration::from_secs(5);
146    while fake_model.pending_completions().is_empty() {
147        if std::time::Instant::now() >= deadline {
148            panic!("Timed out waiting for model to be called after edit completion");
149        }
150        cx.run_until_parked();
151        cx.background_executor
152            .timer(Duration::from_millis(10))
153            .await;
154    }
155
156    // Verify the file was edited
157    let file_content = fs
158        .load(path!("/project/src/main.rs").as_ref())
159        .await
160        .expect("file should exist");
161    assert!(
162        file_content.contains("Hello, Zed!"),
163        "File should have been edited. Content: {}",
164        file_content
165    );
166    assert!(
167        !file_content.contains("Hello, world!"),
168        "Old content should be replaced. Content: {}",
169        file_content
170    );
171
172    // Verify the tool result was sent back to the model
173    let pending = fake_model.pending_completions();
174    assert!(
175        !pending.is_empty(),
176        "Model should have been called with tool result"
177    );
178
179    let last_request = pending.last().unwrap();
180    let has_tool_result = last_request.messages.iter().any(|m| {
181        m.content
182            .iter()
183            .any(|c| matches!(c, language_model::MessageContent::ToolResult(_)))
184    });
185    assert!(
186        has_tool_result,
187        "Tool result should be in the messages sent back to the model"
188    );
189
190    // Complete the turn
191    fake_model.send_last_completion_stream_text_chunk("I've updated the greeting message.");
192    fake_model
193        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
194    fake_model.end_last_completion_stream();
195    cx.run_until_parked();
196
197    // Verify the thread completed successfully
198    thread.update(cx, |thread, _cx| {
199        assert!(
200            thread.is_turn_complete(),
201            "Thread should be complete after the turn ends"
202        );
203    });
204}
205
206#[gpui::test]
207async fn test_streaming_edit_json_parse_error_does_not_cause_unsaved_changes(
208    cx: &mut TestAppContext,
209) {
210    super::init_test(cx);
211    super::always_allow_tools(cx);
212
213    // Enable the streaming edit file tool feature flag.
214    cx.update(|cx| {
215        cx.update_flags(true, vec!["streaming-edit-file-tool".to_string()]);
216    });
217
218    let fs = FakeFs::new(cx.executor());
219    fs.insert_tree(
220        path!("/project"),
221        json!({
222            "src": {
223                "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}\n"
224            }
225        }),
226    )
227    .await;
228
229    let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
230    let project_context = cx.new(|_cx| ProjectContext::default());
231    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
232    let context_server_registry =
233        cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
234    let model = Arc::new(FakeLanguageModel::default());
235    model.as_fake().set_supports_streaming_tools(true);
236    let fake_model = model.as_fake();
237
238    let thread = cx.new(|cx| {
239        let mut thread = crate::Thread::new(
240            project.clone(),
241            project_context,
242            context_server_registry,
243            crate::Templates::new(),
244            Some(model.clone()),
245            cx,
246        );
247        let language_registry = project.read(cx).languages().clone();
248        thread.add_tool(crate::StreamingEditFileTool::new(
249            project.clone(),
250            cx.weak_entity(),
251            thread.action_log().clone(),
252            language_registry,
253        ));
254        thread
255    });
256
257    let _events = thread
258        .update(cx, |thread, cx| {
259            thread.send(
260                UserMessageId::new(),
261                ["Write new content to src/main.rs"],
262                cx,
263            )
264        })
265        .unwrap();
266    cx.run_until_parked();
267
268    let tool_use_id = "edit_1";
269    let partial_1 = LanguageModelToolUse {
270        id: tool_use_id.into(),
271        name: EditFileTool::NAME.into(),
272        raw_input: json!({
273            "display_description": "Rewrite main.rs",
274            "path": "project/src/main.rs",
275            "mode": "write"
276        })
277        .to_string(),
278        input: json!({
279            "display_description": "Rewrite main.rs",
280            "path": "project/src/main.rs",
281            "mode": "write"
282        }),
283        is_input_complete: false,
284        thought_signature: None,
285    };
286    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_1));
287    cx.run_until_parked();
288
289    let partial_2 = LanguageModelToolUse {
290        id: tool_use_id.into(),
291        name: EditFileTool::NAME.into(),
292        raw_input: json!({
293            "display_description": "Rewrite main.rs",
294            "path": "project/src/main.rs",
295            "mode": "write",
296            "content": "fn main() { /* rewritten */ }"
297        })
298        .to_string(),
299        input: json!({
300            "display_description": "Rewrite main.rs",
301            "path": "project/src/main.rs",
302            "mode": "write",
303            "content": "fn main() { /* rewritten */ }"
304        }),
305        is_input_complete: false,
306        thought_signature: None,
307    };
308    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(partial_2));
309    cx.run_until_parked();
310
311    // Now send a json parse error. At this point we have started writing content to the buffer.
312    fake_model.send_last_completion_stream_event(
313        LanguageModelCompletionEvent::ToolUseJsonParseError {
314            id: tool_use_id.into(),
315            tool_name: EditFileTool::NAME.into(),
316            raw_input: r#"{"display_description":"Rewrite main.rs","path":"project/src/main.rs","mode":"write","content":"fn main() { /* rewritten "#.into(),
317            json_parse_error: "EOF while parsing a string at line 1 column 95".into(),
318        },
319    );
320    fake_model
321        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
322    fake_model.end_last_completion_stream();
323    cx.run_until_parked();
324
325    // cx.executor().advance_clock(Duration::from_secs(5));
326    // cx.run_until_parked();
327
328    assert!(
329        !fake_model.pending_completions().is_empty(),
330        "Thread should have retried after the error"
331    );
332
333    // Respond with a new, well-formed, complete edit_file tool use.
334    let tool_use = LanguageModelToolUse {
335        id: "edit_2".into(),
336        name: EditFileTool::NAME.into(),
337        raw_input: json!({
338            "display_description": "Rewrite main.rs",
339            "path": "project/src/main.rs",
340            "mode": "write",
341            "content": "fn main() {\n    println!(\"Hello, rewritten!\");\n}\n"
342        })
343        .to_string(),
344        input: json!({
345            "display_description": "Rewrite main.rs",
346            "path": "project/src/main.rs",
347            "mode": "write",
348            "content": "fn main() {\n    println!(\"Hello, rewritten!\");\n}\n"
349        }),
350        is_input_complete: true,
351        thought_signature: None,
352    };
353    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use));
354    fake_model
355        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
356    fake_model.end_last_completion_stream();
357    cx.run_until_parked();
358
359    let pending_completions = fake_model.pending_completions();
360    assert!(
361        pending_completions.len() == 1,
362        "Expected only the follow-up completion containing the successful tool result"
363    );
364
365    let completion = pending_completions
366        .into_iter()
367        .last()
368        .expect("Expected a completion containing the tool result for edit_2");
369
370    let tool_result = completion
371        .messages
372        .iter()
373        .flat_map(|msg| &msg.content)
374        .find_map(|content| match content {
375            language_model::MessageContent::ToolResult(result)
376                if result.tool_use_id == language_model::LanguageModelToolUseId::from("edit_2") =>
377            {
378                Some(result)
379            }
380            _ => None,
381        })
382        .expect("Should have a tool result for edit_2");
383
384    // Ensure that the second tool call completed successfully and edits were applied.
385    assert!(
386        !tool_result.is_error,
387        "Tool result should succeed, got: {:?}",
388        tool_result
389    );
390    let content_text = match &tool_result.content {
391        language_model::LanguageModelToolResultContent::Text(t) => t.to_string(),
392        other => panic!("Expected text content, got: {:?}", other),
393    };
394    assert!(
395        !content_text.contains("file has been modified since you last read it"),
396        "Did not expect a stale last-read error, got: {content_text}"
397    );
398    assert!(
399        !content_text.contains("This file has unsaved changes"),
400        "Did not expect an unsaved-changes error, got: {content_text}"
401    );
402
403    let file_content = fs
404        .load(path!("/project/src/main.rs").as_ref())
405        .await
406        .expect("file should exist");
407    super::assert_eq!(
408        file_content,
409        "fn main() {\n    println!(\"Hello, rewritten!\");\n}\n",
410        "The second edit should be applied and saved gracefully"
411    );
412
413    fake_model.end_last_completion_stream();
414    cx.run_until_parked();
415}