e2e_tests.rs

  1use crate::{AgentServer, AgentServerDelegate};
  2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  3use agent_client_protocol as acp;
  4use futures::{FutureExt, StreamExt, channel::mpsc, select};
  5use gpui::{AppContext, Entity, TestAppContext};
  6use indoc::indoc;
  7#[cfg(test)]
  8use project::agent_server_store::BuiltinAgentServerSettings;
  9use project::{FakeFs, Project, agent_server_store::AllAgentServersSettings};
 10use std::{
 11    path::{Path, PathBuf},
 12    sync::Arc,
 13    time::Duration,
 14};
 15use util::path;
 16
 17pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
 18where
 19    T: AgentServer + 'static,
 20    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 21{
 22    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
 23    let project = Project::test(fs.clone(), [], cx).await;
 24    let thread = new_test_thread(
 25        server(&fs, &project, cx).await,
 26        project.clone(),
 27        "/private/tmp",
 28        cx,
 29    )
 30    .await;
 31
 32    thread
 33        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 34        .await
 35        .unwrap();
 36
 37    thread.read_with(cx, |thread, _| {
 38        assert!(
 39            thread.entries().len() >= 2,
 40            "Expected at least 2 entries. Got: {:?}",
 41            thread.entries()
 42        );
 43        assert!(matches!(
 44            thread.entries()[0],
 45            AgentThreadEntry::UserMessage(_)
 46        ));
 47        assert!(matches!(
 48            thread.entries()[1],
 49            AgentThreadEntry::AssistantMessage(_)
 50        ));
 51    });
 52}
 53
 54pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
 55where
 56    T: AgentServer + 'static,
 57    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 58{
 59    let fs = init_test(cx).await as _;
 60
 61    let tempdir = tempfile::tempdir().unwrap();
 62    std::fs::write(
 63        tempdir.path().join("foo.rs"),
 64        indoc! {"
 65            fn main() {
 66                println!(\"Hello, world!\");
 67            }
 68        "},
 69    )
 70    .expect("failed to write file");
 71    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 72    let thread = new_test_thread(
 73        server(&fs, &project, cx).await,
 74        project.clone(),
 75        tempdir.path(),
 76        cx,
 77    )
 78    .await;
 79    thread
 80        .update(cx, |thread, cx| {
 81            thread.send(
 82                vec![
 83                    acp::ContentBlock::Text(acp::TextContent {
 84                        text: "Read the file ".into(),
 85                        annotations: None,
 86                    }),
 87                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
 88                        uri: "foo.rs".into(),
 89                        name: "foo.rs".into(),
 90                        annotations: None,
 91                        description: None,
 92                        mime_type: None,
 93                        size: None,
 94                        title: None,
 95                    }),
 96                    acp::ContentBlock::Text(acp::TextContent {
 97                        text: " and tell me what the content of the println! is".into(),
 98                        annotations: None,
 99                    }),
100                ],
101                cx,
102            )
103        })
104        .await
105        .unwrap();
106
107    thread.read_with(cx, |thread, cx| {
108        assert!(matches!(
109            thread.entries()[0],
110            AgentThreadEntry::UserMessage(_)
111        ));
112        let assistant_message = &thread
113            .entries()
114            .iter()
115            .rev()
116            .find_map(|entry| match entry {
117                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
118                _ => None,
119            })
120            .unwrap();
121
122        assert!(
123            assistant_message.to_markdown(cx).contains("Hello, world!"),
124            "unexpected assistant message: {:?}",
125            assistant_message.to_markdown(cx)
126        );
127    });
128
129    drop(tempdir);
130}
131
132pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
133where
134    T: AgentServer + 'static,
135    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
136{
137    let fs = init_test(cx).await as _;
138
139    let tempdir = tempfile::tempdir().unwrap();
140    let foo_path = tempdir.path().join("foo");
141    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
142
143    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
144    let thread = new_test_thread(
145        server(&fs, &project, cx).await,
146        project.clone(),
147        "/private/tmp",
148        cx,
149    )
150    .await;
151
152    thread
153        .update(cx, |thread, cx| {
154            thread.send_raw(
155                &format!("Read {} and tell me what you see.", foo_path.display()),
156                cx,
157            )
158        })
159        .await
160        .unwrap();
161    thread.read_with(cx, |thread, _cx| {
162        assert!(thread.entries().iter().any(|entry| {
163            matches!(
164                entry,
165                AgentThreadEntry::ToolCall(ToolCall {
166                    status: ToolCallStatus::Pending
167                        | ToolCallStatus::InProgress
168                        | ToolCallStatus::Completed,
169                    ..
170                })
171            )
172        }));
173        assert!(
174            thread
175                .entries()
176                .iter()
177                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
178        );
179    });
180
181    drop(tempdir);
182}
183
184pub async fn test_tool_call_with_permission<T, F>(
185    server: F,
186    allow_option_id: acp::PermissionOptionId,
187    cx: &mut TestAppContext,
188) where
189    T: AgentServer + 'static,
190    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
191{
192    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
193    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
194    let thread = new_test_thread(
195        server(&fs, &project, cx).await,
196        project.clone(),
197        "/private/tmp",
198        cx,
199    )
200    .await;
201    let full_turn = thread.update(cx, |thread, cx| {
202        thread.send_raw(
203            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
204            cx,
205        )
206    });
207
208    run_until_first_tool_call(
209        &thread,
210        |entry| {
211            matches!(
212                entry,
213                AgentThreadEntry::ToolCall(ToolCall {
214                    status: ToolCallStatus::WaitingForConfirmation { .. },
215                    ..
216                })
217            )
218        },
219        cx,
220    )
221    .await;
222
223    let tool_call_id = thread.read_with(cx, |thread, cx| {
224        let AgentThreadEntry::ToolCall(ToolCall {
225            id,
226            label,
227            status: ToolCallStatus::WaitingForConfirmation { .. },
228            ..
229        }) = &thread
230            .entries()
231            .iter()
232            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
233            .unwrap()
234        else {
235            panic!();
236        };
237
238        let label = label.read(cx).source();
239        assert!(label.contains("touch"), "Got: {}", label);
240
241        id.clone()
242    });
243
244    thread.update(cx, |thread, cx| {
245        thread.authorize_tool_call(
246            tool_call_id,
247            allow_option_id,
248            acp::PermissionOptionKind::AllowOnce,
249            cx,
250        );
251
252        assert!(thread.entries().iter().any(|entry| matches!(
253            entry,
254            AgentThreadEntry::ToolCall(ToolCall {
255                status: ToolCallStatus::Pending
256                    | ToolCallStatus::InProgress
257                    | ToolCallStatus::Completed,
258                ..
259            })
260        )));
261    });
262
263    full_turn.await.unwrap();
264
265    thread.read_with(cx, |thread, cx| {
266        let AgentThreadEntry::ToolCall(ToolCall {
267            content,
268            status: ToolCallStatus::Pending
269                | ToolCallStatus::InProgress
270                | ToolCallStatus::Completed,
271            ..
272        }) = thread
273            .entries()
274            .iter()
275            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
276            .unwrap()
277        else {
278            panic!();
279        };
280
281        assert!(
282            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
283            "Expected content to contain 'Hello'"
284        );
285    });
286}
287
288pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
289where
290    T: AgentServer + 'static,
291    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
292{
293    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
294
295    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
296    let thread = new_test_thread(
297        server(&fs, &project, cx).await,
298        project.clone(),
299        "/private/tmp",
300        cx,
301    )
302    .await;
303    let _ = thread.update(cx, |thread, cx| {
304        thread.send_raw(
305            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
306            cx,
307        )
308    });
309
310    let first_tool_call_ix = run_until_first_tool_call(
311        &thread,
312        |entry| {
313            matches!(
314                entry,
315                AgentThreadEntry::ToolCall(ToolCall {
316                    status: ToolCallStatus::WaitingForConfirmation { .. },
317                    ..
318                })
319            )
320        },
321        cx,
322    )
323    .await;
324
325    thread.read_with(cx, |thread, cx| {
326        let AgentThreadEntry::ToolCall(ToolCall {
327            id,
328            label,
329            status: ToolCallStatus::WaitingForConfirmation { .. },
330            ..
331        }) = &thread.entries()[first_tool_call_ix]
332        else {
333            panic!("{:?}", thread.entries()[1]);
334        };
335
336        let label = label.read(cx).source();
337        assert!(label.contains("touch"), "Got: {}", label);
338
339        id.clone()
340    });
341
342    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
343    thread.read_with(cx, |thread, _cx| {
344        let AgentThreadEntry::ToolCall(ToolCall {
345            status: ToolCallStatus::Canceled,
346            ..
347        }) = &thread.entries()[first_tool_call_ix]
348        else {
349            panic!();
350        };
351    });
352
353    thread
354        .update(cx, |thread, cx| {
355            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
356        })
357        .await
358        .unwrap();
359    thread.read_with(cx, |thread, _| {
360        assert!(matches!(
361            &thread.entries().last().unwrap(),
362            AgentThreadEntry::AssistantMessage(..),
363        ))
364    });
365}
366
367pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
368where
369    T: AgentServer + 'static,
370    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
371{
372    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
373    let project = Project::test(fs.clone(), [], cx).await;
374    let thread = new_test_thread(
375        server(&fs, &project, cx).await,
376        project.clone(),
377        "/private/tmp",
378        cx,
379    )
380    .await;
381
382    thread
383        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
384        .await
385        .unwrap();
386
387    thread.read_with(cx, |thread, _| {
388        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
389    });
390
391    let weak_thread = thread.downgrade();
392    drop(thread);
393
394    cx.executor().run_until_parked();
395    assert!(!weak_thread.is_upgradable());
396}
397
/// Expands to a `common_e2e` module containing one `gpui` test per shared
/// end-to-end scenario in `e2e_tests`.
///
/// `$server` is the server-factory expression handed to every scenario, and
/// `allow_option_id` is converted into the `PermissionOptionId` used by the
/// permission scenario. Each generated test is ignored unless the `e2e`
/// feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            use super::*;

            // Single prompt, single reply.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            // Prompt containing a file mention.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            // Prompt that should trigger a tool call.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            // Tool call that must be authorized before it runs.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                let option_id =
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into());
                $crate::e2e_tests::test_tool_call_with_permission($server, option_id, cx).await;
            }

            // Cancelling a turn while a tool call awaits confirmation.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            // Dropping the thread releases the entity.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
