1use crate::AgentServer;
2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
3use agent_client_protocol as acp;
4use futures::{FutureExt, StreamExt, channel::mpsc, select};
5use gpui::{AppContext, Entity, TestAppContext};
6use indoc::indoc;
7use project::{FakeFs, Project};
8use std::{
9 path::{Path, PathBuf},
10 sync::Arc,
11 time::Duration,
12};
13use util::path;
14
15pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
16where
17 T: AgentServer + 'static,
18 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
19{
20 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
21 let project = Project::test(fs.clone(), [], cx).await;
22 let thread = new_test_thread(
23 server(&fs, &project, cx).await,
24 project.clone(),
25 "/private/tmp",
26 cx,
27 )
28 .await;
29
30 thread
31 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
32 .await
33 .unwrap();
34
35 thread.read_with(cx, |thread, _| {
36 assert!(
37 thread.entries().len() >= 2,
38 "Expected at least 2 entries. Got: {:?}",
39 thread.entries()
40 );
41 assert!(matches!(
42 thread.entries()[0],
43 AgentThreadEntry::UserMessage(_)
44 ));
45 assert!(matches!(
46 thread.entries()[1],
47 AgentThreadEntry::AssistantMessage(_)
48 ));
49 });
50}
51
52pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
53where
54 T: AgentServer + 'static,
55 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
56{
57 let fs = init_test(cx).await as _;
58
59 let tempdir = tempfile::tempdir().unwrap();
60 std::fs::write(
61 tempdir.path().join("foo.rs"),
62 indoc! {"
63 fn main() {
64 println!(\"Hello, world!\");
65 }
66 "},
67 )
68 .expect("failed to write file");
69 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
70 let thread = new_test_thread(
71 server(&fs, &project, cx).await,
72 project.clone(),
73 tempdir.path(),
74 cx,
75 )
76 .await;
77 thread
78 .update(cx, |thread, cx| {
79 thread.send(
80 vec![
81 acp::ContentBlock::Text(acp::TextContent {
82 text: "Read the file ".into(),
83 annotations: None,
84 }),
85 acp::ContentBlock::ResourceLink(acp::ResourceLink {
86 uri: "foo.rs".into(),
87 name: "foo.rs".into(),
88 annotations: None,
89 description: None,
90 mime_type: None,
91 size: None,
92 title: None,
93 }),
94 acp::ContentBlock::Text(acp::TextContent {
95 text: " and tell me what the content of the println! is".into(),
96 annotations: None,
97 }),
98 ],
99 cx,
100 )
101 })
102 .await
103 .unwrap();
104
105 thread.read_with(cx, |thread, cx| {
106 assert!(matches!(
107 thread.entries()[0],
108 AgentThreadEntry::UserMessage(_)
109 ));
110 let assistant_message = &thread
111 .entries()
112 .iter()
113 .rev()
114 .find_map(|entry| match entry {
115 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
116 _ => None,
117 })
118 .unwrap();
119
120 assert!(
121 assistant_message.to_markdown(cx).contains("Hello, world!"),
122 "unexpected assistant message: {:?}",
123 assistant_message.to_markdown(cx)
124 );
125 });
126
127 drop(tempdir);
128}
129
130pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
131where
132 T: AgentServer + 'static,
133 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
134{
135 let fs = init_test(cx).await as _;
136
137 let tempdir = tempfile::tempdir().unwrap();
138 let foo_path = tempdir.path().join("foo");
139 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
140
141 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
142 let thread = new_test_thread(
143 server(&fs, &project, cx).await,
144 project.clone(),
145 "/private/tmp",
146 cx,
147 )
148 .await;
149
150 thread
151 .update(cx, |thread, cx| {
152 thread.send_raw(
153 &format!("Read {} and tell me what you see.", foo_path.display()),
154 cx,
155 )
156 })
157 .await
158 .unwrap();
159 thread.read_with(cx, |thread, _cx| {
160 assert!(thread.entries().iter().any(|entry| {
161 matches!(
162 entry,
163 AgentThreadEntry::ToolCall(ToolCall {
164 status: ToolCallStatus::Pending
165 | ToolCallStatus::InProgress
166 | ToolCallStatus::Completed,
167 ..
168 })
169 )
170 }));
171 assert!(
172 thread
173 .entries()
174 .iter()
175 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
176 );
177 });
178
179 drop(tempdir);
180}
181
182pub async fn test_tool_call_with_permission<T, F>(
183 server: F,
184 allow_option_id: acp::PermissionOptionId,
185 cx: &mut TestAppContext,
186) where
187 T: AgentServer + 'static,
188 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
189{
190 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
191 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
192 let thread = new_test_thread(
193 server(&fs, &project, cx).await,
194 project.clone(),
195 "/private/tmp",
196 cx,
197 )
198 .await;
199 let full_turn = thread.update(cx, |thread, cx| {
200 thread.send_raw(
201 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
202 cx,
203 )
204 });
205
206 run_until_first_tool_call(
207 &thread,
208 |entry| {
209 matches!(
210 entry,
211 AgentThreadEntry::ToolCall(ToolCall {
212 status: ToolCallStatus::WaitingForConfirmation { .. },
213 ..
214 })
215 )
216 },
217 cx,
218 )
219 .await;
220
221 let tool_call_id = thread.read_with(cx, |thread, cx| {
222 let AgentThreadEntry::ToolCall(ToolCall {
223 id,
224 label,
225 status: ToolCallStatus::WaitingForConfirmation { .. },
226 ..
227 }) = &thread
228 .entries()
229 .iter()
230 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
231 .unwrap()
232 else {
233 panic!();
234 };
235
236 let label = label.read(cx).source();
237 assert!(label.contains("touch"), "Got: {}", label);
238
239 id.clone()
240 });
241
242 thread.update(cx, |thread, cx| {
243 thread.authorize_tool_call(
244 tool_call_id,
245 allow_option_id,
246 acp::PermissionOptionKind::AllowOnce,
247 cx,
248 );
249
250 assert!(thread.entries().iter().any(|entry| matches!(
251 entry,
252 AgentThreadEntry::ToolCall(ToolCall {
253 status: ToolCallStatus::Pending
254 | ToolCallStatus::InProgress
255 | ToolCallStatus::Completed,
256 ..
257 })
258 )));
259 });
260
261 full_turn.await.unwrap();
262
263 thread.read_with(cx, |thread, cx| {
264 let AgentThreadEntry::ToolCall(ToolCall {
265 content,
266 status: ToolCallStatus::Pending
267 | ToolCallStatus::InProgress
268 | ToolCallStatus::Completed,
269 ..
270 }) = thread
271 .entries()
272 .iter()
273 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
274 .unwrap()
275 else {
276 panic!();
277 };
278
279 assert!(
280 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
281 "Expected content to contain 'Hello'"
282 );
283 });
284}
285
286pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
287where
288 T: AgentServer + 'static,
289 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
290{
291 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
292
293 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
294 let thread = new_test_thread(
295 server(&fs, &project, cx).await,
296 project.clone(),
297 "/private/tmp",
298 cx,
299 )
300 .await;
301 let _ = thread.update(cx, |thread, cx| {
302 thread.send_raw(
303 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
304 cx,
305 )
306 });
307
308 let first_tool_call_ix = run_until_first_tool_call(
309 &thread,
310 |entry| {
311 matches!(
312 entry,
313 AgentThreadEntry::ToolCall(ToolCall {
314 status: ToolCallStatus::WaitingForConfirmation { .. },
315 ..
316 })
317 )
318 },
319 cx,
320 )
321 .await;
322
323 thread.read_with(cx, |thread, cx| {
324 let AgentThreadEntry::ToolCall(ToolCall {
325 id,
326 label,
327 status: ToolCallStatus::WaitingForConfirmation { .. },
328 ..
329 }) = &thread.entries()[first_tool_call_ix]
330 else {
331 panic!("{:?}", thread.entries()[1]);
332 };
333
334 let label = label.read(cx).source();
335 assert!(label.contains("touch"), "Got: {}", label);
336
337 id.clone()
338 });
339
340 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
341 thread.read_with(cx, |thread, _cx| {
342 let AgentThreadEntry::ToolCall(ToolCall {
343 status: ToolCallStatus::Canceled,
344 ..
345 }) = &thread.entries()[first_tool_call_ix]
346 else {
347 panic!();
348 };
349 });
350
351 thread
352 .update(cx, |thread, cx| {
353 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
354 })
355 .await
356 .unwrap();
357 thread.read_with(cx, |thread, _| {
358 assert!(matches!(
359 &thread.entries().last().unwrap(),
360 AgentThreadEntry::AssistantMessage(..),
361 ))
362 });
363}
364
365pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
366where
367 T: AgentServer + 'static,
368 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
369{
370 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
371 let project = Project::test(fs.clone(), [], cx).await;
372 let thread = new_test_thread(
373 server(&fs, &project, cx).await,
374 project.clone(),
375 "/private/tmp",
376 cx,
377 )
378 .await;
379
380 thread
381 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
382 .await
383 .unwrap();
384
385 thread.read_with(cx, |thread, _| {
386 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
387 });
388
389 let weak_thread = thread.downgrade();
390 drop(thread);
391
392 cx.executor().run_until_parked();
393 assert!(!weak_thread.is_upgradable());
394}
395
/// Generates the shared end-to-end test suite for one agent server.
///
/// Expands to a `common_e2e` module containing one `#[gpui::test]` per
/// scenario in `e2e_tests` (`basic`, `path_mentions`, `tool_call`,
/// `tool_call_with_permission`, `cancel`, `thread_drop`).
///
/// * `$server` is forwarded as the `server` argument of every `test_*`
///   function; the expression is expanded once per generated test.
/// * `$allow_option_id` is converted with `.into()` and wrapped in
///   `agent_client_protocol::PermissionOptionId` for the permission test.
///
/// Every generated test is marked `ignore` unless the `e2e` feature is
/// enabled.
///
/// Invocation shape:
///
/// ```ignore
/// common_e2e_tests!(<server expression>, allow_option_id = <id expression>);
/// ```
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            // Bring the invoking module's items into scope for `$server`.
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
// Re-export so the macro is also reachable through this module's path.
pub use common_e2e_tests;
446
447// Helpers
448
/// Prepares `cx` for an agent-server test and returns a new [`FakeFs`].
///
/// Inside a single `cx.update` this installs a test `SettingsStore` as a
/// global, runs the settings/initialisation hooks of the project, language,
/// tokio, client, language-model and agent-settings crates, and sets an HTTP
/// client whose user agent is `"agent tests"`. When compiled with
/// `cfg(test)` it additionally overrides `AllAgentServersSettings` so the
/// `claude` and `gemini` entries use the commands returned by their
/// `tests::local_command()` helpers.
///
/// Finally the executor is allowed to park, and a `FakeFs` bound to that
/// executor is returned.
pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
    // Imported only under `cfg(test)`, matching the test-only
    // `override_global` call below (presumably a `Settings` trait method).
    #[cfg(test)]
    use settings::Settings;

    // `.ok()` discards any error from `try_init` (presumably reported when a
    // logger is already installed).
    env_logger::try_init().ok();

    cx.update(|cx| {
        let settings_store = settings::SettingsStore::test(cx);
        cx.set_global(settings_store);
        Project::init_settings(cx);
        language::init(cx);
        gpui_tokio::init(cx);
        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));
        client::init_settings(cx);
        let client = client::Client::production(cx);
        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client, cx);
        agent_settings::init(cx);
        crate::settings::init(cx);

        // Point the built-in agent servers at their local test commands.
        #[cfg(test)]
        crate::AllAgentServersSettings::override_global(
            crate::AllAgentServersSettings {
                claude: Some(crate::AgentServerSettings {
                    command: crate::claude::tests::local_command(),
                }),
                gemini: Some(crate::AgentServerSettings {
                    command: crate::gemini::tests::local_command(),
                }),
                custom: collections::HashMap::default(),
            },
            cx,
        );
    });

    // NOTE(review): presumably required because these tests wait on real
    // external agent processes rather than simulated work — confirm.
    cx.executor().allow_parking();

    FakeFs::new(cx.executor())
}
490
491pub async fn new_test_thread(
492 server: impl AgentServer + 'static,
493 project: Entity<Project>,
494 current_dir: impl AsRef<Path>,
495 cx: &mut TestAppContext,
496) -> Entity<AcpThread> {
497 let connection = cx
498 .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
499 .await
500 .unwrap();
501
502 cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
503 .await
504 .unwrap()
505}
506
507pub async fn run_until_first_tool_call(
508 thread: &Entity<AcpThread>,
509 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
510 cx: &mut TestAppContext,
511) -> usize {
512 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
513
514 let subscription = cx.update(|cx| {
515 cx.subscribe(thread, move |thread, _, cx| {
516 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
517 if wait_until(entry) {
518 return tx.try_send(ix).unwrap();
519 }
520 }
521 })
522 });
523
524 select! {
525 // We have to use a smol timer here because
526 // cx.background_executor().timer isn't real in the test context
527 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
528 panic!("Timeout waiting for tool call")
529 }
530 ix = rx.next().fuse() => {
531 drop(subscription);
532 ix.unwrap()
533 }
534 }
535}
536
/// Returns the path of the `zed` binary that sits in the same Cargo `debug`
/// output directory as the currently running test executable.
///
/// Starting from `std::env::current_exe()`, path components are popped until
/// one named `debug` is found; the platform-specific executable name
/// (`zed` plus `std::env::consts::EXE_SUFFIX`, i.e. `zed.exe` on Windows) is
/// then appended.
///
/// # Panics
///
/// * if the current executable path cannot be determined,
/// * if no ancestor directory is named `debug`,
/// * if the resulting binary does not exist (it has not been built yet).
pub fn get_zed_path() -> PathBuf {
    let mut zed_path = std::env::current_exe().unwrap();

    // Walk up until the last component is the `debug` directory.
    while zed_path.file_name().is_none_or(|name| name != "debug") {
        if !zed_path.pop() {
            panic!("Could not find target directory");
        }
    }

    // Use the platform's executable suffix so the lookup also works where
    // binaries carry an extension.
    zed_path.push(format!("zed{}", std::env::consts::EXE_SUFFIX));

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}