1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4 time::Duration,
5};
6
7use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
8use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
9use agent_client_protocol as acp;
10
11use futures::{FutureExt, StreamExt, channel::mpsc, select};
12use gpui::{Entity, TestAppContext};
13use indoc::indoc;
14use project::{FakeFs, Project};
15use settings::{Settings, SettingsStore};
16
17pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
18 let fs = init_test(cx).await;
19 let tempdir = tempfile::tempdir().unwrap();
20 let project = Project::test(fs, [], cx).await;
21 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
22
23 thread
24 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
25 .await
26 .unwrap();
27
28 thread.read_with(cx, |thread, _| {
29 assert!(
30 thread.entries().len() >= 2,
31 "Expected at least 2 entries. Got: {:?}",
32 thread.entries()
33 );
34 assert!(matches!(
35 thread.entries()[0],
36 AgentThreadEntry::UserMessage(_)
37 ));
38 assert!(matches!(
39 thread.entries()[1],
40 AgentThreadEntry::AssistantMessage(_)
41 ));
42 });
43
44 drop(tempdir);
45}
46
47pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
48 let _fs = init_test(cx).await;
49
50 let tempdir = tempfile::tempdir().unwrap();
51 std::fs::write(
52 tempdir.path().join("foo.rs"),
53 indoc! {"
54 fn main() {
55 println!(\"Hello, world!\");
56 }
57 "},
58 )
59 .expect("failed to write file");
60 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
61 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
62 thread
63 .update(cx, |thread, cx| {
64 thread.send(
65 vec![
66 acp::ContentBlock::Text(acp::TextContent {
67 text: "Read the file ".into(),
68 annotations: None,
69 }),
70 acp::ContentBlock::ResourceLink(acp::ResourceLink {
71 uri: "foo.rs".into(),
72 name: "foo.rs".into(),
73 annotations: None,
74 description: None,
75 mime_type: None,
76 size: None,
77 title: None,
78 }),
79 acp::ContentBlock::Text(acp::TextContent {
80 text: " and tell me what the content of the println! is".into(),
81 annotations: None,
82 }),
83 ],
84 cx,
85 )
86 })
87 .await
88 .unwrap();
89
90 thread.read_with(cx, |thread, cx| {
91 assert!(matches!(
92 thread.entries()[0],
93 AgentThreadEntry::UserMessage(_)
94 ));
95 let assistant_message = &thread
96 .entries()
97 .iter()
98 .rev()
99 .find_map(|entry| match entry {
100 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
101 _ => None,
102 })
103 .unwrap();
104
105 assert!(
106 assistant_message.to_markdown(cx).contains("Hello, world!"),
107 "unexpected assistant message: {:?}",
108 assistant_message.to_markdown(cx)
109 );
110 });
111
112 drop(tempdir);
113}
114
115pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
116 let _fs = init_test(cx).await;
117
118 let tempdir = tempfile::tempdir().unwrap();
119 let foo_path = tempdir.path().join("foo");
120 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
121
122 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
123 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
124
125 thread
126 .update(cx, |thread, cx| {
127 thread.send_raw(
128 &format!("Read {} and tell me what you see.", foo_path.display()),
129 cx,
130 )
131 })
132 .await
133 .unwrap();
134 thread.read_with(cx, |thread, _cx| {
135 assert!(thread.entries().iter().any(|entry| {
136 matches!(
137 entry,
138 AgentThreadEntry::ToolCall(ToolCall {
139 status: ToolCallStatus::Allowed { .. },
140 ..
141 })
142 )
143 }));
144 assert!(
145 thread
146 .entries()
147 .iter()
148 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
149 );
150 });
151
152 drop(tempdir);
153}
154
155pub async fn test_tool_call_with_permission(
156 server: impl AgentServer + 'static,
157 allow_option_id: acp::PermissionOptionId,
158 cx: &mut TestAppContext,
159) {
160 let fs = init_test(cx).await;
161 let tempdir = tempfile::tempdir().unwrap();
162 let project = Project::test(fs, [tempdir.path()], cx).await;
163 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
164 let full_turn = thread.update(cx, |thread, cx| {
165 thread.send_raw(
166 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
167 cx,
168 )
169 });
170
171 run_until_first_tool_call(
172 &thread,
173 |entry| {
174 matches!(
175 entry,
176 AgentThreadEntry::ToolCall(ToolCall {
177 status: ToolCallStatus::WaitingForConfirmation { .. },
178 ..
179 })
180 )
181 },
182 cx,
183 )
184 .await;
185
186 let tool_call_id = thread.read_with(cx, |thread, cx| {
187 let AgentThreadEntry::ToolCall(ToolCall {
188 id,
189 label,
190 status: ToolCallStatus::WaitingForConfirmation { .. },
191 ..
192 }) = &thread
193 .entries()
194 .iter()
195 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
196 .unwrap()
197 else {
198 panic!();
199 };
200
201 let label = label.read(cx).source();
202 assert!(label.contains("touch"), "Got: {}", label);
203
204 id.clone()
205 });
206
207 thread.update(cx, |thread, cx| {
208 thread.authorize_tool_call(
209 tool_call_id,
210 allow_option_id,
211 acp::PermissionOptionKind::AllowOnce,
212 cx,
213 );
214
215 assert!(thread.entries().iter().any(|entry| matches!(
216 entry,
217 AgentThreadEntry::ToolCall(ToolCall {
218 status: ToolCallStatus::Allowed { .. },
219 ..
220 })
221 )));
222 });
223
224 full_turn.await.unwrap();
225
226 thread.read_with(cx, |thread, cx| {
227 let AgentThreadEntry::ToolCall(ToolCall {
228 content,
229 status: ToolCallStatus::Allowed { .. },
230 ..
231 }) = thread
232 .entries()
233 .iter()
234 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
235 .unwrap()
236 else {
237 panic!();
238 };
239
240 assert!(
241 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
242 "Expected content to contain 'Hello'"
243 );
244 });
245
246 drop(tempdir);
247}
248
249pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
250 let fs = init_test(cx).await;
251 let tempdir = tempfile::tempdir().unwrap();
252 let project = Project::test(fs, [tempdir.path()], cx).await;
253 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
254 let _ = thread.update(cx, |thread, cx| {
255 thread.send_raw(
256 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
257 cx,
258 )
259 });
260
261 let first_tool_call_ix = run_until_first_tool_call(
262 &thread,
263 |entry| {
264 matches!(
265 entry,
266 AgentThreadEntry::ToolCall(ToolCall {
267 status: ToolCallStatus::WaitingForConfirmation { .. },
268 ..
269 })
270 )
271 },
272 cx,
273 )
274 .await;
275
276 thread.read_with(cx, |thread, cx| {
277 let AgentThreadEntry::ToolCall(ToolCall {
278 id,
279 label,
280 status: ToolCallStatus::WaitingForConfirmation { .. },
281 ..
282 }) = &thread.entries()[first_tool_call_ix]
283 else {
284 panic!("{:?}", thread.entries()[1]);
285 };
286
287 let label = label.read(cx).source();
288 assert!(label.contains("touch"), "Got: {}", label);
289
290 id.clone()
291 });
292
293 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
294 thread.read_with(cx, |thread, _cx| {
295 let AgentThreadEntry::ToolCall(ToolCall {
296 status: ToolCallStatus::Canceled,
297 ..
298 }) = &thread.entries()[first_tool_call_ix]
299 else {
300 panic!();
301 };
302 });
303
304 thread
305 .update(cx, |thread, cx| {
306 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
307 })
308 .await
309 .unwrap();
310 thread.read_with(cx, |thread, _| {
311 assert!(matches!(
312 &thread.entries().last().unwrap(),
313 AgentThreadEntry::AssistantMessage(..),
314 ))
315 });
316
317 drop(tempdir);
318}
319
320pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
321 let fs = init_test(cx).await;
322 let tempdir = tempfile::tempdir().unwrap();
323 let project = Project::test(fs, [], cx).await;
324 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
325
326 thread
327 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
328 .await
329 .unwrap();
330
331 thread.read_with(cx, |thread, _| {
332 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
333 });
334
335 let weak_thread = thread.downgrade();
336 drop(thread);
337
338 cx.executor().run_until_parked();
339 assert!(!weak_thread.is_upgradable());
340
341 drop(tempdir);
342}
343
/// Expands to a `common_e2e` module holding the shared end-to-end tests for a
/// single agent server.
///
/// The server expression is evaluated separately inside every generated test.
/// `allow_option_id` is wrapped in an `agent_client_protocol::PermissionOptionId`
/// and passed to `test_tool_call_with_permission`. Every generated test is
/// ignored unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($agent_server:expr, allow_option_id = $allow_id:expr) => {
        mod common_e2e {
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($agent_server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($agent_server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($agent_server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                // Bound in the same order the original call evaluated its arguments.
                let server = $agent_server;
                let allow_option_id =
                    ::agent_client_protocol::PermissionOptionId($allow_id.into());
                $crate::e2e_tests::test_tool_call_with_permission(server, allow_option_id, cx)
                    .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($agent_server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($agent_server, cx).await;
            }
        }
    };
}
393
394// Helpers
395
396pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
397 env_logger::try_init().ok();
398
399 cx.update(|cx| {
400 let settings_store = SettingsStore::test(cx);
401 cx.set_global(settings_store);
402 Project::init_settings(cx);
403 language::init(cx);
404 crate::settings::init(cx);
405
406 crate::AllAgentServersSettings::override_global(
407 AllAgentServersSettings {
408 claude: Some(AgentServerSettings {
409 command: crate::claude::tests::local_command(),
410 }),
411 gemini: Some(AgentServerSettings {
412 command: crate::gemini::tests::local_command(),
413 }),
414 },
415 cx,
416 );
417 });
418
419 cx.executor().allow_parking();
420
421 FakeFs::new(cx.executor())
422}
423
424pub async fn new_test_thread(
425 server: impl AgentServer + 'static,
426 project: Entity<Project>,
427 current_dir: impl AsRef<Path>,
428 cx: &mut TestAppContext,
429) -> Entity<AcpThread> {
430 let connection = cx
431 .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
432 .await
433 .unwrap();
434
435 let thread = connection
436 .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
437 .await
438 .unwrap();
439
440 thread
441}
442
443pub async fn run_until_first_tool_call(
444 thread: &Entity<AcpThread>,
445 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
446 cx: &mut TestAppContext,
447) -> usize {
448 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
449
450 let subscription = cx.update(|cx| {
451 cx.subscribe(thread, move |thread, _, cx| {
452 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
453 if wait_until(entry) {
454 return tx.try_send(ix).unwrap();
455 }
456 }
457 })
458 });
459
460 select! {
461 // We have to use a smol timer here because
462 // cx.background_executor().timer isn't real in the test context
463 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
464 panic!("Timeout waiting for tool call")
465 }
466 ix = rx.next().fuse() => {
467 drop(subscription);
468 ix.unwrap()
469 }
470 }
471}
472
/// Returns the path of the `zed` binary inside the nearest ancestor directory
/// named `debug` of the currently running executable.
///
/// The platform executable suffix (`.exe` on Windows) is appended to the name.
///
/// # Panics
/// Panics if no ancestor of the current executable is named `debug`, or if
/// the `zed` binary does not exist there yet.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    // `ancestors()` yields the executable path itself first, then each parent.
    let Some(target_dir) = current_exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name == "debug"))
    else {
        panic!("Could not find target directory");
    };

    let zed_path = target_dir.join(format!("zed{}", std::env::consts::EXE_SUFFIX));

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}