1use super::*;
2use crate::{AgentTool, EditFileTool, ReadFileTool};
3use acp_thread::UserMessageId;
4use fs::FakeFs;
5use language_model::{
6 LanguageModelCompletionEvent, LanguageModelToolUse, StopReason,
7 fake_provider::FakeLanguageModel,
8};
9use prompt_store::ProjectContext;
10use serde_json::json;
11use std::{sync::Arc, time::Duration};
12use util::path;
13
14#[gpui::test]
15async fn test_edit_file_tool_in_thread_context(cx: &mut TestAppContext) {
16 // This test verifies that the edit_file tool works correctly when invoked
17 // through the full thread flow (model sends ToolUse event -> tool runs -> result sent back).
18 // This is different from tests that call tool.run() directly.
19 super::init_test(cx);
20 super::always_allow_tools(cx);
21
22 let fs = FakeFs::new(cx.executor());
23 fs.insert_tree(
24 path!("/project"),
25 json!({
26 "src": {
27 "main.rs": "fn main() {\n println!(\"Hello, world!\");\n}\n"
28 }
29 }),
30 )
31 .await;
32
33 let project = project::Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
34 let project_context = cx.new(|_cx| ProjectContext::default());
35 let context_server_store = project.read_with(cx, |project, _| project.context_server_store());
36 let context_server_registry =
37 cx.new(|cx| crate::ContextServerRegistry::new(context_server_store.clone(), cx));
38 let model = Arc::new(FakeLanguageModel::default());
39 let fake_model = model.as_fake();
40
41 let thread = cx.new(|cx| {
42 let mut thread = crate::Thread::new(
43 project.clone(),
44 project_context,
45 context_server_registry,
46 crate::Templates::new(),
47 Some(model.clone()),
48 cx,
49 );
50 // Add just the tools we need for this test
51 let language_registry = project.read(cx).languages().clone();
52 thread.add_tool(
53 crate::ReadFileTool::new(
54 cx.weak_entity(),
55 project.clone(),
56 thread.action_log().clone(),
57 ),
58 None,
59 );
60 thread.add_tool(
61 crate::EditFileTool::new(
62 project.clone(),
63 cx.weak_entity(),
64 language_registry,
65 crate::Templates::new(),
66 ),
67 None,
68 );
69 thread
70 });
71
72 // First, read the file so the thread knows about its contents
73 let _events = thread
74 .update(cx, |thread, cx| {
75 thread.send(UserMessageId::new(), ["Read the file src/main.rs"], cx)
76 })
77 .unwrap();
78 cx.run_until_parked();
79
80 // Model calls read_file tool
81 let read_tool_use = LanguageModelToolUse {
82 id: "read_tool_1".into(),
83 name: ReadFileTool::NAME.into(),
84 raw_input: json!({"path": "project/src/main.rs"}).to_string(),
85 input: json!({"path": "project/src/main.rs"}),
86 is_input_complete: true,
87 thought_signature: None,
88 };
89 fake_model
90 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(read_tool_use));
91 fake_model
92 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
93 fake_model.end_last_completion_stream();
94 cx.run_until_parked();
95
96 // Wait for the read tool to complete and model to be called again
97 while fake_model.pending_completions().is_empty() {
98 cx.run_until_parked();
99 }
100
101 // Model responds after seeing the file content, then calls edit_file
102 fake_model.send_last_completion_stream_text_chunk("I'll edit the file now.");
103 let edit_tool_use = LanguageModelToolUse {
104 id: "edit_tool_1".into(),
105 name: EditFileTool::NAME.into(),
106 raw_input: json!({
107 "display_description": "Change greeting message",
108 "path": "project/src/main.rs",
109 "mode": "edit"
110 })
111 .to_string(),
112 input: json!({
113 "display_description": "Change greeting message",
114 "path": "project/src/main.rs",
115 "mode": "edit"
116 }),
117 is_input_complete: true,
118 thought_signature: None,
119 };
120 fake_model
121 .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(edit_tool_use));
122 fake_model
123 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::ToolUse));
124 fake_model.end_last_completion_stream();
125 cx.run_until_parked();
126
127 // The edit_file tool creates an EditAgent which makes its own model request.
128 // We need to respond to that request with the edit instructions.
129 // Wait for the edit agent's completion request
130 let deadline = std::time::Instant::now() + Duration::from_secs(5);
131 while fake_model.pending_completions().is_empty() {
132 if std::time::Instant::now() >= deadline {
133 panic!(
134 "Timed out waiting for edit agent completion request. Pending: {}",
135 fake_model.pending_completions().len()
136 );
137 }
138 cx.run_until_parked();
139 cx.background_executor
140 .timer(Duration::from_millis(10))
141 .await;
142 }
143
144 // Send the edit agent's response with the XML format it expects
145 let edit_response = "<old_text>println!(\"Hello, world!\");</old_text>\n<new_text>println!(\"Hello, Zed!\");</new_text>";
146 fake_model.send_last_completion_stream_text_chunk(edit_response);
147 fake_model.end_last_completion_stream();
148 cx.run_until_parked();
149
150 // Wait for the edit to complete and the thread to call the model again with tool results
151 let deadline = std::time::Instant::now() + Duration::from_secs(5);
152 while fake_model.pending_completions().is_empty() {
153 if std::time::Instant::now() >= deadline {
154 panic!("Timed out waiting for model to be called after edit completion");
155 }
156 cx.run_until_parked();
157 cx.background_executor
158 .timer(Duration::from_millis(10))
159 .await;
160 }
161
162 // Verify the file was edited
163 let file_content = fs
164 .load(path!("/project/src/main.rs").as_ref())
165 .await
166 .expect("file should exist");
167 assert!(
168 file_content.contains("Hello, Zed!"),
169 "File should have been edited. Content: {}",
170 file_content
171 );
172 assert!(
173 !file_content.contains("Hello, world!"),
174 "Old content should be replaced. Content: {}",
175 file_content
176 );
177
178 // Verify the tool result was sent back to the model
179 let pending = fake_model.pending_completions();
180 assert!(
181 !pending.is_empty(),
182 "Model should have been called with tool result"
183 );
184
185 let last_request = pending.last().unwrap();
186 let has_tool_result = last_request.messages.iter().any(|m| {
187 m.content
188 .iter()
189 .any(|c| matches!(c, language_model::MessageContent::ToolResult(_)))
190 });
191 assert!(
192 has_tool_result,
193 "Tool result should be in the messages sent back to the model"
194 );
195
196 // Complete the turn
197 fake_model.send_last_completion_stream_text_chunk("I've updated the greeting message.");
198 fake_model
199 .send_last_completion_stream_event(LanguageModelCompletionEvent::Stop(StopReason::EndTurn));
200 fake_model.end_last_completion_stream();
201 cx.run_until_parked();
202
203 // Verify the thread completed successfully
204 thread.update(cx, |thread, _cx| {
205 assert!(
206 thread.is_turn_complete(),
207 "Thread should be complete after the turn ends"
208 );
209 });
210}