1use crate::{AgentServer, AgentServerDelegate};
2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
3use agent_client_protocol as acp;
4use futures::{FutureExt, StreamExt, channel::mpsc, select};
5use gpui::{AppContext, Entity, TestAppContext};
6use indoc::indoc;
7#[cfg(test)]
8use project::agent_server_store::BuiltinAgentServerSettings;
9use project::{FakeFs, Project, agent_server_store::AllAgentServersSettings};
10use std::{
11 path::{Path, PathBuf},
12 sync::Arc,
13 time::Duration,
14};
15use util::path;
16
17pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
18where
19 T: AgentServer + 'static,
20 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
21{
22 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
23 let project = Project::test(fs.clone(), [], cx).await;
24 let thread = new_test_thread(
25 server(&fs, &project, cx).await,
26 project.clone(),
27 "/private/tmp",
28 cx,
29 )
30 .await;
31
32 thread
33 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
34 .await
35 .unwrap();
36
37 thread.read_with(cx, |thread, _| {
38 assert!(
39 thread.entries().len() >= 2,
40 "Expected at least 2 entries. Got: {:?}",
41 thread.entries()
42 );
43 assert!(matches!(
44 thread.entries()[0],
45 AgentThreadEntry::UserMessage(_)
46 ));
47 assert!(matches!(
48 thread.entries()[1],
49 AgentThreadEntry::AssistantMessage(_)
50 ));
51 });
52}
53
54pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
55where
56 T: AgentServer + 'static,
57 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
58{
59 let fs = init_test(cx).await as _;
60
61 let tempdir = tempfile::tempdir().unwrap();
62 std::fs::write(
63 tempdir.path().join("foo.rs"),
64 indoc! {"
65 fn main() {
66 println!(\"Hello, world!\");
67 }
68 "},
69 )
70 .expect("failed to write file");
71 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
72 let thread = new_test_thread(
73 server(&fs, &project, cx).await,
74 project.clone(),
75 tempdir.path(),
76 cx,
77 )
78 .await;
79 thread
80 .update(cx, |thread, cx| {
81 thread.send(
82 vec![
83 acp::ContentBlock::Text(acp::TextContent {
84 text: "Read the file ".into(),
85 annotations: None,
86 meta: None,
87 }),
88 acp::ContentBlock::ResourceLink(acp::ResourceLink {
89 uri: "foo.rs".into(),
90 name: "foo.rs".into(),
91 annotations: None,
92 description: None,
93 mime_type: None,
94 size: None,
95 title: None,
96 meta: None,
97 }),
98 acp::ContentBlock::Text(acp::TextContent {
99 text: " and tell me what the content of the println! is".into(),
100 annotations: None,
101 meta: None,
102 }),
103 ],
104 cx,
105 )
106 })
107 .await
108 .unwrap();
109
110 thread.read_with(cx, |thread, cx| {
111 assert!(matches!(
112 thread.entries()[0],
113 AgentThreadEntry::UserMessage(_)
114 ));
115 let assistant_message = &thread
116 .entries()
117 .iter()
118 .rev()
119 .find_map(|entry| match entry {
120 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
121 _ => None,
122 })
123 .unwrap();
124
125 assert!(
126 assistant_message.to_markdown(cx).contains("Hello, world!"),
127 "unexpected assistant message: {:?}",
128 assistant_message.to_markdown(cx)
129 );
130 });
131
132 drop(tempdir);
133}
134
135pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
136where
137 T: AgentServer + 'static,
138 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
139{
140 let fs = init_test(cx).await as _;
141
142 let tempdir = tempfile::tempdir().unwrap();
143 let foo_path = tempdir.path().join("foo");
144 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
145
146 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
147 let thread = new_test_thread(
148 server(&fs, &project, cx).await,
149 project.clone(),
150 "/private/tmp",
151 cx,
152 )
153 .await;
154
155 thread
156 .update(cx, |thread, cx| {
157 thread.send_raw(
158 &format!("Read {} and tell me what you see.", foo_path.display()),
159 cx,
160 )
161 })
162 .await
163 .unwrap();
164 thread.read_with(cx, |thread, _cx| {
165 assert!(thread.entries().iter().any(|entry| {
166 matches!(
167 entry,
168 AgentThreadEntry::ToolCall(ToolCall {
169 status: ToolCallStatus::Pending
170 | ToolCallStatus::InProgress
171 | ToolCallStatus::Completed,
172 ..
173 })
174 )
175 }));
176 assert!(
177 thread
178 .entries()
179 .iter()
180 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
181 );
182 });
183
184 drop(tempdir);
185}
186
187pub async fn test_tool_call_with_permission<T, F>(
188 server: F,
189 allow_option_id: acp::PermissionOptionId,
190 cx: &mut TestAppContext,
191) where
192 T: AgentServer + 'static,
193 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
194{
195 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
196 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
197 let thread = new_test_thread(
198 server(&fs, &project, cx).await,
199 project.clone(),
200 "/private/tmp",
201 cx,
202 )
203 .await;
204 let full_turn = thread.update(cx, |thread, cx| {
205 thread.send_raw(
206 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
207 cx,
208 )
209 });
210
211 run_until_first_tool_call(
212 &thread,
213 |entry| {
214 matches!(
215 entry,
216 AgentThreadEntry::ToolCall(ToolCall {
217 status: ToolCallStatus::WaitingForConfirmation { .. },
218 ..
219 })
220 )
221 },
222 cx,
223 )
224 .await;
225
226 let tool_call_id = thread.read_with(cx, |thread, cx| {
227 let AgentThreadEntry::ToolCall(ToolCall {
228 id,
229 label,
230 status: ToolCallStatus::WaitingForConfirmation { .. },
231 ..
232 }) = &thread
233 .entries()
234 .iter()
235 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
236 .unwrap()
237 else {
238 panic!();
239 };
240
241 let label = label.read(cx).source();
242 assert!(label.contains("touch"), "Got: {}", label);
243
244 id.clone()
245 });
246
247 thread.update(cx, |thread, cx| {
248 thread.authorize_tool_call(
249 tool_call_id,
250 allow_option_id,
251 acp::PermissionOptionKind::AllowOnce,
252 cx,
253 );
254
255 assert!(thread.entries().iter().any(|entry| matches!(
256 entry,
257 AgentThreadEntry::ToolCall(ToolCall {
258 status: ToolCallStatus::Pending
259 | ToolCallStatus::InProgress
260 | ToolCallStatus::Completed,
261 ..
262 })
263 )));
264 });
265
266 full_turn.await.unwrap();
267
268 thread.read_with(cx, |thread, cx| {
269 let AgentThreadEntry::ToolCall(ToolCall {
270 content,
271 status: ToolCallStatus::Pending
272 | ToolCallStatus::InProgress
273 | ToolCallStatus::Completed,
274 ..
275 }) = thread
276 .entries()
277 .iter()
278 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
279 .unwrap()
280 else {
281 panic!();
282 };
283
284 assert!(
285 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
286 "Expected content to contain 'Hello'"
287 );
288 });
289}
290
291pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
292where
293 T: AgentServer + 'static,
294 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
295{
296 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
297
298 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
299 let thread = new_test_thread(
300 server(&fs, &project, cx).await,
301 project.clone(),
302 "/private/tmp",
303 cx,
304 )
305 .await;
306 let _ = thread.update(cx, |thread, cx| {
307 thread.send_raw(
308 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
309 cx,
310 )
311 });
312
313 let first_tool_call_ix = run_until_first_tool_call(
314 &thread,
315 |entry| {
316 matches!(
317 entry,
318 AgentThreadEntry::ToolCall(ToolCall {
319 status: ToolCallStatus::WaitingForConfirmation { .. },
320 ..
321 })
322 )
323 },
324 cx,
325 )
326 .await;
327
328 thread.read_with(cx, |thread, cx| {
329 let AgentThreadEntry::ToolCall(ToolCall {
330 id,
331 label,
332 status: ToolCallStatus::WaitingForConfirmation { .. },
333 ..
334 }) = &thread.entries()[first_tool_call_ix]
335 else {
336 panic!("{:?}", thread.entries()[1]);
337 };
338
339 let label = label.read(cx).source();
340 assert!(label.contains("touch"), "Got: {}", label);
341
342 id.clone()
343 });
344
345 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
346 thread.read_with(cx, |thread, _cx| {
347 let AgentThreadEntry::ToolCall(ToolCall {
348 status: ToolCallStatus::Canceled,
349 ..
350 }) = &thread.entries()[first_tool_call_ix]
351 else {
352 panic!();
353 };
354 });
355
356 thread
357 .update(cx, |thread, cx| {
358 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
359 })
360 .await
361 .unwrap();
362 thread.read_with(cx, |thread, _| {
363 assert!(matches!(
364 &thread.entries().last().unwrap(),
365 AgentThreadEntry::AssistantMessage(..),
366 ))
367 });
368}
369
370pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
371where
372 T: AgentServer + 'static,
373 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
374{
375 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
376 let project = Project::test(fs.clone(), [], cx).await;
377 let thread = new_test_thread(
378 server(&fs, &project, cx).await,
379 project.clone(),
380 "/private/tmp",
381 cx,
382 )
383 .await;
384
385 thread
386 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
387 .await
388 .unwrap();
389
390 thread.read_with(cx, |thread, _| {
391 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
392 });
393
394 let weak_thread = thread.downgrade();
395 drop(thread);
396
397 cx.executor().run_until_parked();
398 assert!(!weak_thread.is_upgradable());
399}
400
/// Generates a `common_e2e` module containing the shared end-to-end test suite.
///
/// `$server` is expanded once per generated test and passed as the `server` argument of
/// the corresponding `test_*` function in `$crate::e2e_tests`. `$allow_option_id` is
/// converted with `.into()` and wrapped in `::agent_client_protocol::PermissionOptionId`
/// for the permission test.
///
/// Every generated test is marked `ignore` unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            // Bring the invoking module's items into scope so `$server` can refer to them.
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
// Re-export so the macro is also reachable through this module's path.
pub use common_e2e_tests;
451
452// Helpers
453
454pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
455 use settings::Settings;
456
457 env_logger::try_init().ok();
458
459 cx.update(|cx| {
460 let settings_store = settings::SettingsStore::test(cx);
461 cx.set_global(settings_store);
462 Project::init_settings(cx);
463 language::init(cx);
464 gpui_tokio::init(cx);
465 let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
466 cx.set_http_client(Arc::new(http_client));
467 client::init_settings(cx);
468 let client = client::Client::production(cx);
469 let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
470 language_model::init(client.clone(), cx);
471 language_models::init(user_store, client, cx);
472 agent_settings::init(cx);
473 AllAgentServersSettings::register(cx);
474
475 #[cfg(test)]
476 AllAgentServersSettings::override_global(
477 AllAgentServersSettings {
478 claude: Some(BuiltinAgentServerSettings {
479 path: Some("claude-code-acp".into()),
480 args: None,
481 env: None,
482 ignore_system_version: None,
483 default_mode: None,
484 }),
485 gemini: Some(crate::gemini::tests::local_command().into()),
486 codex: Some(BuiltinAgentServerSettings {
487 path: Some("codex-acp".into()),
488 args: None,
489 env: None,
490 ignore_system_version: None,
491 default_mode: None,
492 }),
493 custom: collections::HashMap::default(),
494 },
495 cx,
496 );
497 });
498
499 cx.executor().allow_parking();
500
501 FakeFs::new(cx.executor())
502}
503
504pub async fn new_test_thread(
505 server: impl AgentServer + 'static,
506 project: Entity<Project>,
507 current_dir: impl AsRef<Path>,
508 cx: &mut TestAppContext,
509) -> Entity<AcpThread> {
510 let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
511 let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
512
513 let (connection, _) = cx
514 .update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx))
515 .await
516 .unwrap();
517
518 cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
519 .await
520 .unwrap()
521}
522
523pub async fn run_until_first_tool_call(
524 thread: &Entity<AcpThread>,
525 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
526 cx: &mut TestAppContext,
527) -> usize {
528 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
529
530 let subscription = cx.update(|cx| {
531 cx.subscribe(thread, move |thread, _, cx| {
532 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
533 if wait_until(entry) {
534 return tx.try_send(ix).unwrap();
535 }
536 }
537 })
538 });
539
540 select! {
541 // We have to use a smol timer here because
542 // cx.background_executor().timer isn't real in the test context
543 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
544 panic!("Timeout waiting for tool call")
545 }
546 ix = rx.next().fuse() => {
547 drop(subscription);
548 ix.unwrap()
549 }
550 }
551}
552
/// Returns the path of the `zed` binary that sits in the `debug` build directory.
///
/// Starting from the path of the current executable, the closest ancestor (the path itself
/// included) whose final component is `debug` is located, and `zed` is appended to it.
///
/// # Panics
///
/// Panics if the current executable path cannot be determined, if no ancestor is named
/// `debug`, or if the resulting `zed` path does not exist.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    let debug_dir = current_exe.ancestors().find(|dir| {
        dir.file_name()
            .is_some_and(|name| name.to_string_lossy() == "debug")
    });
    let Some(debug_dir) = debug_dir else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}