1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4 time::Duration,
5};
6
7use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
8use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
9use agent_client_protocol as acp;
10
11use futures::{FutureExt, StreamExt, channel::mpsc, select};
12use gpui::{Entity, TestAppContext};
13use indoc::indoc;
14use project::{FakeFs, Project};
15use settings::{Settings, SettingsStore};
16use util::path;
17
18pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
19 let fs = init_test(cx).await;
20 let project = Project::test(fs, [], cx).await;
21 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
22
23 thread
24 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
25 .await
26 .unwrap();
27
28 thread.read_with(cx, |thread, _| {
29 assert!(
30 thread.entries().len() >= 2,
31 "Expected at least 2 entries. Got: {:?}",
32 thread.entries()
33 );
34 assert!(matches!(
35 thread.entries()[0],
36 AgentThreadEntry::UserMessage(_)
37 ));
38 assert!(matches!(
39 thread.entries()[1],
40 AgentThreadEntry::AssistantMessage(_)
41 ));
42 });
43}
44
45pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
46 let _fs = init_test(cx).await;
47
48 let tempdir = tempfile::tempdir().unwrap();
49 std::fs::write(
50 tempdir.path().join("foo.rs"),
51 indoc! {"
52 fn main() {
53 println!(\"Hello, world!\");
54 }
55 "},
56 )
57 .expect("failed to write file");
58 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
59 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
60 thread
61 .update(cx, |thread, cx| {
62 thread.send(
63 vec![
64 acp::ContentBlock::Text(acp::TextContent {
65 text: "Read the file ".into(),
66 annotations: None,
67 }),
68 acp::ContentBlock::ResourceLink(acp::ResourceLink {
69 uri: "foo.rs".into(),
70 name: "foo.rs".into(),
71 annotations: None,
72 description: None,
73 mime_type: None,
74 size: None,
75 title: None,
76 }),
77 acp::ContentBlock::Text(acp::TextContent {
78 text: " and tell me what the content of the println! is".into(),
79 annotations: None,
80 }),
81 ],
82 cx,
83 )
84 })
85 .await
86 .unwrap();
87
88 thread.read_with(cx, |thread, cx| {
89 assert!(matches!(
90 thread.entries()[0],
91 AgentThreadEntry::UserMessage(_)
92 ));
93 let assistant_message = &thread
94 .entries()
95 .iter()
96 .rev()
97 .find_map(|entry| match entry {
98 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
99 _ => None,
100 })
101 .unwrap();
102
103 assert!(
104 assistant_message.to_markdown(cx).contains("Hello, world!"),
105 "unexpected assistant message: {:?}",
106 assistant_message.to_markdown(cx)
107 );
108 });
109
110 drop(tempdir);
111}
112
113pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
114 let _fs = init_test(cx).await;
115
116 let tempdir = tempfile::tempdir().unwrap();
117 let foo_path = tempdir.path().join("foo");
118 std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
119
120 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
121 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
122
123 thread
124 .update(cx, |thread, cx| {
125 thread.send_raw(
126 &format!("Read {} and tell me what you see.", foo_path.display()),
127 cx,
128 )
129 })
130 .await
131 .unwrap();
132 thread.read_with(cx, |thread, _cx| {
133 assert!(thread.entries().iter().any(|entry| {
134 matches!(
135 entry,
136 AgentThreadEntry::ToolCall(ToolCall {
137 status: ToolCallStatus::Pending
138 | ToolCallStatus::InProgress
139 | ToolCallStatus::Completed,
140 ..
141 })
142 )
143 }));
144 assert!(
145 thread
146 .entries()
147 .iter()
148 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
149 );
150 });
151
152 drop(tempdir);
153}
154
155pub async fn test_tool_call_with_permission(
156 server: impl AgentServer + 'static,
157 allow_option_id: acp::PermissionOptionId,
158 cx: &mut TestAppContext,
159) {
160 let fs = init_test(cx).await;
161 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
162 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
163 let full_turn = thread.update(cx, |thread, cx| {
164 thread.send_raw(
165 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
166 cx,
167 )
168 });
169
170 run_until_first_tool_call(
171 &thread,
172 |entry| {
173 matches!(
174 entry,
175 AgentThreadEntry::ToolCall(ToolCall {
176 status: ToolCallStatus::WaitingForConfirmation { .. },
177 ..
178 })
179 )
180 },
181 cx,
182 )
183 .await;
184
185 let tool_call_id = thread.read_with(cx, |thread, cx| {
186 let AgentThreadEntry::ToolCall(ToolCall {
187 id,
188 label,
189 status: ToolCallStatus::WaitingForConfirmation { .. },
190 ..
191 }) = &thread
192 .entries()
193 .iter()
194 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
195 .unwrap()
196 else {
197 panic!();
198 };
199
200 let label = label.read(cx).source();
201 assert!(label.contains("touch"), "Got: {}", label);
202
203 id.clone()
204 });
205
206 thread.update(cx, |thread, cx| {
207 thread.authorize_tool_call(
208 tool_call_id,
209 allow_option_id,
210 acp::PermissionOptionKind::AllowOnce,
211 cx,
212 );
213
214 assert!(thread.entries().iter().any(|entry| matches!(
215 entry,
216 AgentThreadEntry::ToolCall(ToolCall {
217 status: ToolCallStatus::Pending
218 | ToolCallStatus::InProgress
219 | ToolCallStatus::Completed,
220 ..
221 })
222 )));
223 });
224
225 full_turn.await.unwrap();
226
227 thread.read_with(cx, |thread, cx| {
228 let AgentThreadEntry::ToolCall(ToolCall {
229 content,
230 status: ToolCallStatus::Pending
231 | ToolCallStatus::InProgress
232 | ToolCallStatus::Completed,
233 ..
234 }) = thread
235 .entries()
236 .iter()
237 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
238 .unwrap()
239 else {
240 panic!();
241 };
242
243 assert!(
244 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
245 "Expected content to contain 'Hello'"
246 );
247 });
248}
249
250pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
251 let fs = init_test(cx).await;
252
253 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
254 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
255 let _ = thread.update(cx, |thread, cx| {
256 thread.send_raw(
257 r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
258 cx,
259 )
260 });
261
262 let first_tool_call_ix = run_until_first_tool_call(
263 &thread,
264 |entry| {
265 matches!(
266 entry,
267 AgentThreadEntry::ToolCall(ToolCall {
268 status: ToolCallStatus::WaitingForConfirmation { .. },
269 ..
270 })
271 )
272 },
273 cx,
274 )
275 .await;
276
277 thread.read_with(cx, |thread, cx| {
278 let AgentThreadEntry::ToolCall(ToolCall {
279 id,
280 label,
281 status: ToolCallStatus::WaitingForConfirmation { .. },
282 ..
283 }) = &thread.entries()[first_tool_call_ix]
284 else {
285 panic!("{:?}", thread.entries()[1]);
286 };
287
288 let label = label.read(cx).source();
289 assert!(label.contains("touch"), "Got: {}", label);
290
291 id.clone()
292 });
293
294 thread.update(cx, |thread, cx| thread.cancel(cx)).await;
295 thread.read_with(cx, |thread, _cx| {
296 let AgentThreadEntry::ToolCall(ToolCall {
297 status: ToolCallStatus::Canceled,
298 ..
299 }) = &thread.entries()[first_tool_call_ix]
300 else {
301 panic!();
302 };
303 });
304
305 thread
306 .update(cx, |thread, cx| {
307 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
308 })
309 .await
310 .unwrap();
311 thread.read_with(cx, |thread, _| {
312 assert!(matches!(
313 &thread.entries().last().unwrap(),
314 AgentThreadEntry::AssistantMessage(..),
315 ))
316 });
317}
318
319pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
320 let fs = init_test(cx).await;
321 let project = Project::test(fs, [], cx).await;
322 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
323
324 thread
325 .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
326 .await
327 .unwrap();
328
329 thread.read_with(cx, |thread, _| {
330 assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
331 });
332
333 let weak_thread = thread.downgrade();
334 drop(thread);
335
336 cx.executor().run_until_parked();
337 assert!(!weak_thread.is_upgradable());
338}
339
/// Generates a `common_e2e` test module that runs every shared end-to-end test against
/// an agent server.
///
/// * `$server` — expression producing the server under test; it is expanded into, and
///   therefore evaluated by, each generated test.
/// * `allow_option_id = $allow_option_id` — value converted with `.into()` and wrapped in
///   `PermissionOptionId` for the permission test.
///
/// A trailing comma after the last argument is accepted. Every generated test is marked
/// `ignore` unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr $(,)?) => {
        mod common_e2e {
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
389
390// Helpers
391
392pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
393 env_logger::try_init().ok();
394
395 cx.update(|cx| {
396 let settings_store = SettingsStore::test(cx);
397 cx.set_global(settings_store);
398 Project::init_settings(cx);
399 language::init(cx);
400 crate::settings::init(cx);
401
402 crate::AllAgentServersSettings::override_global(
403 AllAgentServersSettings {
404 claude: Some(AgentServerSettings {
405 command: crate::claude::tests::local_command(),
406 }),
407 gemini: Some(AgentServerSettings {
408 command: crate::gemini::tests::local_command(),
409 }),
410 },
411 cx,
412 );
413 });
414
415 cx.executor().allow_parking();
416
417 FakeFs::new(cx.executor())
418}
419
420pub async fn new_test_thread(
421 server: impl AgentServer + 'static,
422 project: Entity<Project>,
423 current_dir: impl AsRef<Path>,
424 cx: &mut TestAppContext,
425) -> Entity<AcpThread> {
426 let connection = cx
427 .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
428 .await
429 .unwrap();
430
431 let thread = cx
432 .update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
433 .await
434 .unwrap();
435
436 thread
437}
438
439pub async fn run_until_first_tool_call(
440 thread: &Entity<AcpThread>,
441 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
442 cx: &mut TestAppContext,
443) -> usize {
444 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
445
446 let subscription = cx.update(|cx| {
447 cx.subscribe(thread, move |thread, _, cx| {
448 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
449 if wait_until(entry) {
450 return tx.try_send(ix).unwrap();
451 }
452 }
453 })
454 });
455
456 select! {
457 // We have to use a smol timer here because
458 // cx.background_executor().timer isn't real in the test context
459 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
460 panic!("Timeout waiting for tool call")
461 }
462 ix = rx.next().fuse() => {
463 drop(subscription);
464 ix.unwrap()
465 }
466 }
467}
468
/// Returns the path of the `zed` binary located in the `debug` directory that contains
/// the currently running executable.
///
/// Starting at [`std::env::current_exe`], the path and its ancestors are searched for a
/// component named `debug`. The binary name is `zed` plus the platform's executable
/// suffix (for example `.exe` on Windows, empty on Unix-like systems).
///
/// # Panics
///
/// Panics if the current executable path cannot be determined, if neither the path nor
/// any of its ancestors is named `debug`, or if the resulting binary does not exist.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    let Some(debug_dir) = current_exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name == "debug"))
    else {
        panic!("Could not find target directory");
    };

    // Without the suffix the existence check below could never succeed on platforms
    // whose executables carry an extension.
    let zed_path = debug_dir.join(format!("zed{}", std::env::consts::EXE_SUFFIX));

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}