1use crate::{AgentServer, AgentServerDelegate};
2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
3use agent_client_protocol as acp;
4use futures::{FutureExt, StreamExt, channel::mpsc, select};
5use gpui::{AppContext, Entity, TestAppContext};
6use indoc::indoc;
7#[cfg(test)]
8use project::agent_server_store::BuiltinAgentServerSettings;
9use project::{FakeFs, Project};
10#[cfg(test)]
11use settings::Settings;
12use std::{
13 path::{Path, PathBuf},
14 sync::Arc,
15 time::Duration,
16};
17use util::path;
18
19pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
20where
21 T: AgentServer + 'static,
22 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
23{
24 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
25 let project = Project::test(fs.clone(), [], cx).await;
26 let thread = new_test_thread(
27 server(&fs, &project, cx).await,
28 project.clone(),
29 "/private/tmp",
30 cx,
31 )
32 .await;
33
34 thread
35 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
36 .await
37 .unwrap();
38
39 thread.read_with(cx, |thread, _| {
40 assert!(
41 thread.entries().len() >= 2,
42 "Expected at least 2 entries. Got: {:?}",
43 thread.entries()
44 );
45 assert!(matches!(
46 thread.entries()[0],
47 AgentThreadEntry::UserMessage(_)
48 ));
49 assert!(matches!(
50 thread.entries()[1],
51 AgentThreadEntry::AssistantMessage(_)
52 ));
53 });
54}
55
56pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
57where
58 T: AgentServer + 'static,
59 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
60{
61 let fs = init_test(cx).await as _;
62
63 let tempdir = tempfile::tempdir().unwrap();
64 std::fs::write(
65 tempdir.path().join("foo.rs"),
66 indoc! {"
67 fn main() {
68 println!(\"Hello, world!\");
69 }
70 "},
71 )
72 .expect("failed to write file");
73 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
74 let thread = new_test_thread(
75 server(&fs, &project, cx).await,
76 project.clone(),
77 tempdir.path(),
78 cx,
79 )
80 .await;
81 thread
82 .update(cx, |thread, cx| {
83 thread.send(
84 vec![
85 acp::ContentBlock::Text(acp::TextContent {
86 text: "Read the file ".into(),
87 annotations: None,
88 meta: None,
89 }),
90 acp::ContentBlock::ResourceLink(acp::ResourceLink {
91 uri: "foo.rs".into(),
92 name: "foo.rs".into(),
93 annotations: None,
94 description: None,
95 mime_type: None,
96 size: None,
97 title: None,
98 meta: None,
99 }),
100 acp::ContentBlock::Text(acp::TextContent {
101 text: " and tell me what the content of the println! is".into(),
102 annotations: None,
103 meta: None,
104 }),
105 ],
106 cx,
107 )
108 })
109 .await
110 .unwrap();
111
112 thread.read_with(cx, |thread, cx| {
113 assert!(matches!(
114 thread.entries()[0],
115 AgentThreadEntry::UserMessage(_)
116 ));
117 let assistant_message = &thread
118 .entries()
119 .iter()
120 .rev()
121 .find_map(|entry| match entry {
122 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
123 _ => None,
124 })
125 .unwrap();
126
127 assert!(
128 assistant_message.to_markdown(cx).contains("Hello, world!"),
129 "unexpected assistant message: {:?}",
130 assistant_message.to_markdown(cx)
131 );
132 });
133
134 drop(tempdir);
135}
136
137pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
138where
139 T: AgentServer + 'static,
140 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
141{
142 let fs = init_test(cx).await as _;
143
144 let tempdir = tempfile::tempdir().unwrap();
145 let foo_path = tempdir.path().join("foo");
146 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
147
148 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
149 let thread = new_test_thread(
150 server(&fs, &project, cx).await,
151 project.clone(),
152 "/private/tmp",
153 cx,
154 )
155 .await;
156
157 thread
158 .update(cx, |thread, cx| {
159 thread.send_raw(
160 &format!("Read {} and tell me what you see.", foo_path.display()),
161 cx,
162 )
163 })
164 .await
165 .unwrap();
166 thread.read_with(cx, |thread, _cx| {
167 assert!(thread.entries().iter().any(|entry| {
168 matches!(
169 entry,
170 AgentThreadEntry::ToolCall(ToolCall {
171 status: ToolCallStatus::Pending
172 | ToolCallStatus::InProgress
173 | ToolCallStatus::Completed,
174 ..
175 })
176 )
177 }));
178 assert!(
179 thread
180 .entries()
181 .iter()
182 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
183 );
184 });
185
186 drop(tempdir);
187}
188
189pub async fn test_tool_call_with_permission<T, F>(
190 server: F,
191 allow_option_id: acp::PermissionOptionId,
192 cx: &mut TestAppContext,
193) where
194 T: AgentServer + 'static,
195 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
196{
197 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
198 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
199 let thread = new_test_thread(
200 server(&fs, &project, cx).await,
201 project.clone(),
202 "/private/tmp",
203 cx,
204 )
205 .await;
206 let full_turn = thread.update(cx, |thread, cx| {
207 thread.send_raw(
208 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
209 cx,
210 )
211 });
212
213 run_until_first_tool_call(
214 &thread,
215 |entry| {
216 matches!(
217 entry,
218 AgentThreadEntry::ToolCall(ToolCall {
219 status: ToolCallStatus::WaitingForConfirmation { .. },
220 ..
221 })
222 )
223 },
224 cx,
225 )
226 .await;
227
228 let tool_call_id = thread.read_with(cx, |thread, cx| {
229 let AgentThreadEntry::ToolCall(ToolCall {
230 id,
231 label,
232 status: ToolCallStatus::WaitingForConfirmation { .. },
233 ..
234 }) = &thread
235 .entries()
236 .iter()
237 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
238 .unwrap()
239 else {
240 panic!();
241 };
242
243 let label = label.read(cx).source();
244 assert!(label.contains("touch"), "Got: {}", label);
245
246 id.clone()
247 });
248
249 thread.update(cx, |thread, cx| {
250 thread.authorize_tool_call(
251 tool_call_id,
252 allow_option_id,
253 acp::PermissionOptionKind::AllowOnce,
254 cx,
255 );
256
257 assert!(thread.entries().iter().any(|entry| matches!(
258 entry,
259 AgentThreadEntry::ToolCall(ToolCall {
260 status: ToolCallStatus::Pending
261 | ToolCallStatus::InProgress
262 | ToolCallStatus::Completed,
263 ..
264 })
265 )));
266 });
267
268 full_turn.await.unwrap();
269
270 thread.read_with(cx, |thread, cx| {
271 let AgentThreadEntry::ToolCall(ToolCall {
272 content,
273 status: ToolCallStatus::Pending
274 | ToolCallStatus::InProgress
275 | ToolCallStatus::Completed,
276 ..
277 }) = thread
278 .entries()
279 .iter()
280 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
281 .unwrap()
282 else {
283 panic!();
284 };
285
286 assert!(
287 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
288 "Expected content to contain 'Hello'"
289 );
290 });
291}
292
293pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
294where
295 T: AgentServer + 'static,
296 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
297{
298 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
299
300 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
301 let thread = new_test_thread(
302 server(&fs, &project, cx).await,
303 project.clone(),
304 "/private/tmp",
305 cx,
306 )
307 .await;
308 let _ = thread.update(cx, |thread, cx| {
309 thread.send_raw(
310 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
311 cx,
312 )
313 });
314
315 let first_tool_call_ix = run_until_first_tool_call(
316 &thread,
317 |entry| {
318 matches!(
319 entry,
320 AgentThreadEntry::ToolCall(ToolCall {
321 status: ToolCallStatus::WaitingForConfirmation { .. },
322 ..
323 })
324 )
325 },
326 cx,
327 )
328 .await;
329
330 thread.read_with(cx, |thread, cx| {
331 let AgentThreadEntry::ToolCall(ToolCall {
332 id,
333 label,
334 status: ToolCallStatus::WaitingForConfirmation { .. },
335 ..
336 }) = &thread.entries()[first_tool_call_ix]
337 else {
338 panic!("{:?}", thread.entries()[1]);
339 };
340
341 let label = label.read(cx).source();
342 assert!(label.contains("touch"), "Got: {}", label);
343
344 id.clone()
345 });
346
347 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
348 thread.read_with(cx, |thread, _cx| {
349 let AgentThreadEntry::ToolCall(ToolCall {
350 status: ToolCallStatus::Canceled,
351 ..
352 }) = &thread.entries()[first_tool_call_ix]
353 else {
354 panic!();
355 };
356 });
357
358 thread
359 .update(cx, |thread, cx| {
360 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
361 })
362 .await
363 .unwrap();
364 thread.read_with(cx, |thread, _| {
365 assert!(matches!(
366 &thread.entries().last().unwrap(),
367 AgentThreadEntry::AssistantMessage(..),
368 ))
369 });
370}
371
372pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
373where
374 T: AgentServer + 'static,
375 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
376{
377 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
378 let project = Project::test(fs.clone(), [], cx).await;
379 let thread = new_test_thread(
380 server(&fs, &project, cx).await,
381 project.clone(),
382 "/private/tmp",
383 cx,
384 )
385 .await;
386
387 thread
388 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
389 .await
390 .unwrap();
391
392 thread.read_with(cx, |thread, _| {
393 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
394 });
395
396 let weak_thread = thread.downgrade();
397 drop(thread);
398
399 cx.executor().run_until_parked();
400 assert!(!weak_thread.is_upgradable());
401}
402
/// Generates a `common_e2e` module containing the shared end-to-end tests.
///
/// * `$server` is passed unchanged as the `server` argument of every
///   `test_*` function in `$crate::e2e_tests`.
/// * `$allow_option_id` is converted with `.into()` and wrapped in
///   `::agent_client_protocol::PermissionOptionId`; it is only used by the
///   `tool_call_with_permission` test.
///
/// Every generated test is a `#[::gpui::test]` and is marked `ignore` unless
/// the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            // Brings the invoking module's items into scope for `$server`
            // and `$allow_option_id`.
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
452pub use common_e2e_tests;
453
454// Helpers
455
456pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
457 env_logger::try_init().ok();
458
459 cx.update(|cx| {
460 let settings_store = settings::SettingsStore::test(cx);
461 cx.set_global(settings_store);
462 gpui_tokio::init(cx);
463 let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
464 cx.set_http_client(Arc::new(http_client));
465 let client = client::Client::production(cx);
466 let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
467 language_model::init(client.clone(), cx);
468 language_models::init(user_store, client, cx);
469
470 #[cfg(test)]
471 project::agent_server_store::AllAgentServersSettings::override_global(
472 project::agent_server_store::AllAgentServersSettings {
473 claude: Some(BuiltinAgentServerSettings {
474 path: Some("claude-code-acp".into()),
475 args: None,
476 env: None,
477 ignore_system_version: None,
478 default_mode: None,
479 }),
480 gemini: Some(crate::gemini::tests::local_command().into()),
481 codex: Some(BuiltinAgentServerSettings {
482 path: Some("codex-acp".into()),
483 args: None,
484 env: None,
485 ignore_system_version: None,
486 default_mode: None,
487 }),
488 custom: collections::HashMap::default(),
489 },
490 cx,
491 );
492 });
493
494 cx.executor().allow_parking();
495
496 FakeFs::new(cx.executor())
497}
498
499pub async fn new_test_thread(
500 server: impl AgentServer + 'static,
501 project: Entity<Project>,
502 current_dir: impl AsRef<Path>,
503 cx: &mut TestAppContext,
504) -> Entity<AcpThread> {
505 let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
506 let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
507
508 let (connection, _) = cx
509 .update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx))
510 .await
511 .unwrap();
512
513 cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
514 .await
515 .unwrap()
516}
517
518pub async fn run_until_first_tool_call(
519 thread: &Entity<AcpThread>,
520 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
521 cx: &mut TestAppContext,
522) -> usize {
523 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
524
525 let subscription = cx.update(|cx| {
526 cx.subscribe(thread, move |thread, _, cx| {
527 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
528 if wait_until(entry) {
529 return tx.try_send(ix).unwrap();
530 }
531 }
532 })
533 });
534
535 select! {
536 // We have to use a smol timer here because
537 // cx.background_executor().timer isn't real in the test context
538 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
539 panic!("Timeout waiting for tool call")
540 }
541 ix = rx.next().fuse() => {
542 drop(subscription);
543 ix.unwrap()
544 }
545 }
546}
547
/// Returns the path of the `zed` binary inside the `debug` build directory.
///
/// Starting from the path of the currently running executable, ancestors are
/// walked upwards until a component named `debug` is found; `zed` is then
/// appended to that directory.
///
/// # Panics
///
/// Panics if the current executable path cannot be determined, if no ancestor
/// is named `debug`, or if the resulting `zed` path does not exist.
pub fn get_zed_path() -> PathBuf {
    let mut candidate = std::env::current_exe().unwrap();

    loop {
        let at_debug_dir = candidate
            .file_name()
            .is_some_and(|name| name == "debug");
        if at_debug_dir {
            break;
        }

        // `pop` returns false once there is no parent left to move to.
        if !candidate.pop() {
            panic!("Could not find target directory");
        }
    }

    candidate.push("zed");

    if candidate.exists() {
        candidate
    } else {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }
}