447pub use common_e2e_tests;
448
449// Helpers
450
451pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
452    use settings::Settings;
453
454    env_logger::try_init().ok();
455
456    cx.update(|cx| {
457        let settings_store = settings::SettingsStore::test(cx);
458        cx.set_global(settings_store);
459        Project::init_settings(cx);
460        language::init(cx);
461        gpui_tokio::init(cx);
462        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
463        cx.set_http_client(Arc::new(http_client));
464        client::init_settings(cx);
465        let client = client::Client::production(cx);
466        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
467        language_model::init(client.clone(), cx);
468        language_models::init(user_store, client, cx);
469        agent_settings::init(cx);
470        AllAgentServersSettings::register(cx);
471
472        #[cfg(test)]
473        AllAgentServersSettings::override_global(
474            AllAgentServersSettings {
475                claude: Some(BuiltinAgentServerSettings {
476                    path: Some("claude-code-acp".into()),
477                    args: None,
478                    env: None,
479                    ignore_system_version: None,
480                    default_mode: None,
481                }),
482                gemini: Some(crate::gemini::tests::local_command().into()),
483                custom: collections::HashMap::default(),
484            },
485            cx,
486        );
487    });
488
489    cx.executor().allow_parking();
490
491    FakeFs::new(cx.executor())
492}
493
494pub async fn new_test_thread(
495    server: impl AgentServer + 'static,
496    project: Entity<Project>,
497    current_dir: impl AsRef<Path>,
498    cx: &mut TestAppContext,
499) -> Entity<AcpThread> {
500    let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
501    let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
502
503    let (connection, _) = cx
504        .update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx))
505        .await
506        .unwrap();
507
508    cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
509        .await
510        .unwrap()
511}
512
513pub async fn run_until_first_tool_call(
514    thread: &Entity<AcpThread>,
515    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
516    cx: &mut TestAppContext,
517) -> usize {
518    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
519
520    let subscription = cx.update(|cx| {
521        cx.subscribe(thread, move |thread, _, cx| {
522            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
523                if wait_until(entry) {
524                    return tx.try_send(ix).unwrap();
525                }
526            }
527        })
528    });
529
530    select! {
531        // We have to use a smol timer here because
532        // cx.background_executor().timer isn't real in the test context
533        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
534            panic!("Timeout waiting for tool call")
535        }
536        ix = rx.next().fuse() => {
537            drop(subscription);
538            ix.unwrap()
539        }
540    }
541}
542
/// Returns the path of the `zed` binary that sits in the `debug` directory
/// above the currently running executable.
///
/// The platform's executable suffix is appended to the binary name, so the
/// lookup also works where executables carry an extension (e.g. `zed.exe`).
///
/// # Panics
///
/// Panics if the current executable's path cannot be determined, if none of
/// its ancestors is named `debug`, or if the `zed` binary does not exist there.
pub fn get_zed_path() -> PathBuf {
    let mut zed_path = std::env::current_exe().unwrap();

    // Walk up from the running executable until the last path component is
    // `debug`.
    while !zed_path.ends_with("debug") {
        if !zed_path.pop() {
            panic!("Could not find target directory");
        }
    }

    zed_path.push(format!("zed{}", std::env::consts::EXE_SUFFIX));

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}