1use crate::{AgentServer, AgentServerDelegate};
2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
3use agent_client_protocol as acp;
4use futures::{FutureExt, StreamExt, channel::mpsc, select};
5use gpui::AppContext;
6use gpui::{Entity, TestAppContext};
7use indoc::indoc;
8use project::{FakeFs, Project};
9#[cfg(test)]
10use settings::Settings;
11use std::{
12 path::{Path, PathBuf},
13 sync::Arc,
14 time::Duration,
15};
16use util::path;
17
18pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
19where
20 T: AgentServer + 'static,
21 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
22{
23 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
24 let project = Project::test(fs.clone(), [], cx).await;
25 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
26
27 thread
28 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
29 .await
30 .unwrap();
31
32 thread.read_with(cx, |thread, _| {
33 assert!(
34 thread.entries().len() >= 2,
35 "Expected at least 2 entries. Got: {:?}",
36 thread.entries()
37 );
38 assert!(matches!(
39 thread.entries()[0],
40 AgentThreadEntry::UserMessage(_)
41 ));
42 assert!(matches!(
43 thread.entries()[1],
44 AgentThreadEntry::AssistantMessage(_)
45 ));
46 });
47}
48
49pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
50where
51 T: AgentServer + 'static,
52 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
53{
54 let fs = init_test(cx).await as _;
55
56 let tempdir = tempfile::tempdir().unwrap();
57 std::fs::write(
58 tempdir.path().join("foo.rs"),
59 indoc! {"
60 fn main() {
61 println!(\"Hello, world!\");
62 }
63 "},
64 )
65 .expect("failed to write file");
66 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
67 let thread = new_test_thread(server(&fs, cx).await, project.clone(), tempdir.path(), cx).await;
68 thread
69 .update(cx, |thread, cx| {
70 thread.send(
71 vec![
72 "Read the file ".into(),
73 acp::ContentBlock::ResourceLink(acp::ResourceLink::new("foo.rs", "foo.rs")),
74 " and tell me what the content of the println! is".into(),
75 ],
76 cx,
77 )
78 })
79 .await
80 .unwrap();
81
82 thread.read_with(cx, |thread, cx| {
83 assert!(matches!(
84 thread.entries()[0],
85 AgentThreadEntry::UserMessage(_)
86 ));
87 let assistant_message = &thread
88 .entries()
89 .iter()
90 .rev()
91 .find_map(|entry| match entry {
92 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
93 _ => None,
94 })
95 .unwrap();
96
97 assert!(
98 assistant_message.to_markdown(cx).contains("Hello, world!"),
99 "unexpected assistant message: {:?}",
100 assistant_message.to_markdown(cx)
101 );
102 });
103
104 drop(tempdir);
105}
106
107pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
108where
109 T: AgentServer + 'static,
110 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
111{
112 let fs = init_test(cx).await as _;
113
114 let tempdir = tempfile::tempdir().unwrap();
115 let foo_path = tempdir.path().join("foo");
116 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
117
118 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
119 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
120
121 thread
122 .update(cx, |thread, cx| {
123 thread.send_raw(
124 &format!("Read {} and tell me what you see.", foo_path.display()),
125 cx,
126 )
127 })
128 .await
129 .unwrap();
130 thread.read_with(cx, |thread, _cx| {
131 assert!(thread.entries().iter().any(|entry| {
132 matches!(
133 entry,
134 AgentThreadEntry::ToolCall(ToolCall {
135 status: ToolCallStatus::Pending
136 | ToolCallStatus::InProgress
137 | ToolCallStatus::Completed,
138 ..
139 })
140 )
141 }));
142 assert!(
143 thread
144 .entries()
145 .iter()
146 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
147 );
148 });
149
150 drop(tempdir);
151}
152
153pub async fn test_tool_call_with_permission<T, F>(
154 server: F,
155 allow_option_id: acp::PermissionOptionId,
156 cx: &mut TestAppContext,
157) where
158 T: AgentServer + 'static,
159 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
160{
161 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
162 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
163 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
164 let full_turn = thread.update(cx, |thread, cx| {
165 thread.send_raw(
166 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
167 cx,
168 )
169 });
170
171 run_until_first_tool_call(
172 &thread,
173 |entry| {
174 matches!(
175 entry,
176 AgentThreadEntry::ToolCall(ToolCall {
177 status: ToolCallStatus::WaitingForConfirmation { .. },
178 ..
179 })
180 )
181 },
182 cx,
183 )
184 .await;
185
186 let tool_call_id = thread.read_with(cx, |thread, cx| {
187 let AgentThreadEntry::ToolCall(ToolCall {
188 id,
189 label,
190 status: ToolCallStatus::WaitingForConfirmation { .. },
191 ..
192 }) = &thread
193 .entries()
194 .iter()
195 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
196 .unwrap()
197 else {
198 panic!();
199 };
200
201 let label = label.read(cx).source();
202 assert!(label.contains("touch"), "Got: {}", label);
203
204 id.clone()
205 });
206
207 thread.update(cx, |thread, cx| {
208 thread.authorize_tool_call(
209 tool_call_id,
210 allow_option_id,
211 acp::PermissionOptionKind::AllowOnce,
212 cx,
213 );
214
215 assert!(thread.entries().iter().any(|entry| matches!(
216 entry,
217 AgentThreadEntry::ToolCall(ToolCall {
218 status: ToolCallStatus::Pending
219 | ToolCallStatus::InProgress
220 | ToolCallStatus::Completed,
221 ..
222 })
223 )));
224 });
225
226 full_turn.await.unwrap();
227
228 thread.read_with(cx, |thread, cx| {
229 let AgentThreadEntry::ToolCall(ToolCall {
230 content,
231 status: ToolCallStatus::Pending
232 | ToolCallStatus::InProgress
233 | ToolCallStatus::Completed,
234 ..
235 }) = thread
236 .entries()
237 .iter()
238 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
239 .unwrap()
240 else {
241 panic!();
242 };
243
244 assert!(
245 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
246 "Expected content to contain 'Hello'"
247 );
248 });
249}
250
251pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
252where
253 T: AgentServer + 'static,
254 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
255{
256 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
257
258 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
259 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
260 let _ = thread.update(cx, |thread, cx| {
261 thread.send_raw(
262 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
263 cx,
264 )
265 });
266
267 let first_tool_call_ix = run_until_first_tool_call(
268 &thread,
269 |entry| {
270 matches!(
271 entry,
272 AgentThreadEntry::ToolCall(ToolCall {
273 status: ToolCallStatus::WaitingForConfirmation { .. },
274 ..
275 })
276 )
277 },
278 cx,
279 )
280 .await;
281
282 thread.read_with(cx, |thread, cx| {
283 let AgentThreadEntry::ToolCall(ToolCall {
284 id,
285 label,
286 status: ToolCallStatus::WaitingForConfirmation { .. },
287 ..
288 }) = &thread.entries()[first_tool_call_ix]
289 else {
290 panic!("{:?}", thread.entries()[1]);
291 };
292
293 let label = label.read(cx).source();
294 assert!(label.contains("touch"), "Got: {}", label);
295
296 id.clone()
297 });
298
299 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
300 thread.read_with(cx, |thread, _cx| {
301 let AgentThreadEntry::ToolCall(ToolCall {
302 status: ToolCallStatus::Canceled,
303 ..
304 }) = &thread.entries()[first_tool_call_ix]
305 else {
306 panic!();
307 };
308 });
309
310 thread
311 .update(cx, |thread, cx| {
312 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
313 })
314 .await
315 .unwrap();
316 thread.read_with(cx, |thread, _| {
317 assert!(matches!(
318 &thread.entries().last().unwrap(),
319 AgentThreadEntry::AssistantMessage(..),
320 ))
321 });
322}
323
324pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
325where
326 T: AgentServer + 'static,
327 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
328{
329 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
330 let project = Project::test(fs.clone(), [], cx).await;
331 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
332
333 thread
334 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
335 .await
336 .unwrap();
337
338 thread.read_with(cx, |thread, _| {
339 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
340 });
341
342 let weak_thread = thread.downgrade();
343 drop(thread);
344
345 cx.executor().run_until_parked();
346 assert!(!weak_thread.is_upgradable());
347}
348
/// Generates a `common_e2e` module containing one `#[gpui::test]` per shared
/// end-to-end scenario defined in this module (`test_basic`, `test_path_mentions`,
/// `test_tool_call`, `test_tool_call_with_permission`, `test_cancel`,
/// `test_thread_drop`).
///
/// * `$server` — passed as the `server` argument of every scenario. It is expanded
///   separately inside each generated test function.
/// * `allow_option_id` — wrapped in `PermissionOptionId::new` and used by
///   `tool_call_with_permission` to approve the tool call.
///
/// The generated tests are `#[ignore]`d unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            // Bring the invoking module's items (e.g. whatever `$server` refers to) into scope.
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId::new($allow_option_id),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
// `#[macro_export]` places the macro at the crate root; this also makes it nameable
// through this module's path.
pub use common_e2e_tests;
399
400// Helpers
401
/// Prepares the global state used by the agent e2e tests and returns a fresh
/// in-memory [`FakeFs`] driven by the test executor.
///
/// Inside the app context it does the following:
///
/// * installs a test `SettingsStore` global;
/// * initializes `gpui_tokio`;
/// * sets a reqwest-based HTTP client;
/// * initializes `language_model` with a `Client`/`UserStore` pair;
/// * under `cfg(test)`, overrides the agent-server settings with an empty map.
pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
    // `try_init` fails if a logger is already installed (e.g. by an earlier test in
    // the same process); that is harmless, so the error is discarded.
    env_logger::try_init().ok();

    cx.update(|cx| {
        // The settings store must be registered before anything below reads settings.
        let settings_store = settings::SettingsStore::test(cx);
        cx.set_global(settings_store);
        gpui_tokio::init(cx);
        // A real reqwest client rather than a fake one.
        // NOTE(review): presumably because e2e tests talk to real agent servers — confirm.
        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));
        let client = client::Client::production(cx);
        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
        language_model::init(user_store, client, cx);

        // Replace the global agent-server settings with an empty map. This relies on the
        // `cfg(test)`-gated `settings::Settings` import at the top of the file.
        // NOTE(review): presumably so no configured custom agents affect these tests.
        #[cfg(test)]
        project::agent_server_store::AllAgentServersSettings::override_global(
            project::agent_server_store::AllAgentServersSettings(collections::HashMap::default()),
            cx,
        );
    });

    // Allow the test executor to park (block) while waiting for work it does not
    // schedule itself. NOTE(review): presumably required because agents do real I/O.
    cx.executor().allow_parking();

    FakeFs::new(cx.executor())
}
426
427pub async fn new_test_thread(
428 server: impl AgentServer + 'static,
429 project: Entity<Project>,
430 current_dir: impl AsRef<Path>,
431 cx: &mut TestAppContext,
432) -> Entity<AcpThread> {
433 let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
434 let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
435
436 let connection = cx.update(|cx| server.connect(delegate, cx)).await.unwrap();
437
438 cx.update(|cx| connection.new_session(project.clone(), current_dir.as_ref(), cx))
439 .await
440 .unwrap()
441}
442
443pub async fn run_until_first_tool_call(
444 thread: &Entity<AcpThread>,
445 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
446 cx: &mut TestAppContext,
447) -> usize {
448 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
449
450 let subscription = cx.update(|cx| {
451 cx.subscribe(thread, move |thread, _, cx| {
452 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
453 if wait_until(entry) {
454 return tx.try_send(ix).unwrap();
455 }
456 }
457 })
458 });
459
460 select! {
461 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(20))) => {
462 panic!("Timeout waiting for tool call")
463 }
464 ix = rx.next().fuse() => {
465 drop(subscription);
466 ix.unwrap()
467 }
468 }
469}
470
/// Returns the path of the `zed` binary built by `cargo build`. The binary is
/// expected inside the nearest ancestor of the running executable (including the
/// executable path itself) whose file name is `debug`.
///
/// # Panics
///
/// Panics in two cases:
///
/// * no such `debug` ancestor exists;
/// * `zed` has not been built there yet.
pub fn get_zed_path() -> PathBuf {
    let running_exe = std::env::current_exe().unwrap();

    let debug_dir = running_exe
        .ancestors()
        .find(|dir| {
            dir.file_name()
                .is_some_and(|name| name.to_string_lossy() == "debug")
        })
        .unwrap_or_else(|| panic!("Could not find target directory"));

    let zed_path = debug_dir.join("zed");
    assert!(
        zed_path.exists(),
        "\n🚨 Run `cargo build` at least once before running e2e tests\n\n"
    );
    zed_path
}