e2e_tests.rs

  1use crate::AgentServer;
  2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus, new_prompt_id};
  3use agent_client_protocol as acp;
  4use futures::{FutureExt, StreamExt, channel::mpsc, select};
  5use gpui::{AppContext, Entity, TestAppContext};
  6use indoc::indoc;
  7use project::{FakeFs, Project};
  8use std::{
  9    path::{Path, PathBuf},
 10    sync::Arc,
 11    time::Duration,
 12};
 13use util::path;
 14
 15pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
 16where
 17    T: AgentServer + 'static,
 18    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 19{
 20    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
 21    let project = Project::test(fs.clone(), [], cx).await;
 22    let thread = new_test_thread(
 23        server(&fs, &project, cx).await,
 24        project.clone(),
 25        "/private/tmp",
 26        cx,
 27    )
 28    .await;
 29
 30    thread
 31        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 32        .await
 33        .unwrap();
 34
 35    thread.read_with(cx, |thread, _| {
 36        assert!(
 37            thread.entries().len() >= 2,
 38            "Expected at least 2 entries. Got: {:?}",
 39            thread.entries()
 40        );
 41        assert!(matches!(
 42            thread.entries()[0],
 43            AgentThreadEntry::UserMessage(_)
 44        ));
 45        assert!(matches!(
 46            thread.entries()[1],
 47            AgentThreadEntry::AssistantMessage(_)
 48        ));
 49    });
 50}
 51
 52pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
 53where
 54    T: AgentServer + 'static,
 55    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 56{
 57    let fs = init_test(cx).await as _;
 58
 59    let tempdir = tempfile::tempdir().unwrap();
 60    std::fs::write(
 61        tempdir.path().join("foo.rs"),
 62        indoc! {"
 63            fn main() {
 64                println!(\"Hello, world!\");
 65            }
 66        "},
 67    )
 68    .expect("failed to write file");
 69    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 70    let thread = new_test_thread(
 71        server(&fs, &project, cx).await,
 72        project.clone(),
 73        tempdir.path(),
 74        cx,
 75    )
 76    .await;
 77    thread
 78        .update(cx, |thread, cx| {
 79            thread.send(
 80                new_prompt_id(),
 81                vec![
 82                    acp::ContentBlock::Text(acp::TextContent {
 83                        text: "Read the file ".into(),
 84                        annotations: None,
 85                    }),
 86                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
 87                        uri: "foo.rs".into(),
 88                        name: "foo.rs".into(),
 89                        annotations: None,
 90                        description: None,
 91                        mime_type: None,
 92                        size: None,
 93                        title: None,
 94                    }),
 95                    acp::ContentBlock::Text(acp::TextContent {
 96                        text: " and tell me what the content of the println! is".into(),
 97                        annotations: None,
 98                    }),
 99                ],
100                cx,
101            )
102        })
103        .await
104        .unwrap();
105
106    thread.read_with(cx, |thread, cx| {
107        assert!(matches!(
108            thread.entries()[0],
109            AgentThreadEntry::UserMessage(_)
110        ));
111        let assistant_message = &thread
112            .entries()
113            .iter()
114            .rev()
115            .find_map(|entry| match entry {
116                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
117                _ => None,
118            })
119            .unwrap();
120
121        assert!(
122            assistant_message.to_markdown(cx).contains("Hello, world!"),
123            "unexpected assistant message: {:?}",
124            assistant_message.to_markdown(cx)
125        );
126    });
127
128    drop(tempdir);
129}
130
131pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
132where
133    T: AgentServer + 'static,
134    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
135{
136    let fs = init_test(cx).await as _;
137
138    let tempdir = tempfile::tempdir().unwrap();
139    let foo_path = tempdir.path().join("foo");
140    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
141
142    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
143    let thread = new_test_thread(
144        server(&fs, &project, cx).await,
145        project.clone(),
146        "/private/tmp",
147        cx,
148    )
149    .await;
150
151    thread
152        .update(cx, |thread, cx| {
153            thread.send_raw(
154                &format!("Read {} and tell me what you see.", foo_path.display()),
155                cx,
156            )
157        })
158        .await
159        .unwrap();
160    thread.read_with(cx, |thread, _cx| {
161        assert!(thread.entries().iter().any(|entry| {
162            matches!(
163                entry,
164                AgentThreadEntry::ToolCall(ToolCall {
165                    status: ToolCallStatus::Pending
166                        | ToolCallStatus::InProgress
167                        | ToolCallStatus::Completed,
168                    ..
169                })
170            )
171        }));
172        assert!(
173            thread
174                .entries()
175                .iter()
176                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
177        );
178    });
179
180    drop(tempdir);
181}
182
183pub async fn test_tool_call_with_permission<T, F>(
184    server: F,
185    allow_option_id: acp::PermissionOptionId,
186    cx: &mut TestAppContext,
187) where
188    T: AgentServer + 'static,
189    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
190{
191    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
192    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
193    let thread = new_test_thread(
194        server(&fs, &project, cx).await,
195        project.clone(),
196        "/private/tmp",
197        cx,
198    )
199    .await;
200    let full_turn = thread.update(cx, |thread, cx| {
201        thread.send_raw(
202            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
203            cx,
204        )
205    });
206
207    run_until_first_tool_call(
208        &thread,
209        |entry| {
210            matches!(
211                entry,
212                AgentThreadEntry::ToolCall(ToolCall {
213                    status: ToolCallStatus::WaitingForConfirmation { .. },
214                    ..
215                })
216            )
217        },
218        cx,
219    )
220    .await;
221
222    let tool_call_id = thread.read_with(cx, |thread, cx| {
223        let AgentThreadEntry::ToolCall(ToolCall {
224            id,
225            label,
226            status: ToolCallStatus::WaitingForConfirmation { .. },
227            ..
228        }) = &thread
229            .entries()
230            .iter()
231            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
232            .unwrap()
233        else {
234            panic!();
235        };
236
237        let label = label.read(cx).source();
238        assert!(label.contains("touch"), "Got: {}", label);
239
240        id.clone()
241    });
242
243    thread.update(cx, |thread, cx| {
244        thread.authorize_tool_call(
245            tool_call_id,
246            allow_option_id,
247            acp::PermissionOptionKind::AllowOnce,
248            cx,
249        );
250
251        assert!(thread.entries().iter().any(|entry| matches!(
252            entry,
253            AgentThreadEntry::ToolCall(ToolCall {
254                status: ToolCallStatus::Pending
255                    | ToolCallStatus::InProgress
256                    | ToolCallStatus::Completed,
257                ..
258            })
259        )));
260    });
261
262    full_turn.await.unwrap();
263
264    thread.read_with(cx, |thread, cx| {
265        let AgentThreadEntry::ToolCall(ToolCall {
266            content,
267            status: ToolCallStatus::Pending
268                | ToolCallStatus::InProgress
269                | ToolCallStatus::Completed,
270            ..
271        }) = thread
272            .entries()
273            .iter()
274            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
275            .unwrap()
276        else {
277            panic!();
278        };
279
280        assert!(
281            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
282            "Expected content to contain 'Hello'"
283        );
284    });
285}
286
287pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
288where
289    T: AgentServer + 'static,
290    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
291{
292    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
293
294    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
295    let thread = new_test_thread(
296        server(&fs, &project, cx).await,
297        project.clone(),
298        "/private/tmp",
299        cx,
300    )
301    .await;
302    let _ = thread.update(cx, |thread, cx| {
303        thread.send_raw(
304            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
305            cx,
306        )
307    });
308
309    let first_tool_call_ix = run_until_first_tool_call(
310        &thread,
311        |entry| {
312            matches!(
313                entry,
314                AgentThreadEntry::ToolCall(ToolCall {
315                    status: ToolCallStatus::WaitingForConfirmation { .. },
316                    ..
317                })
318            )
319        },
320        cx,
321    )
322    .await;
323
324    thread.read_with(cx, |thread, cx| {
325        let AgentThreadEntry::ToolCall(ToolCall {
326            id,
327            label,
328            status: ToolCallStatus::WaitingForConfirmation { .. },
329            ..
330        }) = &thread.entries()[first_tool_call_ix]
331        else {
332            panic!("{:?}", thread.entries()[1]);
333        };
334
335        let label = label.read(cx).source();
336        assert!(label.contains("touch"), "Got: {}", label);
337
338        id.clone()
339    });
340
341    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
342    thread.read_with(cx, |thread, _cx| {
343        let AgentThreadEntry::ToolCall(ToolCall {
344            status: ToolCallStatus::Canceled,
345            ..
346        }) = &thread.entries()[first_tool_call_ix]
347        else {
348            panic!();
349        };
350    });
351
352    thread
353        .update(cx, |thread, cx| {
354            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
355        })
356        .await
357        .unwrap();
358    thread.read_with(cx, |thread, _| {
359        assert!(matches!(
360            &thread.entries().last().unwrap(),
361            AgentThreadEntry::AssistantMessage(..),
362        ))
363    });
364}
365
366pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
367where
368    T: AgentServer + 'static,
369    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
370{
371    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
372    let project = Project::test(fs.clone(), [], cx).await;
373    let thread = new_test_thread(
374        server(&fs, &project, cx).await,
375        project.clone(),
376        "/private/tmp",
377        cx,
378    )
379    .await;
380
381    thread
382        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
383        .await
384        .unwrap();
385
386    thread.read_with(cx, |thread, _| {
387        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
388    });
389
390    let weak_thread = thread.downgrade();
391    drop(thread);
392
393    cx.executor().run_until_parked();
394    assert!(!weak_thread.is_upgradable());
395}
396
/// Generates a `common_e2e` module containing one `#[gpui::test]` per shared
/// end-to-end scenario defined in this file.
///
/// * `$server` is the expression handed to each scenario as its server factory.
/// * `allow_option_id = $allow_option_id` is converted with `.into()` and
///   wrapped in `PermissionOptionId` for the permission scenario.
///
/// Every generated test is marked `ignore` unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            use super::*;

            // Plain prompt / reply round trip.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            // Prompt containing a resource link to a project file.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            // Prompt that should make the agent issue a tool call.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            // Tool call that has to be authorized before it may proceed.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            // Cancelling a turn while a tool call awaits confirmation.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            // Dropping the thread must release it.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
