e2e_tests.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    sync::Arc,
  4    time::Duration,
  5};
  6
  7use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
  8use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  9use agent_client_protocol as acp;
 10use agent_settings::AgentSettings;
 11
 12use futures::{FutureExt, StreamExt, channel::mpsc, select};
 13use gpui::{Entity, TestAppContext};
 14use indoc::indoc;
 15use project::{FakeFs, Project};
 16use settings::{Settings, SettingsStore};
 17use util::path;
 18
 19pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 20    let fs = init_test(cx).await;
 21    let project = Project::test(fs, [], cx).await;
 22    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
 23
 24    thread
 25        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 26        .await
 27        .unwrap();
 28
 29    thread.read_with(cx, |thread, _| {
 30        assert!(
 31            thread.entries().len() >= 2,
 32            "Expected at least 2 entries. Got: {:?}",
 33            thread.entries()
 34        );
 35        assert!(matches!(
 36            thread.entries()[0],
 37            AgentThreadEntry::UserMessage(_)
 38        ));
 39        assert!(matches!(
 40            thread.entries()[1],
 41            AgentThreadEntry::AssistantMessage(_)
 42        ));
 43    });
 44}
 45
 46pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 47    let _fs = init_test(cx).await;
 48
 49    let tempdir = tempfile::tempdir().unwrap();
 50    std::fs::write(
 51        tempdir.path().join("foo.rs"),
 52        indoc! {"
 53            fn main() {
 54                println!(\"Hello, world!\");
 55            }
 56        "},
 57    )
 58    .expect("failed to write file");
 59    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 60    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
 61    thread
 62        .update(cx, |thread, cx| {
 63            thread.send(
 64                vec![
 65                    acp::ContentBlock::Text(acp::TextContent {
 66                        text: "Read the file ".into(),
 67                        annotations: None,
 68                    }),
 69                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
 70                        uri: "foo.rs".into(),
 71                        name: "foo.rs".into(),
 72                        annotations: None,
 73                        description: None,
 74                        mime_type: None,
 75                        size: None,
 76                        title: None,
 77                    }),
 78                    acp::ContentBlock::Text(acp::TextContent {
 79                        text: " and tell me what the content of the println! is".into(),
 80                        annotations: None,
 81                    }),
 82                ],
 83                cx,
 84            )
 85        })
 86        .await
 87        .unwrap();
 88
 89    thread.read_with(cx, |thread, cx| {
 90        assert!(matches!(
 91            thread.entries()[0],
 92            AgentThreadEntry::UserMessage(_)
 93        ));
 94        let assistant_message = &thread
 95            .entries()
 96            .iter()
 97            .rev()
 98            .find_map(|entry| match entry {
 99                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
100                _ => None,
101            })
102            .unwrap();
103
104        assert!(
105            assistant_message.to_markdown(cx).contains("Hello, world!"),
106            "unexpected assistant message: {:?}",
107            assistant_message.to_markdown(cx)
108        );
109    });
110
111    drop(tempdir);
112}
113
114pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
115    let _fs = init_test(cx).await;
116
117    let tempdir = tempfile::tempdir().unwrap();
118    let foo_path = tempdir.path().join("foo");
119    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
120
121    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
122    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
123
124    thread
125        .update(cx, |thread, cx| {
126            thread.send_raw(
127                &format!("Read {} and tell me what you see.", foo_path.display()),
128                cx,
129            )
130        })
131        .await
132        .unwrap();
133    thread.read_with(cx, |thread, _cx| {
134        assert!(thread.entries().iter().any(|entry| {
135            matches!(
136                entry,
137                AgentThreadEntry::ToolCall(ToolCall {
138                    status: ToolCallStatus::Allowed { .. },
139                    ..
140                })
141            )
142        }));
143        assert!(
144            thread
145                .entries()
146                .iter()
147                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
148        );
149    });
150
151    drop(tempdir);
152}
153
154pub async fn test_tool_call_with_confirmation(
155    server: impl AgentServer + 'static,
156    allow_option_id: acp::PermissionOptionId,
157    cx: &mut TestAppContext,
158) {
159    let fs = init_test(cx).await;
160    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
161    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
162    let full_turn = thread.update(cx, |thread, cx| {
163        thread.send_raw(
164            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
165            cx,
166        )
167    });
168
169    run_until_first_tool_call(
170        &thread,
171        |entry| {
172            matches!(
173                entry,
174                AgentThreadEntry::ToolCall(ToolCall {
175                    status: ToolCallStatus::WaitingForConfirmation { .. },
176                    ..
177                })
178            )
179        },
180        cx,
181    )
182    .await;
183
184    let tool_call_id = thread.read_with(cx, |thread, cx| {
185        let AgentThreadEntry::ToolCall(ToolCall {
186            id,
187            label,
188            status: ToolCallStatus::WaitingForConfirmation { .. },
189            ..
190        }) = &thread
191            .entries()
192            .iter()
193            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
194            .unwrap()
195        else {
196            panic!();
197        };
198
199        let label = label.read(cx).source();
200        assert!(label.contains("touch"), "Got: {}", label);
201
202        id.clone()
203    });
204
205    thread.update(cx, |thread, cx| {
206        thread.authorize_tool_call(
207            tool_call_id,
208            allow_option_id,
209            acp::PermissionOptionKind::AllowOnce,
210            cx,
211        );
212
213        assert!(thread.entries().iter().any(|entry| matches!(
214            entry,
215            AgentThreadEntry::ToolCall(ToolCall {
216                status: ToolCallStatus::Allowed { .. },
217                ..
218            })
219        )));
220    });
221
222    full_turn.await.unwrap();
223
224    thread.read_with(cx, |thread, cx| {
225        let AgentThreadEntry::ToolCall(ToolCall {
226            content,
227            status: ToolCallStatus::Allowed { .. },
228            ..
229        }) = thread
230            .entries()
231            .iter()
232            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
233            .unwrap()
234        else {
235            panic!();
236        };
237
238        assert!(
239            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
240            "Expected content to contain 'Hello'"
241        );
242    });
243}
244
245pub async fn test_tool_call_always_allow(
246    server: impl AgentServer + 'static,
247    cx: &mut TestAppContext,
248) {
249    let fs = init_test(cx).await;
250    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
251
252    // Enable always_allow_tool_actions
253    cx.update(|cx| {
254        AgentSettings::override_global(
255            AgentSettings {
256                always_allow_tool_actions: true,
257                ..Default::default()
258            },
259            cx,
260        );
261    });
262
263    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
264    let full_turn = thread.update(cx, |thread, cx| {
265        thread.send_raw(
266            r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
267            cx,
268        )
269    });
270
271    // Wait for the tool call to complete
272    full_turn.await.unwrap();
273
274    thread.read_with(cx, |thread, _cx| {
275        // With always_allow_tool_actions enabled, the tool call should be immediately allowed
276        // without waiting for confirmation
277        let tool_call_entry = thread
278            .entries()
279            .iter()
280            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
281            .expect("Expected a tool call entry");
282
283        let AgentThreadEntry::ToolCall(tool_call) = tool_call_entry else {
284            panic!("Expected tool call entry");
285        };
286
287        // Should be allowed, not waiting for confirmation
288        assert!(
289            matches!(tool_call.status, ToolCallStatus::Allowed { .. }),
290            "Expected tool call to be allowed automatically, but got {:?}",
291            tool_call.status
292        );
293    });
294}
295
296pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
297    let fs = init_test(cx).await;
298
299    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
300    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
301    let full_turn = thread.update(cx, |thread, cx| {
302        thread.send_raw(
303            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
304            cx,
305        )
306    });
307
308    let first_tool_call_ix = run_until_first_tool_call(
309        &thread,
310        |entry| {
311            matches!(
312                entry,
313                AgentThreadEntry::ToolCall(ToolCall {
314                    status: ToolCallStatus::WaitingForConfirmation { .. },
315                    ..
316                })
317            )
318        },
319        cx,
320    )
321    .await;
322
323    thread.read_with(cx, |thread, cx| {
324        let AgentThreadEntry::ToolCall(ToolCall {
325            id,
326            label,
327            status: ToolCallStatus::WaitingForConfirmation { .. },
328            ..
329        }) = &thread.entries()[first_tool_call_ix]
330        else {
331            panic!("{:?}", thread.entries()[1]);
332        };
333
334        let label = label.read(cx).source();
335        assert!(label.contains("touch"), "Got: {}", label);
336
337        id.clone()
338    });
339
340    let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
341    full_turn.await.unwrap();
342    thread.read_with(cx, |thread, _| {
343        let AgentThreadEntry::ToolCall(ToolCall {
344            status: ToolCallStatus::Canceled,
345            ..
346        }) = &thread.entries()[first_tool_call_ix]
347        else {
348            panic!();
349        };
350    });
351
352    thread
353        .update(cx, |thread, cx| {
354            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
355        })
356        .await
357        .unwrap();
358    thread.read_with(cx, |thread, _| {
359        assert!(matches!(
360            &thread.entries().last().unwrap(),
361            AgentThreadEntry::AssistantMessage(..),
362        ))
363    });
364}
365
/// Generates a `common_e2e` test module that runs every shared end-to-end test in this file
/// against one agent server.
///
/// * `$server`: expression producing the server under test. It is evaluated separately
///   inside each generated test.
/// * `allow_option_id`: value converted (via `.into()`) into the inner value of
///   `agent_client_protocol::PermissionOptionId`. It is the option the confirmation test
///   selects to approve the tool call.
///
/// Each generated test carries `#[cfg_attr(not(feature = "e2e"), ignore)]`, so it is ignored
/// unless compiled with the `e2e` feature.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                let server = $server;
                $crate::e2e_tests::test_basic(server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                let server = $server;
                $crate::e2e_tests::test_path_mentions(server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                let server = $server;
                $crate::e2e_tests::test_tool_call(server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                let server = $server;
                let allow_option_id =
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into());
                $crate::e2e_tests::test_tool_call_with_confirmation(server, allow_option_id, cx)
                    .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                let server = $server;
                $crate::e2e_tests::test_cancel(server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_always_allow(cx: &mut ::gpui::TestAppContext) {
                let server = $server;
                $crate::e2e_tests::test_tool_call_always_allow(server, cx).await;
            }
        }
    };
}
415
// Helpers shared by the tests above.
417
418pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
419    env_logger::try_init().ok();
420
421    cx.update(|cx| {
422        let settings_store = SettingsStore::test(cx);
423        cx.set_global(settings_store);
424        Project::init_settings(cx);
425        language::init(cx);
426        crate::settings::init(cx);
427
428        crate::AllAgentServersSettings::override_global(
429            AllAgentServersSettings {
430                claude: Some(AgentServerSettings {
431                    command: crate::claude::tests::local_command(),
432                }),
433                gemini: Some(AgentServerSettings {
434                    command: crate::gemini::tests::local_command(),
435                }),
436                codex: Some(AgentServerSettings {
437                    command: crate::codex::tests::local_command(),
438                }),
439            },
440            cx,
441        );
442    });
443
444    cx.executor().allow_parking();
445
446    FakeFs::new(cx.executor())
447}
448
449pub async fn new_test_thread(
450    server: impl AgentServer + 'static,
451    project: Entity<Project>,
452    current_dir: impl AsRef<Path>,
453    cx: &mut TestAppContext,
454) -> Entity<AcpThread> {
455    let connection = cx
456        .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
457        .await
458        .unwrap();
459
460    let thread = connection
461        .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
462        .await
463        .unwrap();
464
465    thread
466}
467
468pub async fn run_until_first_tool_call(
469    thread: &Entity<AcpThread>,
470    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
471    cx: &mut TestAppContext,
472) -> usize {
473    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
474
475    let subscription = cx.update(|cx| {
476        cx.subscribe(thread, move |thread, _, cx| {
477            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
478                if wait_until(entry) {
479                    return tx.try_send(ix).unwrap();
480                }
481            }
482        })
483    });
484
485    select! {
486        // We have to use a smol timer here because
487        // cx.background_executor().timer isn't real in the test context
488        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
489            panic!("Timeout waiting for tool call")
490        }
491        ix = rx.next().fuse() => {
492            drop(subscription);
493            ix.unwrap()
494        }
495    }
496}
497
/// Locates the `zed` binary built next to the currently running executable.
///
/// Starting from the current executable's own path, this finds the nearest ancestor whose
/// final component is named `debug` (for example `target/debug`) and returns
/// `<that directory>/zed`.
///
/// # Panics
///
/// Panics if the current executable path cannot be determined, if no ancestor is named
/// `debug`, or if the `zed` binary does not exist there.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    let Some(debug_dir) = current_exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name.to_string_lossy() == "debug"))
    else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}