1use std::{path::Path, sync::Arc, time::Duration};
2
3use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
4use acp_thread::{
5 AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus,
6};
7use agentic_coding_protocol as acp_old;
8use futures::{FutureExt, StreamExt, channel::mpsc, select};
9use gpui::{Entity, TestAppContext};
10use indoc::indoc;
11use project::{FakeFs, Project};
12use serde_json::json;
13use settings::{Settings, SettingsStore};
14use util::path;
15
16pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
17 let fs = init_test(cx).await;
18 let project = Project::test(fs, [], cx).await;
19 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
20
21 thread
22 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
23 .await
24 .unwrap();
25
26 thread.read_with(cx, |thread, _| {
27 assert_eq!(thread.entries().len(), 2);
28 assert!(matches!(
29 thread.entries()[0],
30 AgentThreadEntry::UserMessage(_)
31 ));
32 assert!(matches!(
33 thread.entries()[1],
34 AgentThreadEntry::AssistantMessage(_)
35 ));
36 });
37}
38
39pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
40 let _fs = init_test(cx).await;
41
42 let tempdir = tempfile::tempdir().unwrap();
43 std::fs::write(
44 tempdir.path().join("foo.rs"),
45 indoc! {"
46 fn main() {
47 println!(\"Hello, world!\");
48 }
49 "},
50 )
51 .expect("failed to write file");
52 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
53 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
54 thread
55 .update(cx, |thread, cx| {
56 thread.send(
57 acp_old::SendUserMessageParams {
58 chunks: vec![
59 acp_old::UserMessageChunk::Text {
60 text: "Read the file ".into(),
61 },
62 acp_old::UserMessageChunk::Path {
63 path: Path::new("foo.rs").into(),
64 },
65 acp_old::UserMessageChunk::Text {
66 text: " and tell me what the content of the println! is".into(),
67 },
68 ],
69 },
70 cx,
71 )
72 })
73 .await
74 .unwrap();
75
76 thread.read_with(cx, |thread, cx| {
77 assert_eq!(thread.entries().len(), 3);
78 assert!(matches!(
79 thread.entries()[0],
80 AgentThreadEntry::UserMessage(_)
81 ));
82 assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
83 let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
84 panic!("Expected AssistantMessage")
85 };
86 assert!(
87 assistant_message.to_markdown(cx).contains("Hello, world!"),
88 "unexpected assistant message: {:?}",
89 assistant_message.to_markdown(cx)
90 );
91 });
92}
93
94pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
95 let fs = init_test(cx).await;
96 fs.insert_tree(
97 path!("/private/tmp"),
98 json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
99 )
100 .await;
101 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
102 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
103
104 thread
105 .update(cx, |thread, cx| {
106 thread.send_raw(
107 "Read the '/private/tmp/foo' file and tell me what you see.",
108 cx,
109 )
110 })
111 .await
112 .unwrap();
113 thread.read_with(cx, |thread, _cx| {
114 assert!(thread.entries().iter().any(|entry| {
115 matches!(
116 entry,
117 AgentThreadEntry::ToolCall(ToolCall {
118 status: ToolCallStatus::Allowed { .. },
119 ..
120 })
121 )
122 }));
123 assert!(
124 thread
125 .entries()
126 .iter()
127 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
128 );
129 });
130}
131
132pub async fn test_tool_call_with_confirmation(
133 server: impl AgentServer + 'static,
134 cx: &mut TestAppContext,
135) {
136 let fs = init_test(cx).await;
137 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
138 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
139 let full_turn = thread.update(cx, |thread, cx| {
140 thread.send_raw(
141 r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
142 cx,
143 )
144 });
145
146 run_until_first_tool_call(
147 &thread,
148 |entry| {
149 matches!(
150 entry,
151 AgentThreadEntry::ToolCall(ToolCall {
152 status: ToolCallStatus::WaitingForConfirmation { .. },
153 ..
154 })
155 )
156 },
157 cx,
158 )
159 .await;
160
161 let tool_call_id = thread.read_with(cx, |thread, _cx| {
162 let AgentThreadEntry::ToolCall(ToolCall {
163 id,
164 content: Some(content),
165 status: ToolCallStatus::WaitingForConfirmation { .. },
166 ..
167 }) = &thread
168 .entries()
169 .iter()
170 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
171 .unwrap()
172 else {
173 panic!();
174 };
175
176 assert!(content.to_markdown(cx).contains("touch"));
177
178 *id
179 });
180
181 thread.update(cx, |thread, cx| {
182 thread.authorize_tool_call(
183 tool_call_id,
184 acp_old::ToolCallConfirmationOutcome::Allow,
185 cx,
186 );
187
188 assert!(thread.entries().iter().any(|entry| matches!(
189 entry,
190 AgentThreadEntry::ToolCall(ToolCall {
191 status: ToolCallStatus::Allowed { .. },
192 ..
193 })
194 )));
195 });
196
197 full_turn.await.unwrap();
198
199 thread.read_with(cx, |thread, cx| {
200 let AgentThreadEntry::ToolCall(ToolCall {
201 content: Some(ToolCallContent::Markdown { markdown }),
202 status: ToolCallStatus::Allowed { .. },
203 ..
204 }) = thread
205 .entries()
206 .iter()
207 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
208 .unwrap()
209 else {
210 panic!();
211 };
212
213 markdown.read_with(cx, |md, _cx| {
214 assert!(
215 md.source().contains("Hello"),
216 r#"Expected '{}' to contain "Hello""#,
217 md.source()
218 );
219 });
220 });
221}
222
223pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
224 let fs = init_test(cx).await;
225
226 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
227 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
228 let full_turn = thread.update(cx, |thread, cx| {
229 thread.send_raw(
230 r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
231 cx,
232 )
233 });
234
235 let first_tool_call_ix = run_until_first_tool_call(
236 &thread,
237 |entry| {
238 matches!(
239 entry,
240 AgentThreadEntry::ToolCall(ToolCall {
241 status: ToolCallStatus::WaitingForConfirmation { .. },
242 ..
243 })
244 )
245 },
246 cx,
247 )
248 .await;
249
250 thread.read_with(cx, |thread, _cx| {
251 let AgentThreadEntry::ToolCall(ToolCall {
252 id,
253 content: Some(content),
254 status: ToolCallStatus::WaitingForConfirmation { .. },
255 ..
256 }) = &thread.entries()[first_tool_call_ix]
257 else {
258 panic!("{:?}", thread.entries()[1]);
259 };
260
261 assert!(content.to_markdown(cx).contains("touch"));
262
263 *id
264 });
265
266 thread
267 .update(cx, |thread, cx| thread.cancel(cx))
268 .await
269 .unwrap();
270 full_turn.await.unwrap();
271 thread.read_with(cx, |thread, _| {
272 let AgentThreadEntry::ToolCall(ToolCall {
273 status: ToolCallStatus::Canceled,
274 ..
275 }) = &thread.entries()[first_tool_call_ix]
276 else {
277 panic!();
278 };
279 });
280
281 thread
282 .update(cx, |thread, cx| {
283 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
284 })
285 .await
286 .unwrap();
287 thread.read_with(cx, |thread, _| {
288 assert!(matches!(
289 &thread.entries().last().unwrap(),
290 AgentThreadEntry::AssistantMessage(..),
291 ))
292 });
293}
294
/// Generates the shared end-to-end test suite for one agent server.
///
/// `$server` is an expression producing a value that implements `AgentServer`.
/// It is substituted into every generated test body, so it is evaluated
/// separately in each test.
///
/// The tests are emitted into a `common_e2e` module that glob-imports the
/// invoking module (`use super::*`), so `$server` may refer to names defined
/// there. Every test is `#[ignore]`d unless the crate is built with the `e2e`
/// feature.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr) => {
        mod common_e2e {
            use super::*;

            // Each test forwards to the matching `$crate::e2e_tests::test_*` fn.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }
        }
    };
}
333
334// Helpers
335
336pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
337 env_logger::try_init().ok();
338
339 cx.update(|cx| {
340 let settings_store = SettingsStore::test(cx);
341 cx.set_global(settings_store);
342 Project::init_settings(cx);
343 language::init(cx);
344 crate::settings::init(cx);
345
346 crate::AllAgentServersSettings::override_global(
347 AllAgentServersSettings {
348 claude: Some(AgentServerSettings {
349 command: crate::claude::tests::local_command(),
350 }),
351 codex: Some(AgentServerSettings {
352 command: crate::codex::tests::local_command(),
353 }),
354 gemini: Some(AgentServerSettings {
355 command: crate::gemini::tests::local_command(),
356 }),
357 },
358 cx,
359 );
360 });
361
362 cx.executor().allow_parking();
363
364 FakeFs::new(cx.executor())
365}
366
367pub async fn new_test_thread(
368 server: impl AgentServer + 'static,
369 project: Entity<Project>,
370 current_dir: impl AsRef<Path>,
371 cx: &mut TestAppContext,
372) -> Entity<AcpThread> {
373 let thread = cx
374 .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx))
375 .await
376 .unwrap();
377
378 thread
379 .update(cx, |thread, _| thread.initialize())
380 .await
381 .unwrap();
382 thread
383}
384
385pub async fn run_until_first_tool_call(
386 thread: &Entity<AcpThread>,
387 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
388 cx: &mut TestAppContext,
389) -> usize {
390 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
391
392 let subscription = cx.update(|cx| {
393 cx.subscribe(thread, move |thread, _, cx| {
394 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
395 if wait_until(entry) {
396 return tx.try_send(ix).unwrap();
397 }
398 }
399 })
400 });
401
402 select! {
403 // We have to use a smol timer here because
404 // cx.background_executor().timer isn't real in the test context
405 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
406 panic!("Timeout waiting for tool call")
407 }
408 ix = rx.next().fuse() => {
409 drop(subscription);
410 ix.unwrap()
411 }
412 }
413}