e2e_tests.rs

  1use crate::{AgentServer, AgentServerDelegate};
  2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  3use agent_client_protocol as acp;
  4use futures::{FutureExt, StreamExt, channel::mpsc, select};
  5use gpui::{AppContext, Entity, TestAppContext};
  6use indoc::indoc;
  7#[cfg(test)]
  8use project::agent_server_store::BuiltinAgentServerSettings;
  9use project::{FakeFs, Project};
 10#[cfg(test)]
 11use settings::Settings;
 12use std::{
 13    path::{Path, PathBuf},
 14    sync::Arc,
 15    time::Duration,
 16};
 17use util::path;
 18
 19pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
 20where
 21    T: AgentServer + 'static,
 22    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 23{
 24    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
 25    let project = Project::test(fs.clone(), [], cx).await;
 26    let thread = new_test_thread(
 27        server(&fs, &project, cx).await,
 28        project.clone(),
 29        "/private/tmp",
 30        cx,
 31    )
 32    .await;
 33
 34    thread
 35        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 36        .await
 37        .unwrap();
 38
 39    thread.read_with(cx, |thread, _| {
 40        assert!(
 41            thread.entries().len() >= 2,
 42            "Expected at least 2 entries. Got: {:?}",
 43            thread.entries()
 44        );
 45        assert!(matches!(
 46            thread.entries()[0],
 47            AgentThreadEntry::UserMessage(_)
 48        ));
 49        assert!(matches!(
 50            thread.entries()[1],
 51            AgentThreadEntry::AssistantMessage(_)
 52        ));
 53    });
 54}
 55
 56pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
 57where
 58    T: AgentServer + 'static,
 59    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 60{
 61    let fs = init_test(cx).await as _;
 62
 63    let tempdir = tempfile::tempdir().unwrap();
 64    std::fs::write(
 65        tempdir.path().join("foo.rs"),
 66        indoc! {"
 67            fn main() {
 68                println!(\"Hello, world!\");
 69            }
 70        "},
 71    )
 72    .expect("failed to write file");
 73    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 74    let thread = new_test_thread(
 75        server(&fs, &project, cx).await,
 76        project.clone(),
 77        tempdir.path(),
 78        cx,
 79    )
 80    .await;
 81    thread
 82        .update(cx, |thread, cx| {
 83            thread.send(
 84                vec![
 85                    acp::ContentBlock::Text(acp::TextContent {
 86                        text: "Read the file ".into(),
 87                        annotations: None,
 88                        meta: None,
 89                    }),
 90                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
 91                        uri: "foo.rs".into(),
 92                        name: "foo.rs".into(),
 93                        annotations: None,
 94                        description: None,
 95                        mime_type: None,
 96                        size: None,
 97                        title: None,
 98                        meta: None,
 99                    }),
100                    acp::ContentBlock::Text(acp::TextContent {
101                        text: " and tell me what the content of the println! is".into(),
102                        annotations: None,
103                        meta: None,
104                    }),
105                ],
106                cx,
107            )
108        })
109        .await
110        .unwrap();
111
112    thread.read_with(cx, |thread, cx| {
113        assert!(matches!(
114            thread.entries()[0],
115            AgentThreadEntry::UserMessage(_)
116        ));
117        let assistant_message = &thread
118            .entries()
119            .iter()
120            .rev()
121            .find_map(|entry| match entry {
122                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
123                _ => None,
124            })
125            .unwrap();
126
127        assert!(
128            assistant_message.to_markdown(cx).contains("Hello, world!"),
129            "unexpected assistant message: {:?}",
130            assistant_message.to_markdown(cx)
131        );
132    });
133
134    drop(tempdir);
135}
136
137pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
138where
139    T: AgentServer + 'static,
140    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
141{
142    let fs = init_test(cx).await as _;
143
144    let tempdir = tempfile::tempdir().unwrap();
145    let foo_path = tempdir.path().join("foo");
146    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
147
148    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
149    let thread = new_test_thread(
150        server(&fs, &project, cx).await,
151        project.clone(),
152        "/private/tmp",
153        cx,
154    )
155    .await;
156
157    thread
158        .update(cx, |thread, cx| {
159            thread.send_raw(
160                &format!("Read {} and tell me what you see.", foo_path.display()),
161                cx,
162            )
163        })
164        .await
165        .unwrap();
166    thread.read_with(cx, |thread, _cx| {
167        assert!(thread.entries().iter().any(|entry| {
168            matches!(
169                entry,
170                AgentThreadEntry::ToolCall(ToolCall {
171                    status: ToolCallStatus::Pending
172                        | ToolCallStatus::InProgress
173                        | ToolCallStatus::Completed,
174                    ..
175                })
176            )
177        }));
178        assert!(
179            thread
180                .entries()
181                .iter()
182                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
183        );
184    });
185
186    drop(tempdir);
187}
188
189pub async fn test_tool_call_with_permission<T, F>(
190    server: F,
191    allow_option_id: acp::PermissionOptionId,
192    cx: &mut TestAppContext,
193) where
194    T: AgentServer + 'static,
195    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
196{
197    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
198    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
199    let thread = new_test_thread(
200        server(&fs, &project, cx).await,
201        project.clone(),
202        "/private/tmp",
203        cx,
204    )
205    .await;
206    let full_turn = thread.update(cx, |thread, cx| {
207        thread.send_raw(
208            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
209            cx,
210        )
211    });
212
213    run_until_first_tool_call(
214        &thread,
215        |entry| {
216            matches!(
217                entry,
218                AgentThreadEntry::ToolCall(ToolCall {
219                    status: ToolCallStatus::WaitingForConfirmation { .. },
220                    ..
221                })
222            )
223        },
224        cx,
225    )
226    .await;
227
228    let tool_call_id = thread.read_with(cx, |thread, cx| {
229        let AgentThreadEntry::ToolCall(ToolCall {
230            id,
231            label,
232            status: ToolCallStatus::WaitingForConfirmation { .. },
233            ..
234        }) = &thread
235            .entries()
236            .iter()
237            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
238            .unwrap()
239        else {
240            panic!();
241        };
242
243        let label = label.read(cx).source();
244        assert!(label.contains("touch"), "Got: {}", label);
245
246        id.clone()
247    });
248
249    thread.update(cx, |thread, cx| {
250        thread.authorize_tool_call(
251            tool_call_id,
252            allow_option_id,
253            acp::PermissionOptionKind::AllowOnce,
254            cx,
255        );
256
257        assert!(thread.entries().iter().any(|entry| matches!(
258            entry,
259            AgentThreadEntry::ToolCall(ToolCall {
260                status: ToolCallStatus::Pending
261                    | ToolCallStatus::InProgress
262                    | ToolCallStatus::Completed,
263                ..
264            })
265        )));
266    });
267
268    full_turn.await.unwrap();
269
270    thread.read_with(cx, |thread, cx| {
271        let AgentThreadEntry::ToolCall(ToolCall {
272            content,
273            status: ToolCallStatus::Pending
274                | ToolCallStatus::InProgress
275                | ToolCallStatus::Completed,
276            ..
277        }) = thread
278            .entries()
279            .iter()
280            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
281            .unwrap()
282        else {
283            panic!();
284        };
285
286        assert!(
287            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
288            "Expected content to contain 'Hello'"
289        );
290    });
291}
292
293pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
294where
295    T: AgentServer + 'static,
296    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
297{
298    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
299
300    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
301    let thread = new_test_thread(
302        server(&fs, &project, cx).await,
303        project.clone(),
304        "/private/tmp",
305        cx,
306    )
307    .await;
308    let _ = thread.update(cx, |thread, cx| {
309        thread.send_raw(
310            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
311            cx,
312        )
313    });
314
315    let first_tool_call_ix = run_until_first_tool_call(
316        &thread,
317        |entry| {
318            matches!(
319                entry,
320                AgentThreadEntry::ToolCall(ToolCall {
321                    status: ToolCallStatus::WaitingForConfirmation { .. },
322                    ..
323                })
324            )
325        },
326        cx,
327    )
328    .await;
329
330    thread.read_with(cx, |thread, cx| {
331        let AgentThreadEntry::ToolCall(ToolCall {
332            id,
333            label,
334            status: ToolCallStatus::WaitingForConfirmation { .. },
335            ..
336        }) = &thread.entries()[first_tool_call_ix]
337        else {
338            panic!("{:?}", thread.entries()[1]);
339        };
340
341        let label = label.read(cx).source();
342        assert!(label.contains("touch"), "Got: {}", label);
343
344        id.clone()
345    });
346
347    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
348    thread.read_with(cx, |thread, _cx| {
349        let AgentThreadEntry::ToolCall(ToolCall {
350            status: ToolCallStatus::Canceled,
351            ..
352        }) = &thread.entries()[first_tool_call_ix]
353        else {
354            panic!();
355        };
356    });
357
358    thread
359        .update(cx, |thread, cx| {
360            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
361        })
362        .await
363        .unwrap();
364    thread.read_with(cx, |thread, _| {
365        assert!(matches!(
366            &thread.entries().last().unwrap(),
367            AgentThreadEntry::AssistantMessage(..),
368        ))
369    });
370}
371
372pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
373where
374    T: AgentServer + 'static,
375    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
376{
377    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
378    let project = Project::test(fs.clone(), [], cx).await;
379    let thread = new_test_thread(
380        server(&fs, &project, cx).await,
381        project.clone(),
382        "/private/tmp",
383        cx,
384    )
385    .await;
386
387    thread
388        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
389        .await
390        .unwrap();
391
392    thread.read_with(cx, |thread, _| {
393        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
394    });
395
396    let weak_thread = thread.downgrade();
397    drop(thread);
398
399    cx.executor().run_until_parked();
400    assert!(!weak_thread.is_upgradable());
401}
402
/// Expands to a `common_e2e` module with one `#[gpui::test]` per shared
/// end-to-end scenario in this file (`test_basic`, `test_path_mentions`,
/// `test_tool_call`, `test_tool_call_with_permission`, `test_cancel`,
/// `test_thread_drop`).
///
/// `$server` is substituted into every generated test and passed as the `F`
/// argument of those functions, so it must be an async closure that builds an
/// `AgentServer`. `allow_option_id` is converted with `.into()`, wrapped in
/// `agent_client_protocol::PermissionOptionId`, and used to approve the
/// permission request in `tool_call_with_permission`.
///
/// Every generated test is `#[ignore]`d unless the `e2e` feature is enabled.
/// NOTE(review): presumably because they spawn real agent binaries (see the
/// commands configured in `init_test`) — confirm.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                // The agent-specific "allow once" option id is supplied by the caller.
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
// `#[macro_export]` places the macro at the crate root; this re-export also
// makes it reachable by path through this module.
pub use common_e2e_tests;
453
454// Helpers
455
456pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
457    env_logger::try_init().ok();
458
459    cx.update(|cx| {
460        let settings_store = settings::SettingsStore::test(cx);
461        cx.set_global(settings_store);
462        gpui_tokio::init(cx);
463        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
464        cx.set_http_client(Arc::new(http_client));
465        let client = client::Client::production(cx);
466        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
467        language_model::init(client.clone(), cx);
468        language_models::init(user_store, client, cx);
469
470        #[cfg(test)]
471        project::agent_server_store::AllAgentServersSettings::override_global(
472            project::agent_server_store::AllAgentServersSettings {
473                claude: Some(BuiltinAgentServerSettings {
474                    path: Some("claude-code-acp".into()),
475                    args: None,
476                    env: None,
477                    ignore_system_version: None,
478                    default_mode: None,
479                }),
480                gemini: Some(crate::gemini::tests::local_command().into()),
481                codex: Some(BuiltinAgentServerSettings {
482                    path: Some("codex-acp".into()),
483                    args: None,
484                    env: None,
485                    ignore_system_version: None,
486                    default_mode: None,
487                }),
488                custom: collections::HashMap::default(),
489            },
490            cx,
491        );
492    });
493
494    cx.executor().allow_parking();
495
496    FakeFs::new(cx.executor())
497}
498
499pub async fn new_test_thread(
500    server: impl AgentServer + 'static,
501    project: Entity<Project>,
502    current_dir: impl AsRef<Path>,
503    cx: &mut TestAppContext,
504) -> Entity<AcpThread> {
505    let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
506    let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
507
508    let (connection, _) = cx
509        .update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx))
510        .await
511        .unwrap();
512
513    cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
514        .await
515        .unwrap()
516}
517
518pub async fn run_until_first_tool_call(
519    thread: &Entity<AcpThread>,
520    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
521    cx: &mut TestAppContext,
522) -> usize {
523    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
524
525    let subscription = cx.update(|cx| {
526        cx.subscribe(thread, move |thread, _, cx| {
527            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
528                if wait_until(entry) {
529                    return tx.try_send(ix).unwrap();
530                }
531            }
532        })
533    });
534
535    select! {
536        // We have to use a smol timer here because
537        // cx.background_executor().timer isn't real in the test context
538        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
539            panic!("Timeout waiting for tool call")
540        }
541        ix = rx.next().fuse() => {
542            drop(subscription);
543            ix.unwrap()
544        }
545    }
546}
547
/// Locates the `zed` binary in the cargo `debug` target directory that
/// contains the currently running test executable.
///
/// # Panics
///
/// Panics if no ancestor of the current executable is named `debug`, or if
/// `debug/zed` does not exist (i.e. `cargo build` has not been run yet).
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    let Some(debug_dir) = current_exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name == "debug"))
    else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}