1use crate::{AgentServer, AgentServerDelegate};
2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
3use agent_client_protocol as acp;
4use client::RefreshLlmTokenListener;
5use futures::{FutureExt, StreamExt, channel::mpsc, select};
6use gpui::AppContext;
7use gpui::{Entity, TestAppContext};
8use indoc::indoc;
9use project::{FakeFs, Project};
10#[cfg(test)]
11use settings::Settings;
12use std::{
13 path::{Path, PathBuf},
14 sync::Arc,
15 time::Duration,
16};
17use util::path;
18use util::path_list::PathList;
19
20pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
21where
22 T: AgentServer + 'static,
23 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
24{
25 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
26 let project = Project::test(fs.clone(), [], cx).await;
27 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
28
29 thread
30 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
31 .await
32 .unwrap();
33
34 thread.read_with(cx, |thread, _| {
35 assert!(
36 thread.entries().len() >= 2,
37 "Expected at least 2 entries. Got: {:?}",
38 thread.entries()
39 );
40 assert!(matches!(
41 thread.entries()[0],
42 AgentThreadEntry::UserMessage(_)
43 ));
44 assert!(matches!(
45 thread.entries()[1],
46 AgentThreadEntry::AssistantMessage(_)
47 ));
48 });
49}
50
51pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
52where
53 T: AgentServer + 'static,
54 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
55{
56 let fs = init_test(cx).await as _;
57
58 let tempdir = tempfile::tempdir().unwrap();
59 std::fs::write(
60 tempdir.path().join("foo.rs"),
61 indoc! {"
62 fn main() {
63 println!(\"Hello, world!\");
64 }
65 "},
66 )
67 .expect("failed to write file");
68 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
69 let thread = new_test_thread(server(&fs, cx).await, project.clone(), tempdir.path(), cx).await;
70 thread
71 .update(cx, |thread, cx| {
72 thread.send(
73 vec![
74 "Read the file ".into(),
75 acp::ContentBlock::ResourceLink(acp::ResourceLink::new("foo.rs", "foo.rs")),
76 " and tell me what the content of the println! is".into(),
77 ],
78 cx,
79 )
80 })
81 .await
82 .unwrap();
83
84 thread.read_with(cx, |thread, cx| {
85 assert!(matches!(
86 thread.entries()[0],
87 AgentThreadEntry::UserMessage(_)
88 ));
89 let assistant_message = &thread
90 .entries()
91 .iter()
92 .rev()
93 .find_map(|entry| match entry {
94 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
95 _ => None,
96 })
97 .unwrap();
98
99 assert!(
100 assistant_message.to_markdown(cx).contains("Hello, world!"),
101 "unexpected assistant message: {:?}",
102 assistant_message.to_markdown(cx)
103 );
104 });
105
106 drop(tempdir);
107}
108
109pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
110where
111 T: AgentServer + 'static,
112 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
113{
114 let fs = init_test(cx).await as _;
115
116 let tempdir = tempfile::tempdir().unwrap();
117 let foo_path = tempdir.path().join("foo");
118 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
119
120 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
121 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
122
123 thread
124 .update(cx, |thread, cx| {
125 thread.send_raw(
126 &format!("Read {} and tell me what you see.", foo_path.display()),
127 cx,
128 )
129 })
130 .await
131 .unwrap();
132 thread.read_with(cx, |thread, _cx| {
133 assert!(thread.entries().iter().any(|entry| {
134 matches!(
135 entry,
136 AgentThreadEntry::ToolCall(ToolCall {
137 status: ToolCallStatus::Pending
138 | ToolCallStatus::InProgress
139 | ToolCallStatus::Completed,
140 ..
141 })
142 )
143 }));
144 assert!(
145 thread
146 .entries()
147 .iter()
148 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
149 );
150 });
151
152 drop(tempdir);
153}
154
155pub async fn test_tool_call_with_permission<T, F>(
156 server: F,
157 allow_option_id: acp::PermissionOptionId,
158 cx: &mut TestAppContext,
159) where
160 T: AgentServer + 'static,
161 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
162{
163 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
164 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
165 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
166 let full_turn = thread.update(cx, |thread, cx| {
167 thread.send_raw(
168 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
169 cx,
170 )
171 });
172
173 run_until_first_tool_call(
174 &thread,
175 |entry| {
176 matches!(
177 entry,
178 AgentThreadEntry::ToolCall(ToolCall {
179 status: ToolCallStatus::WaitingForConfirmation { .. },
180 ..
181 })
182 )
183 },
184 cx,
185 )
186 .await;
187
188 let tool_call_id = thread.read_with(cx, |thread, cx| {
189 let AgentThreadEntry::ToolCall(ToolCall {
190 id,
191 label,
192 status: ToolCallStatus::WaitingForConfirmation { .. },
193 ..
194 }) = &thread
195 .entries()
196 .iter()
197 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
198 .unwrap()
199 else {
200 panic!();
201 };
202
203 let label = label.read(cx).source();
204 assert!(label.contains("touch"), "Got: {}", label);
205
206 id.clone()
207 });
208
209 thread.update(cx, |thread, cx| {
210 thread.authorize_tool_call(
211 tool_call_id,
212 acp_thread::SelectedPermissionOutcome::new(
213 allow_option_id,
214 acp::PermissionOptionKind::AllowOnce,
215 ),
216 cx,
217 );
218
219 assert!(thread.entries().iter().any(|entry| matches!(
220 entry,
221 AgentThreadEntry::ToolCall(ToolCall {
222 status: ToolCallStatus::Pending
223 | ToolCallStatus::InProgress
224 | ToolCallStatus::Completed,
225 ..
226 })
227 )));
228 });
229
230 full_turn.await.unwrap();
231
232 thread.read_with(cx, |thread, cx| {
233 let AgentThreadEntry::ToolCall(ToolCall {
234 content,
235 status: ToolCallStatus::Pending
236 | ToolCallStatus::InProgress
237 | ToolCallStatus::Completed,
238 ..
239 }) = thread
240 .entries()
241 .iter()
242 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
243 .unwrap()
244 else {
245 panic!();
246 };
247
248 assert!(
249 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
250 "Expected content to contain 'Hello'"
251 );
252 });
253}
254
255pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
256where
257 T: AgentServer + 'static,
258 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
259{
260 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
261
262 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
263 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
264 let _ = thread.update(cx, |thread, cx| {
265 thread.send_raw(
266 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
267 cx,
268 )
269 });
270
271 let first_tool_call_ix = run_until_first_tool_call(
272 &thread,
273 |entry| {
274 matches!(
275 entry,
276 AgentThreadEntry::ToolCall(ToolCall {
277 status: ToolCallStatus::WaitingForConfirmation { .. },
278 ..
279 })
280 )
281 },
282 cx,
283 )
284 .await;
285
286 thread.read_with(cx, |thread, cx| {
287 let AgentThreadEntry::ToolCall(ToolCall {
288 id,
289 label,
290 status: ToolCallStatus::WaitingForConfirmation { .. },
291 ..
292 }) = &thread.entries()[first_tool_call_ix]
293 else {
294 panic!("{:?}", thread.entries()[1]);
295 };
296
297 let label = label.read(cx).source();
298 assert!(label.contains("touch"), "Got: {}", label);
299
300 id.clone()
301 });
302
303 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
304 thread.read_with(cx, |thread, _cx| {
305 let AgentThreadEntry::ToolCall(ToolCall {
306 status: ToolCallStatus::Canceled,
307 ..
308 }) = &thread.entries()[first_tool_call_ix]
309 else {
310 panic!();
311 };
312 });
313
314 thread
315 .update(cx, |thread, cx| {
316 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
317 })
318 .await
319 .unwrap();
320 thread.read_with(cx, |thread, _| {
321 assert!(matches!(
322 &thread.entries().last().unwrap(),
323 AgentThreadEntry::AssistantMessage(..),
324 ))
325 });
326}
327
328pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
329where
330 T: AgentServer + 'static,
331 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
332{
333 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
334 let project = Project::test(fs.clone(), [], cx).await;
335 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
336
337 thread
338 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
339 .await
340 .unwrap();
341
342 thread.read_with(cx, |thread, _| {
343 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
344 });
345
346 let weak_thread = thread.downgrade();
347 drop(thread);
348
349 cx.executor().run_until_parked();
350 assert!(!weak_thread.is_upgradable());
351}
352
/// Generates a `common_e2e` test module that runs each shared e2e test in this
/// file against a single agent server. The tests are `test_basic`,
/// `test_path_mentions`, `test_tool_call`, `test_tool_call_with_permission`,
/// `test_cancel` and `test_thread_drop`.
///
/// - `$server` is forwarded to every `$crate::e2e_tests::test_*` function.
///   Because it is a macro argument, it is expanded separately in each
///   generated test.
/// - `allow_option_id` is wrapped in
///   `agent_client_protocol::PermissionOptionId::new` and passed to
///   `test_tool_call_with_permission` as the option that grants permission.
///
/// Every generated test is ignored unless the `e2e` feature is enabled. The
/// generated module glob-imports `super::*`, so items named in `$server` can
/// come from the module that invokes the macro.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId::new($allow_option_id),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
// `#[macro_export]` places the macro at the crate root; this re-export also
// makes it reachable through this module's path.
pub use common_e2e_tests;
403
404// Helpers
405
/// Initializes the global state shared by the e2e tests and returns a fresh
/// in-memory [`FakeFs`] driven by the test executor.
///
/// It sets up, in order:
/// - a test `SettingsStore`;
/// - `gpui_tokio`;
/// - a real `ReqwestClient` as the app's HTTP client;
/// - a production `client::Client` with a `UserStore`;
/// - `language_model`.
///
/// It then registers the LLM token refresh listener.
pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
    // `try_init` errors if a logger is already installed (e.g. by an earlier
    // test in the same process); that is safe to ignore.
    env_logger::try_init().ok();

    cx.update(|cx| {
        let settings_store = settings::SettingsStore::test(cx);
        cx.set_global(settings_store);
        gpui_tokio::init(cx);
        // A real (non-fake) HTTP client. Presumably needed because the e2e
        // agents make actual network requests — confirm.
        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));
        let client = client::Client::production(cx);
        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
        // NOTE(review): keep this initialization order unless it is verified
        // that the globals above are independent of each other.
        language_model::init(cx);
        RefreshLlmTokenListener::register(client.clone(), user_store, cx);

        // Only compiled for this crate's own test builds. It replaces the
        // agent-server settings with an empty map, presumably so that no
        // user-configured agent servers leak into the tests.
        #[cfg(test)]
        project::agent_server_store::AllAgentServersSettings::override_global(
            project::agent_server_store::AllAgentServersSettings(collections::HashMap::default()),
            cx,
        );
    });

    // Let the test executor park while it waits on work it does not drive
    // itself, such as the real network/process I/O these tests perform.
    cx.executor().allow_parking();

    FakeFs::new(cx.executor())
}
431
432pub async fn new_test_thread(
433 server: impl AgentServer + 'static,
434 project: Entity<Project>,
435 current_dir: impl AsRef<Path>,
436 cx: &mut TestAppContext,
437) -> Entity<AcpThread> {
438 let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
439 let delegate = AgentServerDelegate::new(store, None);
440
441 let connection = cx
442 .update(|cx| server.connect(delegate, project.clone(), cx))
443 .await
444 .unwrap();
445
446 cx.update(|cx| {
447 connection.new_session(project.clone(), PathList::new(&[current_dir.as_ref()]), cx)
448 })
449 .await
450 .unwrap()
451}
452
453pub async fn run_until_first_tool_call(
454 thread: &Entity<AcpThread>,
455 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
456 cx: &mut TestAppContext,
457) -> usize {
458 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
459
460 let subscription = cx.update(|cx| {
461 cx.subscribe(thread, move |thread, _, cx| {
462 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
463 if wait_until(entry) {
464 return tx.try_send(ix).unwrap();
465 }
466 }
467 })
468 });
469
470 select! {
471 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(20))) => {
472 panic!("Timeout waiting for tool call")
473 }
474 ix = rx.next().fuse() => {
475 drop(subscription);
476 ix.unwrap()
477 }
478 }
479}
480
/// Locates the `zed` binary inside the nearest ancestor directory named
/// `debug` of the currently running executable (Cargo's debug target dir).
///
/// # Panics
///
/// Panics if no such ancestor exists, or if the `zed` binary has not been
/// built there yet.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    // `ancestors()` yields the path itself, then each successive parent.
    // That is the same sequence repeated `PathBuf::pop` calls would visit.
    let Some(debug_dir) = current_exe.ancestors().find(|dir| {
        dir.file_name()
            .is_some_and(|name| name.to_string_lossy() == "debug")
    }) else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}