e2e_tests.rs

  1use crate::{AgentServer, AgentServerDelegate};
  2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  3use agent_client_protocol as acp;
  4use futures::{FutureExt, StreamExt, channel::mpsc, select};
  5use gpui::{AppContext, Entity, TestAppContext};
  6use indoc::indoc;
  7#[cfg(test)]
  8use project::agent_server_store::BuiltinAgentServerSettings;
  9use project::{FakeFs, Project, agent_server_store::AllAgentServersSettings};
 10use std::{
 11    path::{Path, PathBuf},
 12    sync::Arc,
 13    time::Duration,
 14};
 15use util::path;
 16
 17pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
 18where
 19    T: AgentServer + 'static,
 20    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 21{
 22    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
 23    let project = Project::test(fs.clone(), [], cx).await;
 24    let thread = new_test_thread(
 25        server(&fs, &project, cx).await,
 26        project.clone(),
 27        "/private/tmp",
 28        cx,
 29    )
 30    .await;
 31
 32    thread
 33        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 34        .await
 35        .unwrap();
 36
 37    thread.read_with(cx, |thread, _| {
 38        assert!(
 39            thread.entries().len() >= 2,
 40            "Expected at least 2 entries. Got: {:?}",
 41            thread.entries()
 42        );
 43        assert!(matches!(
 44            thread.entries()[0],
 45            AgentThreadEntry::UserMessage(_)
 46        ));
 47        assert!(matches!(
 48            thread.entries()[1],
 49            AgentThreadEntry::AssistantMessage(_)
 50        ));
 51    });
 52}
 53
 54pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
 55where
 56    T: AgentServer + 'static,
 57    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 58{
 59    let fs = init_test(cx).await as _;
 60
 61    let tempdir = tempfile::tempdir().unwrap();
 62    std::fs::write(
 63        tempdir.path().join("foo.rs"),
 64        indoc! {"
 65            fn main() {
 66                println!(\"Hello, world!\");
 67            }
 68        "},
 69    )
 70    .expect("failed to write file");
 71    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 72    let thread = new_test_thread(
 73        server(&fs, &project, cx).await,
 74        project.clone(),
 75        tempdir.path(),
 76        cx,
 77    )
 78    .await;
 79    thread
 80        .update(cx, |thread, cx| {
 81            thread.send(
 82                vec![
 83                    acp::ContentBlock::Text(acp::TextContent {
 84                        text: "Read the file ".into(),
 85                        annotations: None,
 86                        meta: None,
 87                    }),
 88                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
 89                        uri: "foo.rs".into(),
 90                        name: "foo.rs".into(),
 91                        annotations: None,
 92                        description: None,
 93                        mime_type: None,
 94                        size: None,
 95                        title: None,
 96                        meta: None,
 97                    }),
 98                    acp::ContentBlock::Text(acp::TextContent {
 99                        text: " and tell me what the content of the println! is".into(),
100                        annotations: None,
101                        meta: None,
102                    }),
103                ],
104                cx,
105            )
106        })
107        .await
108        .unwrap();
109
110    thread.read_with(cx, |thread, cx| {
111        assert!(matches!(
112            thread.entries()[0],
113            AgentThreadEntry::UserMessage(_)
114        ));
115        let assistant_message = &thread
116            .entries()
117            .iter()
118            .rev()
119            .find_map(|entry| match entry {
120                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
121                _ => None,
122            })
123            .unwrap();
124
125        assert!(
126            assistant_message.to_markdown(cx).contains("Hello, world!"),
127            "unexpected assistant message: {:?}",
128            assistant_message.to_markdown(cx)
129        );
130    });
131
132    drop(tempdir);
133}
134
135pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
136where
137    T: AgentServer + 'static,
138    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
139{
140    let fs = init_test(cx).await as _;
141
142    let tempdir = tempfile::tempdir().unwrap();
143    let foo_path = tempdir.path().join("foo");
144    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
145
146    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
147    let thread = new_test_thread(
148        server(&fs, &project, cx).await,
149        project.clone(),
150        "/private/tmp",
151        cx,
152    )
153    .await;
154
155    thread
156        .update(cx, |thread, cx| {
157            thread.send_raw(
158                &format!("Read {} and tell me what you see.", foo_path.display()),
159                cx,
160            )
161        })
162        .await
163        .unwrap();
164    thread.read_with(cx, |thread, _cx| {
165        assert!(thread.entries().iter().any(|entry| {
166            matches!(
167                entry,
168                AgentThreadEntry::ToolCall(ToolCall {
169                    status: ToolCallStatus::Pending
170                        | ToolCallStatus::InProgress
171                        | ToolCallStatus::Completed,
172                    ..
173                })
174            )
175        }));
176        assert!(
177            thread
178                .entries()
179                .iter()
180                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
181        );
182    });
183
184    drop(tempdir);
185}
186
187pub async fn test_tool_call_with_permission<T, F>(
188    server: F,
189    allow_option_id: acp::PermissionOptionId,
190    cx: &mut TestAppContext,
191) where
192    T: AgentServer + 'static,
193    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
194{
195    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
196    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
197    let thread = new_test_thread(
198        server(&fs, &project, cx).await,
199        project.clone(),
200        "/private/tmp",
201        cx,
202    )
203    .await;
204    let full_turn = thread.update(cx, |thread, cx| {
205        thread.send_raw(
206            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
207            cx,
208        )
209    });
210
211    run_until_first_tool_call(
212        &thread,
213        |entry| {
214            matches!(
215                entry,
216                AgentThreadEntry::ToolCall(ToolCall {
217                    status: ToolCallStatus::WaitingForConfirmation { .. },
218                    ..
219                })
220            )
221        },
222        cx,
223    )
224    .await;
225
226    let tool_call_id = thread.read_with(cx, |thread, cx| {
227        let AgentThreadEntry::ToolCall(ToolCall {
228            id,
229            label,
230            status: ToolCallStatus::WaitingForConfirmation { .. },
231            ..
232        }) = &thread
233            .entries()
234            .iter()
235            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
236            .unwrap()
237        else {
238            panic!();
239        };
240
241        let label = label.read(cx).source();
242        assert!(label.contains("touch"), "Got: {}", label);
243
244        id.clone()
245    });
246
247    thread.update(cx, |thread, cx| {
248        thread.authorize_tool_call(
249            tool_call_id,
250            allow_option_id,
251            acp::PermissionOptionKind::AllowOnce,
252            cx,
253        );
254
255        assert!(thread.entries().iter().any(|entry| matches!(
256            entry,
257            AgentThreadEntry::ToolCall(ToolCall {
258                status: ToolCallStatus::Pending
259                    | ToolCallStatus::InProgress
260                    | ToolCallStatus::Completed,
261                ..
262            })
263        )));
264    });
265
266    full_turn.await.unwrap();
267
268    thread.read_with(cx, |thread, cx| {
269        let AgentThreadEntry::ToolCall(ToolCall {
270            content,
271            status: ToolCallStatus::Pending
272                | ToolCallStatus::InProgress
273                | ToolCallStatus::Completed,
274            ..
275        }) = thread
276            .entries()
277            .iter()
278            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
279            .unwrap()
280        else {
281            panic!();
282        };
283
284        assert!(
285            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
286            "Expected content to contain 'Hello'"
287        );
288    });
289}
290
291pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
292where
293    T: AgentServer + 'static,
294    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
295{
296    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
297
298    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
299    let thread = new_test_thread(
300        server(&fs, &project, cx).await,
301        project.clone(),
302        "/private/tmp",
303        cx,
304    )
305    .await;
306    let _ = thread.update(cx, |thread, cx| {
307        thread.send_raw(
308            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
309            cx,
310        )
311    });
312
313    let first_tool_call_ix = run_until_first_tool_call(
314        &thread,
315        |entry| {
316            matches!(
317                entry,
318                AgentThreadEntry::ToolCall(ToolCall {
319                    status: ToolCallStatus::WaitingForConfirmation { .. },
320                    ..
321                })
322            )
323        },
324        cx,
325    )
326    .await;
327
328    thread.read_with(cx, |thread, cx| {
329        let AgentThreadEntry::ToolCall(ToolCall {
330            id,
331            label,
332            status: ToolCallStatus::WaitingForConfirmation { .. },
333            ..
334        }) = &thread.entries()[first_tool_call_ix]
335        else {
336            panic!("{:?}", thread.entries()[1]);
337        };
338
339        let label = label.read(cx).source();
340        assert!(label.contains("touch"), "Got: {}", label);
341
342        id.clone()
343    });
344
345    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
346    thread.read_with(cx, |thread, _cx| {
347        let AgentThreadEntry::ToolCall(ToolCall {
348            status: ToolCallStatus::Canceled,
349            ..
350        }) = &thread.entries()[first_tool_call_ix]
351        else {
352            panic!();
353        };
354    });
355
356    thread
357        .update(cx, |thread, cx| {
358            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
359        })
360        .await
361        .unwrap();
362    thread.read_with(cx, |thread, _| {
363        assert!(matches!(
364            &thread.entries().last().unwrap(),
365            AgentThreadEntry::AssistantMessage(..),
366        ))
367    });
368}
369
370pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
371where
372    T: AgentServer + 'static,
373    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
374{
375    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
376    let project = Project::test(fs.clone(), [], cx).await;
377    let thread = new_test_thread(
378        server(&fs, &project, cx).await,
379        project.clone(),
380        "/private/tmp",
381        cx,
382    )
383    .await;
384
385    thread
386        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
387        .await
388        .unwrap();
389
390    thread.read_with(cx, |thread, _| {
391        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
392    });
393
394    let weak_thread = thread.downgrade();
395    drop(thread);
396
397    cx.executor().run_until_parked();
398    assert!(!weak_thread.is_upgradable());
399}
400
401#[macro_export]
402macro_rules! common_e2e_tests {
403    ($server:expr, allow_option_id = $allow_option_id:expr) => {
404        mod common_e2e {
405            use super::*;
406
407            #[::gpui::test]
408            #[cfg_attr(not(feature = "e2e"), ignore)]
409            async fn basic(cx: &mut ::gpui::TestAppContext) {
410                $crate::e2e_tests::test_basic($server, cx).await;
411            }
412
413            #[::gpui::test]
414            #[cfg_attr(not(feature = "e2e"), ignore)]
415            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
416                $crate::e2e_tests::test_path_mentions($server, cx).await;
417            }
418
419            #[::gpui::test]
420            #[cfg_attr(not(feature = "e2e"), ignore)]
421            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
422                $crate::e2e_tests::test_tool_call($server, cx).await;
423            }
424
425            #[::gpui::test]
426            #[cfg_attr(not(feature = "e2e"), ignore)]
427            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
428                $crate::e2e_tests::test_tool_call_with_permission(
429                    $server,
430                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
431                    cx,
432                )
433                .await;
434            }
435
436            #[::gpui::test]
437            #[cfg_attr(not(feature = "e2e"), ignore)]
438            async fn cancel(cx: &mut ::gpui::TestAppContext) {
439                $crate::e2e_tests::test_cancel($server, cx).await;
440            }
441
442            #[::gpui::test]
443            #[cfg_attr(not(feature = "e2e"), ignore)]
444            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
445                $crate::e2e_tests::test_thread_drop($server, cx).await;
446            }
447        }
448    };
449}
450pub use common_e2e_tests;
451
452// Helpers
453
454pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
455    use settings::Settings;
456
457    env_logger::try_init().ok();
458
459    cx.update(|cx| {
460        let settings_store = settings::SettingsStore::test(cx);
461        cx.set_global(settings_store);
462        Project::init_settings(cx);
463        language::init(cx);
464        gpui_tokio::init(cx);
465        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
466        cx.set_http_client(Arc::new(http_client));
467        client::init_settings(cx);
468        let client = client::Client::production(cx);
469        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
470        language_model::init(client.clone(), cx);
471        language_models::init(user_store, client, cx);
472        agent_settings::init(cx);
473        AllAgentServersSettings::register(cx);
474
475        #[cfg(test)]
476        AllAgentServersSettings::override_global(
477            AllAgentServersSettings {
478                claude: Some(BuiltinAgentServerSettings {
479                    path: Some("claude-code-acp".into()),
480                    args: None,
481                    env: None,
482                    ignore_system_version: None,
483                    default_mode: None,
484                }),
485                gemini: Some(crate::gemini::tests::local_command().into()),
486                codex: Some(BuiltinAgentServerSettings {
487                    path: Some("codex-acp".into()),
488                    args: None,
489                    env: None,
490                    ignore_system_version: None,
491                    default_mode: None,
492                }),
493                custom: collections::HashMap::default(),
494            },
495            cx,
496        );
497    });
498
499    cx.executor().allow_parking();
500
501    FakeFs::new(cx.executor())
502}
503
504pub async fn new_test_thread(
505    server: impl AgentServer + 'static,
506    project: Entity<Project>,
507    current_dir: impl AsRef<Path>,
508    cx: &mut TestAppContext,
509) -> Entity<AcpThread> {
510    let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
511    let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
512
513    let (connection, _) = cx
514        .update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx))
515        .await
516        .unwrap();
517
518    cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
519        .await
520        .unwrap()
521}
522
523pub async fn run_until_first_tool_call(
524    thread: &Entity<AcpThread>,
525    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
526    cx: &mut TestAppContext,
527) -> usize {
528    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
529
530    let subscription = cx.update(|cx| {
531        cx.subscribe(thread, move |thread, _, cx| {
532            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
533                if wait_until(entry) {
534                    return tx.try_send(ix).unwrap();
535                }
536            }
537        })
538    });
539
540    select! {
541        // We have to use a smol timer here because
542        // cx.background_executor().timer isn't real in the test context
543        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
544            panic!("Timeout waiting for tool call")
545        }
546        ix = rx.next().fuse() => {
547            drop(subscription);
548            ix.unwrap()
549        }
550    }
551}
552
/// Returns the path of the `zed` binary in the nearest ancestor directory of the running
/// executable that is named `debug`.
///
/// # Panics
/// Panics if no ancestor is named `debug`, or if no `zed` binary exists there yet.
pub fn get_zed_path() -> PathBuf {
    let mut dir = std::env::current_exe().unwrap();

    // Walk up one component at a time until the current directory is named `debug`.
    loop {
        let is_debug_dir = dir
            .file_name()
            .is_some_and(|name| name.to_string_lossy() == "debug");
        if is_debug_dir {
            break;
        }
        if !dir.pop() {
            panic!("Could not find target directory");
        }
    }

    let zed_path = dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }
    zed_path
}