1use crate::{AgentServer, AgentServerDelegate};
2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
3use agent_client_protocol as acp;
4use futures::{FutureExt, StreamExt, channel::mpsc, select};
5use gpui::{AppContext, Entity, TestAppContext};
6use indoc::indoc;
7#[cfg(test)]
8use project::agent_server_store::BuiltinAgentServerSettings;
9use project::{FakeFs, Project, agent_server_store::AllAgentServersSettings};
10use std::{
11 path::{Path, PathBuf},
12 sync::Arc,
13 time::Duration,
14};
15use util::path;
16
17pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
18where
19 T: AgentServer + 'static,
20 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
21{
22 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
23 let project = Project::test(fs.clone(), [], cx).await;
24 let thread = new_test_thread(
25 server(&fs, &project, cx).await,
26 project.clone(),
27 "/private/tmp",
28 cx,
29 )
30 .await;
31
32 thread
33 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
34 .await
35 .unwrap();
36
37 thread.read_with(cx, |thread, _| {
38 assert!(
39 thread.entries().len() >= 2,
40 "Expected at least 2 entries. Got: {:?}",
41 thread.entries()
42 );
43 assert!(matches!(
44 thread.entries()[0],
45 AgentThreadEntry::UserMessage(_)
46 ));
47 assert!(matches!(
48 thread.entries()[1],
49 AgentThreadEntry::AssistantMessage(_)
50 ));
51 });
52}
53
54pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
55where
56 T: AgentServer + 'static,
57 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
58{
59 let fs = init_test(cx).await as _;
60
61 let tempdir = tempfile::tempdir().unwrap();
62 std::fs::write(
63 tempdir.path().join("foo.rs"),
64 indoc! {"
65 fn main() {
66 println!(\"Hello, world!\");
67 }
68 "},
69 )
70 .expect("failed to write file");
71 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
72 let thread = new_test_thread(
73 server(&fs, &project, cx).await,
74 project.clone(),
75 tempdir.path(),
76 cx,
77 )
78 .await;
79 thread
80 .update(cx, |thread, cx| {
81 thread.send(
82 vec![
83 acp::ContentBlock::Text(acp::TextContent {
84 text: "Read the file ".into(),
85 annotations: None,
86 }),
87 acp::ContentBlock::ResourceLink(acp::ResourceLink {
88 uri: "foo.rs".into(),
89 name: "foo.rs".into(),
90 annotations: None,
91 description: None,
92 mime_type: None,
93 size: None,
94 title: None,
95 }),
96 acp::ContentBlock::Text(acp::TextContent {
97 text: " and tell me what the content of the println! is".into(),
98 annotations: None,
99 }),
100 ],
101 cx,
102 )
103 })
104 .await
105 .unwrap();
106
107 thread.read_with(cx, |thread, cx| {
108 assert!(matches!(
109 thread.entries()[0],
110 AgentThreadEntry::UserMessage(_)
111 ));
112 let assistant_message = &thread
113 .entries()
114 .iter()
115 .rev()
116 .find_map(|entry| match entry {
117 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
118 _ => None,
119 })
120 .unwrap();
121
122 assert!(
123 assistant_message.to_markdown(cx).contains("Hello, world!"),
124 "unexpected assistant message: {:?}",
125 assistant_message.to_markdown(cx)
126 );
127 });
128
129 drop(tempdir);
130}
131
132pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
133where
134 T: AgentServer + 'static,
135 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
136{
137 let fs = init_test(cx).await as _;
138
139 let tempdir = tempfile::tempdir().unwrap();
140 let foo_path = tempdir.path().join("foo");
141 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
142
143 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
144 let thread = new_test_thread(
145 server(&fs, &project, cx).await,
146 project.clone(),
147 "/private/tmp",
148 cx,
149 )
150 .await;
151
152 thread
153 .update(cx, |thread, cx| {
154 thread.send_raw(
155 &format!("Read {} and tell me what you see.", foo_path.display()),
156 cx,
157 )
158 })
159 .await
160 .unwrap();
161 thread.read_with(cx, |thread, _cx| {
162 assert!(thread.entries().iter().any(|entry| {
163 matches!(
164 entry,
165 AgentThreadEntry::ToolCall(ToolCall {
166 status: ToolCallStatus::Pending
167 | ToolCallStatus::InProgress
168 | ToolCallStatus::Completed,
169 ..
170 })
171 )
172 }));
173 assert!(
174 thread
175 .entries()
176 .iter()
177 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
178 );
179 });
180
181 drop(tempdir);
182}
183
184pub async fn test_tool_call_with_permission<T, F>(
185 server: F,
186 allow_option_id: acp::PermissionOptionId,
187 cx: &mut TestAppContext,
188) where
189 T: AgentServer + 'static,
190 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
191{
192 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
193 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
194 let thread = new_test_thread(
195 server(&fs, &project, cx).await,
196 project.clone(),
197 "/private/tmp",
198 cx,
199 )
200 .await;
201 let full_turn = thread.update(cx, |thread, cx| {
202 thread.send_raw(
203 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
204 cx,
205 )
206 });
207
208 run_until_first_tool_call(
209 &thread,
210 |entry| {
211 matches!(
212 entry,
213 AgentThreadEntry::ToolCall(ToolCall {
214 status: ToolCallStatus::WaitingForConfirmation { .. },
215 ..
216 })
217 )
218 },
219 cx,
220 )
221 .await;
222
223 let tool_call_id = thread.read_with(cx, |thread, cx| {
224 let AgentThreadEntry::ToolCall(ToolCall {
225 id,
226 label,
227 status: ToolCallStatus::WaitingForConfirmation { .. },
228 ..
229 }) = &thread
230 .entries()
231 .iter()
232 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
233 .unwrap()
234 else {
235 panic!();
236 };
237
238 let label = label.read(cx).source();
239 assert!(label.contains("touch"), "Got: {}", label);
240
241 id.clone()
242 });
243
244 thread.update(cx, |thread, cx| {
245 thread.authorize_tool_call(
246 tool_call_id,
247 allow_option_id,
248 acp::PermissionOptionKind::AllowOnce,
249 cx,
250 );
251
252 assert!(thread.entries().iter().any(|entry| matches!(
253 entry,
254 AgentThreadEntry::ToolCall(ToolCall {
255 status: ToolCallStatus::Pending
256 | ToolCallStatus::InProgress
257 | ToolCallStatus::Completed,
258 ..
259 })
260 )));
261 });
262
263 full_turn.await.unwrap();
264
265 thread.read_with(cx, |thread, cx| {
266 let AgentThreadEntry::ToolCall(ToolCall {
267 content,
268 status: ToolCallStatus::Pending
269 | ToolCallStatus::InProgress
270 | ToolCallStatus::Completed,
271 ..
272 }) = thread
273 .entries()
274 .iter()
275 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
276 .unwrap()
277 else {
278 panic!();
279 };
280
281 assert!(
282 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
283 "Expected content to contain 'Hello'"
284 );
285 });
286}
287
288pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
289where
290 T: AgentServer + 'static,
291 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
292{
293 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
294
295 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
296 let thread = new_test_thread(
297 server(&fs, &project, cx).await,
298 project.clone(),
299 "/private/tmp",
300 cx,
301 )
302 .await;
303 let _ = thread.update(cx, |thread, cx| {
304 thread.send_raw(
305 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
306 cx,
307 )
308 });
309
310 let first_tool_call_ix = run_until_first_tool_call(
311 &thread,
312 |entry| {
313 matches!(
314 entry,
315 AgentThreadEntry::ToolCall(ToolCall {
316 status: ToolCallStatus::WaitingForConfirmation { .. },
317 ..
318 })
319 )
320 },
321 cx,
322 )
323 .await;
324
325 thread.read_with(cx, |thread, cx| {
326 let AgentThreadEntry::ToolCall(ToolCall {
327 id,
328 label,
329 status: ToolCallStatus::WaitingForConfirmation { .. },
330 ..
331 }) = &thread.entries()[first_tool_call_ix]
332 else {
333 panic!("{:?}", thread.entries()[1]);
334 };
335
336 let label = label.read(cx).source();
337 assert!(label.contains("touch"), "Got: {}", label);
338
339 id.clone()
340 });
341
342 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
343 thread.read_with(cx, |thread, _cx| {
344 let AgentThreadEntry::ToolCall(ToolCall {
345 status: ToolCallStatus::Canceled,
346 ..
347 }) = &thread.entries()[first_tool_call_ix]
348 else {
349 panic!();
350 };
351 });
352
353 thread
354 .update(cx, |thread, cx| {
355 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
356 })
357 .await
358 .unwrap();
359 thread.read_with(cx, |thread, _| {
360 assert!(matches!(
361 &thread.entries().last().unwrap(),
362 AgentThreadEntry::AssistantMessage(..),
363 ))
364 });
365}
366
367pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
368where
369 T: AgentServer + 'static,
370 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
371{
372 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
373 let project = Project::test(fs.clone(), [], cx).await;
374 let thread = new_test_thread(
375 server(&fs, &project, cx).await,
376 project.clone(),
377 "/private/tmp",
378 cx,
379 )
380 .await;
381
382 thread
383 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
384 .await
385 .unwrap();
386
387 thread.read_with(cx, |thread, _| {
388 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
389 });
390
391 let weak_thread = thread.downgrade();
392 drop(thread);
393
394 cx.executor().run_until_parked();
395 assert!(!weak_thread.is_upgradable());
396}
397
/// Generates a `common_e2e` module with one `#[gpui::test]` per shared end-to-end
/// scenario defined in `e2e_tests`.
///
/// * The first argument is the server factory expression handed to each scenario.
/// * `allow_option_id = ...` is converted with `.into()` and wrapped in a
///   `PermissionOptionId` for the permission scenario.
///
/// Every generated test is marked `ignore` unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($make_server:expr, allow_option_id = $permission_option:expr) => {
        mod common_e2e {
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($make_server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($make_server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($make_server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $make_server,
                    ::agent_client_protocol::PermissionOptionId($permission_option.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($make_server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($make_server, cx).await;
            }
        }
    };
}
447pub use common_e2e_tests;
448
449// Helpers
450
451pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
452 use settings::Settings;
453
454 env_logger::try_init().ok();
455
456 cx.update(|cx| {
457 let settings_store = settings::SettingsStore::test(cx);
458 cx.set_global(settings_store);
459 Project::init_settings(cx);
460 language::init(cx);
461 gpui_tokio::init(cx);
462 let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
463 cx.set_http_client(Arc::new(http_client));
464 client::init_settings(cx);
465 let client = client::Client::production(cx);
466 let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
467 language_model::init(client.clone(), cx);
468 language_models::init(user_store, client, cx);
469 agent_settings::init(cx);
470 AllAgentServersSettings::register(cx);
471
472 #[cfg(test)]
473 AllAgentServersSettings::override_global(
474 AllAgentServersSettings {
475 claude: Some(BuiltinAgentServerSettings {
476 path: Some("claude-code-acp".into()),
477 args: None,
478 env: None,
479 ignore_system_version: None,
480 default_mode: None,
481 }),
482 gemini: Some(crate::gemini::tests::local_command().into()),
483 custom: collections::HashMap::default(),
484 },
485 cx,
486 );
487 });
488
489 cx.executor().allow_parking();
490
491 FakeFs::new(cx.executor())
492}
493
494pub async fn new_test_thread(
495 server: impl AgentServer + 'static,
496 project: Entity<Project>,
497 current_dir: impl AsRef<Path>,
498 cx: &mut TestAppContext,
499) -> Entity<AcpThread> {
500 let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
501 let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
502
503 let (connection, _) = cx
504 .update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx))
505 .await
506 .unwrap();
507
508 cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
509 .await
510 .unwrap()
511}
512
513pub async fn run_until_first_tool_call(
514 thread: &Entity<AcpThread>,
515 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
516 cx: &mut TestAppContext,
517) -> usize {
518 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
519
520 let subscription = cx.update(|cx| {
521 cx.subscribe(thread, move |thread, _, cx| {
522 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
523 if wait_until(entry) {
524 return tx.try_send(ix).unwrap();
525 }
526 }
527 })
528 });
529
530 select! {
531 // We have to use a smol timer here because
532 // cx.background_executor().timer isn't real in the test context
533 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
534 panic!("Timeout waiting for tool call")
535 }
536 ix = rx.next().fuse() => {
537 drop(subscription);
538 ix.unwrap()
539 }
540 }
541}
542
/// Returns the path of the `zed` binary in the `debug` directory that contains the
/// currently running executable.
///
/// Starting from `std::env::current_exe()`, path components are popped until the
/// last component is `debug`; the binary name is then appended. The platform
/// executable suffix is included, so the lookup also works where binaries are
/// named `zed.exe`.
///
/// # Panics
///
/// Panics if the current executable path cannot be determined, if no `debug`
/// ancestor directory exists, or if the binary has not been built yet.
pub fn get_zed_path() -> PathBuf {
    let mut zed_path = std::env::current_exe().unwrap();

    // Walk upwards until the path itself ends in `debug`.
    while zed_path
        .file_name()
        .is_none_or(|name| name.to_string_lossy() != "debug")
    {
        if !zed_path.pop() {
            panic!("Could not find target directory");
        }
    }

    // `EXE_SUFFIX` is empty on platforms without an executable extension.
    zed_path.push(format!("zed{}", std::env::consts::EXE_SUFFIX));

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}