1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4 time::Duration,
5};
6
7use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
8use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
9use agent_client_protocol as acp;
10
11use futures::{FutureExt, StreamExt, channel::mpsc, select};
12use gpui::{Entity, TestAppContext};
13use indoc::indoc;
14use project::{FakeFs, Project};
15use settings::{Settings, SettingsStore};
16use util::path;
17
18pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
19 let fs = init_test(cx).await;
20 let project = Project::test(fs, [], cx).await;
21 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
22
23 thread
24 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
25 .await
26 .unwrap();
27
28 thread.read_with(cx, |thread, _| {
29 assert!(
30 thread.entries().len() >= 2,
31 "Expected at least 2 entries. Got: {:?}",
32 thread.entries()
33 );
34 assert!(matches!(
35 thread.entries()[0],
36 AgentThreadEntry::UserMessage(_)
37 ));
38 assert!(matches!(
39 thread.entries()[1],
40 AgentThreadEntry::AssistantMessage(_)
41 ));
42 });
43}
44
45pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
46 let _fs = init_test(cx).await;
47
48 let tempdir = tempfile::tempdir().unwrap();
49 std::fs::write(
50 tempdir.path().join("foo.rs"),
51 indoc! {"
52 fn main() {
53 println!(\"Hello, world!\");
54 }
55 "},
56 )
57 .expect("failed to write file");
58 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
59 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
60 thread
61 .update(cx, |thread, cx| {
62 thread.send(
63 vec![
64 acp::ContentBlock::Text(acp::TextContent {
65 text: "Read the file ".into(),
66 annotations: None,
67 }),
68 acp::ContentBlock::ResourceLink(acp::ResourceLink {
69 uri: "foo.rs".into(),
70 name: "foo.rs".into(),
71 annotations: None,
72 description: None,
73 mime_type: None,
74 size: None,
75 title: None,
76 }),
77 acp::ContentBlock::Text(acp::TextContent {
78 text: " and tell me what the content of the println! is".into(),
79 annotations: None,
80 }),
81 ],
82 cx,
83 )
84 })
85 .await
86 .unwrap();
87
88 thread.read_with(cx, |thread, cx| {
89 assert!(matches!(
90 thread.entries()[0],
91 AgentThreadEntry::UserMessage(_)
92 ));
93 let assistant_message = &thread
94 .entries()
95 .iter()
96 .rev()
97 .find_map(|entry| match entry {
98 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
99 _ => None,
100 })
101 .unwrap();
102
103 assert!(
104 assistant_message.to_markdown(cx).contains("Hello, world!"),
105 "unexpected assistant message: {:?}",
106 assistant_message.to_markdown(cx)
107 );
108 });
109
110 drop(tempdir);
111}
112
113pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
114 let _fs = init_test(cx).await;
115
116 let tempdir = tempfile::tempdir().unwrap();
117 let foo_path = tempdir.path().join("foo");
118 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
119
120 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
121 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
122
123 thread
124 .update(cx, |thread, cx| {
125 thread.send_raw(
126 &format!("Read {} and tell me what you see.", foo_path.display()),
127 cx,
128 )
129 })
130 .await
131 .unwrap();
132 thread.read_with(cx, |thread, _cx| {
133 assert!(thread.entries().iter().any(|entry| {
134 matches!(
135 entry,
136 AgentThreadEntry::ToolCall(ToolCall {
137 status: ToolCallStatus::Allowed { .. },
138 ..
139 })
140 )
141 }));
142 assert!(
143 thread
144 .entries()
145 .iter()
146 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
147 );
148 });
149
150 drop(tempdir);
151}
152
153pub async fn test_tool_call_with_permission(
154 server: impl AgentServer + 'static,
155 allow_option_id: acp::PermissionOptionId,
156 cx: &mut TestAppContext,
157) {
158 let fs = init_test(cx).await;
159 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
160 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
161 let full_turn = thread.update(cx, |thread, cx| {
162 thread.send_raw(
163 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
164 cx,
165 )
166 });
167
168 run_until_first_tool_call(
169 &thread,
170 |entry| {
171 matches!(
172 entry,
173 AgentThreadEntry::ToolCall(ToolCall {
174 status: ToolCallStatus::WaitingForConfirmation { .. },
175 ..
176 })
177 )
178 },
179 cx,
180 )
181 .await;
182
183 let tool_call_id = thread.read_with(cx, |thread, cx| {
184 let AgentThreadEntry::ToolCall(ToolCall {
185 id,
186 label,
187 status: ToolCallStatus::WaitingForConfirmation { .. },
188 ..
189 }) = &thread
190 .entries()
191 .iter()
192 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
193 .unwrap()
194 else {
195 panic!();
196 };
197
198 let label = label.read(cx).source();
199 assert!(label.contains("touch"), "Got: {}", label);
200
201 id.clone()
202 });
203
204 thread.update(cx, |thread, cx| {
205 thread.authorize_tool_call(
206 tool_call_id,
207 allow_option_id,
208 acp::PermissionOptionKind::AllowOnce,
209 cx,
210 );
211
212 assert!(thread.entries().iter().any(|entry| matches!(
213 entry,
214 AgentThreadEntry::ToolCall(ToolCall {
215 status: ToolCallStatus::Allowed { .. },
216 ..
217 })
218 )));
219 });
220
221 full_turn.await.unwrap();
222
223 thread.read_with(cx, |thread, cx| {
224 let AgentThreadEntry::ToolCall(ToolCall {
225 content,
226 status: ToolCallStatus::Allowed { .. },
227 ..
228 }) = thread
229 .entries()
230 .iter()
231 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
232 .unwrap()
233 else {
234 panic!();
235 };
236
237 assert!(
238 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
239 "Expected content to contain 'Hello'"
240 );
241 });
242}
243
244pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
245 let fs = init_test(cx).await;
246
247 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
248 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
249 let full_turn = thread.update(cx, |thread, cx| {
250 thread.send_raw(
251 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
252 cx,
253 )
254 });
255
256 let first_tool_call_ix = run_until_first_tool_call(
257 &thread,
258 |entry| {
259 matches!(
260 entry,
261 AgentThreadEntry::ToolCall(ToolCall {
262 status: ToolCallStatus::WaitingForConfirmation { .. },
263 ..
264 })
265 )
266 },
267 cx,
268 )
269 .await;
270
271 thread.read_with(cx, |thread, cx| {
272 let AgentThreadEntry::ToolCall(ToolCall {
273 id,
274 label,
275 status: ToolCallStatus::WaitingForConfirmation { .. },
276 ..
277 }) = &thread.entries()[first_tool_call_ix]
278 else {
279 panic!("{:?}", thread.entries()[1]);
280 };
281
282 let label = label.read(cx).source();
283 assert!(label.contains("touch"), "Got: {}", label);
284
285 id.clone()
286 });
287
288 let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
289 full_turn.await.unwrap();
290 thread.read_with(cx, |thread, _| {
291 let AgentThreadEntry::ToolCall(ToolCall {
292 status: ToolCallStatus::Canceled,
293 ..
294 }) = &thread.entries()[first_tool_call_ix]
295 else {
296 panic!();
297 };
298 });
299
300 thread
301 .update(cx, |thread, cx| {
302 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
303 })
304 .await
305 .unwrap();
306 thread.read_with(cx, |thread, _| {
307 assert!(matches!(
308 &thread.entries().last().unwrap(),
309 AgentThreadEntry::AssistantMessage(..),
310 ))
311 });
312}
313
314pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
315 let fs = init_test(cx).await;
316 let project = Project::test(fs, [], cx).await;
317 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
318
319 thread
320 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
321 .await
322 .unwrap();
323
324 thread.read_with(cx, |thread, _| {
325 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
326 });
327
328 let weak_thread = thread.downgrade();
329 drop(thread);
330
331 cx.executor().run_until_parked();
332 assert!(!weak_thread.is_upgradable());
333}
334
/// Generates a `common_e2e` test module that runs every shared end-to-end
/// scenario defined in this file against a single agent server.
///
/// - `$server`: an expression producing the server under test. It is expanded
///   separately inside each generated test, so every test evaluates it anew.
/// - `allow_option_id`: converted with `.into()` and wrapped in
///   `agent_client_protocol::PermissionOptionId`. It is passed to
///   `test_tool_call_with_permission` as the option used to approve the tool
///   call.
///
/// The generated module does `use super::*;`, so the invoking module's items
/// are in scope for `$server`. Every generated test is marked `#[ignore]`
/// unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    // Server-specific id of the "allow" permission option.
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
384
385// Helpers
386
387pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
388 env_logger::try_init().ok();
389
390 cx.update(|cx| {
391 let settings_store = SettingsStore::test(cx);
392 cx.set_global(settings_store);
393 Project::init_settings(cx);
394 language::init(cx);
395 crate::settings::init(cx);
396
397 crate::AllAgentServersSettings::override_global(
398 AllAgentServersSettings {
399 claude: Some(AgentServerSettings {
400 command: crate::claude::tests::local_command(),
401 }),
402 gemini: Some(AgentServerSettings {
403 command: crate::gemini::tests::local_command(),
404 }),
405 },
406 cx,
407 );
408 });
409
410 cx.executor().allow_parking();
411
412 FakeFs::new(cx.executor())
413}
414
415pub async fn new_test_thread(
416 server: impl AgentServer + 'static,
417 project: Entity<Project>,
418 current_dir: impl AsRef<Path>,
419 cx: &mut TestAppContext,
420) -> Entity<AcpThread> {
421 let connection = cx
422 .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
423 .await
424 .unwrap();
425
426 let thread = connection
427 .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
428 .await
429 .unwrap();
430
431 thread
432}
433
434pub async fn run_until_first_tool_call(
435 thread: &Entity<AcpThread>,
436 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
437 cx: &mut TestAppContext,
438) -> usize {
439 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
440
441 let subscription = cx.update(|cx| {
442 cx.subscribe(thread, move |thread, _, cx| {
443 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
444 if wait_until(entry) {
445 return tx.try_send(ix).unwrap();
446 }
447 }
448 })
449 });
450
451 select! {
452 // We have to use a smol timer here because
453 // cx.background_executor().timer isn't real in the test context
454 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
455 panic!("Timeout waiting for tool call")
456 }
457 ix = rx.next().fuse() => {
458 drop(subscription);
459 ix.unwrap()
460 }
461 }
462}
463
/// Returns the path of the `zed` binary in the nearest `debug` directory that
/// contains the currently running executable.
///
/// # Panics
/// - Panics if no ancestor of the current executable is named `debug`.
/// - Panics if `zed` has not been built into that directory yet.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    // `ancestors` yields the path itself first, then each parent in turn.
    let Some(debug_dir) = current_exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name == "debug"))
    else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }
    zed_path
}