e2e_tests.rs

  1use crate::AgentServer;
  2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  3use agent_client_protocol as acp;
  4use futures::{FutureExt, StreamExt, channel::mpsc, select};
  5use gpui::{AppContext, Entity, TestAppContext};
  6use indoc::indoc;
  7use project::{FakeFs, Project};
  8use std::{
  9    path::{Path, PathBuf},
 10    sync::Arc,
 11    time::Duration,
 12};
 13use util::path;
 14
 15pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
 16where
 17    T: AgentServer + 'static,
 18    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 19{
 20    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
 21    let project = Project::test(fs.clone(), [], cx).await;
 22    let thread = new_test_thread(
 23        server(&fs, &project, cx).await,
 24        project.clone(),
 25        "/private/tmp",
 26        cx,
 27    )
 28    .await;
 29
 30    thread
 31        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 32        .await
 33        .unwrap();
 34
 35    thread.read_with(cx, |thread, _| {
 36        assert!(
 37            thread.entries().len() >= 2,
 38            "Expected at least 2 entries. Got: {:?}",
 39            thread.entries()
 40        );
 41        assert!(matches!(
 42            thread.entries()[0],
 43            AgentThreadEntry::UserMessage(_)
 44        ));
 45        assert!(matches!(
 46            thread.entries()[1],
 47            AgentThreadEntry::AssistantMessage(_)
 48        ));
 49    });
 50}
 51
 52pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
 53where
 54    T: AgentServer + 'static,
 55    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 56{
 57    let fs = init_test(cx).await as _;
 58
 59    let tempdir = tempfile::tempdir().unwrap();
 60    std::fs::write(
 61        tempdir.path().join("foo.rs"),
 62        indoc! {"
 63            fn main() {
 64                println!(\"Hello, world!\");
 65            }
 66        "},
 67    )
 68    .expect("failed to write file");
 69    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 70    let thread = new_test_thread(
 71        server(&fs, &project, cx).await,
 72        project.clone(),
 73        tempdir.path(),
 74        cx,
 75    )
 76    .await;
 77    thread
 78        .update(cx, |thread, cx| {
 79            thread.send(
 80                vec![
 81                    acp::ContentBlock::Text(acp::TextContent {
 82                        text: "Read the file ".into(),
 83                        annotations: None,
 84                    }),
 85                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
 86                        uri: "foo.rs".into(),
 87                        name: "foo.rs".into(),
 88                        annotations: None,
 89                        description: None,
 90                        mime_type: None,
 91                        size: None,
 92                        title: None,
 93                    }),
 94                    acp::ContentBlock::Text(acp::TextContent {
 95                        text: " and tell me what the content of the println! is".into(),
 96                        annotations: None,
 97                    }),
 98                ],
 99                cx,
100            )
101        })
102        .await
103        .unwrap();
104
105    thread.read_with(cx, |thread, cx| {
106        assert!(matches!(
107            thread.entries()[0],
108            AgentThreadEntry::UserMessage(_)
109        ));
110        let assistant_message = &thread
111            .entries()
112            .iter()
113            .rev()
114            .find_map(|entry| match entry {
115                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
116                _ => None,
117            })
118            .unwrap();
119
120        assert!(
121            assistant_message.to_markdown(cx).contains("Hello, world!"),
122            "unexpected assistant message: {:?}",
123            assistant_message.to_markdown(cx)
124        );
125    });
126
127    drop(tempdir);
128}
129
130pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
131where
132    T: AgentServer + 'static,
133    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
134{
135    let fs = init_test(cx).await as _;
136
137    let tempdir = tempfile::tempdir().unwrap();
138    let foo_path = tempdir.path().join("foo");
139    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
140
141    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
142    let thread = new_test_thread(
143        server(&fs, &project, cx).await,
144        project.clone(),
145        "/private/tmp",
146        cx,
147    )
148    .await;
149
150    thread
151        .update(cx, |thread, cx| {
152            thread.send_raw(
153                &format!("Read {} and tell me what you see.", foo_path.display()),
154                cx,
155            )
156        })
157        .await
158        .unwrap();
159    thread.read_with(cx, |thread, _cx| {
160        assert!(thread.entries().iter().any(|entry| {
161            matches!(
162                entry,
163                AgentThreadEntry::ToolCall(ToolCall {
164                    status: ToolCallStatus::Pending
165                        | ToolCallStatus::InProgress
166                        | ToolCallStatus::Completed,
167                    ..
168                })
169            )
170        }));
171        assert!(
172            thread
173                .entries()
174                .iter()
175                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
176        );
177    });
178
179    drop(tempdir);
180}
181
182pub async fn test_tool_call_with_permission<T, F>(
183    server: F,
184    allow_option_id: acp::PermissionOptionId,
185    cx: &mut TestAppContext,
186) where
187    T: AgentServer + 'static,
188    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
189{
190    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
191    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
192    let thread = new_test_thread(
193        server(&fs, &project, cx).await,
194        project.clone(),
195        "/private/tmp",
196        cx,
197    )
198    .await;
199    let full_turn = thread.update(cx, |thread, cx| {
200        thread.send_raw(
201            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
202            cx,
203        )
204    });
205
206    run_until_first_tool_call(
207        &thread,
208        |entry| {
209            matches!(
210                entry,
211                AgentThreadEntry::ToolCall(ToolCall {
212                    status: ToolCallStatus::WaitingForConfirmation { .. },
213                    ..
214                })
215            )
216        },
217        cx,
218    )
219    .await;
220
221    let tool_call_id = thread.read_with(cx, |thread, cx| {
222        let AgentThreadEntry::ToolCall(ToolCall {
223            id,
224            label,
225            status: ToolCallStatus::WaitingForConfirmation { .. },
226            ..
227        }) = &thread
228            .entries()
229            .iter()
230            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
231            .unwrap()
232        else {
233            panic!();
234        };
235
236        let label = label.read(cx).source();
237        assert!(label.contains("touch"), "Got: {}", label);
238
239        id.clone()
240    });
241
242    thread.update(cx, |thread, cx| {
243        thread.authorize_tool_call(
244            tool_call_id,
245            allow_option_id,
246            acp::PermissionOptionKind::AllowOnce,
247            cx,
248        );
249
250        assert!(thread.entries().iter().any(|entry| matches!(
251            entry,
252            AgentThreadEntry::ToolCall(ToolCall {
253                status: ToolCallStatus::Pending
254                    | ToolCallStatus::InProgress
255                    | ToolCallStatus::Completed,
256                ..
257            })
258        )));
259    });
260
261    full_turn.await.unwrap();
262
263    thread.read_with(cx, |thread, cx| {
264        let AgentThreadEntry::ToolCall(ToolCall {
265            content,
266            status: ToolCallStatus::Pending
267                | ToolCallStatus::InProgress
268                | ToolCallStatus::Completed,
269            ..
270        }) = thread
271            .entries()
272            .iter()
273            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
274            .unwrap()
275        else {
276            panic!();
277        };
278
279        assert!(
280            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
281            "Expected content to contain 'Hello'"
282        );
283    });
284}
285
286pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
287where
288    T: AgentServer + 'static,
289    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
290{
291    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
292
293    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
294    let thread = new_test_thread(
295        server(&fs, &project, cx).await,
296        project.clone(),
297        "/private/tmp",
298        cx,
299    )
300    .await;
301    let _ = thread.update(cx, |thread, cx| {
302        thread.send_raw(
303            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
304            cx,
305        )
306    });
307
308    let first_tool_call_ix = run_until_first_tool_call(
309        &thread,
310        |entry| {
311            matches!(
312                entry,
313                AgentThreadEntry::ToolCall(ToolCall {
314                    status: ToolCallStatus::WaitingForConfirmation { .. },
315                    ..
316                })
317            )
318        },
319        cx,
320    )
321    .await;
322
323    thread.read_with(cx, |thread, cx| {
324        let AgentThreadEntry::ToolCall(ToolCall {
325            id,
326            label,
327            status: ToolCallStatus::WaitingForConfirmation { .. },
328            ..
329        }) = &thread.entries()[first_tool_call_ix]
330        else {
331            panic!("{:?}", thread.entries()[1]);
332        };
333
334        let label = label.read(cx).source();
335        assert!(label.contains("touch"), "Got: {}", label);
336
337        id.clone()
338    });
339
340    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
341    thread.read_with(cx, |thread, _cx| {
342        let AgentThreadEntry::ToolCall(ToolCall {
343            status: ToolCallStatus::Canceled,
344            ..
345        }) = &thread.entries()[first_tool_call_ix]
346        else {
347            panic!();
348        };
349    });
350
351    thread
352        .update(cx, |thread, cx| {
353            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
354        })
355        .await
356        .unwrap();
357    thread.read_with(cx, |thread, _| {
358        assert!(matches!(
359            &thread.entries().last().unwrap(),
360            AgentThreadEntry::AssistantMessage(..),
361        ))
362    });
363}
364
365pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
366where
367    T: AgentServer + 'static,
368    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
369{
370    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
371    let project = Project::test(fs.clone(), [], cx).await;
372    let thread = new_test_thread(
373        server(&fs, &project, cx).await,
374        project.clone(),
375        "/private/tmp",
376        cx,
377    )
378    .await;
379
380    thread
381        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
382        .await
383        .unwrap();
384
385    thread.read_with(cx, |thread, _| {
386        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
387    });
388
389    let weak_thread = thread.downgrade();
390    drop(thread);
391
392    cx.executor().run_until_parked();
393    assert!(!weak_thread.is_upgradable());
394}
395
/// Generates a `common_e2e` test module that runs every shared e2e scenario of this file
/// against one agent server.
///
/// - `$server`: passed as the `server` argument of each `test_*` function, so it must
///   satisfy their `AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T`
///   bound. The expression is expanded once per generated test.
/// - `allow_option_id`: `.into()`-converted and wrapped in
///   `agent_client_protocol::PermissionOptionId`. `test_tool_call_with_permission` uses it
///   to approve the tool call.
///
/// Every generated test is `#[ignore]`d when the `e2e` feature is not enabled. The `cfg`
/// is evaluated in the crate that invokes the macro.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            // Names used inside `$server` resolve within this generated module, so pull in
            // the invoking module's items.
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
// `#[macro_export]` places the macro at the crate root; this re-export also makes it
// reachable by path through this module.
pub use common_e2e_tests;
446
447// Helpers
448
/// Initializes the global application state the e2e tests rely on and returns a new `FakeFs`.
///
/// It does the following, in order:
/// - Installs a test `SettingsStore` and initializes project and language settings.
/// - Initializes `gpui_tokio`.
/// - Installs a real `reqwest` HTTP client with the user agent `"agent tests"`.
/// - Creates a production `Client` and `UserStore` and initializes the language-model crates.
/// - Registers agent and agent-server settings.
///
/// In `cfg(test)` builds it also overrides the Claude and Gemini server commands with
/// `crate::claude::tests::local_command()` and `crate::gemini::tests::local_command()`.
///
/// Parking is allowed on the executor. This is presumably because the agents run as real
/// processes doing real I/O — TODO confirm.
///
/// # Panics
/// Panics if the `reqwest` client cannot be constructed.
pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
    // Imported for the `override_global` call below; both are `cfg(test)`-only.
    #[cfg(test)]
    use settings::Settings;

    // A logger may already be installed by an earlier test; that error is ignored.
    env_logger::try_init().ok();

    cx.update(|cx| {
        let settings_store = settings::SettingsStore::test(cx);
        cx.set_global(settings_store);
        Project::init_settings(cx);
        language::init(cx);
        gpui_tokio::init(cx);
        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
        cx.set_http_client(Arc::new(http_client));
        client::init_settings(cx);
        let client = client::Client::production(cx);
        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
        language_model::init(client.clone(), cx);
        language_models::init(user_store, client, cx);
        agent_settings::init(cx);
        crate::settings::init(cx);

        // Point the built-in agents at locally configured commands; no custom servers.
        #[cfg(test)]
        crate::AllAgentServersSettings::override_global(
            crate::AllAgentServersSettings {
                claude: Some(crate::AgentServerSettings {
                    command: crate::claude::tests::local_command(),
                }),
                gemini: Some(crate::AgentServerSettings {
                    command: crate::gemini::tests::local_command(),
                }),
                custom: collections::HashMap::default(),
            },
            cx,
        );
    });

    cx.executor().allow_parking();

    FakeFs::new(cx.executor())
}
490
491pub async fn new_test_thread(
492    server: impl AgentServer + 'static,
493    project: Entity<Project>,
494    current_dir: impl AsRef<Path>,
495    cx: &mut TestAppContext,
496) -> Entity<AcpThread> {
497    let connection = cx
498        .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
499        .await
500        .unwrap();
501
502    cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
503        .await
504        .unwrap()
505}
506
507pub async fn run_until_first_tool_call(
508    thread: &Entity<AcpThread>,
509    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
510    cx: &mut TestAppContext,
511) -> usize {
512    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
513
514    let subscription = cx.update(|cx| {
515        cx.subscribe(thread, move |thread, _, cx| {
516            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
517                if wait_until(entry) {
518                    return tx.try_send(ix).unwrap();
519                }
520            }
521        })
522    });
523
524    select! {
525        // We have to use a smol timer here because
526        // cx.background_executor().timer isn't real in the test context
527        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
528            panic!("Timeout waiting for tool call")
529        }
530        ix = rx.next().fuse() => {
531            drop(subscription);
532            ix.unwrap()
533        }
534    }
535}
536
/// Returns the path of the `zed` binary that sits in the nearest ancestor directory of the
/// current executable named `debug`.
///
/// # Panics
/// Panics if no ancestor directory is named `debug`, or if no `zed` file exists there.
pub fn get_zed_path() -> PathBuf {
    let exe = std::env::current_exe().unwrap();

    let Some(debug_dir) = exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name.to_string_lossy() == "debug"))
    else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if zed_path.exists() {
        zed_path
    } else {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }
}