1use super::*;
2use crate::{AgentTool, EditFileTool, ReadFileTool};
3use acp_thread::UserMessageId;
4use fs::FakeFs;
5use language_model::{
6 LanguageModelCompletionEvent, LanguageModelToolUse, StopReason,
7 fake_provider::FakeLanguageModel,
8};
9use prompt_store::ProjectContext;
10use serde_json::json;
11use std::{sync::Arc, time::Duration};
12use util::path;
13
14#[gpui::test]
15async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
16 // This test verifies that the edit_file tool works correctly when invoked
17 // through the full thread flow (model sends ToolUse event -> tool runs -> result sent back).
18 // This is different from tests that call tool.run() directly.
19 super::init_test(cx);
20 super::always_allow_tools(cx);
21
22 let fs = FakeFs::new(cx.executor());
23 fs.insert_tree(
24 path!("/project"),
25 json!({
26 "src": {
27 "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n"
28 }
29 }),
30 )
31 .await;
32
33 let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
34 let project_context = cx.new(|_cx| ProjectContext::default());
35 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
36 let context_server_registry =
37 cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
38 let model = Arc::new(FakeLanguageModel::default());
39 let fake_model = model.as_fake();
40
41 let thread = cx.new(|cx| {
42 let mut thread = crate::Thread::new(
43 project.clone(),
44 project_context,
45 context_server_registry,
46 crate::Templates::new(),
47 Some(model.clone()),
48 cx,
49 );
50 // Add just the tools we need for this test
51 let language_registry = project.read(cx).languages().clone();
52 thread.add_tool(crate::ReadFileTool::new(
53 cx.weak_entity(),
54 project.clone(),
55 thread.action_log().clone(),
56 ));
57 thread.add_tool(crate::EditFileTool::new(
58 project.clone(),
59 cx.weak_entity(),
60 language_registry,
61 crate::Templates::new(),
62 ));
63 thread
64 });
65
66 // First, read the file so the thread knows about its contents
67 let _events = thread
68 .update(cx, |thread, cx| {
69 thread.send(UserMessageId::new(), ["Read the file src/main.rs"], cx)
70 })
71 .unwrap();
72 cx.run_until_parked();
73
74 // Model calls read_file tool
75 let read_tool_use = LanguageModelToolUse {
76 id: "read_tool_1".into(),
77 name: ReadFileTool::NAME.into(),
78 raw_input: json!({"path": "project/src/main.rs"}).to_string(),
79 input: json!({"path": "project/src/main.rs"}),
80 is_input_complete: true,
81 thought_signature: None,
82 };
83 fake_model
84 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
85 fake_model
86 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
87 fake_model.end_last_completion_stream();
88 cx.run_until_parked();
89
90 // Wait for the read tool to complete and model to be called again
91 while fake_model.pending_completions().is_empty() {
92 cx.run_until_parked();
93 }
94
95 // Model responds after seeing the file content, then calls edit_file
96 fake_model.send_last_completion_stream_text_chunk("I'll edit the file now.");
97 let edit_tool_use = LanguageModelToolUse {
98 id: "edit_tool_1".into(),
99 name: EditFileTool::NAME.into(),
100 raw_input: json!({
101 "display_description": "Change greeting message",
102 "path": "project/src/main.rs",
103 "mode": "edit"
104 })
105 .to_string(),
106 input: json!({
107 "display_description": "Change greeting message",
108 "path": "project/src/main.rs",
109 "mode": "edit"
110 }),
111 is_input_complete: true,
112 thought_signature: None,
113 };
114 fake_model
115 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use));
116 fake_model
117 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
118 fake_model.end_last_completion_stream();
119 cx.run_until_parked();
120
121 // The edit_file tool creates an EditAgent which makes its own model request.
122 // We need to respond to that request with the edit instructions.
123 // Wait for the edit agent's completion request
124 let deadline = std::time::Instant::now() + Duration::from_secs(5);
125 while fake_model.pending_completions().is_empty() {
126 if std::time::Instant::now() >= deadline {
127 panic!(
128 "Timed out waiting for edit agent completion request. Pending: {}",
129 fake_model.pending_completions().len()
130 );
131 }
132 cx.run_until_parked();
133 cx.background_executor
134 .timer(Duration::from_millis(10))
135 .await;
136 }
137
138 // Send the edit agent's response with the XML format it expects
139 let edit_response = "<old_text>println!(\"Hello, world!\");</old_text>\n<new_text>println!(\"Hello, Zed!\");</new_text>";
140 fake_model.send_last_completion_stream_text_chunk(edit_response);
141 fake_model.end_last_completion_stream();
142 cx.run_until_parked();
143
144 // Wait for the edit to complete and the thread to call the model again with tool results
145 let deadline = std::time::Instant::now() + Duration::from_secs(5);
146 while fake_model.pending_completions().is_empty() {
147 if std::time::Instant::now() >= deadline {
148 panic!("Timed out waiting for model to be called after edit completion");
149 }
150 cx.run_until_parked();
151 cx.background_executor
152 .timer(Duration::from_millis(10))
153 .await;
154 }
155
156 // Verify the file was edited
157 let file_content = fs
158 .load(path!("/project/src/main.rs").as_ref())
159 .await
160 .expect("file should exist");
161 assert!(
162 file_content.contains("Hello, Zed!"),
163 "File should have been edited. Content: {}",
164 file_content
165 );
166 assert!(
167 !file_content.contains("Hello, world!"),
168 "Old content should be replaced. Content: {}",
169 file_content
170 );
171
172 // Verify the tool result was sent back to the model
173 let pending = fake_model.pending_completions();
174 assert!(
175 !pending.is_empty(),
176 "Model should have been called with tool result"
177 );
178
179 let last_request = pending.last().unwrap();
180 let has_tool_result = last_request.messages.iter().any(|m| {
181 m.content
182 .iter()
183 .any(|c| matches!(c, language_model::MessageContent::ToolResult(_)))
184 });
185 assert!(
186 has_tool_result,
187 "Tool result should be in the messages sent back to the model"
188 );
189
190 // Complete the turn
191 fake_model.send_last_completion_stream_text_chunk("I've updated the greeting message.");
192 fake_model
193 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
194 fake_model.end_last_completion_stream();
195 cx.run_until_parked();
196
197 // Verify the thread completed successfully
198 thread.update(cx, |thread, _cx| {
199 assert!(
200 thread.is_turn_complete(),
201 "Thread should be complete after the turn ends"
202 );
203 });
204}