1use crate::{AgentServer, AgentServerDelegate};
2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
3use agent_client_protocol as acp;
4use futures::{FutureExt, StreamExt, channel::mpsc, select};
5use gpui::{AppContext, Entity, TestAppContext};
6use indoc::indoc;
7#[cfg(test)]
8use project::agent_server_store::BuiltinAgentServerSettings;
9use project::{FakeFs, Project};
10#[cfg(test)]
11use settings::Settings;
12use std::{
13 path::{Path, PathBuf},
14 sync::Arc,
15 time::Duration,
16};
17use util::path;
18
19pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
20where
21 T: AgentServer + 'static,
22 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
23{
24 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
25 let project = Project::test(fs.clone(), [], cx).await;
26 let thread = new_test_thread(
27 server(&fs, &project, cx).await,
28 project.clone(),
29 "/private/tmp",
30 cx,
31 )
32 .await;
33
34 thread
35 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
36 .await
37 .unwrap();
38
39 thread.read_with(cx, |thread, _| {
40 assert!(
41 thread.entries().len() >= 2,
42 "Expected at least 2 entries. Got: {:?}",
43 thread.entries()
44 );
45 assert!(matches!(
46 thread.entries()[0],
47 AgentThreadEntry::UserMessage(_)
48 ));
49 assert!(matches!(
50 thread.entries()[1],
51 AgentThreadEntry::AssistantMessage(_)
52 ));
53 });
54}
55
56pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
57where
58 T: AgentServer + 'static,
59 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
60{
61 let fs = init_test(cx).await as _;
62
63 let tempdir = tempfile::tempdir().unwrap();
64 std::fs::write(
65 tempdir.path().join("foo.rs"),
66 indoc! {"
67 fn main() {
68 println!(\"Hello, world!\");
69 }
70 "},
71 )
72 .expect("failed to write file");
73 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
74 let thread = new_test_thread(
75 server(&fs, &project, cx).await,
76 project.clone(),
77 tempdir.path(),
78 cx,
79 )
80 .await;
81 thread
82 .update(cx, |thread, cx| {
83 thread.send(
84 vec![
85 "Read the file ".into(),
86 acp::ContentBlock::ResourceLink(acp::ResourceLink::new("foo.rs", "foo.rs")),
87 " and tell me what the content of the println! is".into(),
88 ],
89 cx,
90 )
91 })
92 .await
93 .unwrap();
94
95 thread.read_with(cx, |thread, cx| {
96 assert!(matches!(
97 thread.entries()[0],
98 AgentThreadEntry::UserMessage(_)
99 ));
100 let assistant_message = &thread
101 .entries()
102 .iter()
103 .rev()
104 .find_map(|entry| match entry {
105 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
106 _ => None,
107 })
108 .unwrap();
109
110 assert!(
111 assistant_message.to_markdown(cx).contains("Hello, world!"),
112 "unexpected assistant message: {:?}",
113 assistant_message.to_markdown(cx)
114 );
115 });
116
117 drop(tempdir);
118}
119
120pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
121where
122 T: AgentServer + 'static,
123 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
124{
125 let fs = init_test(cx).await as _;
126
127 let tempdir = tempfile::tempdir().unwrap();
128 let foo_path = tempdir.path().join("foo");
129 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
130
131 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
132 let thread = new_test_thread(
133 server(&fs, &project, cx).await,
134 project.clone(),
135 "/private/tmp",
136 cx,
137 )
138 .await;
139
140 thread
141 .update(cx, |thread, cx| {
142 thread.send_raw(
143 &format!("Read {} and tell me what you see.", foo_path.display()),
144 cx,
145 )
146 })
147 .await
148 .unwrap();
149 thread.read_with(cx, |thread, _cx| {
150 assert!(thread.entries().iter().any(|entry| {
151 matches!(
152 entry,
153 AgentThreadEntry::ToolCall(ToolCall {
154 status: ToolCallStatus::Pending
155 | ToolCallStatus::InProgress
156 | ToolCallStatus::Completed,
157 ..
158 })
159 )
160 }));
161 assert!(
162 thread
163 .entries()
164 .iter()
165 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
166 );
167 });
168
169 drop(tempdir);
170}
171
172pub async fn test_tool_call_with_permission<T, F>(
173 server: F,
174 allow_option_id: acp::PermissionOptionId,
175 cx: &mut TestAppContext,
176) where
177 T: AgentServer + 'static,
178 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
179{
180 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
181 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
182 let thread = new_test_thread(
183 server(&fs, &project, cx).await,
184 project.clone(),
185 "/private/tmp",
186 cx,
187 )
188 .await;
189 let full_turn = thread.update(cx, |thread, cx| {
190 thread.send_raw(
191 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
192 cx,
193 )
194 });
195
196 run_until_first_tool_call(
197 &thread,
198 |entry| {
199 matches!(
200 entry,
201 AgentThreadEntry::ToolCall(ToolCall {
202 status: ToolCallStatus::WaitingForConfirmation { .. },
203 ..
204 })
205 )
206 },
207 cx,
208 )
209 .await;
210
211 let tool_call_id = thread.read_with(cx, |thread, cx| {
212 let AgentThreadEntry::ToolCall(ToolCall {
213 id,
214 label,
215 status: ToolCallStatus::WaitingForConfirmation { .. },
216 ..
217 }) = &thread
218 .entries()
219 .iter()
220 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
221 .unwrap()
222 else {
223 panic!();
224 };
225
226 let label = label.read(cx).source();
227 assert!(label.contains("touch"), "Got: {}", label);
228
229 id.clone()
230 });
231
232 thread.update(cx, |thread, cx| {
233 thread.authorize_tool_call(
234 tool_call_id,
235 allow_option_id,
236 acp::PermissionOptionKind::AllowOnce,
237 cx,
238 );
239
240 assert!(thread.entries().iter().any(|entry| matches!(
241 entry,
242 AgentThreadEntry::ToolCall(ToolCall {
243 status: ToolCallStatus::Pending
244 | ToolCallStatus::InProgress
245 | ToolCallStatus::Completed,
246 ..
247 })
248 )));
249 });
250
251 full_turn.await.unwrap();
252
253 thread.read_with(cx, |thread, cx| {
254 let AgentThreadEntry::ToolCall(ToolCall {
255 content,
256 status: ToolCallStatus::Pending
257 | ToolCallStatus::InProgress
258 | ToolCallStatus::Completed,
259 ..
260 }) = thread
261 .entries()
262 .iter()
263 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
264 .unwrap()
265 else {
266 panic!();
267 };
268
269 assert!(
270 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
271 "Expected content to contain 'Hello'"
272 );
273 });
274}
275
276pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
277where
278 T: AgentServer + 'static,
279 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
280{
281 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
282
283 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
284 let thread = new_test_thread(
285 server(&fs, &project, cx).await,
286 project.clone(),
287 "/private/tmp",
288 cx,
289 )
290 .await;
291 let _ = thread.update(cx, |thread, cx| {
292 thread.send_raw(
293 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
294 cx,
295 )
296 });
297
298 let first_tool_call_ix = run_until_first_tool_call(
299 &thread,
300 |entry| {
301 matches!(
302 entry,
303 AgentThreadEntry::ToolCall(ToolCall {
304 status: ToolCallStatus::WaitingForConfirmation { .. },
305 ..
306 })
307 )
308 },
309 cx,
310 )
311 .await;
312
313 thread.read_with(cx, |thread, cx| {
314 let AgentThreadEntry::ToolCall(ToolCall {
315 id,
316 label,
317 status: ToolCallStatus::WaitingForConfirmation { .. },
318 ..
319 }) = &thread.entries()[first_tool_call_ix]
320 else {
321 panic!("{:?}", thread.entries()[1]);
322 };
323
324 let label = label.read(cx).source();
325 assert!(label.contains("touch"), "Got: {}", label);
326
327 id.clone()
328 });
329
330 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
331 thread.read_with(cx, |thread, _cx| {
332 let AgentThreadEntry::ToolCall(ToolCall {
333 status: ToolCallStatus::Canceled,
334 ..
335 }) = &thread.entries()[first_tool_call_ix]
336 else {
337 panic!();
338 };
339 });
340
341 thread
342 .update(cx, |thread, cx| {
343 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
344 })
345 .await
346 .unwrap();
347 thread.read_with(cx, |thread, _| {
348 assert!(matches!(
349 &thread.entries().last().unwrap(),
350 AgentThreadEntry::AssistantMessage(..),
351 ))
352 });
353}
354
355pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
356where
357 T: AgentServer + 'static,
358 F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
359{
360 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
361 let project = Project::test(fs.clone(), [], cx).await;
362 let thread = new_test_thread(
363 server(&fs, &project, cx).await,
364 project.clone(),
365 "/private/tmp",
366 cx,
367 )
368 .await;
369
370 thread
371 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
372 .await
373 .unwrap();
374
375 thread.read_with(cx, |thread, _| {
376 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
377 });
378
379 let weak_thread = thread.downgrade();
380 drop(thread);
381
382 cx.executor().run_until_parked();
383 assert!(!weak_thread.is_upgradable());
384}
385
386#[macro_export]
387macro_rules! common_e2e_tests {
388 ($server:expr, allow_option_id = $allow_option_id:expr) => {
389 mod common_e2e {
390 use super::*;
391
392 #[::gpui::test]
393 #[cfg_attr(not(feature = "e2e"), ignore)]
394 async fn basic(cx: &mut ::gpui::TestAppContext) {
395 $crate::e2e_tests::test_basic($server, cx).await;
396 }
397
398 #[::gpui::test]
399 #[cfg_attr(not(feature = "e2e"), ignore)]
400 async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
401 $crate::e2e_tests::test_path_mentions($server, cx).await;
402 }
403
404 #[::gpui::test]
405 #[cfg_attr(not(feature = "e2e"), ignore)]
406 async fn tool_call(cx: &mut ::gpui::TestAppContext) {
407 $crate::e2e_tests::test_tool_call($server, cx).await;
408 }
409
410 #[::gpui::test]
411 #[cfg_attr(not(feature = "e2e"), ignore)]
412 async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
413 $crate::e2e_tests::test_tool_call_with_permission(
414 $server,
415 ::agent_client_protocol::PermissionOptionId::new($allow_option_id),
416 cx,
417 )
418 .await;
419 }
420
421 #[::gpui::test]
422 #[cfg_attr(not(feature = "e2e"), ignore)]
423 async fn cancel(cx: &mut ::gpui::TestAppContext) {
424 $crate::e2e_tests::test_cancel($server, cx).await;
425 }
426
427 #[::gpui::test]
428 #[cfg_attr(not(feature = "e2e"), ignore)]
429 async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
430 $crate::e2e_tests::test_thread_drop($server, cx).await;
431 }
432 }
433 };
434}
435pub use common_e2e_tests;
436
437// Helpers
438
439pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
440 env_logger::try_init().ok();
441
442 cx.update(|cx| {
443 let settings_store = settings::SettingsStore::test(cx);
444 cx.set_global(settings_store);
445 gpui_tokio::init(cx);
446 let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
447 cx.set_http_client(Arc::new(http_client));
448 let client = client::Client::production(cx);
449 let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
450 language_model::init(client.clone(), cx);
451 language_models::init(user_store, client, cx);
452
453 #[cfg(test)]
454 project::agent_server_store::AllAgentServersSettings::override_global(
455 project::agent_server_store::AllAgentServersSettings {
456 claude: Some(BuiltinAgentServerSettings {
457 path: Some("claude-code-acp".into()),
458 args: None,
459 env: None,
460 ignore_system_version: None,
461 default_mode: None,
462 default_model: None,
463 }),
464 gemini: Some(crate::gemini::tests::local_command().into()),
465 codex: Some(BuiltinAgentServerSettings {
466 path: Some("codex-acp".into()),
467 args: None,
468 env: None,
469 ignore_system_version: None,
470 default_mode: None,
471 default_model: None,
472 }),
473 custom: collections::HashMap::default(),
474 },
475 cx,
476 );
477 });
478
479 cx.executor().allow_parking();
480
481 FakeFs::new(cx.executor())
482}
483
484pub async fn new_test_thread(
485 server: impl AgentServer + 'static,
486 project: Entity<Project>,
487 current_dir: impl AsRef<Path>,
488 cx: &mut TestAppContext,
489) -> Entity<AcpThread> {
490 let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
491 let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
492
493 let (connection, _) = cx
494 .update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx))
495 .await
496 .unwrap();
497
498 cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
499 .await
500 .unwrap()
501}
502
503pub async fn run_until_first_tool_call(
504 thread: &Entity<AcpThread>,
505 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
506 cx: &mut TestAppContext,
507) -> usize {
508 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
509
510 let subscription = cx.update(|cx| {
511 cx.subscribe(thread, move |thread, _, cx| {
512 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
513 if wait_until(entry) {
514 return tx.try_send(ix).unwrap();
515 }
516 }
517 })
518 });
519
520 select! {
521 // We have to use a smol timer here because
522 // cx.background_executor().timer isn't real in the test context
523 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
524 panic!("Timeout waiting for tool call")
525 }
526 ix = rx.next().fuse() => {
527 drop(subscription);
528 ix.unwrap()
529 }
530 }
531}
532
/// Returns the path of the `zed` binary that sits in the `debug` directory
/// containing the currently running executable.
///
/// The executable's path is walked upwards until a component named `debug` is
/// found; `zed` is then appended to that directory.
///
/// # Panics
///
/// Panics if the current executable path cannot be determined, if no ancestor
/// is named `debug`, or if the resulting `zed` path does not exist.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    let debug_dir = current_exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name == "debug"));
    let Some(debug_dir) = debug_dir else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}