1use crate::{AgentServer, AgentServerDelegate};
2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
3use agent_client_protocol as acp;
4use futures::{FutureExt, StreamExt, channel::mpsc, select};
5use gpui::{AppContext, Entity, TestAppContext};
6use indoc::indoc;
7#[cfg(test)]
8use project::agent_server_store::BuiltinAgentServerSettings;
9use project::{FakeFs, Project};
10#[cfg(test)]
11use settings::Settings;
12use std::{
13 path::{Path, PathBuf},
14 sync::Arc,
15 time::Duration,
16};
17use util::path;
18
19pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
20where
21 T: AgentServer + 'static,
22 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
23{
24 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
25 let project = Project::test(fs.clone(), [], cx).await;
26 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
27
28 thread
29 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
30 .await
31 .unwrap();
32
33 thread.read_with(cx, |thread, _| {
34 assert!(
35 thread.entries().len() >= 2,
36 "Expected at least 2 entries. Got: {:?}",
37 thread.entries()
38 );
39 assert!(matches!(
40 thread.entries()[0],
41 AgentThreadEntry::UserMessage(_)
42 ));
43 assert!(matches!(
44 thread.entries()[1],
45 AgentThreadEntry::AssistantMessage(_)
46 ));
47 });
48}
49
50pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
51where
52 T: AgentServer + 'static,
53 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
54{
55 let fs = init_test(cx).await as _;
56
57 let tempdir = tempfile::tempdir().unwrap();
58 std::fs::write(
59 tempdir.path().join("foo.rs"),
60 indoc! {"
61 fn main() {
62 println!(\"Hello, world!\");
63 }
64 "},
65 )
66 .expect("failed to write file");
67 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
68 let thread = new_test_thread(server(&fs, cx).await, project.clone(), tempdir.path(), cx).await;
69 thread
70 .update(cx, |thread, cx| {
71 thread.send(
72 vec![
73 "Read the file ".into(),
74 acp::ContentBlock::ResourceLink(acp::ResourceLink::new("foo.rs", "foo.rs")),
75 " and tell me what the content of the println! is".into(),
76 ],
77 cx,
78 )
79 })
80 .await
81 .unwrap();
82
83 thread.read_with(cx, |thread, cx| {
84 assert!(matches!(
85 thread.entries()[0],
86 AgentThreadEntry::UserMessage(_)
87 ));
88 let assistant_message = &thread
89 .entries()
90 .iter()
91 .rev()
92 .find_map(|entry| match entry {
93 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
94 _ => None,
95 })
96 .unwrap();
97
98 assert!(
99 assistant_message.to_markdown(cx).contains("Hello, world!"),
100 "unexpected assistant message: {:?}",
101 assistant_message.to_markdown(cx)
102 );
103 });
104
105 drop(tempdir);
106}
107
108pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
109where
110 T: AgentServer + 'static,
111 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
112{
113 let fs = init_test(cx).await as _;
114
115 let tempdir = tempfile::tempdir().unwrap();
116 let foo_path = tempdir.path().join("foo");
117 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
118
119 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
120 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
121
122 thread
123 .update(cx, |thread, cx| {
124 thread.send_raw(
125 &format!("Read {} and tell me what you see.", foo_path.display()),
126 cx,
127 )
128 })
129 .await
130 .unwrap();
131 thread.read_with(cx, |thread, _cx| {
132 assert!(thread.entries().iter().any(|entry| {
133 matches!(
134 entry,
135 AgentThreadEntry::ToolCall(ToolCall {
136 status: ToolCallStatus::Pending
137 | ToolCallStatus::InProgress
138 | ToolCallStatus::Completed,
139 ..
140 })
141 )
142 }));
143 assert!(
144 thread
145 .entries()
146 .iter()
147 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
148 );
149 });
150
151 drop(tempdir);
152}
153
154pub async fn test_tool_call_with_permission<T, F>(
155 server: F,
156 allow_option_id: acp::PermissionOptionId,
157 cx: &mut TestAppContext,
158) where
159 T: AgentServer + 'static,
160 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
161{
162 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
163 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
164 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
165 let full_turn = thread.update(cx, |thread, cx| {
166 thread.send_raw(
167 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
168 cx,
169 )
170 });
171
172 run_until_first_tool_call(
173 &thread,
174 |entry| {
175 matches!(
176 entry,
177 AgentThreadEntry::ToolCall(ToolCall {
178 status: ToolCallStatus::WaitingForConfirmation { .. },
179 ..
180 })
181 )
182 },
183 cx,
184 )
185 .await;
186
187 let tool_call_id = thread.read_with(cx, |thread, cx| {
188 let AgentThreadEntry::ToolCall(ToolCall {
189 id,
190 label,
191 status: ToolCallStatus::WaitingForConfirmation { .. },
192 ..
193 }) = &thread
194 .entries()
195 .iter()
196 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
197 .unwrap()
198 else {
199 panic!();
200 };
201
202 let label = label.read(cx).source();
203 assert!(label.contains("touch"), "Got: {}", label);
204
205 id.clone()
206 });
207
208 thread.update(cx, |thread, cx| {
209 thread.authorize_tool_call(
210 tool_call_id,
211 allow_option_id,
212 acp::PermissionOptionKind::AllowOnce,
213 cx,
214 );
215
216 assert!(thread.entries().iter().any(|entry| matches!(
217 entry,
218 AgentThreadEntry::ToolCall(ToolCall {
219 status: ToolCallStatus::Pending
220 | ToolCallStatus::InProgress
221 | ToolCallStatus::Completed,
222 ..
223 })
224 )));
225 });
226
227 full_turn.await.unwrap();
228
229 thread.read_with(cx, |thread, cx| {
230 let AgentThreadEntry::ToolCall(ToolCall {
231 content,
232 status: ToolCallStatus::Pending
233 | ToolCallStatus::InProgress
234 | ToolCallStatus::Completed,
235 ..
236 }) = thread
237 .entries()
238 .iter()
239 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
240 .unwrap()
241 else {
242 panic!();
243 };
244
245 assert!(
246 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
247 "Expected content to contain 'Hello'"
248 );
249 });
250}
251
252pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
253where
254 T: AgentServer + 'static,
255 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
256{
257 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
258
259 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
260 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
261 let _ = thread.update(cx, |thread, cx| {
262 thread.send_raw(
263 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
264 cx,
265 )
266 });
267
268 let first_tool_call_ix = run_until_first_tool_call(
269 &thread,
270 |entry| {
271 matches!(
272 entry,
273 AgentThreadEntry::ToolCall(ToolCall {
274 status: ToolCallStatus::WaitingForConfirmation { .. },
275 ..
276 })
277 )
278 },
279 cx,
280 )
281 .await;
282
283 thread.read_with(cx, |thread, cx| {
284 let AgentThreadEntry::ToolCall(ToolCall {
285 id,
286 label,
287 status: ToolCallStatus::WaitingForConfirmation { .. },
288 ..
289 }) = &thread.entries()[first_tool_call_ix]
290 else {
291 panic!("{:?}", thread.entries()[1]);
292 };
293
294 let label = label.read(cx).source();
295 assert!(label.contains("touch"), "Got: {}", label);
296
297 id.clone()
298 });
299
300 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
301 thread.read_with(cx, |thread, _cx| {
302 let AgentThreadEntry::ToolCall(ToolCall {
303 status: ToolCallStatus::Canceled,
304 ..
305 }) = &thread.entries()[first_tool_call_ix]
306 else {
307 panic!();
308 };
309 });
310
311 thread
312 .update(cx, |thread, cx| {
313 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
314 })
315 .await
316 .unwrap();
317 thread.read_with(cx, |thread, _| {
318 assert!(matches!(
319 &thread.entries().last().unwrap(),
320 AgentThreadEntry::AssistantMessage(..),
321 ))
322 });
323}
324
325pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
326where
327 T: AgentServer + 'static,
328 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
329{
330 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
331 let project = Project::test(fs.clone(), [], cx).await;
332 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
333
334 thread
335 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
336 .await
337 .unwrap();
338
339 thread.read_with(cx, |thread, _| {
340 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
341 });
342
343 let weak_thread = thread.downgrade();
344 drop(thread);
345
346 cx.executor().run_until_parked();
347 assert!(!weak_thread.is_upgradable());
348}
349
350#[macro_export]
351macro_rules! common_e2e_tests {
352 ($server:expr, allow_option_id = $allow_option_id:expr) => {
353 mod common_e2e {
354 use super::*;
355
356 #[::gpui::test]
357 #[cfg_attr(not(feature = "e2e"), ignore)]
358 async fn basic(cx: &mut ::gpui::TestAppContext) {
359 $crate::e2e_tests::test_basic($server, cx).await;
360 }
361
362 #[::gpui::test]
363 #[cfg_attr(not(feature = "e2e"), ignore)]
364 async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
365 $crate::e2e_tests::test_path_mentions($server, cx).await;
366 }
367
368 #[::gpui::test]
369 #[cfg_attr(not(feature = "e2e"), ignore)]
370 async fn tool_call(cx: &mut ::gpui::TestAppContext) {
371 $crate::e2e_tests::test_tool_call($server, cx).await;
372 }
373
374 #[::gpui::test]
375 #[cfg_attr(not(feature = "e2e"), ignore)]
376 async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
377 $crate::e2e_tests::test_tool_call_with_permission(
378 $server,
379 ::agent_client_protocol::PermissionOptionId::new($allow_option_id),
380 cx,
381 )
382 .await;
383 }
384
385 #[::gpui::test]
386 #[cfg_attr(not(feature = "e2e"), ignore)]
387 async fn cancel(cx: &mut ::gpui::TestAppContext) {
388 $crate::e2e_tests::test_cancel($server, cx).await;
389 }
390
391 #[::gpui::test]
392 #[cfg_attr(not(feature = "e2e"), ignore)]
393 async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
394 $crate::e2e_tests::test_thread_drop($server, cx).await;
395 }
396 }
397 };
398}
399pub use common_e2e_tests;
400
401// Helpers
402
403pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
404 env_logger::try_init().ok();
405
406 cx.update(|cx| {
407 let settings_store = settings::SettingsStore::test(cx);
408 cx.set_global(settings_store);
409 gpui_tokio::init(cx);
410 let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
411 cx.set_http_client(Arc::new(http_client));
412 let client = client::Client::production(cx);
413 let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
414 language_model::init(client.clone(), cx);
415 language_models::init(user_store, client, cx);
416
417 #[cfg(test)]
418 project::agent_server_store::AllAgentServersSettings::override_global(
419 project::agent_server_store::AllAgentServersSettings {
420 claude: Some(BuiltinAgentServerSettings {
421 path: Some("claude-code-acp".into()),
422 ..Default::default()
423 }),
424 gemini: Some(crate::gemini::tests::local_command().into()),
425 codex: Some(BuiltinAgentServerSettings {
426 path: Some("codex-acp".into()),
427 ..Default::default()
428 }),
429 custom: collections::HashMap::default(),
430 },
431 cx,
432 );
433 });
434
435 cx.executor().allow_parking();
436
437 FakeFs::new(cx.executor())
438}
439
440pub async fn new_test_thread(
441 server: impl AgentServer + 'static,
442 project: Entity<Project>,
443 current_dir: impl AsRef<Path>,
444 cx: &mut TestAppContext,
445) -> Entity<AcpThread> {
446 let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
447 let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
448
449 let (connection, _) = cx
450 .update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx))
451 .await
452 .unwrap();
453
454 cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
455 .await
456 .unwrap()
457}
458
459pub async fn run_until_first_tool_call(
460 thread: &Entity<AcpThread>,
461 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
462 cx: &mut TestAppContext,
463) -> usize {
464 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
465
466 let subscription = cx.update(|cx| {
467 cx.subscribe(thread, move |thread, _, cx| {
468 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
469 if wait_until(entry) {
470 return tx.try_send(ix).unwrap();
471 }
472 }
473 })
474 });
475
476 select! {
477 // We have to use a smol timer here because
478 // cx.background_executor().timer isn't real in the test context
479 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
480 panic!("Timeout waiting for tool call")
481 }
482 ix = rx.next().fuse() => {
483 drop(subscription);
484 ix.unwrap()
485 }
486 }
487}
488
/// Returns the path of the `zed` binary inside the enclosing `debug` directory.
///
/// Starting from the current executable's path, components are removed until
/// the last one is `debug`; `zed` is then appended to that directory.
///
/// # Panics
///
/// Panics if the current executable's path cannot be determined, if no
/// ancestor component is named `debug`, or if the resulting path does not
/// exist.
pub fn get_zed_path() -> PathBuf {
    let mut candidate = std::env::current_exe().unwrap();

    loop {
        let at_debug_dir = candidate
            .file_name()
            .is_some_and(|name| name.to_string_lossy() == "debug");
        if at_debug_dir {
            break;
        }

        // `pop` returns false once there is no parent left to walk up to.
        if !candidate.pop() {
            panic!("Could not find target directory");
        }
    }

    candidate.push("zed");

    if !candidate.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    candidate
}