e2e_tests.rs

  1use crate::{AgentServer, AgentServerDelegate};
  2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  3use agent_client_protocol as acp;
  4use client::RefreshLlmTokenListener;
  5use futures::{FutureExt, StreamExt, channel::mpsc, select};
  6use gpui::AppContext;
  7use gpui::{Entity, TestAppContext};
  8use indoc::indoc;
  9use project::{FakeFs, Project};
 10#[cfg(test)]
 11use settings::Settings;
 12use std::{
 13    path::{Path, PathBuf},
 14    sync::Arc,
 15    time::Duration,
 16};
 17use util::path;
 18use util::path_list::PathList;
 19
 20pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
 21where
 22    T: AgentServer + 'static,
 23    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
 24{
 25    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
 26    let project = Project::test(fs.clone(), [], cx).await;
 27    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
 28
 29    thread
 30        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 31        .await
 32        .unwrap();
 33
 34    thread.read_with(cx, |thread, _| {
 35        assert!(
 36            thread.entries().len() >= 2,
 37            "Expected at least 2 entries. Got: {:?}",
 38            thread.entries()
 39        );
 40        assert!(matches!(
 41            thread.entries()[0],
 42            AgentThreadEntry::UserMessage(_)
 43        ));
 44        assert!(matches!(
 45            thread.entries()[1],
 46            AgentThreadEntry::AssistantMessage(_)
 47        ));
 48    });
 49}
 50
 51pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
 52where
 53    T: AgentServer + 'static,
 54    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
 55{
 56    let fs = init_test(cx).await as _;
 57
 58    let tempdir = tempfile::tempdir().unwrap();
 59    std::fs::write(
 60        tempdir.path().join("foo.rs"),
 61        indoc! {"
 62            fn main() {
 63                println!(\"Hello, world!\");
 64            }
 65        "},
 66    )
 67    .expect("failed to write file");
 68    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 69    let thread = new_test_thread(server(&fs, cx).await, project.clone(), tempdir.path(), cx).await;
 70    thread
 71        .update(cx, |thread, cx| {
 72            thread.send(
 73                vec![
 74                    "Read the file ".into(),
 75                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new("foo.rs", "foo.rs")),
 76                    " and tell me what the content of the println! is".into(),
 77                ],
 78                cx,
 79            )
 80        })
 81        .await
 82        .unwrap();
 83
 84    thread.read_with(cx, |thread, cx| {
 85        assert!(matches!(
 86            thread.entries()[0],
 87            AgentThreadEntry::UserMessage(_)
 88        ));
 89        let assistant_message = &thread
 90            .entries()
 91            .iter()
 92            .rev()
 93            .find_map(|entry| match entry {
 94                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
 95                _ => None,
 96            })
 97            .unwrap();
 98
 99        assert!(
100            assistant_message.to_markdown(cx).contains("Hello, world!"),
101            "unexpected assistant message: {:?}",
102            assistant_message.to_markdown(cx)
103        );
104    });
105
106    drop(tempdir);
107}
108
109pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
110where
111    T: AgentServer + 'static,
112    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
113{
114    let fs = init_test(cx).await as _;
115
116    let tempdir = tempfile::tempdir().unwrap();
117    let foo_path = tempdir.path().join("foo");
118    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
119
120    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
121    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
122
123    thread
124        .update(cx, |thread, cx| {
125            thread.send_raw(
126                &format!("Read {} and tell me what you see.", foo_path.display()),
127                cx,
128            )
129        })
130        .await
131        .unwrap();
132    thread.read_with(cx, |thread, _cx| {
133        assert!(thread.entries().iter().any(|entry| {
134            matches!(
135                entry,
136                AgentThreadEntry::ToolCall(ToolCall {
137                    status: ToolCallStatus::Pending
138                        | ToolCallStatus::InProgress
139                        | ToolCallStatus::Completed,
140                    ..
141                })
142            )
143        }));
144        assert!(
145            thread
146                .entries()
147                .iter()
148                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
149        );
150    });
151
152    drop(tempdir);
153}
154
155pub async fn test_tool_call_with_permission<T, F>(
156    server: F,
157    allow_option_id: acp::PermissionOptionId,
158    cx: &mut TestAppContext,
159) where
160    T: AgentServer + 'static,
161    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
162{
163    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
164    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
165    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
166    let full_turn = thread.update(cx, |thread, cx| {
167        thread.send_raw(
168            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
169            cx,
170        )
171    });
172
173    run_until_first_tool_call(
174        &thread,
175        |entry| {
176            matches!(
177                entry,
178                AgentThreadEntry::ToolCall(ToolCall {
179                    status: ToolCallStatus::WaitingForConfirmation { .. },
180                    ..
181                })
182            )
183        },
184        cx,
185    )
186    .await;
187
188    let tool_call_id = thread.read_with(cx, |thread, cx| {
189        let AgentThreadEntry::ToolCall(ToolCall {
190            id,
191            label,
192            status: ToolCallStatus::WaitingForConfirmation { .. },
193            ..
194        }) = &thread
195            .entries()
196            .iter()
197            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
198            .unwrap()
199        else {
200            panic!();
201        };
202
203        let label = label.read(cx).source();
204        assert!(label.contains("touch"), "Got: {}", label);
205
206        id.clone()
207    });
208
209    thread.update(cx, |thread, cx| {
210        thread.authorize_tool_call(
211            tool_call_id,
212            acp_thread::SelectedPermissionOutcome::new(
213                allow_option_id,
214                acp::PermissionOptionKind::AllowOnce,
215            ),
216            cx,
217        );
218
219        assert!(thread.entries().iter().any(|entry| matches!(
220            entry,
221            AgentThreadEntry::ToolCall(ToolCall {
222                status: ToolCallStatus::Pending
223                    | ToolCallStatus::InProgress
224                    | ToolCallStatus::Completed,
225                ..
226            })
227        )));
228    });
229
230    full_turn.await.unwrap();
231
232    thread.read_with(cx, |thread, cx| {
233        let AgentThreadEntry::ToolCall(ToolCall {
234            content,
235            status: ToolCallStatus::Pending
236                | ToolCallStatus::InProgress
237                | ToolCallStatus::Completed,
238            ..
239        }) = thread
240            .entries()
241            .iter()
242            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
243            .unwrap()
244        else {
245            panic!();
246        };
247
248        assert!(
249            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
250            "Expected content to contain 'Hello'"
251        );
252    });
253}
254
255pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
256where
257    T: AgentServer + 'static,
258    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
259{
260    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
261
262    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
263    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
264    let _ = thread.update(cx, |thread, cx| {
265        thread.send_raw(
266            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
267            cx,
268        )
269    });
270
271    let first_tool_call_ix = run_until_first_tool_call(
272        &thread,
273        |entry| {
274            matches!(
275                entry,
276                AgentThreadEntry::ToolCall(ToolCall {
277                    status: ToolCallStatus::WaitingForConfirmation { .. },
278                    ..
279                })
280            )
281        },
282        cx,
283    )
284    .await;
285
286    thread.read_with(cx, |thread, cx| {
287        let AgentThreadEntry::ToolCall(ToolCall {
288            id,
289            label,
290            status: ToolCallStatus::WaitingForConfirmation { .. },
291            ..
292        }) = &thread.entries()[first_tool_call_ix]
293        else {
294            panic!("{:?}", thread.entries()[1]);
295        };
296
297        let label = label.read(cx).source();
298        assert!(label.contains("touch"), "Got: {}", label);
299
300        id.clone()
301    });
302
303    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
304    thread.read_with(cx, |thread, _cx| {
305        let AgentThreadEntry::ToolCall(ToolCall {
306            status: ToolCallStatus::Canceled,
307            ..
308        }) = &thread.entries()[first_tool_call_ix]
309        else {
310            panic!();
311        };
312    });
313
314    thread
315        .update(cx, |thread, cx| {
316            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
317        })
318        .await
319        .unwrap();
320    thread.read_with(cx, |thread, _| {
321        assert!(matches!(
322            &thread.entries().last().unwrap(),
323            AgentThreadEntry::AssistantMessage(..),
324        ))
325    });
326}
327
328pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
329where
330    T: AgentServer + 'static,
331    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
332{
333    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
334    let project = Project::test(fs.clone(), [], cx).await;
335    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
336
337    thread
338        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
339        .await
340        .unwrap();
341
342    thread.read_with(cx, |thread, _| {
343        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
344    });
345
346    let weak_thread = thread.downgrade();
347    drop(thread);
348
349    cx.executor().run_until_parked();
350    assert!(!weak_thread.is_upgradable());
351}
352
/// Generates a `common_e2e` module containing one `#[gpui::test]` per shared
/// end-to-end scenario defined in this file.
///
/// * `$server` — expression passed as the `server` argument of each
///   `test_*` function (the factory that builds the agent server under test).
/// * `$allow_option_id` — value handed to `PermissionOptionId::new` for the
///   permission test.
///
/// Every generated test is marked `ignore` unless the `e2e` feature is
/// enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            use super::*;

            // Plain prompt / response round trip.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            // Prompt that mentions a file through a resource link.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            // Prompt that should make the agent issue a tool call.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            // Tool call that has to be authorized with `$allow_option_id`.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId::new($allow_option_id),
                    cx,
                )
                .await;
            }

            // Cancelling a turn while a tool call awaits confirmation.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            // Dropping the last strong handle to a thread.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
