e2e_tests.rs

  1use crate::{AgentServer, AgentServerDelegate};
  2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  3use agent_client_protocol as acp;
  4use futures::{FutureExt, StreamExt, channel::mpsc, select};
  5use gpui::{AppContext, Entity, TestAppContext};
  6use indoc::indoc;
  7#[cfg(test)]
  8use project::agent_server_store::BuiltinAgentServerSettings;
  9use project::{FakeFs, Project};
 10#[cfg(test)]
 11use settings::Settings;
 12use std::{
 13    path::{Path, PathBuf},
 14    sync::Arc,
 15    time::Duration,
 16};
 17use util::path;
 18
 19pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
 20where
 21    T: AgentServer + 'static,
 22    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
 23{
 24    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
 25    let project = Project::test(fs.clone(), [], cx).await;
 26    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
 27
 28    thread
 29        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 30        .await
 31        .unwrap();
 32
 33    thread.read_with(cx, |thread, _| {
 34        assert!(
 35            thread.entries().len() >= 2,
 36            "Expected at least 2 entries. Got: {:?}",
 37            thread.entries()
 38        );
 39        assert!(matches!(
 40            thread.entries()[0],
 41            AgentThreadEntry::UserMessage(_)
 42        ));
 43        assert!(matches!(
 44            thread.entries()[1],
 45            AgentThreadEntry::AssistantMessage(_)
 46        ));
 47    });
 48}
 49
 50pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
 51where
 52    T: AgentServer + 'static,
 53    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
 54{
 55    let fs = init_test(cx).await as _;
 56
 57    let tempdir = tempfile::tempdir().unwrap();
 58    std::fs::write(
 59        tempdir.path().join("foo.rs"),
 60        indoc! {"
 61            fn main() {
 62                println!(\"Hello, world!\");
 63            }
 64        "},
 65    )
 66    .expect("failed to write file");
 67    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 68    let thread = new_test_thread(server(&fs, cx).await, project.clone(), tempdir.path(), cx).await;
 69    thread
 70        .update(cx, |thread, cx| {
 71            thread.send(
 72                vec![
 73                    "Read the file ".into(),
 74                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new("foo.rs", "foo.rs")),
 75                    " and tell me what the content of the println! is".into(),
 76                ],
 77                cx,
 78            )
 79        })
 80        .await
 81        .unwrap();
 82
 83    thread.read_with(cx, |thread, cx| {
 84        assert!(matches!(
 85            thread.entries()[0],
 86            AgentThreadEntry::UserMessage(_)
 87        ));
 88        let assistant_message = &thread
 89            .entries()
 90            .iter()
 91            .rev()
 92            .find_map(|entry| match entry {
 93                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
 94                _ => None,
 95            })
 96            .unwrap();
 97
 98        assert!(
 99            assistant_message.to_markdown(cx).contains("Hello, world!"),
100            "unexpected assistant message: {:?}",
101            assistant_message.to_markdown(cx)
102        );
103    });
104
105    drop(tempdir);
106}
107
108pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
109where
110    T: AgentServer + 'static,
111    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
112{
113    let fs = init_test(cx).await as _;
114
115    let tempdir = tempfile::tempdir().unwrap();
116    let foo_path = tempdir.path().join("foo");
117    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
118
119    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
120    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
121
122    thread
123        .update(cx, |thread, cx| {
124            thread.send_raw(
125                &format!("Read {} and tell me what you see.", foo_path.display()),
126                cx,
127            )
128        })
129        .await
130        .unwrap();
131    thread.read_with(cx, |thread, _cx| {
132        assert!(thread.entries().iter().any(|entry| {
133            matches!(
134                entry,
135                AgentThreadEntry::ToolCall(ToolCall {
136                    status: ToolCallStatus::Pending
137                        | ToolCallStatus::InProgress
138                        | ToolCallStatus::Completed,
139                    ..
140                })
141            )
142        }));
143        assert!(
144            thread
145                .entries()
146                .iter()
147                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
148        );
149    });
150
151    drop(tempdir);
152}
153
154pub async fn test_tool_call_with_permission<T, F>(
155    server: F,
156    allow_option_id: acp::PermissionOptionId,
157    cx: &mut TestAppContext,
158) where
159    T: AgentServer + 'static,
160    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
161{
162    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
163    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
164    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
165    let full_turn = thread.update(cx, |thread, cx| {
166        thread.send_raw(
167            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
168            cx,
169        )
170    });
171
172    run_until_first_tool_call(
173        &thread,
174        |entry| {
175            matches!(
176                entry,
177                AgentThreadEntry::ToolCall(ToolCall {
178                    status: ToolCallStatus::WaitingForConfirmation { .. },
179                    ..
180                })
181            )
182        },
183        cx,
184    )
185    .await;
186
187    let tool_call_id = thread.read_with(cx, |thread, cx| {
188        let AgentThreadEntry::ToolCall(ToolCall {
189            id,
190            label,
191            status: ToolCallStatus::WaitingForConfirmation { .. },
192            ..
193        }) = &thread
194            .entries()
195            .iter()
196            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
197            .unwrap()
198        else {
199            panic!();
200        };
201
202        let label = label.read(cx).source();
203        assert!(label.contains("touch"), "Got: {}", label);
204
205        id.clone()
206    });
207
208    thread.update(cx, |thread, cx| {
209        thread.authorize_tool_call(
210            tool_call_id,
211            allow_option_id,
212            acp::PermissionOptionKind::AllowOnce,
213            cx,
214        );
215
216        assert!(thread.entries().iter().any(|entry| matches!(
217            entry,
218            AgentThreadEntry::ToolCall(ToolCall {
219                status: ToolCallStatus::Pending
220                    | ToolCallStatus::InProgress
221                    | ToolCallStatus::Completed,
222                ..
223            })
224        )));
225    });
226
227    full_turn.await.unwrap();
228
229    thread.read_with(cx, |thread, cx| {
230        let AgentThreadEntry::ToolCall(ToolCall {
231            content,
232            status: ToolCallStatus::Pending
233                | ToolCallStatus::InProgress
234                | ToolCallStatus::Completed,
235            ..
236        }) = thread
237            .entries()
238            .iter()
239            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
240            .unwrap()
241        else {
242            panic!();
243        };
244
245        assert!(
246            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
247            "Expected content to contain 'Hello'"
248        );
249    });
250}
251
252pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
253where
254    T: AgentServer + 'static,
255    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
256{
257    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
258
259    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
260    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
261    let _ = thread.update(cx, |thread, cx| {
262        thread.send_raw(
263            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
264            cx,
265        )
266    });
267
268    let first_tool_call_ix = run_until_first_tool_call(
269        &thread,
270        |entry| {
271            matches!(
272                entry,
273                AgentThreadEntry::ToolCall(ToolCall {
274                    status: ToolCallStatus::WaitingForConfirmation { .. },
275                    ..
276                })
277            )
278        },
279        cx,
280    )
281    .await;
282
283    thread.read_with(cx, |thread, cx| {
284        let AgentThreadEntry::ToolCall(ToolCall {
285            id,
286            label,
287            status: ToolCallStatus::WaitingForConfirmation { .. },
288            ..
289        }) = &thread.entries()[first_tool_call_ix]
290        else {
291            panic!("{:?}", thread.entries()[1]);
292        };
293
294        let label = label.read(cx).source();
295        assert!(label.contains("touch"), "Got: {}", label);
296
297        id.clone()
298    });
299
300    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
301    thread.read_with(cx, |thread, _cx| {
302        let AgentThreadEntry::ToolCall(ToolCall {
303            status: ToolCallStatus::Canceled,
304            ..
305        }) = &thread.entries()[first_tool_call_ix]
306        else {
307            panic!();
308        };
309    });
310
311    thread
312        .update(cx, |thread, cx| {
313            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
314        })
315        .await
316        .unwrap();
317    thread.read_with(cx, |thread, _| {
318        assert!(matches!(
319            &thread.entries().last().unwrap(),
320            AgentThreadEntry::AssistantMessage(..),
321        ))
322    });
323}
324
325pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
326where
327    T: AgentServer + 'static,
328    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
329{
330    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
331    let project = Project::test(fs.clone(), [], cx).await;
332    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
333
334    thread
335        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
336        .await
337        .unwrap();
338
339    thread.read_with(cx, |thread, _| {
340        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
341    });
342
343    let weak_thread = thread.downgrade();
344    drop(thread);
345
346    cx.executor().run_until_parked();
347    assert!(!weak_thread.is_upgradable());
348}
349
/// Generates the shared end-to-end test suite for one agent server.
///
/// Expands to a `common_e2e` module with one `#[gpui::test]` per `test_*` function in
/// this module. Each generated test passes `$server` as the server factory.
/// `allow_option_id` is wrapped in `agent_client_protocol::PermissionOptionId::new` and
/// used to approve the tool call in `tool_call_with_permission`.
///
/// Every generated test is marked `ignore` unless the `e2e` feature is enabled.
///
/// NOTE(review): `$server` is expanded once per generated test, so it is evaluated
/// separately in each one — confirm call sites pass a cheap, side-effect-free expression.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            // Pulls in the invoking module's items (e.g. whatever `$server` refers to).
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId::new($allow_option_id),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
// Re-export the macro from this module so it can also be named by path.
pub use common_e2e_tests;
400
401// Helpers
402
403pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
404    env_logger::try_init().ok();
405
406    cx.update(|cx| {
407        let settings_store = settings::SettingsStore::test(cx);
408        cx.set_global(settings_store);
409        gpui_tokio::init(cx);
410        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
411        cx.set_http_client(Arc::new(http_client));
412        let client = client::Client::production(cx);
413        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
414        language_model::init(client.clone(), cx);
415        language_models::init(user_store, client, cx);
416
417        #[cfg(test)]
418        project::agent_server_store::AllAgentServersSettings::override_global(
419            project::agent_server_store::AllAgentServersSettings {
420                claude: Some(BuiltinAgentServerSettings {
421                    path: Some("claude-code-acp".into()),
422                    ..Default::default()
423                }),
424                gemini: Some(crate::gemini::tests::local_command().into()),
425                codex: Some(BuiltinAgentServerSettings {
426                    path: Some("codex-acp".into()),
427                    ..Default::default()
428                }),
429                custom: collections::HashMap::default(),
430            },
431            cx,
432        );
433    });
434
435    cx.executor().allow_parking();
436
437    FakeFs::new(cx.executor())
438}
439
440pub async fn new_test_thread(
441    server: impl AgentServer + 'static,
442    project: Entity<Project>,
443    current_dir: impl AsRef<Path>,
444    cx: &mut TestAppContext,
445) -> Entity<AcpThread> {
446    let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
447    let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
448
449    let (connection, _) = cx
450        .update(|cx| server.connect(Some(current_dir.as_ref()), delegate, cx))
451        .await
452        .unwrap();
453
454    cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
455        .await
456        .unwrap()
457}
458
459pub async fn run_until_first_tool_call(
460    thread: &Entity<AcpThread>,
461    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
462    cx: &mut TestAppContext,
463) -> usize {
464    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
465
466    let subscription = cx.update(|cx| {
467        cx.subscribe(thread, move |thread, _, cx| {
468            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
469                if wait_until(entry) {
470                    return tx.try_send(ix).unwrap();
471                }
472            }
473        })
474    });
475
476    select! {
477        // We have to use a smol timer here because
478        // cx.background_executor().timer isn't real in the test context
479        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
480            panic!("Timeout waiting for tool call")
481        }
482        ix = rx.next().fuse() => {
483            drop(subscription);
484            ix.unwrap()
485        }
486    }
487}
488
/// Returns the path of the `zed` binary in the Cargo `debug` directory that contains the
/// currently running test executable.
///
/// The search walks up from the test executable to the nearest ancestor named `debug`.
///
/// # Panics
/// Panics if no ancestor directory is named `debug`, or if the `zed` binary has not been
/// built there yet.
pub fn get_zed_path() -> PathBuf {
    let mut zed_path = std::env::current_exe().unwrap();

    while zed_path
        .file_name()
        .is_none_or(|name| name.to_string_lossy() != "debug")
    {
        if !zed_path.pop() {
            panic!("Could not find target directory");
        }
    }

    // Cargo emits `zed.exe` on Windows and `zed` elsewhere; a bare `zed` never exists on
    // Windows.
    zed_path.push(format!("zed{}", std::env::consts::EXE_SUFFIX));

    assert!(
        zed_path.exists(),
        "\n🚨 Run `cargo build` at least once before running e2e tests\n\n"
    );

    zed_path
}