e2e_tests.rs

  1use crate::{AgentServer, AgentServerDelegate};
  2use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  3use agent_client_protocol as acp;
  4use futures::{FutureExt, StreamExt, channel::mpsc, select};
  5use gpui::AppContext;
  6use gpui::{Entity, TestAppContext};
  7use indoc::indoc;
  8use project::{FakeFs, Project};
  9#[cfg(test)]
 10use settings::Settings;
 11use std::{
 12    path::{Path, PathBuf},
 13    sync::Arc,
 14    time::Duration,
 15};
 16use util::path;
 17
 18pub async fn test_basic<T, F>(server: F, cx: &mut TestAppContext)
 19where
 20    T: AgentServer + 'static,
 21    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
 22{
 23    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
 24    let project = Project::test(fs.clone(), [], cx).await;
 25    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
 26
 27    thread
 28        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 29        .await
 30        .unwrap();
 31
 32    thread.read_with(cx, |thread, _| {
 33        assert!(
 34            thread.entries().len() >= 2,
 35            "Expected at least 2 entries. Got: {:?}",
 36            thread.entries()
 37        );
 38        assert!(matches!(
 39            thread.entries()[0],
 40            AgentThreadEntry::UserMessage(_)
 41        ));
 42        assert!(matches!(
 43            thread.entries()[1],
 44            AgentThreadEntry::AssistantMessage(_)
 45        ));
 46    });
 47}
 48
 49pub async fn test_path_mentions<T, F>(server: F, cx: &mut TestAppContext)
 50where
 51    T: AgentServer + 'static,
 52    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
 53{
 54    let fs = init_test(cx).await as _;
 55
 56    let tempdir = tempfile::tempdir().unwrap();
 57    std::fs::write(
 58        tempdir.path().join("foo.rs"),
 59        indoc! {"
 60            fn main() {
 61                println!(\"Hello, world!\");
 62            }
 63        "},
 64    )
 65    .expect("failed to write file");
 66    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 67    let thread = new_test_thread(server(&fs, cx).await, project.clone(), tempdir.path(), cx).await;
 68    thread
 69        .update(cx, |thread, cx| {
 70            thread.send(
 71                vec![
 72                    "Read the file ".into(),
 73                    acp::ContentBlock::ResourceLink(acp::ResourceLink::new("foo.rs", "foo.rs")),
 74                    " and tell me what the content of the println! is".into(),
 75                ],
 76                cx,
 77            )
 78        })
 79        .await
 80        .unwrap();
 81
 82    thread.read_with(cx, |thread, cx| {
 83        assert!(matches!(
 84            thread.entries()[0],
 85            AgentThreadEntry::UserMessage(_)
 86        ));
 87        let assistant_message = &thread
 88            .entries()
 89            .iter()
 90            .rev()
 91            .find_map(|entry| match entry {
 92                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
 93                _ => None,
 94            })
 95            .unwrap();
 96
 97        assert!(
 98            assistant_message.to_markdown(cx).contains("Hello, world!"),
 99            "unexpected assistant message: {:?}",
100            assistant_message.to_markdown(cx)
101        );
102    });
103
104    drop(tempdir);
105}
106
107pub async fn test_tool_call<T, F>(server: F, cx: &mut TestAppContext)
108where
109    T: AgentServer + 'static,
110    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
111{
112    let fs = init_test(cx).await as _;
113
114    let tempdir = tempfile::tempdir().unwrap();
115    let foo_path = tempdir.path().join("foo");
116    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
117
118    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
119    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
120
121    thread
122        .update(cx, |thread, cx| {
123            thread.send_raw(
124                &format!("Read {} and tell me what you see.", foo_path.display()),
125                cx,
126            )
127        })
128        .await
129        .unwrap();
130    thread.read_with(cx, |thread, _cx| {
131        assert!(thread.entries().iter().any(|entry| {
132            matches!(
133                entry,
134                AgentThreadEntry::ToolCall(ToolCall {
135                    status: ToolCallStatus::Pending
136                        | ToolCallStatus::InProgress
137                        | ToolCallStatus::Completed,
138                    ..
139                })
140            )
141        }));
142        assert!(
143            thread
144                .entries()
145                .iter()
146                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
147        );
148    });
149
150    drop(tempdir);
151}
152
153pub async fn test_tool_call_with_permission<T, F>(
154    server: F,
155    allow_option_id: acp::PermissionOptionId,
156    cx: &mut TestAppContext,
157) where
158    T: AgentServer + 'static,
159    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
160{
161    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
162    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
163    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
164    let full_turn = thread.update(cx, |thread, cx| {
165        thread.send_raw(
166            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
167            cx,
168        )
169    });
170
171    run_until_first_tool_call(
172        &thread,
173        |entry| {
174            matches!(
175                entry,
176                AgentThreadEntry::ToolCall(ToolCall {
177                    status: ToolCallStatus::WaitingForConfirmation { .. },
178                    ..
179                })
180            )
181        },
182        cx,
183    )
184    .await;
185
186    let tool_call_id = thread.read_with(cx, |thread, cx| {
187        let AgentThreadEntry::ToolCall(ToolCall {
188            id,
189            label,
190            status: ToolCallStatus::WaitingForConfirmation { .. },
191            ..
192        }) = &thread
193            .entries()
194            .iter()
195            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
196            .unwrap()
197        else {
198            panic!();
199        };
200
201        let label = label.read(cx).source();
202        assert!(label.contains("touch"), "Got: {}", label);
203
204        id.clone()
205    });
206
207    thread.update(cx, |thread, cx| {
208        thread.authorize_tool_call(
209            tool_call_id,
210            allow_option_id,
211            acp::PermissionOptionKind::AllowOnce,
212            cx,
213        );
214
215        assert!(thread.entries().iter().any(|entry| matches!(
216            entry,
217            AgentThreadEntry::ToolCall(ToolCall {
218                status: ToolCallStatus::Pending
219                    | ToolCallStatus::InProgress
220                    | ToolCallStatus::Completed,
221                ..
222            })
223        )));
224    });
225
226    full_turn.await.unwrap();
227
228    thread.read_with(cx, |thread, cx| {
229        let AgentThreadEntry::ToolCall(ToolCall {
230            content,
231            status: ToolCallStatus::Pending
232                | ToolCallStatus::InProgress
233                | ToolCallStatus::Completed,
234            ..
235        }) = thread
236            .entries()
237            .iter()
238            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
239            .unwrap()
240        else {
241            panic!();
242        };
243
244        assert!(
245            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
246            "Expected content to contain 'Hello'"
247        );
248    });
249}
250
251pub async fn test_cancel<T, F>(server: F, cx: &mut TestAppContext)
252where
253    T: AgentServer + 'static,
254    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
255{
256    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
257
258    let project = Project::test(fs.clone(), [path!("/private/tmp").as_ref()], cx).await;
259    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
260    let _ = thread.update(cx, |thread, cx| {
261        thread.send_raw(
262            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
263            cx,
264        )
265    });
266
267    let first_tool_call_ix = run_until_first_tool_call(
268        &thread,
269        |entry| {
270            matches!(
271                entry,
272                AgentThreadEntry::ToolCall(ToolCall {
273                    status: ToolCallStatus::WaitingForConfirmation { .. },
274                    ..
275                })
276            )
277        },
278        cx,
279    )
280    .await;
281
282    thread.read_with(cx, |thread, cx| {
283        let AgentThreadEntry::ToolCall(ToolCall {
284            id,
285            label,
286            status: ToolCallStatus::WaitingForConfirmation { .. },
287            ..
288        }) = &thread.entries()[first_tool_call_ix]
289        else {
290            panic!("{:?}", thread.entries()[1]);
291        };
292
293        let label = label.read(cx).source();
294        assert!(label.contains("touch"), "Got: {}", label);
295
296        id.clone()
297    });
298
299    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
300    thread.read_with(cx, |thread, _cx| {
301        let AgentThreadEntry::ToolCall(ToolCall {
302            status: ToolCallStatus::Canceled,
303            ..
304        }) = &thread.entries()[first_tool_call_ix]
305        else {
306            panic!();
307        };
308    });
309
310    thread
311        .update(cx, |thread, cx| {
312            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
313        })
314        .await
315        .unwrap();
316    thread.read_with(cx, |thread, _| {
317        assert!(matches!(
318            &thread.entries().last().unwrap(),
319            AgentThreadEntry::AssistantMessage(..),
320        ))
321    });
322}
323
324pub async fn test_thread_drop<T, F>(server: F, cx: &mut TestAppContext)
325where
326    T: AgentServer + 'static,
327    F: AsyncFn(&Arc<dyn fs::Fs>, &mut TestAppContext) -> T,
328{
329    let fs = init_test(cx).await as Arc<dyn fs::Fs>;
330    let project = Project::test(fs.clone(), [], cx).await;
331    let thread = new_test_thread(server(&fs, cx).await, project.clone(), "/private/tmp", cx).await;
332
333    thread
334        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
335        .await
336        .unwrap();
337
338    thread.read_with(cx, |thread, _| {
339        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
340    });
341
342    let weak_thread = thread.downgrade();
343    drop(thread);
344
345    cx.executor().run_until_parked();
346    assert!(!weak_thread.is_upgradable());
347}
348
349#[macro_export]
350macro_rules! common_e2e_tests {
351    ($server:expr, allow_option_id = $allow_option_id:expr) => {
352        mod common_e2e {
353            use super::*;
354
355            #[::gpui::test]
356            #[cfg_attr(not(feature = "e2e"), ignore)]
357            async fn basic(cx: &mut ::gpui::TestAppContext) {
358                $crate::e2e_tests::test_basic($server, cx).await;
359            }
360
361            #[::gpui::test]
362            #[cfg_attr(not(feature = "e2e"), ignore)]
363            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
364                $crate::e2e_tests::test_path_mentions($server, cx).await;
365            }
366
367            #[::gpui::test]
368            #[cfg_attr(not(feature = "e2e"), ignore)]
369            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
370                $crate::e2e_tests::test_tool_call($server, cx).await;
371            }
372
373            #[::gpui::test]
374            #[cfg_attr(not(feature = "e2e"), ignore)]
375            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
376                $crate::e2e_tests::test_tool_call_with_permission(
377                    $server,
378                    ::agent_client_protocol::PermissionOptionId::new($allow_option_id),
379                    cx,
380                )
381                .await;
382            }
383
384            #[::gpui::test]
385            #[cfg_attr(not(feature = "e2e"), ignore)]
386            async fn cancel(cx: &mut ::gpui::TestAppContext) {
387                $crate::e2e_tests::test_cancel($server, cx).await;
388            }
389
390            #[::gpui::test]
391            #[cfg_attr(not(feature = "e2e"), ignore)]
392            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
393                $crate::e2e_tests::test_thread_drop($server, cx).await;
394            }
395        }
396    };
397}
398pub use common_e2e_tests;
399
400// Helpers
401
402pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
403    env_logger::try_init().ok();
404
405    cx.update(|cx| {
406        let settings_store = settings::SettingsStore::test(cx);
407        cx.set_global(settings_store);
408        gpui_tokio::init(cx);
409        let http_client = reqwest_client::ReqwestClient::user_agent("agent tests").unwrap();
410        cx.set_http_client(Arc::new(http_client));
411        let client = client::Client::production(cx);
412        let user_store = cx.new(|cx| client::UserStore::new(client.clone(), cx));
413        language_model::init(user_store, client, cx);
414
415        #[cfg(test)]
416        project::agent_server_store::AllAgentServersSettings::override_global(
417            project::agent_server_store::AllAgentServersSettings(collections::HashMap::default()),
418            cx,
419        );
420    });
421
422    cx.executor().allow_parking();
423
424    FakeFs::new(cx.executor())
425}
426
427pub async fn new_test_thread(
428    server: impl AgentServer + 'static,
429    project: Entity<Project>,
430    current_dir: impl AsRef<Path>,
431    cx: &mut TestAppContext,
432) -> Entity<AcpThread> {
433    let store = project.read_with(cx, |project, _| project.agent_server_store().clone());
434    let delegate = AgentServerDelegate::new(store, project.clone(), None, None);
435
436    let connection = cx.update(|cx| server.connect(delegate, cx)).await.unwrap();
437
438    cx.update(|cx| connection.new_session(project.clone(), current_dir.as_ref(), cx))
439        .await
440        .unwrap()
441}
442
443pub async fn run_until_first_tool_call(
444    thread: &Entity<AcpThread>,
445    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
446    cx: &mut TestAppContext,
447) -> usize {
448    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
449
450    let subscription = cx.update(|cx| {
451        cx.subscribe(thread, move |thread, _, cx| {
452            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
453                if wait_until(entry) {
454                    return tx.try_send(ix).unwrap();
455                }
456            }
457        })
458    });
459
460    select! {
461        _ = futures::FutureExt::fuse(cx.background_executor.timer(Duration::from_secs(20))) => {
462            panic!("Timeout waiting for tool call")
463        }
464        ix = rx.next().fuse() => {
465            drop(subscription);
466            ix.unwrap()
467        }
468    }
469}
470
/// Returns the path of the `zed` binary in the `debug` directory that
/// contains the currently running executable.
///
/// Starting from the current executable's path, the nearest ancestor named
/// `debug` is located and `zed` is appended to it.
///
/// # Panics
///
/// Panics if the current executable path cannot be determined, if no ancestor
/// is named `debug`, or if the resulting `zed` path does not exist.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    // `ancestors` yields the path itself first, then each parent in turn.
    let Some(debug_dir) = current_exe.ancestors().find(|candidate| {
        candidate
            .file_name()
            .is_some_and(|name| name.to_string_lossy() == "debug")
    }) else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}