1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4 time::Duration,
5};
6
7use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
8use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
9use agent_client_protocol as acp;
10use agent_settings::AgentSettings;
11
12use futures::{FutureExt, StreamExt, channel::mpsc, select};
13use gpui::{Entity, TestAppContext};
14use indoc::indoc;
15use project::{FakeFs, Project};
16use settings::{Settings, SettingsStore};
17use util::path;
18
19pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
20 let fs = init_test(cx).await;
21 let project = Project::test(fs, [], cx).await;
22 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
23
24 thread
25 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
26 .await
27 .unwrap();
28
29 thread.read_with(cx, |thread, _| {
30 assert!(
31 thread.entries().len() >= 2,
32 "Expected at least 2 entries. Got: {:?}",
33 thread.entries()
34 );
35 assert!(matches!(
36 thread.entries()[0],
37 AgentThreadEntry::UserMessage(_)
38 ));
39 assert!(matches!(
40 thread.entries()[1],
41 AgentThreadEntry::AssistantMessage(_)
42 ));
43 });
44}
45
46pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
47 let _fs = init_test(cx).await;
48
49 let tempdir = tempfile::tempdir().unwrap();
50 std::fs::write(
51 tempdir.path().join("foo.rs"),
52 indoc! {"
53 fn main() {
54 println!(\"Hello, world!\");
55 }
56 "},
57 )
58 .expect("failed to write file");
59 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
60 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
61 thread
62 .update(cx, |thread, cx| {
63 thread.send(
64 vec![
65 acp::ContentBlock::Text(acp::TextContent {
66 text: "Read the file ".into(),
67 annotations: None,
68 }),
69 acp::ContentBlock::ResourceLink(acp::ResourceLink {
70 uri: "foo.rs".into(),
71 name: "foo.rs".into(),
72 annotations: None,
73 description: None,
74 mime_type: None,
75 size: None,
76 title: None,
77 }),
78 acp::ContentBlock::Text(acp::TextContent {
79 text: " and tell me what the content of the println! is".into(),
80 annotations: None,
81 }),
82 ],
83 cx,
84 )
85 })
86 .await
87 .unwrap();
88
89 thread.read_with(cx, |thread, cx| {
90 assert!(matches!(
91 thread.entries()[0],
92 AgentThreadEntry::UserMessage(_)
93 ));
94 let assistant_message = &thread
95 .entries()
96 .iter()
97 .rev()
98 .find_map(|entry| match entry {
99 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
100 _ => None,
101 })
102 .unwrap();
103
104 assert!(
105 assistant_message.to_markdown(cx).contains("Hello, world!"),
106 "unexpected assistant message: {:?}",
107 assistant_message.to_markdown(cx)
108 );
109 });
110
111 drop(tempdir);
112}
113
114pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
115 let _fs = init_test(cx).await;
116
117 let tempdir = tempfile::tempdir().unwrap();
118 let foo_path = tempdir.path().join("foo");
119 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
120
121 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
122 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
123
124 thread
125 .update(cx, |thread, cx| {
126 thread.send_raw(
127 &format!("Read {} and tell me what you see.", foo_path.display()),
128 cx,
129 )
130 })
131 .await
132 .unwrap();
133 thread.read_with(cx, |thread, _cx| {
134 assert!(thread.entries().iter().any(|entry| {
135 matches!(
136 entry,
137 AgentThreadEntry::ToolCall(ToolCall {
138 status: ToolCallStatus::Allowed { .. },
139 ..
140 })
141 )
142 }));
143 assert!(
144 thread
145 .entries()
146 .iter()
147 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
148 );
149 });
150
151 drop(tempdir);
152}
153
154pub async fn test_tool_call_with_confirmation(
155 server: impl AgentServer + 'static,
156 allow_option_id: acp::PermissionOptionId,
157 cx: &mut TestAppContext,
158) {
159 let fs = init_test(cx).await;
160 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
161 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
162 let full_turn = thread.update(cx, |thread, cx| {
163 thread.send_raw(
164 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
165 cx,
166 )
167 });
168
169 run_until_first_tool_call(
170 &thread,
171 |entry| {
172 matches!(
173 entry,
174 AgentThreadEntry::ToolCall(ToolCall {
175 status: ToolCallStatus::WaitingForConfirmation { .. },
176 ..
177 })
178 )
179 },
180 cx,
181 )
182 .await;
183
184 let tool_call_id = thread.read_with(cx, |thread, cx| {
185 let AgentThreadEntry::ToolCall(ToolCall {
186 id,
187 label,
188 status: ToolCallStatus::WaitingForConfirmation { .. },
189 ..
190 }) = &thread
191 .entries()
192 .iter()
193 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
194 .unwrap()
195 else {
196 panic!();
197 };
198
199 let label = label.read(cx).source();
200 assert!(label.contains("touch"), "Got: {}", label);
201
202 id.clone()
203 });
204
205 thread.update(cx, |thread, cx| {
206 thread.authorize_tool_call(
207 tool_call_id,
208 allow_option_id,
209 acp::PermissionOptionKind::AllowOnce,
210 cx,
211 );
212
213 assert!(thread.entries().iter().any(|entry| matches!(
214 entry,
215 AgentThreadEntry::ToolCall(ToolCall {
216 status: ToolCallStatus::Allowed { .. },
217 ..
218 })
219 )));
220 });
221
222 full_turn.await.unwrap();
223
224 thread.read_with(cx, |thread, cx| {
225 let AgentThreadEntry::ToolCall(ToolCall {
226 content,
227 status: ToolCallStatus::Allowed { .. },
228 ..
229 }) = thread
230 .entries()
231 .iter()
232 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
233 .unwrap()
234 else {
235 panic!();
236 };
237
238 assert!(
239 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
240 "Expected content to contain 'Hello'"
241 );
242 });
243}
244
245pub async fn test_tool_call_always_allow(
246 server: impl AgentServer + 'static,
247 cx: &mut TestAppContext,
248) {
249 let fs = init_test(cx).await;
250 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
251
252 // Enable always_allow_tool_actions
253 cx.update(|cx| {
254 AgentSettings::override_global(
255 AgentSettings {
256 always_allow_tool_actions: true,
257 ..Default::default()
258 },
259 cx,
260 );
261 });
262
263 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
264 let full_turn = thread.update(cx, |thread, cx| {
265 thread.send_raw(
266 r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
267 cx,
268 )
269 });
270
271 // Wait for the tool call to complete
272 full_turn.await.unwrap();
273
274 thread.read_with(cx, |thread, _cx| {
275 // With always_allow_tool_actions enabled, the tool call should be immediately allowed
276 // without waiting for confirmation
277 let tool_call_entry = thread
278 .entries()
279 .iter()
280 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
281 .expect("Expected a tool call entry");
282
283 let AgentThreadEntry::ToolCall(tool_call) = tool_call_entry else {
284 panic!("Expected tool call entry");
285 };
286
287 // Should be allowed, not waiting for confirmation
288 assert!(
289 matches!(tool_call.status, ToolCallStatus::Allowed { .. }),
290 "Expected tool call to be allowed automatically, but got {:?}",
291 tool_call.status
292 );
293 });
294}
295
296pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
297 let fs = init_test(cx).await;
298
299 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
300 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
301 let full_turn = thread.update(cx, |thread, cx| {
302 thread.send_raw(
303 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
304 cx,
305 )
306 });
307
308 let first_tool_call_ix = run_until_first_tool_call(
309 &thread,
310 |entry| {
311 matches!(
312 entry,
313 AgentThreadEntry::ToolCall(ToolCall {
314 status: ToolCallStatus::WaitingForConfirmation { .. },
315 ..
316 })
317 )
318 },
319 cx,
320 )
321 .await;
322
323 thread.read_with(cx, |thread, cx| {
324 let AgentThreadEntry::ToolCall(ToolCall {
325 id,
326 label,
327 status: ToolCallStatus::WaitingForConfirmation { .. },
328 ..
329 }) = &thread.entries()[first_tool_call_ix]
330 else {
331 panic!("{:?}", thread.entries()[1]);
332 };
333
334 let label = label.read(cx).source();
335 assert!(label.contains("touch"), "Got: {}", label);
336
337 id.clone()
338 });
339
340 let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
341 full_turn.await.unwrap();
342 thread.read_with(cx, |thread, _| {
343 let AgentThreadEntry::ToolCall(ToolCall {
344 status: ToolCallStatus::Canceled,
345 ..
346 }) = &thread.entries()[first_tool_call_ix]
347 else {
348 panic!();
349 };
350 });
351
352 thread
353 .update(cx, |thread, cx| {
354 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
355 })
356 .await
357 .unwrap();
358 thread.read_with(cx, |thread, _| {
359 assert!(matches!(
360 &thread.entries().last().unwrap(),
361 AgentThreadEntry::AssistantMessage(..),
362 ))
363 });
364}
365
/// Generates the shared end-to-end test suite for one agent server.
///
/// Expands to a `common_e2e` module with one `#[::gpui::test]` per `test_*` function in
/// this file: `basic`, `path_mentions`, `tool_call`, `tool_call_with_confirmation`,
/// `cancel` and `tool_call_always_allow`. Each generated test is marked `ignore` unless
/// the `e2e` Cargo feature is enabled.
///
/// * `$server` — expression producing the `AgentServer` under test.
/// * `allow_option_id` — value converted with `.into()` and wrapped in
///   `agent_client_protocol::PermissionOptionId`. It is only used by
///   `tool_call_with_confirmation`.
///
/// The generated module starts with `use super::*;`, so anything `$server` refers to must
/// be in scope at the invocation site.
#[macro_export]
macro_rules! common_e2e_tests {
    // `$server` is pasted into every generated test body, so each test evaluates the
    // expression on its own.
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_confirmation(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_always_allow(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_always_allow($server, cx).await;
            }
        }
    };
}
415
416// Helpers
417
418pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
419 env_logger::try_init().ok();
420
421 cx.update(|cx| {
422 let settings_store = SettingsStore::test(cx);
423 cx.set_global(settings_store);
424 Project::init_settings(cx);
425 language::init(cx);
426 crate::settings::init(cx);
427
428 crate::AllAgentServersSettings::override_global(
429 AllAgentServersSettings {
430 claude: Some(AgentServerSettings {
431 command: crate::claude::tests::local_command(),
432 }),
433 gemini: Some(AgentServerSettings {
434 command: crate::gemini::tests::local_command(),
435 }),
436 codex: Some(AgentServerSettings {
437 command: crate::codex::tests::local_command(),
438 }),
439 },
440 cx,
441 );
442 });
443
444 cx.executor().allow_parking();
445
446 FakeFs::new(cx.executor())
447}
448
449pub async fn new_test_thread(
450 server: impl AgentServer + 'static,
451 project: Entity<Project>,
452 current_dir: impl AsRef<Path>,
453 cx: &mut TestAppContext,
454) -> Entity<AcpThread> {
455 let connection = cx
456 .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
457 .await
458 .unwrap();
459
460 let thread = connection
461 .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
462 .await
463 .unwrap();
464
465 thread
466}
467
468pub async fn run_until_first_tool_call(
469 thread: &Entity<AcpThread>,
470 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
471 cx: &mut TestAppContext,
472) -> usize {
473 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
474
475 let subscription = cx.update(|cx| {
476 cx.subscribe(thread, move |thread, _, cx| {
477 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
478 if wait_until(entry) {
479 return tx.try_send(ix).unwrap();
480 }
481 }
482 })
483 });
484
485 select! {
486 // We have to use a smol timer here because
487 // cx.background_executor().timer isn't real in the test context
488 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
489 panic!("Timeout waiting for tool call")
490 }
491 ix = rx.next().fuse() => {
492 drop(subscription);
493 ix.unwrap()
494 }
495 }
496}
497
/// Returns the path of the `zed` binary next to the current executable's `debug` directory.
///
/// Walks up from the running executable to the nearest ancestor directory named `debug`
/// and returns `<that directory>/zed`.
///
/// # Panics
///
/// Panics if no ancestor is named `debug`, or if `<debug>/zed` does not exist yet.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    let Some(debug_dir) = current_exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name.to_string_lossy() == "debug"))
    else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }
    zed_path
}