1use crate::{AgentServer, AgentServerDelegate};
2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
3use agent_client_protocol as acp;
4use futures::{FutureExt, StreamExt, channel::mpsc, select};
5use gpui::AppContext;
6use gpui::{Entity, TestAppContext};
7use indoc::indoc;
8use project::{FakeFs, Project};
9#[cfg(test)]
10use settings::Settings;
11use std::{
12 path::{Path, PathBuf},
13 sync::Arc,
14 time::Duration,
15};
16use util::path;
17use util::path_list::PathList;
18
19pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
20where
21 T: AgentServer + 'static,
22 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
23{
24 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
25 let project = Project::test(fs.clone(), [], cx).await;
26 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
27
28 thread
29 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
30 .await
31 .unwrap();
32
33 thread.read_with(cx, |thread, _| {
34 assert!(
35 thread.entries().len() >= 2,
36 "Expected at least 2 entries. Got: {:?}",
37 thread.entries()
38 );
39 assert!(matches!(
40 thread.entries()[0],
41 AgentThreadEntry::UserMessage(_)
42 ));
43 assert!(matches!(
44 thread.entries()[1],
45 AgentThreadEntry::AssistantMessage(_)
46 ));
47 });
48}
49
50pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
51where
52 T: AgentServer + 'static,
53 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
54{
55 let fs = init_test(cx).await as _;
56
57 let tempdir = tempfile::tempdir().unwrap();
58 std::fs::write(
59 tempdir.path().join("foo.rs"),
60 indoc! {"
61 fn main() {
62 println!(\"Hello, world!\");
63 }
64 "},
65 )
66 .expect("failed to write file");
67 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
68 let thread = new_test_thread(server(&fs, cx).await, project.clone(), tempdir.path(), cx).await;
69 thread
70 .update(cx, |thread, cx| {
71 thread.send(
72 vec![
73 "Read the file ".into(),
74 acp::ContentBlock::ResourceLink(acp::ResourceLink::new("foo.rs", "foo.rs")),
75 " and tell me what the content of the println! is".into(),
76 ],
77 cx,
78 )
79 })
80 .await
81 .unwrap();
82
83 thread.read_with(cx, |thread, cx| {
84 assert!(matches!(
85 thread.entries()[0],
86 AgentThreadEntry::UserMessage(_)
87 ));
88 let assistant_message = &thread
89 .entries()
90 .iter()
91 .rev()
92 .find_map(|entry| match entry {
93 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
94 _ => None,
95 })
96 .unwrap();
97
98 assert!(
99 assistant_message.to_markdown(cx).contains("Hello, world!"),
100 "unexpected assistant message: {:?}",
101 assistant_message.to_markdown(cx)
102 );
103 });
104
105 drop(tempdir);
106}
107
108pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
109where
110 T: AgentServer + 'static,
111 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
112{
113 let fs = init_test(cx).await as _;
114
115 let tempdir = tempfile::tempdir().unwrap();
116 let foo_path = tempdir.path().join("foo");
117 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
118
119 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
120 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
121
122 thread
123 .update(cx, |thread, cx| {
124 thread.send_raw(
125 &format!("Read {} and tell me what you see.", foo_path.display()),
126 cx,
127 )
128 })
129 .await
130 .unwrap();
131 thread.read_with(cx, |thread, _cx| {
132 assert!(thread.entries().iter().any(|entry| {
133 matches!(
134 entry,
135 AgentThreadEntry::ToolCall(ToolCall {
136 status: ToolCallStatus::Pending
137 | ToolCallStatus::InProgress
138 | ToolCallStatus::Completed,
139 ..
140 })
141 )
142 }));
143 assert!(
144 thread
145 .entries()
146 .iter()
147 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
148 );
149 });
150
151 drop(tempdir);
152}
153
154pub async fn test_tool_call_with_permission<T, F>(
155 server: F,
156 allow_option_id: acp::PermissionOptionId,
157 cx: &mut TestAppContext,
158) where
159 T: AgentServer + 'static,
160 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
161{
162 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
163 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
164 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
165 let full_turn = thread.update(cx, |thread, cx| {
166 thread.send_raw(
167 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
168 cx,
169 )
170 });
171
172 run_until_first_tool_call(
173 &thread,
174 |entry| {
175 matches!(
176 entry,
177 AgentThreadEntry::ToolCall(ToolCall {
178 status: ToolCallStatus::WaitingForConfirmation { .. },
179 ..
180 })
181 )
182 },
183 cx,
184 )
185 .await;
186
187 let tool_call_id = thread.read_with(cx, |thread, cx| {
188 let AgentThreadEntry::ToolCall(ToolCall {
189 id,
190 label,
191 status: ToolCallStatus::WaitingForConfirmation { .. },
192 ..
193 }) = &thread
194 .entries()
195 .iter()
196 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
197 .unwrap()
198 else {
199 panic!();
200 };
201
202 let label = label.read(cx).source();
203 assert!(label.contains("touch"), "Got: {}", label);
204
205 id.clone()
206 });
207
208 thread.update(cx, |thread, cx| {
209 thread.authorize_tool_call(
210 tool_call_id,
211 acp_thread::SelectedPermissionOutcome::new(
212 allow_option_id,
213 acp::PermissionOptionKind::AllowOnce,
214 ),
215 cx,
216 );
217
218 assert!(thread.entries().iter().any(|entry| matches!(
219 entry,
220 AgentThreadEntry::ToolCall(ToolCall {
221 status: ToolCallStatus::Pending
222 | ToolCallStatus::InProgress
223 | ToolCallStatus::Completed,
224 ..
225 })
226 )));
227 });
228
229 full_turn.await.unwrap();
230
231 thread.read_with(cx, |thread, cx| {
232 let AgentThreadEntry::ToolCall(ToolCall {
233 content,
234 status: ToolCallStatus::Pending
235 | ToolCallStatus::InProgress
236 | ToolCallStatus::Completed,
237 ..
238 }) = thread
239 .entries()
240 .iter()
241 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
242 .unwrap()
243 else {
244 panic!();
245 };
246
247 assert!(
248 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
249 "Expected content to contain 'Hello'"
250 );
251 });
252}
253
254pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
255where
256 T: AgentServer + 'static,
257 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
258{
259 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
260
261 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
262 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
263 let _ = thread.update(cx, |thread, cx| {
264 thread.send_raw(
265 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
266 cx,
267 )
268 });
269
270 let first_tool_call_ix = run_until_first_tool_call(
271 &thread,
272 |entry| {
273 matches!(
274 entry,
275 AgentThreadEntry::ToolCall(ToolCall {
276 status: ToolCallStatus::WaitingForConfirmation { .. },
277 ..
278 })
279 )
280 },
281 cx,
282 )
283 .await;
284
285 thread.read_with(cx, |thread, cx| {
286 let AgentThreadEntry::ToolCall(ToolCall {
287 id,
288 label,
289 status: ToolCallStatus::WaitingForConfirmation { .. },
290 ..
291 }) = &thread.entries()[first_tool_call_ix]
292 else {
293 panic!("{:?}", thread.entries()[1]);
294 };
295
296 let label = label.read(cx).source();
297 assert!(label.contains("touch"), "Got: {}", label);
298
299 id.clone()
300 });
301
302 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
303 thread.read_with(cx, |thread, _cx| {
304 let AgentThreadEntry::ToolCall(ToolCall {
305 status: ToolCallStatus::Canceled,
306 ..
307 }) = &thread.entries()[first_tool_call_ix]
308 else {
309 panic!();
310 };
311 });
312
313 thread
314 .update(cx, |thread, cx| {
315 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
316 })
317 .await
318 .unwrap();
319 thread.read_with(cx, |thread, _| {
320 assert!(matches!(
321 &thread.entries().last().unwrap(),
322 AgentThreadEntry::AssistantMessage(..),
323 ))
324 });
325}
326
327pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
328where
329 T: AgentServer + 'static,
330 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
331{
332 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
333 let project = Project::test(fs.clone(), [], cx).await;
334 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
335
336 thread
337 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
338 .await
339 .unwrap();
340
341 thread.read_with(cx, |thread, _| {
342 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
343 });
344
345 let weak_thread = thread.downgrade();
346 drop(thread);
347
348 cx.executor().run_until_parked();
349 assert!(!weak_thread.is_upgradable());
350}
351
352#[macro_export]
353macro_rules! common_e2e_tests {
354 ($server:expr, allow_option_id = $allow_option_id:expr) => {
355 mod common_e2e {
356 use super::*;
357
358 #[::gpui::test]
359 #[cfg_attr(not(feature = "e2e"), ignore)]
360 async fn basic(cx: &mut ::gpui::TestAppContext) {
361 $crate::e2e_tests::test_basic($server, cx).await;
362 }
363
364 #[::gpui::test]
365 #[cfg_attr(not(feature = "e2e"), ignore)]
366 async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
367 $crate::e2e_tests::test_path_mentions($server, cx).await;
368 }
369
370 #[::gpui::test]
371 #[cfg_attr(not(feature = "e2e"), ignore)]
372 async fn tool_call(cx: &mut ::gpui::TestAppContext) {
373 $crate::e2e_tests::test_tool_call($server, cx).await;
374 }
375
376 #[::gpui::test]
377 #[cfg_attr(not(feature = "e2e"), ignore)]
378 async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
379 $crate::e2e_tests::test_tool_call_with_permission(
380 $server,
381 ::agent_client_protocol::PermissionOptionId::new($allow_option_id),
382 cx,
383 )
384 .await;
385 }
386
387 #[::gpui::test]
388 #[cfg_attr(not(feature = "e2e"), ignore)]
389 async fn cancel(cx: &mut ::gpui::TestAppContext) {
390 $crate::e2e_tests::test_cancel($server, cx).await;
391 }
392
393 #[::gpui::test]
394 #[cfg_attr(not(feature = "e2e"), ignore)]
395 async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
396 $crate::e2e_tests::test_thread_drop($server, cx).await;
397 }
398 }
399 };
400}
401pub use common_e2e_tests;
402
403// Helpers
404
405pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
406 env_logger::try_init().ok();
407
408 cx.update(|cx| {
409 let settings_store = settings::SettingsStore::test(cx);
410 cx.set_global(settings_store);
411 gpui_tokio::init(cx);
412 let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
413 cx.set_http_client(Arc::new(http_client));
414 let client = client::Client::production(cx);
415 let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
416 language_model::init(user_store, client, cx);
417
418 #[cfg(test)]
419 project::agent_server_store::AllAgentServersSettings::override_global(
420 project::agent_server_store::AllAgentServersSettings(collections::HashMap::default()),
421 cx,
422 );
423 });
424
425 cx.executor().allow_parking();
426
427 FakeFs::new(cx.executor())
428}
429
430pub async fn new_test_thread(
431 server: impl AgentServer + 'static,
432 project: Entity<Project>,
433 current_dir: impl AsRef<Path>,
434 cx: &mut TestAppContext,
435) -> Entity<AcpThread> {
436 let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
437 let delegate = AgentServerDelegate::new(store, None);
438
439 let connection = cx
440 .update(|cx| server.connect(delegate, project.clone(), cx))
441 .await
442 .unwrap();
443
444 cx.update(|cx| {
445 connection.new_session(project.clone(), PathList::new(&[current_dir.as_ref()]), cx)
446 })
447 .await
448 .unwrap()
449}
450
451pub async fn run_until_first_tool_call(
452 thread: &Entity<AcpThread>,
453 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
454 cx: &mut TestAppContext,
455) -> usize {
456 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
457
458 let subscription = cx.update(|cx| {
459 cx.subscribe(thread, move |thread, _, cx| {
460 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
461 if wait_until(entry) {
462 return tx.try_send(ix).unwrap();
463 }
464 }
465 })
466 });
467
468 select! {
469 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(20))) => {
470 panic!("Timeout waiting for tool call")
471 }
472 ix = rx.next().fuse() => {
473 drop(subscription);
474 ix.unwrap()
475 }
476 }
477}
478
/// Returns the path of the `zed` binary inside the `debug` directory that
/// contains the currently running executable.
///
/// Starting from the current executable, path components are popped until the
/// last one is `debug`; the platform-specific binary name (`zed` plus
/// [`std::env::consts::EXE_SUFFIX`]) is then appended.
///
/// # Panics
///
/// Panics if the current executable cannot be determined, if no `debug`
/// ancestor directory exists, or if the binary is missing there.
pub fn get_zed_path() -> PathBuf {
    let mut zed_path = std::env::current_exe().unwrap();

    while zed_path
        .file_name()
        .is_none_or(|name| name.to_string_lossy() != "debug")
    {
        if !zed_path.pop() {
            panic!("Could not find target directory");
        }
    }

    // `EXE_SUFFIX` is empty on Unix and ".exe" on Windows, where a bare `zed`
    // would never exist on disk.
    zed_path.push(format!("zed{}", std::env::consts::EXE_SUFFIX));

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}