e2e_tests.rs

  1use crate::{AgentServer, AgentServerDelegate};
  2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  3use agent_client_protocol as acp;
  4use futures::{FutureExt, StreamExt, channel::mpsc, select};
  5use gpui::{Entity, TestAppContext};
  6use indoc::indoc;
  7use project::{FakeFs, Project};
  8#[cfg(test)]
  9use settings::Settings;
 10use std::{
 11    path::{Path, PathBuf},
 12    sync::Arc,
 13    time::Duration,
 14};
 15use util::path;
 16
 17pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
 18where
 19    T: AgentServer + 'static,
 20    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
 21{
 22    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
 23    let project = Project::test(fs.clone(), [], cx).await;
 24    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
 25
 26    thread
 27        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 28        .await
 29        .unwrap();
 30
 31    thread.read_with(cx, |thread, _| {
 32        assert!(
 33            thread.entries().len() >= 2,
 34            "Expected at least 2 entries. Got: {:?}",
 35            thread.entries()
 36        );
 37        assert!(matches!(
 38            thread.entries()[0],
 39            AgentThreadEntry::UserMessage(_)
 40        ));
 41        assert!(matches!(
 42            thread.entries()[1],
 43            AgentThreadEntry::AssistantMessage(_)
 44        ));
 45    });
 46}
 47
 48pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
 49where
 50    T: AgentServer + 'static,
 51    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
 52{
 53    let fs = init_test(cx).await as _;
 54
 55    let tempdir = tempfile::tempdir().unwrap();
 56    std::fs::write(
 57        tempdir.path().join("foo.rs"),
 58        indoc! {"
 59            fn main() {
 60                println!(\"Hello, world!\");
 61            }
 62        "},
 63    )
 64    .expect("failed to write file");
 65    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 66    let thread = new_test_thread(server(&fs, cx).await, project.clone(), tempdir.path(), cx).await;
 67    thread
 68        .update(cx, |thread, cx| {
 69            thread.send(
 70                vec![
 71                    "Read the file ".into(),
 72                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new("foo.rs", "foo.rs")),
 73                    " and tell me what the content of the println! is".into(),
 74                ],
 75                cx,
 76            )
 77        })
 78        .await
 79        .unwrap();
 80
 81    thread.read_with(cx, |thread, cx| {
 82        assert!(matches!(
 83            thread.entries()[0],
 84            AgentThreadEntry::UserMessage(_)
 85        ));
 86        let assistant_message = &thread
 87            .entries()
 88            .iter()
 89            .rev()
 90            .find_map(|entry| match entry {
 91                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
 92                _ => None,
 93            })
 94            .unwrap();
 95
 96        assert!(
 97            assistant_message.to_markdown(cx).contains("Hello, world!"),
 98            "unexpected assistant message: {:?}",
 99            assistant_message.to_markdown(cx)
100        );
101    });
102
103    drop(tempdir);
104}
105
106pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
107where
108    T: AgentServer + 'static,
109    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
110{
111    let fs = init_test(cx).await as _;
112
113    let tempdir = tempfile::tempdir().unwrap();
114    let foo_path = tempdir.path().join("foo");
115    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
116
117    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
118    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
119
120    thread
121        .update(cx, |thread, cx| {
122            thread.send_raw(
123                &format!("Read {} and tell me what you see.", foo_path.display()),
124                cx,
125            )
126        })
127        .await
128        .unwrap();
129    thread.read_with(cx, |thread, _cx| {
130        assert!(thread.entries().iter().any(|entry| {
131            matches!(
132                entry,
133                AgentThreadEntry::ToolCall(ToolCall {
134                    status: ToolCallStatus::Pending
135                        | ToolCallStatus::InProgress
136                        | ToolCallStatus::Completed,
137                    ..
138                })
139            )
140        }));
141        assert!(
142            thread
143                .entries()
144                .iter()
145                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
146        );
147    });
148
149    drop(tempdir);
150}
151
152pub async fn test_tool_call_with_permission<T, F>(
153    server: F,
154    allow_option_id: acp::PermissionOptionId,
155    cx: &mut TestAppContext,
156) where
157    T: AgentServer + 'static,
158    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
159{
160    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
161    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
162    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
163    let full_turn = thread.update(cx, |thread, cx| {
164        thread.send_raw(
165            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
166            cx,
167        )
168    });
169
170    run_until_first_tool_call(
171        &thread,
172        |entry| {
173            matches!(
174                entry,
175                AgentThreadEntry::ToolCall(ToolCall {
176                    status: ToolCallStatus::WaitingForConfirmation { .. },
177                    ..
178                })
179            )
180        },
181        cx,
182    )
183    .await;
184
185    let tool_call_id = thread.read_with(cx, |thread, cx| {
186        let AgentThreadEntry::ToolCall(ToolCall {
187            id,
188            label,
189            status: ToolCallStatus::WaitingForConfirmation { .. },
190            ..
191        }) = &thread
192            .entries()
193            .iter()
194            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
195            .unwrap()
196        else {
197            panic!();
198        };
199
200        let label = label.read(cx).source();
201        assert!(label.contains("touch"), "Got: {}", label);
202
203        id.clone()
204    });
205
206    thread.update(cx, |thread, cx| {
207        thread.authorize_tool_call(
208            tool_call_id,
209            allow_option_id,
210            acp::PermissionOptionKind::AllowOnce,
211            cx,
212        );
213
214        assert!(thread.entries().iter().any(|entry| matches!(
215            entry,
216            AgentThreadEntry::ToolCall(ToolCall {
217                status: ToolCallStatus::Pending
218                    | ToolCallStatus::InProgress
219                    | ToolCallStatus::Completed,
220                ..
221            })
222        )));
223    });
224
225    full_turn.await.unwrap();
226
227    thread.read_with(cx, |thread, cx| {
228        let AgentThreadEntry::ToolCall(ToolCall {
229            content,
230            status: ToolCallStatus::Pending
231                | ToolCallStatus::InProgress
232                | ToolCallStatus::Completed,
233            ..
234        }) = thread
235            .entries()
236            .iter()
237            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
238            .unwrap()
239        else {
240            panic!();
241        };
242
243        assert!(
244            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
245            "Expected content to contain 'Hello'"
246        );
247    });
248}
249
250pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
251where
252    T: AgentServer + 'static,
253    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
254{
255    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
256
257    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
258    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
259    let _ = thread.update(cx, |thread, cx| {
260        thread.send_raw(
261            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
262            cx,
263        )
264    });
265
266    let first_tool_call_ix = run_until_first_tool_call(
267        &thread,
268        |entry| {
269            matches!(
270                entry,
271                AgentThreadEntry::ToolCall(ToolCall {
272                    status: ToolCallStatus::WaitingForConfirmation { .. },
273                    ..
274                })
275            )
276        },
277        cx,
278    )
279    .await;
280
281    thread.read_with(cx, |thread, cx| {
282        let AgentThreadEntry::ToolCall(ToolCall {
283            id,
284            label,
285            status: ToolCallStatus::WaitingForConfirmation { .. },
286            ..
287        }) = &thread.entries()[first_tool_call_ix]
288        else {
289            panic!("{:?}", thread.entries()[1]);
290        };
291
292        let label = label.read(cx).source();
293        assert!(label.contains("touch"), "Got: {}", label);
294
295        id.clone()
296    });
297
298    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
299    thread.read_with(cx, |thread, _cx| {
300        let AgentThreadEntry::ToolCall(ToolCall {
301            status: ToolCallStatus::Canceled,
302            ..
303        }) = &thread.entries()[first_tool_call_ix]
304        else {
305            panic!();
306        };
307    });
308
309    thread
310        .update(cx, |thread, cx| {
311            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
312        })
313        .await
314        .unwrap();
315    thread.read_with(cx, |thread, _| {
316        assert!(matches!(
317            &thread.entries().last().unwrap(),
318            AgentThreadEntry::AssistantMessage(..),
319        ))
320    });
321}
322
323pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
324where
325    T: AgentServer + 'static,
326    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
327{
328    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
329    let project = Project::test(fs.clone(), [], cx).await;
330    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
331
332    thread
333        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
334        .await
335        .unwrap();
336
337    thread.read_with(cx, |thread, _| {
338        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
339    });
340
341    let weak_thread = thread.downgrade();
342    drop(thread);
343
344    cx.executor().run_until_parked();
345    assert!(!weak_thread.is_upgradable());
346}
347
348#[macro_export]
349macro_rules! common_e2e_tests {
350    ($server:expr, allow_option_id = $allow_option_id:expr) => {
351        mod common_e2e {
352            use super::*;
353
354            #[::gpui::test]
355            #[cfg_attr(not(feature = "e2e"), ignore)]
356            async fn basic(cx: &mut ::gpui::TestAppContext) {
357                $crate::e2e_tests::test_basic($server, cx).await;
358            }
359
360            #[::gpui::test]
361            #[cfg_attr(not(feature = "e2e"), ignore)]
362            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
363                $crate::e2e_tests::test_path_mentions($server, cx).await;
364            }
365
366            #[::gpui::test]
367            #[cfg_attr(not(feature = "e2e"), ignore)]
368            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
369                $crate::e2e_tests::test_tool_call($server, cx).await;
370            }
371
372            #[::gpui::test]
373            #[cfg_attr(not(feature = "e2e"), ignore)]
374            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
375                $crate::e2e_tests::test_tool_call_with_permission(
376                    $server,
377                    ::agent_client_protocol::PermissionOptionId::new($allow_option_id),
378                    cx,
379                )
380                .await;
381            }
382
383            #[::gpui::test]
384            #[cfg_attr(not(feature = "e2e"), ignore)]
385            async fn cancel(cx: &mut ::gpui::TestAppContext) {
386                $crate::e2e_tests::test_cancel($server, cx).await;
387            }
388
389            #[::gpui::test]
390            #[cfg_attr(not(feature = "e2e"), ignore)]
391            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
392                $crate::e2e_tests::test_thread_drop($server, cx).await;
393            }
394        }
395    };
396}
397pub use common_e2e_tests;
398
399// Helpers
400
401pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
402    env_logger::try_init().ok();
403
404    cx.update(|cx| {
405        let settings_store = settings::SettingsStore::test(cx);
406        cx.set_global(settings_store);
407        gpui_tokio::init(cx);
408        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
409        cx.set_http_client(Arc::new(http_client));
410        let client = client::Client::production(cx);
411        language_model::init(client, cx);
412
413        #[cfg(test)]
414        project::agent_server_store::AllAgentServersSettings::override_global(
415            project::agent_server_store::AllAgentServersSettings(collections::HashMap::default()),
416            cx,
417        );
418    });
419
420    cx.executor().allow_parking();
421
422    FakeFs::new(cx.executor())
423}
424
425pub async fn new_test_thread(
426    server: impl AgentServer + 'static,
427    project: Entity<Project>,
428    current_dir: impl AsRef<Path>,
429    cx: &mut TestAppContext,
430) -> Entity<AcpThread> {
431    let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
432    let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
433
434    let connection = cx.update(|cx| server.connect(delegate, cx)).await.unwrap();
435
436    cx.update(|cx| connection.new_session(project.clone(), current_dir.as_ref(), cx))
437        .await
438        .unwrap()
439}
440
441pub async fn run_until_first_tool_call(
442    thread: &Entity<AcpThread>,
443    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
444    cx: &mut TestAppContext,
445) -> usize {
446    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
447
448    let subscription = cx.update(|cx| {
449        cx.subscribe(thread, move |thread, _, cx| {
450            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
451                if wait_until(entry) {
452                    return tx.try_send(ix).unwrap();
453                }
454            }
455        })
456    });
457
458    select! {
459        _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(20))) => {
460            panic!("Timeout waiting for tool call")
461        }
462        ix = rx.next().fuse() => {
463            drop(subscription);
464            ix.unwrap()
465        }
466    }
467}
468
/// Locates the `zed` binary that sits in the same `debug` directory as the running test binary.
///
/// Walks up from the current executable's path until a component named `debug` is found, then
/// appends the binary name. The platform's executable suffix is included, so the lookup also
/// works where cargo emits `zed.exe`; the suffix is empty on Unix-like systems.
///
/// # Panics
///
/// Panics if the current executable path cannot be determined, if no ancestor directory is
/// named `debug`, or if the binary does not exist there.
pub fn get_zed_path() -> PathBuf {
    let mut zed_path =
        std::env::current_exe().expect("failed to determine the current executable path");

    while zed_path
        .file_name()
        .is_none_or(|name| name.to_string_lossy() != "debug")
    {
        if !zed_path.pop() {
            panic!("Could not find target directory");
        }
    }

    zed_path.push(format!("zed{}", std::env::consts::EXE_SUFFIX));

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}