1use crate::{AgentServer, AgentServerDelegate};
2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
3use agent_client_protocol as acp;
4use futures::{FutureExt, StreamExt, channel::mpsc, select};
5use gpui::{Entity, TestAppContext};
6use indoc::indoc;
7use project::{FakeFs, Project};
8#[cfg(test)]
9use settings::Settings;
10use std::{
11 path::{Path, PathBuf},
12 sync::Arc,
13 time::Duration,
14};
15use util::path;
16
17pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
18where
19 T: AgentServer + 'static,
20 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
21{
22 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
23 let project = Project::test(fs.clone(), [], cx).await;
24 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
25
26 thread
27 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
28 .await
29 .unwrap();
30
31 thread.read_with(cx, |thread, _| {
32 assert!(
33 thread.entries().len() >= 2,
34 "Expected at least 2 entries. Got: {:?}",
35 thread.entries()
36 );
37 assert!(matches!(
38 thread.entries()[0],
39 AgentThreadEntry::UserMessage(_)
40 ));
41 assert!(matches!(
42 thread.entries()[1],
43 AgentThreadEntry::AssistantMessage(_)
44 ));
45 });
46}
47
48pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
49where
50 T: AgentServer + 'static,
51 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
52{
53 let fs = init_test(cx).await as _;
54
55 let tempdir = tempfile::tempdir().unwrap();
56 std::fs::write(
57 tempdir.path().join("foo.rs"),
58 indoc! {"
59 fn main() {
60 println!(\"Hello, world!\");
61 }
62 "},
63 )
64 .expect("failed to write file");
65 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
66 let thread = new_test_thread(server(&fs, cx).await, project.clone(), tempdir.path(), cx).await;
67 thread
68 .update(cx, |thread, cx| {
69 thread.send(
70 vec![
71 "Read the file ".into(),
72 acp::ContentBlock::ResourceLink(acp::ResourceLink::new("foo.rs", "foo.rs")),
73 " and tell me what the content of the println! is".into(),
74 ],
75 cx,
76 )
77 })
78 .await
79 .unwrap();
80
81 thread.read_with(cx, |thread, cx| {
82 assert!(matches!(
83 thread.entries()[0],
84 AgentThreadEntry::UserMessage(_)
85 ));
86 let assistant_message = &thread
87 .entries()
88 .iter()
89 .rev()
90 .find_map(|entry| match entry {
91 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
92 _ => None,
93 })
94 .unwrap();
95
96 assert!(
97 assistant_message.to_markdown(cx).contains("Hello, world!"),
98 "unexpected assistant message: {:?}",
99 assistant_message.to_markdown(cx)
100 );
101 });
102
103 drop(tempdir);
104}
105
106pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
107where
108 T: AgentServer + 'static,
109 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
110{
111 let fs = init_test(cx).await as _;
112
113 let tempdir = tempfile::tempdir().unwrap();
114 let foo_path = tempdir.path().join("foo");
115 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
116
117 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
118 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
119
120 thread
121 .update(cx, |thread, cx| {
122 thread.send_raw(
123 &format!("Read {} and tell me what you see.", foo_path.display()),
124 cx,
125 )
126 })
127 .await
128 .unwrap();
129 thread.read_with(cx, |thread, _cx| {
130 assert!(thread.entries().iter().any(|entry| {
131 matches!(
132 entry,
133 AgentThreadEntry::ToolCall(ToolCall {
134 status: ToolCallStatus::Pending
135 | ToolCallStatus::InProgress
136 | ToolCallStatus::Completed,
137 ..
138 })
139 )
140 }));
141 assert!(
142 thread
143 .entries()
144 .iter()
145 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
146 );
147 });
148
149 drop(tempdir);
150}
151
152pub async fn test_tool_call_with_permission<T, F>(
153 server: F,
154 allow_option_id: acp::PermissionOptionId,
155 cx: &mut TestAppContext,
156) where
157 T: AgentServer + 'static,
158 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
159{
160 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
161 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
162 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
163 let full_turn = thread.update(cx, |thread, cx| {
164 thread.send_raw(
165 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
166 cx,
167 )
168 });
169
170 run_until_first_tool_call(
171 &thread,
172 |entry| {
173 matches!(
174 entry,
175 AgentThreadEntry::ToolCall(ToolCall {
176 status: ToolCallStatus::WaitingForConfirmation { .. },
177 ..
178 })
179 )
180 },
181 cx,
182 )
183 .await;
184
185 let tool_call_id = thread.read_with(cx, |thread, cx| {
186 let AgentThreadEntry::ToolCall(ToolCall {
187 id,
188 label,
189 status: ToolCallStatus::WaitingForConfirmation { .. },
190 ..
191 }) = &thread
192 .entries()
193 .iter()
194 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
195 .unwrap()
196 else {
197 panic!();
198 };
199
200 let label = label.read(cx).source();
201 assert!(label.contains("touch"), "Got: {}", label);
202
203 id.clone()
204 });
205
206 thread.update(cx, |thread, cx| {
207 thread.authorize_tool_call(
208 tool_call_id,
209 allow_option_id,
210 acp::PermissionOptionKind::AllowOnce,
211 cx,
212 );
213
214 assert!(thread.entries().iter().any(|entry| matches!(
215 entry,
216 AgentThreadEntry::ToolCall(ToolCall {
217 status: ToolCallStatus::Pending
218 | ToolCallStatus::InProgress
219 | ToolCallStatus::Completed,
220 ..
221 })
222 )));
223 });
224
225 full_turn.await.unwrap();
226
227 thread.read_with(cx, |thread, cx| {
228 let AgentThreadEntry::ToolCall(ToolCall {
229 content,
230 status: ToolCallStatus::Pending
231 | ToolCallStatus::InProgress
232 | ToolCallStatus::Completed,
233 ..
234 }) = thread
235 .entries()
236 .iter()
237 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
238 .unwrap()
239 else {
240 panic!();
241 };
242
243 assert!(
244 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
245 "Expected content to contain 'Hello'"
246 );
247 });
248}
249
250pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
251where
252 T: AgentServer + 'static,
253 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
254{
255 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
256
257 let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
258 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
259 let _ = thread.update(cx, |thread, cx| {
260 thread.send_raw(
261 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
262 cx,
263 )
264 });
265
266 let first_tool_call_ix = run_until_first_tool_call(
267 &thread,
268 |entry| {
269 matches!(
270 entry,
271 AgentThreadEntry::ToolCall(ToolCall {
272 status: ToolCallStatus::WaitingForConfirmation { .. },
273 ..
274 })
275 )
276 },
277 cx,
278 )
279 .await;
280
281 thread.read_with(cx, |thread, cx| {
282 let AgentThreadEntry::ToolCall(ToolCall {
283 id,
284 label,
285 status: ToolCallStatus::WaitingForConfirmation { .. },
286 ..
287 }) = &thread.entries()[first_tool_call_ix]
288 else {
289 panic!("{:?}", thread.entries()[1]);
290 };
291
292 let label = label.read(cx).source();
293 assert!(label.contains("touch"), "Got: {}", label);
294
295 id.clone()
296 });
297
298 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
299 thread.read_with(cx, |thread, _cx| {
300 let AgentThreadEntry::ToolCall(ToolCall {
301 status: ToolCallStatus::Canceled,
302 ..
303 }) = &thread.entries()[first_tool_call_ix]
304 else {
305 panic!();
306 };
307 });
308
309 thread
310 .update(cx, |thread, cx| {
311 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
312 })
313 .await
314 .unwrap();
315 thread.read_with(cx, |thread, _| {
316 assert!(matches!(
317 &thread.entries().last().unwrap(),
318 AgentThreadEntry::AssistantMessage(..),
319 ))
320 });
321}
322
323pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
324where
325 T: AgentServer + 'static,
326 F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
327{
328 let fs = init_test(cx).await as Arc<dyn fs::Fs>;
329 let project = Project::test(fs.clone(), [], cx).await;
330 let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
331
332 thread
333 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
334 .await
335 .unwrap();
336
337 thread.read_with(cx, |thread, _| {
338 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
339 });
340
341 let weak_thread = thread.downgrade();
342 drop(thread);
343
344 cx.executor().run_until_parked();
345 assert!(!weak_thread.is_upgradable());
346}
347
348#[macro_export]
349macro_rules! common_e2e_tests {
350 ($server:expr, allow_option_id = $allow_option_id:expr) => {
351 mod common_e2e {
352 use super::*;
353
354 #[::gpui::test]
355 #[cfg_attr(not(feature = "e2e"), ignore)]
356 async fn basic(cx: &mut ::gpui::TestAppContext) {
357 $crate::e2e_tests::test_basic($server, cx).await;
358 }
359
360 #[::gpui::test]
361 #[cfg_attr(not(feature = "e2e"), ignore)]
362 async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
363 $crate::e2e_tests::test_path_mentions($server, cx).await;
364 }
365
366 #[::gpui::test]
367 #[cfg_attr(not(feature = "e2e"), ignore)]
368 async fn tool_call(cx: &mut ::gpui::TestAppContext) {
369 $crate::e2e_tests::test_tool_call($server, cx).await;
370 }
371
372 #[::gpui::test]
373 #[cfg_attr(not(feature = "e2e"), ignore)]
374 async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
375 $crate::e2e_tests::test_tool_call_with_permission(
376 $server,
377 ::agent_client_protocol::PermissionOptionId::new($allow_option_id),
378 cx,
379 )
380 .await;
381 }
382
383 #[::gpui::test]
384 #[cfg_attr(not(feature = "e2e"), ignore)]
385 async fn cancel(cx: &mut ::gpui::TestAppContext) {
386 $crate::e2e_tests::test_cancel($server, cx).await;
387 }
388
389 #[::gpui::test]
390 #[cfg_attr(not(feature = "e2e"), ignore)]
391 async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
392 $crate::e2e_tests::test_thread_drop($server, cx).await;
393 }
394 }
395 };
396}
397pub use common_e2e_tests;
398
399// Helpers
400
401pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
402 env_logger::try_init().ok();
403
404 cx.update(|cx| {
405 let settings_store = settings::SettingsStore::test(cx);
406 cx.set_global(settings_store);
407 gpui_tokio::init(cx);
408 let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
409 cx.set_http_client(Arc::new(http_client));
410 let client = client::Client::production(cx);
411 language_model::init(client, cx);
412
413 #[cfg(test)]
414 project::agent_server_store::AllAgentServersSettings::override_global(
415 project::agent_server_store::AllAgentServersSettings(collections::HashMap::default()),
416 cx,
417 );
418 });
419
420 cx.executor().allow_parking();
421
422 FakeFs::new(cx.executor())
423}
424
425pub async fn new_test_thread(
426 server: impl AgentServer + 'static,
427 project: Entity<Project>,
428 current_dir: impl AsRef<Path>,
429 cx: &mut TestAppContext,
430) -> Entity<AcpThread> {
431 let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
432 let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
433
434 let connection = cx.update(|cx| server.connect(delegate, cx)).await.unwrap();
435
436 cx.update(|cx| connection.new_session(project.clone(), current_dir.as_ref(), cx))
437 .await
438 .unwrap()
439}
440
441pub async fn run_until_first_tool_call(
442 thread: &Entity<AcpThread>,
443 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
444 cx: &mut TestAppContext,
445) -> usize {
446 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
447
448 let subscription = cx.update(|cx| {
449 cx.subscribe(thread, move |thread, _, cx| {
450 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
451 if wait_until(entry) {
452 return tx.try_send(ix).unwrap();
453 }
454 }
455 })
456 });
457
458 select! {
459 _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(20))) => {
460 panic!("Timeout waiting for tool call")
461 }
462 ix = rx.next().fuse() => {
463 drop(subscription);
464 ix.unwrap()
465 }
466 }
467}
468
/// Returns the path of the `zed` binary that sits in the `debug` directory
/// containing the currently running executable.
///
/// The current executable's path and its ancestors are searched, innermost
/// first, for a component named `debug`; `zed` is then appended to it.
///
/// # Panics
///
/// Panics if the current executable's path cannot be determined, if no
/// ancestor is named `debug`, or if the resulting `zed` path does not exist.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    let Some(debug_dir) = current_exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name == "debug"))
    else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if zed_path.exists() {
        zed_path
    } else {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }
}