e2e_tests.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    sync::Arc,
  4    time::Duration,
  5};
  6
  7use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
  8use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  9use agent_client_protocol as acp;
 10
 11use futures::{FutureExt, StreamExt, channel::mpsc, select};
 12use gpui::{Entity, TestAppContext};
 13use indoc::indoc;
 14use project::{FakeFs, Project};
 15use settings::{Settings, SettingsStore};
 16use util::path;
 17
 18pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 19    let fs = init_test(cx).await;
 20    let project = Project::test(fs, [], cx).await;
 21    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
 22
 23    thread
 24        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 25        .await
 26        .unwrap();
 27
 28    thread.read_with(cx, |thread, _| {
 29        assert!(
 30            thread.entries().len() >= 2,
 31            "Expected at least 2 entries. Got: {:?}",
 32            thread.entries()
 33        );
 34        assert!(matches!(
 35            thread.entries()[0],
 36            AgentThreadEntry::UserMessage(_)
 37        ));
 38        assert!(matches!(
 39            thread.entries()[1],
 40            AgentThreadEntry::AssistantMessage(_)
 41        ));
 42    });
 43}
 44
 45pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 46    let _fs = init_test(cx).await;
 47
 48    let tempdir = tempfile::tempdir().unwrap();
 49    std::fs::write(
 50        tempdir.path().join("foo.rs"),
 51        indoc! {"
 52            fn main() {
 53                println!(\"Hello, world!\");
 54            }
 55        "},
 56    )
 57    .expect("failed to write file");
 58    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 59    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
 60    thread
 61        .update(cx, |thread, cx| {
 62            thread.send(
 63                vec![
 64                    acp::ContentBlock::Text(acp::TextContent {
 65                        text: "Read the file ".into(),
 66                        annotations: None,
 67                    }),
 68                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
 69                        uri: "foo.rs".into(),
 70                        name: "foo.rs".into(),
 71                        annotations: None,
 72                        description: None,
 73                        mime_type: None,
 74                        size: None,
 75                        title: None,
 76                    }),
 77                    acp::ContentBlock::Text(acp::TextContent {
 78                        text: " and tell me what the content of the println! is".into(),
 79                        annotations: None,
 80                    }),
 81                ],
 82                cx,
 83            )
 84        })
 85        .await
 86        .unwrap();
 87
 88    thread.read_with(cx, |thread, cx| {
 89        assert!(matches!(
 90            thread.entries()[0],
 91            AgentThreadEntry::UserMessage(_)
 92        ));
 93        let assistant_message = &thread
 94            .entries()
 95            .iter()
 96            .rev()
 97            .find_map(|entry| match entry {
 98                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
 99                _ => None,
100            })
101            .unwrap();
102
103        assert!(
104            assistant_message.to_markdown(cx).contains("Hello, world!"),
105            "unexpected assistant message: {:?}",
106            assistant_message.to_markdown(cx)
107        );
108    });
109
110    drop(tempdir);
111}
112
113pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
114    let _fs = init_test(cx).await;
115
116    let tempdir = tempfile::tempdir().unwrap();
117    let foo_path = tempdir.path().join("foo");
118    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
119
120    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
121    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
122
123    thread
124        .update(cx, |thread, cx| {
125            thread.send_raw(
126                &format!("Read {} and tell me what you see.", foo_path.display()),
127                cx,
128            )
129        })
130        .await
131        .unwrap();
132    thread.read_with(cx, |thread, _cx| {
133        assert!(thread.entries().iter().any(|entry| {
134            matches!(
135                entry,
136                AgentThreadEntry::ToolCall(ToolCall {
137                    status: ToolCallStatus::Allowed { .. },
138                    ..
139                })
140            )
141        }));
142        assert!(
143            thread
144                .entries()
145                .iter()
146                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
147        );
148    });
149
150    drop(tempdir);
151}
152
153pub async fn test_tool_call_with_permission(
154    server: impl AgentServer + 'static,
155    allow_option_id: acp::PermissionOptionId,
156    cx: &mut TestAppContext,
157) {
158    let fs = init_test(cx).await;
159    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
160    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
161    let full_turn = thread.update(cx, |thread, cx| {
162        thread.send_raw(
163            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
164            cx,
165        )
166    });
167
168    run_until_first_tool_call(
169        &thread,
170        |entry| {
171            matches!(
172                entry,
173                AgentThreadEntry::ToolCall(ToolCall {
174                    status: ToolCallStatus::WaitingForConfirmation { .. },
175                    ..
176                })
177            )
178        },
179        cx,
180    )
181    .await;
182
183    let tool_call_id = thread.read_with(cx, |thread, cx| {
184        let AgentThreadEntry::ToolCall(ToolCall {
185            id,
186            label,
187            status: ToolCallStatus::WaitingForConfirmation { .. },
188            ..
189        }) = &thread
190            .entries()
191            .iter()
192            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
193            .unwrap()
194        else {
195            panic!();
196        };
197
198        let label = label.read(cx).source();
199        assert!(label.contains("touch"), "Got: {}", label);
200
201        id.clone()
202    });
203
204    thread.update(cx, |thread, cx| {
205        thread.authorize_tool_call(
206            tool_call_id,
207            allow_option_id,
208            acp::PermissionOptionKind::AllowOnce,
209            cx,
210        );
211
212        assert!(thread.entries().iter().any(|entry| matches!(
213            entry,
214            AgentThreadEntry::ToolCall(ToolCall {
215                status: ToolCallStatus::Allowed { .. },
216                ..
217            })
218        )));
219    });
220
221    full_turn.await.unwrap();
222
223    thread.read_with(cx, |thread, cx| {
224        let AgentThreadEntry::ToolCall(ToolCall {
225            content,
226            status: ToolCallStatus::Allowed { .. },
227            ..
228        }) = thread
229            .entries()
230            .iter()
231            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
232            .unwrap()
233        else {
234            panic!();
235        };
236
237        assert!(
238            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
239            "Expected content to contain 'Hello'"
240        );
241    });
242}
243
244pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
245    let fs = init_test(cx).await;
246
247    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
248    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
249    let full_turn = thread.update(cx, |thread, cx| {
250        thread.send_raw(
251            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
252            cx,
253        )
254    });
255
256    let first_tool_call_ix = run_until_first_tool_call(
257        &thread,
258        |entry| {
259            matches!(
260                entry,
261                AgentThreadEntry::ToolCall(ToolCall {
262                    status: ToolCallStatus::WaitingForConfirmation { .. },
263                    ..
264                })
265            )
266        },
267        cx,
268    )
269    .await;
270
271    thread.read_with(cx, |thread, cx| {
272        let AgentThreadEntry::ToolCall(ToolCall {
273            id,
274            label,
275            status: ToolCallStatus::WaitingForConfirmation { .. },
276            ..
277        }) = &thread.entries()[first_tool_call_ix]
278        else {
279            panic!("{:?}", thread.entries()[1]);
280        };
281
282        let label = label.read(cx).source();
283        assert!(label.contains("touch"), "Got: {}", label);
284
285        id.clone()
286    });
287
288    let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
289    full_turn.await.unwrap();
290    thread.read_with(cx, |thread, _| {
291        let AgentThreadEntry::ToolCall(ToolCall {
292            status: ToolCallStatus::Canceled,
293            ..
294        }) = &thread.entries()[first_tool_call_ix]
295        else {
296            panic!();
297        };
298    });
299
300    thread
301        .update(cx, |thread, cx| {
302            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
303        })
304        .await
305        .unwrap();
306    thread.read_with(cx, |thread, _| {
307        assert!(matches!(
308            &thread.entries().last().unwrap(),
309            AgentThreadEntry::AssistantMessage(..),
310        ))
311    });
312}
313
314pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
315    let fs = init_test(cx).await;
316    let project = Project::test(fs, [], cx).await;
317    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
318
319    thread
320        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
321        .await
322        .unwrap();
323
324    thread.read_with(cx, |thread, _| {
325        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
326    });
327
328    let weak_thread = thread.downgrade();
329    drop(thread);
330
331    cx.executor().run_until_parked();
332    assert!(!weak_thread.is_upgradable());
333}
334
/// Defines a `common_e2e` module containing one `#[gpui::test]` per shared end-to-end
/// test in this file: `basic`, `path_mentions`, `tool_call`,
/// `tool_call_with_permission`, `cancel` and `thread_drop`. Each test runs against the
/// agent server produced by `$server`.
///
/// - `$server`: an expression yielding an `AgentServer`. It is expanded, and therefore
///   evaluated, separately inside every generated test.
/// - `allow_option_id = $allow_option_id`: converted with `.into()` and wrapped in
///   `agent_client_protocol::PermissionOptionId`. It is the permission option passed to
///   `test_tool_call_with_permission`.
///
/// Every generated test is marked `#[ignore]` unless the crate invoking the macro is
/// built with the `e2e` feature.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        // `use super::*` lets the `$server` tokens, which expand inside this new module,
        // name items from the module that invoked the macro.
        mod common_e2e {
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
384
385// Helpers
386
387pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
388    env_logger::try_init().ok();
389
390    cx.update(|cx| {
391        let settings_store = SettingsStore::test(cx);
392        cx.set_global(settings_store);
393        Project::init_settings(cx);
394        language::init(cx);
395        crate::settings::init(cx);
396
397        crate::AllAgentServersSettings::override_global(
398            AllAgentServersSettings {
399                claude: Some(AgentServerSettings {
400                    command: crate::claude::tests::local_command(),
401                }),
402                gemini: Some(AgentServerSettings {
403                    command: crate::gemini::tests::local_command(),
404                }),
405            },
406            cx,
407        );
408    });
409
410    cx.executor().allow_parking();
411
412    FakeFs::new(cx.executor())
413}
414
415pub async fn new_test_thread(
416    server: impl AgentServer + 'static,
417    project: Entity<Project>,
418    current_dir: impl AsRef<Path>,
419    cx: &mut TestAppContext,
420) -> Entity<AcpThread> {
421    let connection = cx
422        .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
423        .await
424        .unwrap();
425
426    let thread = connection
427        .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
428        .await
429        .unwrap();
430
431    thread
432}
433
434pub async fn run_until_first_tool_call(
435    thread: &Entity<AcpThread>,
436    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
437    cx: &mut TestAppContext,
438) -> usize {
439    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
440
441    let subscription = cx.update(|cx| {
442        cx.subscribe(thread, move |thread, _, cx| {
443            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
444                if wait_until(entry) {
445                    return tx.try_send(ix).unwrap();
446                }
447            }
448        })
449    });
450
451    select! {
452        // We have to use a smol timer here because
453        // cx.background_executor().timer isn't real in the test context
454        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
455            panic!("Timeout waiting for tool call")
456        }
457        ix = rx.next().fuse() => {
458            drop(subscription);
459            ix.unwrap()
460        }
461    }
462}
463
/// Returns the path of the `zed` binary that cargo built next to the running test
/// executable.
///
/// The binary is looked up in the nearest ancestor of the test executable named after a
/// standard cargo profile directory (`debug` or `release`).
///
/// # Panics
///
/// Panics if no such profile directory is found, or if `zed` has not been built there.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    let Some(profile_dir) = find_profile_dir(&current_exe) else {
        panic!("Could not find target directory");
    };

    let zed_path = profile_dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}

/// Returns the nearest ancestor of `path` (checking `path` itself first) whose final
/// component is `debug` or `release`.
fn find_profile_dir(path: &Path) -> Option<PathBuf> {
    path.ancestors()
        .find(|dir| {
            matches!(
                dir.file_name().and_then(|name| name.to_str()),
                Some("debug" | "release")
            )
        })
        .map(Path::to_path_buf)
}