1use crate::AgentServer;
2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus, new_prompt_id};
3use agent_client_protocol as acp;
4use futures::{FutureExt, StreamExt, channel::mpsc, select};
5use gpui::{AppContext, Entity, TestAppContext};
6use indoc::indoc;
7use project::{FakeFs, Project};
8use std::{
9 path::{Path, PathBuf},
10 sync::Arc,
11 time::Duration,
12};
13use util::path;
14
15pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
16where
17 T: AgentServer + 'static,
18 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
19{
20 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
21 let project = Project::test(fs.clone(), [], cx).await;
22 let thread = new_test_thread(
23 server(&fs, &project, cx).await,
24 project.clone(),
25 "/private/tmp",
26 cx,
27 )
28 .await;
29
30 thread
31 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
32 .await
33 .unwrap();
34
35 thread.read_with(cx, |thread, _| {
36 assert!(
37 thread.entries().len() >= 2,
38 "Expected at least 2 entries. Got: {:?}",
39 thread.entries()
40 );
41 assert!(matches!(
42 thread.entries()[0],
43 AgentThreadEntry::UserMessage(_)
44 ));
45 assert!(matches!(
46 thread.entries()[1],
47 AgentThreadEntry::AssistantMessage(_)
48 ));
49 });
50}
51
52pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
53where
54 T: AgentServer + 'static,
55 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
56{
57 let fs = init_test(cx).await as _;
58
59 let tempdir = tempfile::tempdir().unwrap();
60 std::fs::write(
61 tempdir.path().join("foo.rs"),
62 indoc! {"
63 fn main() {
64 println!(\"Hello, world!\");
65 }
66 "},
67 )
68 .expect("failed to write file");
69 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
70 let thread = new_test_thread(
71 server(&fs, &project, cx).await,
72 project.clone(),
73 tempdir.path(),
74 cx,
75 )
76 .await;
77 thread
78 .update(cx, |thread, cx| {
79 thread.send(
80 new_prompt_id(),
81 vec![
82 acp::ContentBlock::Text(acp::TextContent {
83 text: "Read the file ".into(),
84 annotations: None,
85 }),
86 acp::ContentBlock::ResourceLink(acp::ResourceLink {
87 uri: "foo.rs".into(),
88 name: "foo.rs".into(),
89 annotations: None,
90 description: None,
91 mime_type: None,
92 size: None,
93 title: None,
94 }),
95 acp::ContentBlock::Text(acp::TextContent {
96 text: " and tell me what the content of the println! is".into(),
97 annotations: None,
98 }),
99 ],
100 cx,
101 )
102 })
103 .await
104 .unwrap();
105
106 thread.read_with(cx, |thread, cx| {
107 assert!(matches!(
108 thread.entries()[0],
109 AgentThreadEntry::UserMessage(_)
110 ));
111 let assistant_message = &thread
112 .entries()
113 .iter()
114 .rev()
115 .find_map(|entry| match entry {
116 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
117 _ => None,
118 })
119 .unwrap();
120
121 assert!(
122 assistant_message.to_markdown(cx).contains("Hello, world!"),
123 "unexpected assistant message: {:?}",
124 assistant_message.to_markdown(cx)
125 );
126 });
127
128 drop(tempdir);
129}
130
131pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
132where
133 T: AgentServer + 'static,
134 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
135{
136 let fs = init_test(cx).await as _;
137
138 let tempdir = tempfile::tempdir().unwrap();
139 let foo_path = tempdir.path().join("foo");
140 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
141
142 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
143 let thread = new_test_thread(
144 server(&fs, &project, cx).await,
145 project.clone(),
146 "/private/tmp",
147 cx,
148 )
149 .await;
150
151 thread
152 .update(cx, |thread, cx| {
153 thread.send_raw(
154 &format!("Read {} and tell me what you see.", foo_path.display()),
155 cx,
156 )
157 })
158 .await
159 .unwrap();
160 thread.read_with(cx, |thread, _cx| {
161 assert!(thread.entries().iter().any(|entry| {
162 matches!(
163 entry,
164 AgentThreadEntry::ToolCall(ToolCall {
165 status: ToolCallStatus::Pending
166 | ToolCallStatus::InProgress
167 | ToolCallStatus::Completed,
168 ..
169 })
170 )
171 }));
172 assert!(
173 thread
174 .entries()
175 .iter()
176 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
177 );
178 });
179
180 drop(tempdir);
181}
182
183pub async fn test_tool_call_with_permission<T, F>(
184 server: F,
185 allow_option_id: acp::PermissionOptionId,
186 cx: &mut TestAppContext,
187) where
188 T: AgentServer + 'static,
189 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
190{
191 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
192 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
193 let thread = new_test_thread(
194 server(&fs, &project, cx).await,
195 project.clone(),
196 "/private/tmp",
197 cx,
198 )
199 .await;
200 let full_turn = thread.update(cx, |thread, cx| {
201 thread.send_raw(
202 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
203 cx,
204 )
205 });
206
207 run_until_first_tool_call(
208 &thread,
209 |entry| {
210 matches!(
211 entry,
212 AgentThreadEntry::ToolCall(ToolCall {
213 status: ToolCallStatus::WaitingForConfirmation { .. },
214 ..
215 })
216 )
217 },
218 cx,
219 )
220 .await;
221
222 let tool_call_id = thread.read_with(cx, |thread, cx| {
223 let AgentThreadEntry::ToolCall(ToolCall {
224 id,
225 label,
226 status: ToolCallStatus::WaitingForConfirmation { .. },
227 ..
228 }) = &thread
229 .entries()
230 .iter()
231 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
232 .unwrap()
233 else {
234 panic!();
235 };
236
237 let label = label.read(cx).source();
238 assert!(label.contains("touch"), "Got: {}", label);
239
240 id.clone()
241 });
242
243 thread.update(cx, |thread, cx| {
244 thread.authorize_tool_call(
245 tool_call_id,
246 allow_option_id,
247 acp::PermissionOptionKind::AllowOnce,
248 cx,
249 );
250
251 assert!(thread.entries().iter().any(|entry| matches!(
252 entry,
253 AgentThreadEntry::ToolCall(ToolCall {
254 status: ToolCallStatus::Pending
255 | ToolCallStatus::InProgress
256 | ToolCallStatus::Completed,
257 ..
258 })
259 )));
260 });
261
262 full_turn.await.unwrap();
263
264 thread.read_with(cx, |thread, cx| {
265 let AgentThreadEntry::ToolCall(ToolCall {
266 content,
267 status: ToolCallStatus::Pending
268 | ToolCallStatus::InProgress
269 | ToolCallStatus::Completed,
270 ..
271 }) = thread
272 .entries()
273 .iter()
274 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
275 .unwrap()
276 else {
277 panic!();
278 };
279
280 assert!(
281 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
282 "Expected content to contain 'Hello'"
283 );
284 });
285}
286
287pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
288where
289 T: AgentServer + 'static,
290 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
291{
292 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
293
294 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
295 let thread = new_test_thread(
296 server(&fs, &project, cx).await,
297 project.clone(),
298 "/private/tmp",
299 cx,
300 )
301 .await;
302 let _ = thread.update(cx, |thread, cx| {
303 thread.send_raw(
304 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
305 cx,
306 )
307 });
308
309 let first_tool_call_ix = run_until_first_tool_call(
310 &thread,
311 |entry| {
312 matches!(
313 entry,
314 AgentThreadEntry::ToolCall(ToolCall {
315 status: ToolCallStatus::WaitingForConfirmation { .. },
316 ..
317 })
318 )
319 },
320 cx,
321 )
322 .await;
323
324 thread.read_with(cx, |thread, cx| {
325 let AgentThreadEntry::ToolCall(ToolCall {
326 id,
327 label,
328 status: ToolCallStatus::WaitingForConfirmation { .. },
329 ..
330 }) = &thread.entries()[first_tool_call_ix]
331 else {
332 panic!("{:?}", thread.entries()[1]);
333 };
334
335 let label = label.read(cx).source();
336 assert!(label.contains("touch"), "Got: {}", label);
337
338 id.clone()
339 });
340
341 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
342 thread.read_with(cx, |thread, _cx| {
343 let AgentThreadEntry::ToolCall(ToolCall {
344 status: ToolCallStatus::Canceled,
345 ..
346 }) = &thread.entries()[first_tool_call_ix]
347 else {
348 panic!();
349 };
350 });
351
352 thread
353 .update(cx, |thread, cx| {
354 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
355 })
356 .await
357 .unwrap();
358 thread.read_with(cx, |thread, _| {
359 assert!(matches!(
360 &thread.entries().last().unwrap(),
361 AgentThreadEntry::AssistantMessage(..),
362 ))
363 });
364}
365
366pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
367where
368 T: AgentServer + 'static,
369 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
370{
371 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
372 let project = Project::test(fs.clone(), [], cx).await;
373 let thread = new_test_thread(
374 server(&fs, &project, cx).await,
375 project.clone(),
376 "/private/tmp",
377 cx,
378 )
379 .await;
380
381 thread
382 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
383 .await
384 .unwrap();
385
386 thread.read_with(cx, |thread, _| {
387 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
388 });
389
390 let weak_thread = thread.downgrade();
391 drop(thread);
392
393 cx.executor().run_until_parked();
394 assert!(!weak_thread.is_upgradable());
395}
396
/// Expands to a `common_e2e` module with one `#[gpui::test]` per shared
/// end-to-end scenario in `$crate::e2e_tests`, each run against `$server`.
///
/// - `$server`: the server-factory expression passed to every scenario (it is
///   re-evaluated in each generated test).
/// - `allow_option_id = $allow_option_id`: converted with `.into()` into the
///   `PermissionOptionId` used to approve the tool call in
///   `tool_call_with_permission`.
///
/// Every generated test is `#[ignore]`d unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            // Brings the invoking module's items (e.g. what `$server` refers to) into scope.
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
// `#[macro_export]` places the macro at the crate root; this re-export also
// makes it nameable by path through this module.
pub use common_e2e_tests;
447
448// Helpers
449
/// Shared global setup for every e2e scenario.
///
/// The setup does the following:
/// - installs a test `SettingsStore`;
/// - initializes project, language, client, language-model and agent settings;
/// - installs a real `reqwest` HTTP client and a production `client::Client`;
/// - in `cfg(test)` builds, overrides the agent-server settings so Claude and
///   Gemini use their crate-local test commands.
///
/// It then allows the executor to park.
///
/// Returns a new `FakeFs` bound to the test executor.
pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
    // Presumably needed to bring `override_global` (a `Settings` trait method)
    // into scope for the `cfg(test)` override below.
    #[cfg(test)]
    use settings::Settings;

    // Ignore the error returned when a logger is already installed.
    env_logger::try_init().ok();

    cx.update(|cx| {
        let settings_store = settings::SettingsStore::test(cx);
        cx.set_global(settings_store);
        Project::init_settings(cx);
        language::init(cx);
        gpui_tokio::init(cx);
        // A real HTTP client and production `Client`: these tests talk to real agents.
        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));
        client::init_settings(cx);
        let client = client::Client::production(cx);
        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client, cx);
        agent_settings::init(cx);
        crate::settings::init(cx);

        // Point the built-in agent servers at their local test commands; no
        // custom agent servers are configured.
        #[cfg(test)]
        crate::AllAgentServersSettings::override_global(
            crate::AllAgentServersSettings {
                claude: Some(crate::AgentServerSettings {
                    command: crate::claude::tests::local_command(),
                }),
                gemini: Some(crate::AgentServerSettings {
                    command: crate::gemini::tests::local_command(),
                }),
                custom: collections::HashMap::default(),
            },
            cx,
        );
    });

    // NOTE(review): presumably required because the tests wait on real external
    // agent processes and network I/O, which the test executor would otherwise
    // treat as a stall — confirm.
    cx.executor().allow_parking();

    FakeFs::new(cx.executor())
}
491
492pub async fn new_test_thread(
493 server: impl AgentServer + 'static,
494 project: Entity<Project>,
495 current_dir: impl AsRef<Path>,
496 cx: &mut TestAppContext,
497) -> Entity<AcpThread> {
498 let connection = cx
499 .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
500 .await
501 .unwrap();
502
503 cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
504 .await
505 .unwrap()
506}
507
508pub async fn run_until_first_tool_call(
509 thread: &Entity<AcpThread>,
510 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
511 cx: &mut TestAppContext,
512) -> usize {
513 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
514
515 let subscription = cx.update(|cx| {
516 cx.subscribe(thread, move |thread, _, cx| {
517 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
518 if wait_until(entry) {
519 return tx.try_send(ix).unwrap();
520 }
521 }
522 })
523 });
524
525 select! {
526 // We have to use a smol timer here because
527 // cx.background_executor().timer isn't real in the test context
528 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
529 panic!("Timeout waiting for tool call")
530 }
531 ix = rx.next().fuse() => {
532 drop(subscription);
533 ix.unwrap()
534 }
535 }
536}
537
/// Returns the path of the `zed` binary inside the nearest ancestor directory
/// named `debug` of the running executable (the test binary itself included).
///
/// # Panics
/// Panics if no ancestor is named `debug`, or if `zed` has not been built there.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    let Some(debug_dir) = current_exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name.to_string_lossy() == "debug"))
    else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }
    zed_path
}