// Re-export so the macro can also be reached through this module's path.
pub use common_e2e_tests;
403
404// Helpers
405
406pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
407    env_logger::try_init().ok();
408
409    cx.update(|cx| {
410        let settings_store = settings::SettingsStore::test(cx);
411        cx.set_global(settings_store);
412        gpui_tokio::init(cx);
413        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
414        cx.set_http_client(Arc::new(http_client));
415        let client = client::Client::production(cx);
416        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
417        language_model::init(cx);
418        RefreshLlmTokenListener::register(client.clone(), user_store, cx);
419
420        #[cfg(test)]
421        project::agent_server_store::AllAgentServersSettings::override_global(
422            project::agent_server_store::AllAgentServersSettings(collections::HashMap::default()),
423            cx,
424        );
425    });
426
427    cx.executor().allow_parking();
428
429    FakeFs::new(cx.executor())
430}
431
432pub async fn new_test_thread(
433    server: impl AgentServer + 'static,
434    project: Entity<Project>,
435    current_dir: impl AsRef<Path>,
436    cx: &mut TestAppContext,
437) -> Entity<AcpThread> {
438    let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
439    let delegate = AgentServerDelegate::new(store, None);
440
441    let connection = cx
442        .update(|cx| server.connect(delegate, project.clone(), cx))
443        .await
444        .unwrap();
445
446    cx.update(|cx| {
447        connection.new_session(project.clone(), PathList::new(&[current_dir.as_ref()]), cx)
448    })
449    .await
450    .unwrap()
451}
452
453pub async fn run_until_first_tool_call(
454    thread: &Entity<AcpThread>,
455    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
456    cx: &mut TestAppContext,
457) -> usize {
458    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
459
460    let subscription = cx.update(|cx| {
461        cx.subscribe(thread, move |thread, _, cx| {
462            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
463                if wait_until(entry) {
464                    return tx.try_send(ix).unwrap();
465                }
466            }
467        })
468    });
469
470    select! {
471        _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(20))) => {
472            panic!("Timeout waiting for tool call")
473        }
474        ix = rx.next().fuse() => {
475            drop(subscription);
476            ix.unwrap()
477        }
478    }
479}
480
/// Locates the `zed` binary next to the currently running test executable.
///
/// Starting from the path of the current executable, walks up the directory
/// tree to the nearest ancestor named `debug` and returns the path of the
/// `zed` executable inside it. The platform's executable suffix is appended,
/// so the lookup also finds `zed.exe` on Windows.
///
/// # Panics
///
/// Panics if the current executable path cannot be determined, if no ancestor
/// directory is named `debug`, or if the `zed` binary does not exist there.
pub fn get_zed_path() -> PathBuf {
    let mut zed_path = std::env::current_exe().unwrap();

    // Pop components until the path ends in `debug`.
    while zed_path.file_name().is_none_or(|name| name != "debug") {
        if !zed_path.pop() {
            panic!("Could not find target directory");
        }
    }

    // `EXE_SUFFIX` is empty on Unix-like systems and ".exe" on Windows.
    zed_path.push(format!("zed{}", std::env::consts::EXE_SUFFIX));

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}