edit_file_thread_test.rs

  1use super::*;
  2use crate::{AgentTool, EditFileTool, ReadFileTool};
  3use acp_thread::UserMessageId;
  4use fs::FakeFs;
  5use language_model::{
  6    LanguageModelCompletionEvent, LanguageModelToolUse, StopReason,
  7    fake_provider::FakeLanguageModel,
  8};
  9use prompt_store::ProjectContext;
 10use serde_json::json;
 11use std::{sync::Arc, time::Duration};
 12use util::path;
 13
 14#[gpui::test]
 15async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
 16    // This test verifies that the edit_file tool works correctly when invoked
 17    // through the full thread flow (model sends ToolUse event -> tool runs -> result sent back).
 18    // This is different from tests that call tool.run() directly.
 19    super::init_test(cx);
 20    super::always_allow_tools(cx);
 21
 22    let fs = FakeFs::new(cx.executor());
 23    fs.insert_tree(
 24        path!("/project"),
 25        json!({
 26            "src": {
 27                "main.rs": "fn main() {\n    println!(\"Hello, world!\");\n}\n"
 28            }
 29        }),
 30    )
 31    .await;
 32
 33    let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
 34    let project_context = cx.new(|_cx| ProjectContext::default());
 35    let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
 36    let context_server_registry =
 37        cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
 38    let model = Arc::new(FakeLanguageModel::default());
 39    let fake_model = model.as_fake();
 40
 41    let thread = cx.new(|cx| {
 42        let mut thread = crate::Thread::new(
 43            project.clone(),
 44            project_context,
 45            context_server_registry,
 46            crate::Templates::new(),
 47            Some(model.clone()),
 48            cx,
 49        );
 50        // Add just the tools we need for this test
 51        let language_registry = project.read(cx).languages().clone();
 52        thread.add_tool(
 53            crate::ReadFileTool::new(
 54                cx.weak_entity(),
 55                project.clone(),
 56                thread.action_log().clone(),
 57            ),
 58            None,
 59        );
 60        thread.add_tool(
 61            crate::EditFileTool::new(
 62                project.clone(),
 63                cx.weak_entity(),
 64                language_registry,
 65                crate::Templates::new(),
 66            ),
 67            None,
 68        );
 69        thread
 70    });
 71
 72    // First, read the file so the thread knows about its contents
 73    let _events = thread
 74        .update(cx, |thread, cx| {
 75            thread.send(UserMessageId::new(), ["Read the file src/main.rs"], cx)
 76        })
 77        .unwrap();
 78    cx.run_until_parked();
 79
 80    // Model calls read_file tool
 81    let read_tool_use = LanguageModelToolUse {
 82        id: "read_tool_1".into(),
 83        name: ReadFileTool::NAME.into(),
 84        raw_input: json!({"path": "project/src/main.rs"}).to_string(),
 85        input: json!({"path": "project/src/main.rs"}),
 86        is_input_complete: true,
 87        thought_signature: None,
 88    };
 89    fake_model
 90        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
 91    fake_model
 92        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
 93    fake_model.end_last_completion_stream();
 94    cx.run_until_parked();
 95
 96    // Wait for the read tool to complete and model to be called again
 97    while fake_model.pending_completions().is_empty() {
 98        cx.run_until_parked();
 99    }
100
101    // Model responds after seeing the file content, then calls edit_file
102    fake_model.send_last_completion_stream_text_chunk("I'll edit the file now.");
103    let edit_tool_use = LanguageModelToolUse {
104        id: "edit_tool_1".into(),
105        name: EditFileTool::NAME.into(),
106        raw_input: json!({
107            "display_description": "Change greeting message",
108            "path": "project/src/main.rs",
109            "mode": "edit"
110        })
111        .to_string(),
112        input: json!({
113            "display_description": "Change greeting message",
114            "path": "project/src/main.rs",
115            "mode": "edit"
116        }),
117        is_input_complete: true,
118        thought_signature: None,
119    };
120    fake_model
121        .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use));
122    fake_model
123        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
124    fake_model.end_last_completion_stream();
125    cx.run_until_parked();
126
127    // The edit_file tool creates an EditAgent which makes its own model request.
128    // We need to respond to that request with the edit instructions.
129    // Wait for the edit agent's completion request
130    let deadline = std::time::Instant::now() + Duration::from_secs(5);
131    while fake_model.pending_completions().is_empty() {
132        if std::time::Instant::now() >= deadline {
133            panic!(
134                "Timed out waiting for edit agent completion request. Pending: {}",
135                fake_model.pending_completions().len()
136            );
137        }
138        cx.run_until_parked();
139        cx.background_executor
140            .timer(Duration::from_millis(10))
141            .await;
142    }
143
144    // Send the edit agent's response with the XML format it expects
145    let edit_response = "<old_text>println!(\"Hello, world!\");</old_text>\n<new_text>println!(\"Hello, Zed!\");</new_text>";
146    fake_model.send_last_completion_stream_text_chunk(edit_response);
147    fake_model.end_last_completion_stream();
148    cx.run_until_parked();
149
150    // Wait for the edit to complete and the thread to call the model again with tool results
151    let deadline = std::time::Instant::now() + Duration::from_secs(5);
152    while fake_model.pending_completions().is_empty() {
153        if std::time::Instant::now() >= deadline {
154            panic!("Timed out waiting for model to be called after edit completion");
155        }
156        cx.run_until_parked();
157        cx.background_executor
158            .timer(Duration::from_millis(10))
159            .await;
160    }
161
162    // Verify the file was edited
163    let file_content = fs
164        .load(path!("/project/src/main.rs").as_ref())
165        .await
166        .expect("file should exist");
167    assert!(
168        file_content.contains("Hello, Zed!"),
169        "File should have been edited. Content: {}",
170        file_content
171    );
172    assert!(
173        !file_content.contains("Hello, world!"),
174        "Old content should be replaced. Content: {}",
175        file_content
176    );
177
178    // Verify the tool result was sent back to the model
179    let pending = fake_model.pending_completions();
180    assert!(
181        !pending.is_empty(),
182        "Model should have been called with tool result"
183    );
184
185    let last_request = pending.last().unwrap();
186    let has_tool_result = last_request.messages.iter().any(|m| {
187        m.content
188            .iter()
189            .any(|c| matches!(c, language_model::MessageContent::ToolResult(_)))
190    });
191    assert!(
192        has_tool_result,
193        "Tool result should be in the messages sent back to the model"
194    );
195
196    // Complete the turn
197    fake_model.send_last_completion_stream_text_chunk("I've updated the greeting message.");
198    fake_model
199        .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
200    fake_model.end_last_completion_stream();
201    cx.run_until_parked();
202
203    // Verify the thread completed successfully
204    thread.update(cx, |thread, _cx| {
205        assert!(
206            thread.is_turn_complete(),
207            "Thread should be complete after the turn ends"
208        );
209    });
210}