446pub use common_e2e_tests;
447
448// Helpers
449
450pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
451    #[cfg(test)]
452    use settings::Settings;
453
454    env_logger::try_init().ok();
455
456    cx.update(|cx| {
457        let settings_store = settings::SettingsStore::test(cx);
458        cx.set_global(settings_store);
459        Project::init_settings(cx);
460        language::init(cx);
461        gpui_tokio::init(cx);
462        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
463        cx.set_http_client(Arc::new(http_client));
464        client::init_settings(cx);
465        let client = client::Client::production(cx);
466        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
467        language_model::init(client.clone(), cx);
468        language_models::init(user_store, client, cx);
469        agent_settings::init(cx);
470        crate::settings::init(cx);
471
472        #[cfg(test)]
473        crate::AllAgentServersSettings::override_global(
474            crate::AllAgentServersSettings {
475                claude: Some(crate::AgentServerSettings {
476                    command: crate::claude::tests::local_command(),
477                }),
478                gemini: Some(crate::AgentServerSettings {
479                    command: crate::gemini::tests::local_command(),
480                }),
481                custom: collections::HashMap::default(),
482            },
483            cx,
484        );
485    });
486
487    cx.executor().allow_parking();
488
489    FakeFs::new(cx.executor())
490}
491
492pub async fn new_test_thread(
493    server: impl AgentServer + 'static,
494    project: Entity<Project>,
495    current_dir: impl AsRef<Path>,
496    cx: &mut TestAppContext,
497) -> Entity<AcpThread> {
498    let connection = cx
499        .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
500        .await
501        .unwrap();
502
503    cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
504        .await
505        .unwrap()
506}
507
508pub async fn run_until_first_tool_call(
509    thread: &Entity<AcpThread>,
510    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
511    cx: &mut TestAppContext,
512) -> usize {
513    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
514
515    let subscription = cx.update(|cx| {
516        cx.subscribe(thread, move |thread, _, cx| {
517            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
518                if wait_until(entry) {
519                    return tx.try_send(ix).unwrap();
520                }
521            }
522        })
523    });
524
525    select! {
526        // We have to use a smol timer here because
527        // cx.background_executor().timer isn't real in the test context
528        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
529            panic!("Timeout waiting for tool call")
530        }
531        ix = rx.next().fuse() => {
532            drop(subscription);
533            ix.unwrap()
534        }
535    }
536}
537
/// Returns the path of the `zed` binary that sits in the same `debug`
/// directory as the currently running executable.
///
/// Starting from `std::env::current_exe()`, the nearest ancestor whose final
/// component is `debug` is located, and the binary name is resolved inside it.
/// The platform's executable suffix is appended, so this resolves to
/// `zed.exe` on Windows and plain `zed` elsewhere.
///
/// # Panics
///
/// Panics if the current executable cannot be determined, if no ancestor
/// named `debug` exists, or if the binary is not present there.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    let Some(debug_dir) = current_exe.ancestors().find(|dir| {
        dir.file_name()
            .is_some_and(|name| name.to_string_lossy() == "debug")
    }) else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join(format!("zed{}", std::env::consts::EXE_SUFFIX));

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}
557}