e2e_tests.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    sync::Arc,
  4    time::Duration,
  5};
  6
  7use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
  8use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  9use agent_client_protocol as acp;
 10
 11use futures::{FutureExt, StreamExt, channel::mpsc, select};
 12use gpui::{Entity, TestAppContext};
 13use indoc::indoc;
 14use project::{FakeFs, Project};
 15use settings::{Settings, SettingsStore};
 16use util::path;
 17
 18pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 19    let fs = init_test(cx).await;
 20    let project = Project::test(fs, [], cx).await;
 21    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
 22
 23    thread
 24        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 25        .await
 26        .unwrap();
 27
 28    thread.read_with(cx, |thread, _| {
 29        assert!(
 30            thread.entries().len() >= 2,
 31            "Expected at least 2 entries. Got: {:?}",
 32            thread.entries()
 33        );
 34        assert!(matches!(
 35            thread.entries()[0],
 36            AgentThreadEntry::UserMessage(_)
 37        ));
 38        assert!(matches!(
 39            thread.entries()[1],
 40            AgentThreadEntry::AssistantMessage(_)
 41        ));
 42    });
 43}
 44
 45pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 46    let _fs = init_test(cx).await;
 47
 48    let tempdir = tempfile::tempdir().unwrap();
 49    std::fs::write(
 50        tempdir.path().join("foo.rs"),
 51        indoc! {"
 52            fn main() {
 53                println!(\"Hello, world!\");
 54            }
 55        "},
 56    )
 57    .expect("failed to write file");
 58    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 59    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
 60    thread
 61        .update(cx, |thread, cx| {
 62            thread.send(
 63                vec![
 64                    acp::ContentBlock::Text(acp::TextContent {
 65                        text: "Read the file ".into(),
 66                        annotations: None,
 67                    }),
 68                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
 69                        uri: "foo.rs".into(),
 70                        name: "foo.rs".into(),
 71                        annotations: None,
 72                        description: None,
 73                        mime_type: None,
 74                        size: None,
 75                        title: None,
 76                    }),
 77                    acp::ContentBlock::Text(acp::TextContent {
 78                        text: " and tell me what the content of the println! is".into(),
 79                        annotations: None,
 80                    }),
 81                ],
 82                cx,
 83            )
 84        })
 85        .await
 86        .unwrap();
 87
 88    thread.read_with(cx, |thread, cx| {
 89        assert!(matches!(
 90            thread.entries()[0],
 91            AgentThreadEntry::UserMessage(_)
 92        ));
 93        let assistant_message = &thread
 94            .entries()
 95            .iter()
 96            .rev()
 97            .find_map(|entry| match entry {
 98                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
 99                _ => None,
100            })
101            .unwrap();
102
103        assert!(
104            assistant_message.to_markdown(cx).contains("Hello, world!"),
105            "unexpected assistant message: {:?}",
106            assistant_message.to_markdown(cx)
107        );
108    });
109
110    drop(tempdir);
111}
112
113pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
114    let _fs = init_test(cx).await;
115
116    let tempdir = tempfile::tempdir().unwrap();
117    let foo_path = tempdir.path().join("foo");
118    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
119
120    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
121    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
122
123    thread
124        .update(cx, |thread, cx| {
125            thread.send_raw(
126                &format!("Read {} and tell me what you see.", foo_path.display()),
127                cx,
128            )
129        })
130        .await
131        .unwrap();
132    thread.read_with(cx, |thread, _cx| {
133        assert!(thread.entries().iter().any(|entry| {
134            matches!(
135                entry,
136                AgentThreadEntry::ToolCall(ToolCall {
137                    status: ToolCallStatus::Pending
138                        | ToolCallStatus::InProgress
139                        | ToolCallStatus::Completed,
140                    ..
141                })
142            )
143        }));
144        assert!(
145            thread
146                .entries()
147                .iter()
148                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
149        );
150    });
151
152    drop(tempdir);
153}
154
155pub async fn test_tool_call_with_permission(
156    server: impl AgentServer + 'static,
157    allow_option_id: acp::PermissionOptionId,
158    cx: &mut TestAppContext,
159) {
160    let fs = init_test(cx).await;
161    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
162    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
163    let full_turn = thread.update(cx, |thread, cx| {
164        thread.send_raw(
165            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
166            cx,
167        )
168    });
169
170    run_until_first_tool_call(
171        &thread,
172        |entry| {
173            matches!(
174                entry,
175                AgentThreadEntry::ToolCall(ToolCall {
176                    status: ToolCallStatus::WaitingForConfirmation { .. },
177                    ..
178                })
179            )
180        },
181        cx,
182    )
183    .await;
184
185    let tool_call_id = thread.read_with(cx, |thread, cx| {
186        let AgentThreadEntry::ToolCall(ToolCall {
187            id,
188            label,
189            status: ToolCallStatus::WaitingForConfirmation { .. },
190            ..
191        }) = &thread
192            .entries()
193            .iter()
194            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
195            .unwrap()
196        else {
197            panic!();
198        };
199
200        let label = label.read(cx).source();
201        assert!(label.contains("touch"), "Got: {}", label);
202
203        id.clone()
204    });
205
206    thread.update(cx, |thread, cx| {
207        thread.authorize_tool_call(
208            tool_call_id,
209            allow_option_id,
210            acp::PermissionOptionKind::AllowOnce,
211            cx,
212        );
213
214        assert!(thread.entries().iter().any(|entry| matches!(
215            entry,
216            AgentThreadEntry::ToolCall(ToolCall {
217                status: ToolCallStatus::Pending
218                    | ToolCallStatus::InProgress
219                    | ToolCallStatus::Completed,
220                ..
221            })
222        )));
223    });
224
225    full_turn.await.unwrap();
226
227    thread.read_with(cx, |thread, cx| {
228        let AgentThreadEntry::ToolCall(ToolCall {
229            content,
230            status: ToolCallStatus::Pending
231                | ToolCallStatus::InProgress
232                | ToolCallStatus::Completed,
233            ..
234        }) = thread
235            .entries()
236            .iter()
237            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
238            .unwrap()
239        else {
240            panic!();
241        };
242
243        assert!(
244            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
245            "Expected content to contain 'Hello'"
246        );
247    });
248}
249
250pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
251    let fs = init_test(cx).await;
252
253    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
254    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
255    let _ = thread.update(cx, |thread, cx| {
256        thread.send_raw(
257            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
258            cx,
259        )
260    });
261
262    let first_tool_call_ix = run_until_first_tool_call(
263        &thread,
264        |entry| {
265            matches!(
266                entry,
267                AgentThreadEntry::ToolCall(ToolCall {
268                    status: ToolCallStatus::WaitingForConfirmation { .. },
269                    ..
270                })
271            )
272        },
273        cx,
274    )
275    .await;
276
277    thread.read_with(cx, |thread, cx| {
278        let AgentThreadEntry::ToolCall(ToolCall {
279            id,
280            label,
281            status: ToolCallStatus::WaitingForConfirmation { .. },
282            ..
283        }) = &thread.entries()[first_tool_call_ix]
284        else {
285            panic!("{:?}", thread.entries()[1]);
286        };
287
288        let label = label.read(cx).source();
289        assert!(label.contains("touch"), "Got: {}", label);
290
291        id.clone()
292    });
293
294    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
295    thread.read_with(cx, |thread, _cx| {
296        let AgentThreadEntry::ToolCall(ToolCall {
297            status: ToolCallStatus::Canceled,
298            ..
299        }) = &thread.entries()[first_tool_call_ix]
300        else {
301            panic!();
302        };
303    });
304
305    thread
306        .update(cx, |thread, cx| {
307            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
308        })
309        .await
310        .unwrap();
311    thread.read_with(cx, |thread, _| {
312        assert!(matches!(
313            &thread.entries().last().unwrap(),
314            AgentThreadEntry::AssistantMessage(..),
315        ))
316    });
317}
318
319pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
320    let fs = init_test(cx).await;
321    let project = Project::test(fs, [], cx).await;
322    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
323
324    thread
325        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
326        .await
327        .unwrap();
328
329    thread.read_with(cx, |thread, _| {
330        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
331    });
332
333    let weak_thread = thread.downgrade();
334    drop(thread);
335
336    cx.executor().run_until_parked();
337    assert!(!weak_thread.is_upgradable());
338}
339
/// Generates the shared e2e test suite for an agent server.
///
/// Expands to a `common_e2e` module (which glob-imports the invoking module
/// via `use super::*`) containing one `#[gpui::test]` per `test_*` function
/// in this file. Every generated test is ignored unless the `e2e` feature is
/// enabled.
///
/// * `$server` — an expression producing the agent server; it is expanded,
///   and therefore evaluated, separately inside each generated test.
/// * `allow_option_id` — the id of the permission option that allows a tool
///   call; it is converted with `.into()` and wrapped in
///   `agent_client_protocol::PermissionOptionId` for
///   `test_tool_call_with_permission`.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            use super::*;

            // Each test below is skipped unless built with `--features e2e`.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_permission(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_thread_drop($server, cx).await;
            }
        }
    };
}
389
390// Helpers
391
392pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
393    env_logger::try_init().ok();
394
395    cx.update(|cx| {
396        let settings_store = SettingsStore::test(cx);
397        cx.set_global(settings_store);
398        Project::init_settings(cx);
399        language::init(cx);
400        crate::settings::init(cx);
401
402        crate::AllAgentServersSettings::override_global(
403            AllAgentServersSettings {
404                claude: Some(AgentServerSettings {
405                    command: crate::claude::tests::local_command(),
406                }),
407                gemini: Some(AgentServerSettings {
408                    command: crate::gemini::tests::local_command(),
409                }),
410            },
411            cx,
412        );
413    });
414
415    cx.executor().allow_parking();
416
417    FakeFs::new(cx.executor())
418}
419
420pub async fn new_test_thread(
421    server: impl AgentServer + 'static,
422    project: Entity<Project>,
423    current_dir: impl AsRef<Path>,
424    cx: &mut TestAppContext,
425) -> Entity<AcpThread> {
426    let connection = cx
427        .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
428        .await
429        .unwrap();
430
431    cx.update(|cx| connection.new_thread(project.clone(), current_dir.as_ref(), cx))
432        .await
433        .unwrap()
434}
435
436pub async fn run_until_first_tool_call(
437    thread: &Entity<AcpThread>,
438    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
439    cx: &mut TestAppContext,
440) -> usize {
441    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
442
443    let subscription = cx.update(|cx| {
444        cx.subscribe(thread, move |thread, _, cx| {
445            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
446                if wait_until(entry) {
447                    return tx.try_send(ix).unwrap();
448                }
449            }
450        })
451    });
452
453    select! {
454        // We have to use a smol timer here because
455        // cx.background_executor().timer isn't real in the test context
456        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
457            panic!("Timeout waiting for tool call")
458        }
459        ix = rx.next().fuse() => {
460            drop(subscription);
461            ix.unwrap()
462        }
463    }
464}
465
/// Returns the path of the locally built `zed` binary used by the e2e tests.
///
/// The binary is expected in the nearest ancestor directory named `debug` of
/// the running test executable (e.g. `target/debug`), with the platform's
/// executable suffix appended (`zed.exe` on Windows).
///
/// # Panics
///
/// Panics if the current executable has no `debug` ancestor, or if the
/// binary has not been built yet.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    // `ancestors()` starts with the path itself and walks up to the root.
    let Some(target_dir) = current_exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name == "debug"))
    else {
        panic!("Could not find target directory");
    };

    // `EXE_SUFFIX` is empty on Unix and ".exe" on Windows; without it the
    // existence check below can never succeed on Windows.
    let zed_path = target_dir.join(format!("zed{}", std::env::consts::EXE_SUFFIX));

    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }

    zed_path
}