e2e_tests.rs

  1use crate::{AgentServer, AgentServerDelegate};
  2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  3use agent_client_protocol as acp;
  4use futures::{FutureExt, StreamExt, channel::mpsc, select};
  5use gpui::{AppContext, Entity, TestAppContext};
  6use indoc::indoc;
  7#[cfg(test)]
  8use project::agent_server_store::BuiltinAgentServerSettings;
  9use project::{FakeFs, Project};
 10#[cfg(test)]
 11use settings::Settings;
 12use std::{
 13    path::{Path, PathBuf},
 14    sync::Arc,
 15    time::Duration,
 16};
 17use util::path;
 18
 19pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
 20where
 21    T: AgentServer + 'static,
 22    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 23{
 24    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
 25    let project = Project::test(fs.clone(), [], cx).await;
 26    let thread = new_test_thread(
 27        server(&fs, &project, cx).await,
 28        project.clone(),
 29        "/private/tmp",
 30        cx,
 31    )
 32    .await;
 33
 34    thread
 35        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 36        .await
 37        .unwrap();
 38
 39    thread.read_with(cx, |thread, _| {
 40        assert!(
 41            thread.entries().len() >= 2,
 42            "Expected at least 2 entries. Got: {:?}",
 43            thread.entries()
 44        );
 45        assert!(matches!(
 46            thread.entries()[0],
 47            AgentThreadEntry::UserMessage(_)
 48        ));
 49        assert!(matches!(
 50            thread.entries()[1],
 51            AgentThreadEntry::AssistantMessage(_)
 52        ));
 53    });
 54}
 55
 56pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
 57where
 58    T: AgentServer + 'static,
 59    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
 60{
 61    let fs = init_test(cx).await as _;
 62
 63    let tempdir = tempfile::tempdir().unwrap();
 64    std::fs::write(
 65        tempdir.path().join("foo.rs"),
 66        indoc! {"
 67            fn main() {
 68                println!(\"Hello, world!\");
 69            }
 70        "},
 71    )
 72    .expect("failed to write file");
 73    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 74    let thread = new_test_thread(
 75        server(&fs, &project, cx).await,
 76        project.clone(),
 77        tempdir.path(),
 78        cx,
 79    )
 80    .await;
 81    thread
 82        .update(cx, |thread, cx| {
 83            thread.send(
 84                vec![
 85                    "Read the file ".into(),
 86                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new("foo.rs", "foo.rs")),
 87                    " and tell me what the content of the println! is".into(),
 88                ],
 89                cx,
 90            )
 91        })
 92        .await
 93        .unwrap();
 94
 95    thread.read_with(cx, |thread, cx| {
 96        assert!(matches!(
 97            thread.entries()[0],
 98            AgentThreadEntry::UserMessage(_)
 99        ));
100        let assistant_message = &thread
101            .entries()
102            .iter()
103            .rev()
104            .find_map(|entry| match entry {
105                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
106                _ => None,
107            })
108            .unwrap();
109
110        assert!(
111            assistant_message.to_markdown(cx).contains("Hello, world!"),
112            "unexpected assistant message: {:?}",
113            assistant_message.to_markdown(cx)
114        );
115    });
116
117    drop(tempdir);
118}
119
120pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
121where
122    T: AgentServer + 'static,
123    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
124{
125    let fs = init_test(cx).await as _;
126
127    let tempdir = tempfile::tempdir().unwrap();
128    let foo_path = tempdir.path().join("foo");
129    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
130
131    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
132    let thread = new_test_thread(
133        server(&fs, &project, cx).await,
134        project.clone(),
135        "/private/tmp",
136        cx,
137    )
138    .await;
139
140    thread
141        .update(cx, |thread, cx| {
142            thread.send_raw(
143                &format!("Read {} and tell me what you see.", foo_path.display()),
144                cx,
145            )
146        })
147        .await
148        .unwrap();
149    thread.read_with(cx, |thread, _cx| {
150        assert!(thread.entries().iter().any(|entry| {
151            matches!(
152                entry,
153                AgentThreadEntry::ToolCall(ToolCall {
154                    status: ToolCallStatus::Pending
155                        | ToolCallStatus::InProgress
156                        | ToolCallStatus::Completed,
157                    ..
158                })
159            )
160        }));
161        assert!(
162            thread
163                .entries()
164                .iter()
165                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
166        );
167    });
168
169    drop(tempdir);
170}
171
172pub async fn test_tool_call_with_permission<T, F>(
173    server: F,
174    allow_option_id: acp::PermissionOptionId,
175    cx: &mut TestAppContext,
176) where
177    T: AgentServer + 'static,
178    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
179{
180    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
181    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
182    let thread = new_test_thread(
183        server(&fs, &project, cx).await,
184        project.clone(),
185        "/private/tmp",
186        cx,
187    )
188    .await;
189    let full_turn = thread.update(cx, |thread, cx| {
190        thread.send_raw(
191            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
192            cx,
193        )
194    });
195
196    run_until_first_tool_call(
197        &thread,
198        |entry| {
199            matches!(
200                entry,
201                AgentThreadEntry::ToolCall(ToolCall {
202                    status: ToolCallStatus::WaitingForConfirmation { .. },
203                    ..
204                })
205            )
206        },
207        cx,
208    )
209    .await;
210
211    let tool_call_id = thread.read_with(cx, |thread, cx| {
212        let AgentThreadEntry::ToolCall(ToolCall {
213            id,
214            label,
215            status: ToolCallStatus::WaitingForConfirmation { .. },
216            ..
217        }) = &thread
218            .entries()
219            .iter()
220            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
221            .unwrap()
222        else {
223            panic!();
224        };
225
226        let label = label.read(cx).source();
227        assert!(label.contains("touch"), "Got: {}", label);
228
229        id.clone()
230    });
231
232    thread.update(cx, |thread, cx| {
233        thread.authorize_tool_call(
234            tool_call_id,
235            allow_option_id,
236            acp::PermissionOptionKind::AllowOnce,
237            cx,
238        );
239
240        assert!(thread.entries().iter().any(|entry| matches!(
241            entry,
242            AgentThreadEntry::ToolCall(ToolCall {
243                status: ToolCallStatus::Pending
244                    | ToolCallStatus::InProgress
245                    | ToolCallStatus::Completed,
246                ..
247            })
248        )));
249    });
250
251    full_turn.await.unwrap();
252
253    thread.read_with(cx, |thread, cx| {
254        let AgentThreadEntry::ToolCall(ToolCall {
255            content,
256            status: ToolCallStatus::Pending
257                | ToolCallStatus::InProgress
258                | ToolCallStatus::Completed,
259            ..
260        }) = thread
261            .entries()
262            .iter()
263            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
264            .unwrap()
265        else {
266            panic!();
267        };
268
269        assert!(
270            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
271            "Expected content to contain 'Hello'"
272        );
273    });
274}
275
276pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
277where
278    T: AgentServer + 'static,
279    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
280{
281    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
282
283    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
284    let thread = new_test_thread(
285        server(&fs, &project, cx).await,
286        project.clone(),
287        "/private/tmp",
288        cx,
289    )
290    .await;
291    let _ = thread.update(cx, |thread, cx| {
292        thread.send_raw(
293            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
294            cx,
295        )
296    });
297
298    let first_tool_call_ix = run_until_first_tool_call(
299        &thread,
300        |entry| {
301            matches!(
302                entry,
303                AgentThreadEntry::ToolCall(ToolCall {
304                    status: ToolCallStatus::WaitingForConfirmation { .. },
305                    ..
306                })
307            )
308        },
309        cx,
310    )
311    .await;
312
313    thread.read_with(cx, |thread, cx| {
314        let AgentThreadEntry::ToolCall(ToolCall {
315            id,
316            label,
317            status: ToolCallStatus::WaitingForConfirmation { .. },
318            ..
319        }) = &thread.entries()[first_tool_call_ix]
320        else {
321            panic!("{:?}", thread.entries()[1]);
322        };
323
324        let label = label.read(cx).source();
325        assert!(label.contains("touch"), "Got: {}", label);
326
327        id.clone()
328    });
329
330    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
331    thread.read_with(cx, |thread, _cx| {
332        let AgentThreadEntry::ToolCall(ToolCall {
333            status: ToolCallStatus::Canceled,
334            ..
335        }) = &thread.entries()[first_tool_call_ix]
336        else {
337            panic!();
338        };
339    });
340
341    thread
342        .update(cx, |thread, cx| {
343            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
344        })
345        .await
346        .unwrap();
347    thread.read_with(cx, |thread, _| {
348        assert!(matches!(
349            &thread.entries().last().unwrap(),
350            AgentThreadEntry::AssistantMessage(..),
351        ))
352    });
353}
354
355pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
356where
357    T: AgentServer + 'static,
358    F: AsyncFn(&Arc<dyn fs::Fs>, &Entity<Project>, &mut TestAppContext) -> T,
359{
360    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
361    let project = Project::test(fs.clone(), [], cx).await;
362    let thread = new_test_thread(
363        server(&fs, &project, cx).await,
364        project.clone(),
365        "/private/tmp",
366        cx,
367    )
368    .await;
369
370    thread
371        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
372        .await
373        .unwrap();
374
375    thread.read_with(cx, |thread, _| {
376        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
377    });
378
379    let weak_thread = thread.downgrade();
380    drop(thread);
381
382    cx.executor().run_until_parked();
383    assert!(!weak_thread.is_upgradable());
384}
385
386#[macro_export]
387macro_rules! common_e2e_tests {
388    ($server:expr, allow_option_id = $allow_option_id:expr) => {
389        mod common_e2e {
390            use super::*;
391
392            #[::gpui::test]
393            #[cfg_attr(not(feature = "e2e"), ignore)]
394            async fn basic(cx: &mut ::gpui::TestAppContext) {
395                $crate::e2e_tests::test_basic($server, cx).await;
396            }
397
398            #[::gpui::test]
399            #[cfg_attr(not(feature = "e2e"), ignore)]
400            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
401                $crate::e2e_tests::test_path_mentions($server, cx).await;
402            }
403
404            #[::gpui::test]
405            #[cfg_attr(not(feature = "e2e"), ignore)]
406            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
407                $crate::e2e_tests::test_tool_call($server, cx).await;
408            }
409
410            #[::gpui::test]
411            #[cfg_attr(not(feature = "e2e"), ignore)]
412            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
413                $crate::e2e_tests::test_tool_call_with_permission(
414                    $server,
415                    ::agent_client_protocol::PermissionOptionId::new($allow_option_id),
416                    cx,
417                )
418                .await;
419            }
420
421            #[::gpui::test]
422            #[cfg_attr(not(feature = "e2e"), ignore)]
423            async fn cancel(cx: &mut ::gpui::TestAppContext) {
424                $crate::e2e_tests::test_cancel($server, cx).await;
425            }
426
427            #[::gpui::test]
428            #[cfg_attr(not(feature = "e2e"), ignore)]
429            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
430                $crate::e2e_tests::test_thread_drop($server, cx).await;
431            }
432        }
433    };
434}
435pub use common_e2e_tests;
436
437// Helpers
438
439pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
440    env_logger::try_init().ok();
441
442    cx.update(|cx| {
443        let settings_store = settings::SettingsStore::test(cx);
444        cx.set_global(settings_store);
445        gpui_tokio::init(cx);
446        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
447        cx.set_http_client(Arc::new(http_client));
448        let client = client::Client::production(cx);
449        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
450        language_model::init(client.clone(), cx);
451        language_models::init(user_store, client, cx);
452
453        #[cfg(test)]
454        project::agent_server_store::AllAgentServersSettings::override_global(
455            project::agent_server_store::AllAgentServersSettings {
456                claude: Some(BuiltinAgentServerSettings {
457                    path: Some("claude-code-acp".into()),
458                    args: None,
459                    env: None,
460                    ignore_system_version: None,
461                    default_mode: None,
462                    default_model: None,
463                }),
464                gemini: Some(crate::gemini::tests::local_command().into()),
465                codex: Some(BuiltinAgentServerSettings {
466                    path: Some("codex-acp".into()),
467                    args: None,
468                    env: None,
469                    ignore_system_version: None,
470                    default_mode: None,
471                    default_model: None,
472                }),
473                custom: collections::HashMap::default(),
474            },
475            cx,
476        );
477    });
478
479    cx.executor().allow_parking();
480
481    FakeFs::new(cx.executor())
482}
483
484pub async fn new_test_thread(
485    server: impl AgentServer + 'static,
486    project: Entity<Project>,
487    current_dir: impl AsRef<Path>,
488    cx: &mut TestAppContext,
489) -> Entity<AcpThread> {
490    let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
491    let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
492
493    let (connection, _) = cx
494        .update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx))
495        .await
496        .unwrap();
497
498    cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
499        .await
500        .unwrap()
501}
502
503pub async fn run_until_first_tool_call(
504    thread: &Entity<AcpThread>,
505    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
506    cx: &mut TestAppContext,
507) -> usize {
508    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
509
510    let subscription = cx.update(|cx| {
511        cx.subscribe(thread, move |thread, _, cx| {
512            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
513                if wait_until(entry) {
514                    return tx.try_send(ix).unwrap();
515                }
516            }
517        })
518    });
519
520    select! {
521        // We have to use a smol timer here because
522        // cx.background_executor().timer isn't real in the test context
523        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
524            panic!("Timeout waiting for tool call")
525        }
526        ix = rx.next().fuse() => {
527            drop(subscription);
528            ix.unwrap()
529        }
530    }
531}
532
/// Returns the path of the `zed` binary in the nearest `debug` directory above
/// the running test executable: `target/debug/zed`, or `zed.exe` on Windows.
///
/// # Panics
/// Panics in either of these cases:
/// - the test executable is not inside a `debug` directory;
/// - the `zed` binary has not been built yet.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();
    let Some(zed_path) = zed_binary_in_debug_dir(&current_exe) else {
        panic!("Could not find target directory");
    };

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}

/// Computes where the `zed` binary lives inside the nearest ancestor of `exe`
/// (including `exe` itself) named `debug`, or `None` if there is no such
/// ancestor. Pure path computation; does not touch the filesystem.
fn zed_binary_in_debug_dir(exe: &Path) -> Option<PathBuf> {
    let debug_dir = exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name == "debug"))?;
    // Cargo names the binary `zed.exe` on Windows; a bare `zed` would never exist there.
    Some(debug_dir.join(format!("zed{}", std::env::consts::EXE_SUFFIX)))
}