1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4 time::Duration,
5};
6
7use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
8use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
9use agent_client_protocol as acp;
10
11use futures::{FutureExt, StreamExt, channel::mpsc, select};
12use gpui::{Entity, TestAppContext};
13use indoc::indoc;
14use project::{FakeFs, Project};
15use serde_json::json;
16use settings::{Settings, SettingsStore};
17use util::path;
18
19pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
20 let fs = init_test(cx).await;
21 let project = Project::test(fs, [], cx).await;
22 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
23
24 thread
25 .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
26 .await
27 .unwrap();
28
29 thread.read_with(cx, |thread, _| {
30 assert_eq!(thread.entries().len(), 2);
31 assert!(matches!(
32 thread.entries()[0],
33 AgentThreadEntry::UserMessage(_)
34 ));
35 assert!(matches!(
36 thread.entries()[1],
37 AgentThreadEntry::AssistantMessage(_)
38 ));
39 });
40}
41
42pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
43 let _fs = init_test(cx).await;
44
45 let tempdir = tempfile::tempdir().unwrap();
46 std::fs::write(
47 tempdir.path().join("foo.rs"),
48 indoc! {"
49 fn main() {
50 println!(\"Hello, world!\");
51 }
52 "},
53 )
54 .expect("failed to write file");
55 let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
56 let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
57 thread
58 .update(cx, |thread, cx| {
59 thread.send(
60 vec![
61 acp::ContentBlock::Text(acp::TextContent {
62 text: "Read the file ".into(),
63 annotations: None,
64 }),
65 acp::ContentBlock::ResourceLink(acp::ResourceLink {
66 uri: "foo.rs".into(),
67 name: "foo.rs".into(),
68 annotations: None,
69 description: None,
70 mime_type: None,
71 size: None,
72 title: None,
73 }),
74 acp::ContentBlock::Text(acp::TextContent {
75 text: " and tell me what the content of the println! is".into(),
76 annotations: None,
77 }),
78 ],
79 cx,
80 )
81 })
82 .await
83 .unwrap();
84
85 thread.read_with(cx, |thread, cx| {
86 assert!(matches!(
87 thread.entries()[0],
88 AgentThreadEntry::UserMessage(_)
89 ));
90 let assistant_message = &thread
91 .entries()
92 .iter()
93 .rev()
94 .find_map(|entry| match entry {
95 AgentThreadEntry::AssistantMessage(msg) => Some(msg),
96 _ => None,
97 })
98 .unwrap();
99
100 assert!(
101 assistant_message.to_markdown(cx).contains("Hello, world!"),
102 "unexpected assistant message: {:?}",
103 assistant_message.to_markdown(cx)
104 );
105 });
106
107 drop(tempdir);
108}
109
110pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
111 let fs = init_test(cx).await;
112 fs.insert_tree(
113 path!("/private/tmp"),
114 json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
115 )
116 .await;
117 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
118 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
119
120 thread
121 .update(cx, |thread, cx| {
122 thread.send_raw(
123 "Read the '/private/tmp/foo' file and tell me what you see.",
124 cx,
125 )
126 })
127 .await
128 .unwrap();
129 thread.read_with(cx, |thread, _cx| {
130 assert!(thread.entries().iter().any(|entry| {
131 matches!(
132 entry,
133 AgentThreadEntry::ToolCall(ToolCall {
134 status: ToolCallStatus::Allowed { .. },
135 ..
136 })
137 )
138 }));
139 assert!(
140 thread
141 .entries()
142 .iter()
143 .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
144 );
145 });
146}
147
148pub async fn test_tool_call_with_confirmation(
149 server: impl AgentServer + 'static,
150 allow_option_id: acp::PermissionOptionId,
151 cx: &mut TestAppContext,
152) {
153 let fs = init_test(cx).await;
154 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
155 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
156 let full_turn = thread.update(cx, |thread, cx| {
157 thread.send_raw(
158 r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
159 cx,
160 )
161 });
162
163 run_until_first_tool_call(
164 &thread,
165 |entry| {
166 matches!(
167 entry,
168 AgentThreadEntry::ToolCall(ToolCall {
169 status: ToolCallStatus::WaitingForConfirmation { .. },
170 ..
171 })
172 )
173 },
174 cx,
175 )
176 .await;
177
178 let tool_call_id = thread.read_with(cx, |thread, _cx| {
179 let AgentThreadEntry::ToolCall(ToolCall {
180 id,
181 content,
182 status: ToolCallStatus::WaitingForConfirmation { .. },
183 ..
184 }) = &thread
185 .entries()
186 .iter()
187 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
188 .unwrap()
189 else {
190 panic!();
191 };
192
193 assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
194
195 id.clone()
196 });
197
198 thread.update(cx, |thread, cx| {
199 thread.authorize_tool_call(
200 tool_call_id,
201 allow_option_id,
202 acp::PermissionOptionKind::AllowOnce,
203 cx,
204 );
205
206 assert!(thread.entries().iter().any(|entry| matches!(
207 entry,
208 AgentThreadEntry::ToolCall(ToolCall {
209 status: ToolCallStatus::Allowed { .. },
210 ..
211 })
212 )));
213 });
214
215 full_turn.await.unwrap();
216
217 thread.read_with(cx, |thread, cx| {
218 let AgentThreadEntry::ToolCall(ToolCall {
219 content,
220 status: ToolCallStatus::Allowed { .. },
221 ..
222 }) = thread
223 .entries()
224 .iter()
225 .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
226 .unwrap()
227 else {
228 panic!();
229 };
230
231 assert!(
232 content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
233 "Expected content to contain 'Hello'"
234 );
235 });
236}
237
238pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
239 let fs = init_test(cx).await;
240
241 let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
242 let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
243 let full_turn = thread.update(cx, |thread, cx| {
244 thread.send_raw(
245 r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
246 cx,
247 )
248 });
249
250 let first_tool_call_ix = run_until_first_tool_call(
251 &thread,
252 |entry| {
253 matches!(
254 entry,
255 AgentThreadEntry::ToolCall(ToolCall {
256 status: ToolCallStatus::WaitingForConfirmation { .. },
257 ..
258 })
259 )
260 },
261 cx,
262 )
263 .await;
264
265 thread.read_with(cx, |thread, _cx| {
266 let AgentThreadEntry::ToolCall(ToolCall {
267 id,
268 content,
269 status: ToolCallStatus::WaitingForConfirmation { .. },
270 ..
271 }) = &thread.entries()[first_tool_call_ix]
272 else {
273 panic!("{:?}", thread.entries()[1]);
274 };
275
276 assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
277
278 id.clone()
279 });
280
281 let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
282 full_turn.await.unwrap();
283 thread.read_with(cx, |thread, _| {
284 let AgentThreadEntry::ToolCall(ToolCall {
285 status: ToolCallStatus::Canceled,
286 ..
287 }) = &thread.entries()[first_tool_call_ix]
288 else {
289 panic!();
290 };
291 });
292
293 thread
294 .update(cx, |thread, cx| {
295 thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
296 })
297 .await
298 .unwrap();
299 thread.read_with(cx, |thread, _| {
300 assert!(matches!(
301 &thread.entries().last().unwrap(),
302 AgentThreadEntry::AssistantMessage(..),
303 ))
304 });
305}
306
/// Generates a `common_e2e` module containing the shared end-to-end tests for one agent server.
///
/// Arguments:
/// - `$server`: an expression producing the server under test. It is expanded once inside each
///   generated test, so every test evaluates it afresh.
/// - `allow_option_id = $allow_option_id`: an expression that is converted with `.into()` and
///   wrapped in `::agent_client_protocol::PermissionOptionId`; it is passed to
///   `test_tool_call_with_confirmation` as the option used to authorize the tool call.
///
/// Every generated test is a `#[::gpui::test]` and is marked `ignore` unless the `e2e` feature
/// is enabled. The generated module does `use super::*`, so it sees the invoking module's items.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            use super::*;

            // Single prompt, expects a user message followed by an assistant message.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            // Prompt that references a file through a resource link.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            // Prompt that should result in an allowed tool call.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            // Tool call that must be authorized with the given permission option.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_confirmation(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            // Cancelling a turn while a tool call awaits confirmation.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }
        }
    };
}
350
351// Helpers
352
353pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
354 env_logger::try_init().ok();
355
356 cx.update(|cx| {
357 let settings_store = SettingsStore::test(cx);
358 cx.set_global(settings_store);
359 Project::init_settings(cx);
360 language::init(cx);
361 crate::settings::init(cx);
362
363 crate::AllAgentServersSettings::override_global(
364 AllAgentServersSettings {
365 claude: Some(AgentServerSettings {
366 command: crate::claude::tests::local_command(),
367 }),
368 gemini: Some(AgentServerSettings {
369 command: crate::gemini::tests::local_command(),
370 }),
371 codex: Some(AgentServerSettings {
372 command: crate::codex::tests::local_command(),
373 }),
374 },
375 cx,
376 );
377 });
378
379 cx.executor().allow_parking();
380
381 FakeFs::new(cx.executor())
382}
383
384pub async fn new_test_thread(
385 server: impl AgentServer + 'static,
386 project: Entity<Project>,
387 current_dir: impl AsRef<Path>,
388 cx: &mut TestAppContext,
389) -> Entity<AcpThread> {
390 let connection = cx
391 .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
392 .await
393 .unwrap();
394
395 let thread = connection
396 .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
397 .await
398 .unwrap();
399
400 thread
401}
402
403pub async fn run_until_first_tool_call(
404 thread: &Entity<AcpThread>,
405 wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
406 cx: &mut TestAppContext,
407) -> usize {
408 let (mut tx, mut rx) = mpsc::channel::<usize>(1);
409
410 let subscription = cx.update(|cx| {
411 cx.subscribe(thread, move |thread, _, cx| {
412 for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
413 if wait_until(entry) {
414 return tx.try_send(ix).unwrap();
415 }
416 }
417 })
418 });
419
420 select! {
421 // We have to use a smol timer here because
422 // cx.background_executor().timer isn't real in the test context
423 _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
424 panic!("Timeout waiting for tool call")
425 }
426 ix = rx.next().fuse() => {
427 drop(subscription);
428 ix.unwrap()
429 }
430 }
431}
432
/// Returns the path of the debug build of the `zed` binary next to the running test executable.
///
/// Starting from the current executable, walks up the directory tree to the nearest ancestor
/// named `debug` and looks for the `zed` executable directly inside it. The platform's
/// executable suffix is appended, so the lookup also finds `zed.exe` on Windows.
///
/// # Panics
///
/// Panics if the current executable path cannot be determined, if no ancestor directory is
/// named `debug`, or if the `zed` binary does not exist there.
pub fn get_zed_path() -> PathBuf {
    let current_exe =
        std::env::current_exe().expect("failed to determine the current executable path");

    // `ancestors` yields the path itself first and then each parent in turn.
    let Some(debug_dir) = current_exe.ancestors().find(|dir| dir.ends_with("debug")) else {
        panic!("Could not find target directory");
    };

    // `EXE_SUFFIX` is empty on Unix-like systems and ".exe" on Windows.
    let zed_path = debug_dir.join(format!("zed{}", std::env::consts::EXE_SUFFIX));

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}