e2e_tests.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    sync::Arc,
  4    time::Duration,
  5};
  6
  7use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
  8use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  9use agent_client_protocol as acp;
 10
 11use futures::{FutureExt, StreamExt, channel::mpsc, select};
 12use gpui::{Entity, TestAppContext};
 13use indoc::indoc;
 14use project::{FakeFs, Project};
 15use settings::{Settings, SettingsStore};
 16
 17pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 18    let fs = init_test(cx).await;
 19    let tempdir = tempfile::tempdir().unwrap();
 20    let project = Project::test(fs, [], cx).await;
 21    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
 22
 23    thread
 24        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 25        .await
 26        .unwrap();
 27
 28    thread.read_with(cx, |thread, _| {
 29        assert!(
 30            thread.entries().len() >= 2,
 31            "Expected at least 2 entries. Got: {:?}",
 32            thread.entries()
 33        );
 34        assert!(matches!(
 35            thread.entries()[0],
 36            AgentThreadEntry::UserMessage(_)
 37        ));
 38        assert!(matches!(
 39            thread.entries()[1],
 40            AgentThreadEntry::AssistantMessage(_)
 41        ));
 42    });
 43
 44    drop(tempdir);
 45}
 46
 47pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 48    let _fs = init_test(cx).await;
 49
 50    let tempdir = tempfile::tempdir().unwrap();
 51    std::fs::write(
 52        tempdir.path().join("foo.rs"),
 53        indoc! {"
 54            fn main() {
 55                println!(\"Hello, world!\");
 56            }
 57        "},
 58    )
 59    .expect("failed to write file");
 60    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 61    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
 62    thread
 63        .update(cx, |thread, cx| {
 64            thread.send(
 65                vec![
 66                    acp::ContentBlock::Text(acp::TextContent {
 67                        text: "Read the file ".into(),
 68                        annotations: None,
 69                    }),
 70                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
 71                        uri: "foo.rs".into(),
 72                        name: "foo.rs".into(),
 73                        annotations: None,
 74                        description: None,
 75                        mime_type: None,
 76                        size: None,
 77                        title: None,
 78                    }),
 79                    acp::ContentBlock::Text(acp::TextContent {
 80                        text: " and tell me what the content of the println! is".into(),
 81                        annotations: None,
 82                    }),
 83                ],
 84                cx,
 85            )
 86        })
 87        .await
 88        .unwrap();
 89
 90    thread.read_with(cx, |thread, cx| {
 91        assert!(matches!(
 92            thread.entries()[0],
 93            AgentThreadEntry::UserMessage(_)
 94        ));
 95        let assistant_message = &thread
 96            .entries()
 97            .iter()
 98            .rev()
 99            .find_map(|entry| match entry {
100                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
101                _ => None,
102            })
103            .unwrap();
104
105        assert!(
106            assistant_message.to_markdown(cx).contains("Hello, world!"),
107            "unexpected assistant message: {:?}",
108            assistant_message.to_markdown(cx)
109        );
110    });
111
112    drop(tempdir);
113}
114
115pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
116    let _fs = init_test(cx).await;
117
118    let tempdir = tempfile::tempdir().unwrap();
119    let foo_path = tempdir.path().join("foo");
120    std::fs::write(&foo_path, "Lorem ipsum dolor").expect("failed to write file");
121
122    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
123    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
124
125    thread
126        .update(cx, |thread, cx| {
127            thread.send_raw(
128                &format!("Read {} and tell me what you see.", foo_path.display()),
129                cx,
130            )
131        })
132        .await
133        .unwrap();
134    thread.read_with(cx, |thread, _cx| {
135        assert!(thread.entries().iter().any(|entry| {
136            matches!(
137                entry,
138                AgentThreadEntry::ToolCall(ToolCall {
139                    status: ToolCallStatus::Allowed { .. },
140                    ..
141                })
142            )
143        }));
144        assert!(
145            thread
146                .entries()
147                .iter()
148                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
149        );
150    });
151
152    drop(tempdir);
153}
154
155pub async fn test_tool_call_with_permission(
156    server: impl AgentServer + 'static,
157    allow_option_id: acp::PermissionOptionId,
158    cx: &mut TestAppContext,
159) {
160    let fs = init_test(cx).await;
161    let tempdir = tempfile::tempdir().unwrap();
162    let project = Project::test(fs, [tempdir.path()], cx).await;
163    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
164    let full_turn = thread.update(cx, |thread, cx| {
165        thread.send_raw(
166            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
167            cx,
168        )
169    });
170
171    run_until_first_tool_call(
172        &thread,
173        |entry| {
174            matches!(
175                entry,
176                AgentThreadEntry::ToolCall(ToolCall {
177                    status: ToolCallStatus::WaitingForConfirmation { .. },
178                    ..
179                })
180            )
181        },
182        cx,
183    )
184    .await;
185
186    let tool_call_id = thread.read_with(cx, |thread, cx| {
187        let AgentThreadEntry::ToolCall(ToolCall {
188            id,
189            label,
190            status: ToolCallStatus::WaitingForConfirmation { .. },
191            ..
192        }) = &thread
193            .entries()
194            .iter()
195            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
196            .unwrap()
197        else {
198            panic!();
199        };
200
201        let label = label.read(cx).source();
202        assert!(label.contains("touch"), "Got: {}", label);
203
204        id.clone()
205    });
206
207    thread.update(cx, |thread, cx| {
208        thread.authorize_tool_call(
209            tool_call_id,
210            allow_option_id,
211            acp::PermissionOptionKind::AllowOnce,
212            cx,
213        );
214
215        assert!(thread.entries().iter().any(|entry| matches!(
216            entry,
217            AgentThreadEntry::ToolCall(ToolCall {
218                status: ToolCallStatus::Allowed { .. },
219                ..
220            })
221        )));
222    });
223
224    full_turn.await.unwrap();
225
226    thread.read_with(cx, |thread, cx| {
227        let AgentThreadEntry::ToolCall(ToolCall {
228            content,
229            status: ToolCallStatus::Allowed { .. },
230            ..
231        }) = thread
232            .entries()
233            .iter()
234            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
235            .unwrap()
236        else {
237            panic!();
238        };
239
240        assert!(
241            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
242            "Expected content to contain 'Hello'"
243        );
244    });
245
246    drop(tempdir);
247}
248
249pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
250    let fs = init_test(cx).await;
251    let tempdir = tempfile::tempdir().unwrap();
252    let project = Project::test(fs, [tempdir.path()], cx).await;
253    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
254    let _ = thread.update(cx, |thread, cx| {
255        thread.send_raw(
256            r#"Run exactly `touch hello.txt && echo "Hello, world!" | tee hello.txt` in the terminal."#,
257            cx,
258        )
259    });
260
261    let first_tool_call_ix = run_until_first_tool_call(
262        &thread,
263        |entry| {
264            matches!(
265                entry,
266                AgentThreadEntry::ToolCall(ToolCall {
267                    status: ToolCallStatus::WaitingForConfirmation { .. },
268                    ..
269                })
270            )
271        },
272        cx,
273    )
274    .await;
275
276    thread.read_with(cx, |thread, cx| {
277        let AgentThreadEntry::ToolCall(ToolCall {
278            id,
279            label,
280            status: ToolCallStatus::WaitingForConfirmation { .. },
281            ..
282        }) = &thread.entries()[first_tool_call_ix]
283        else {
284            panic!("{:?}", thread.entries()[1]);
285        };
286
287        let label = label.read(cx).source();
288        assert!(label.contains("touch"), "Got: {}", label);
289
290        id.clone()
291    });
292
293    thread.update(cx, |thread, cx| thread.cancel(cx)).await;
294    thread.read_with(cx, |thread, _cx| {
295        let AgentThreadEntry::ToolCall(ToolCall {
296            status: ToolCallStatus::Canceled,
297            ..
298        }) = &thread.entries()[first_tool_call_ix]
299        else {
300            panic!();
301        };
302    });
303
304    thread
305        .update(cx, |thread, cx| {
306            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
307        })
308        .await
309        .unwrap();
310    thread.read_with(cx, |thread, _| {
311        assert!(matches!(
312            &thread.entries().last().unwrap(),
313            AgentThreadEntry::AssistantMessage(..),
314        ))
315    });
316
317    drop(tempdir);
318}
319
320pub async fn test_thread_drop(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
321    let fs = init_test(cx).await;
322    let tempdir = tempfile::tempdir().unwrap();
323    let project = Project::test(fs, [], cx).await;
324    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
325
326    thread
327        .update(cx, |thread, cx| thread.send_raw("Hello from test!", cx))
328        .await
329        .unwrap();
330
331    thread.read_with(cx, |thread, _| {
332        assert!(thread.entries().len() >= 2, "Expected at least 2 entries");
333    });
334
335    let weak_thread = thread.downgrade();
336    drop(thread);
337
338    cx.executor().run_until_parked();
339    assert!(!weak_thread.is_upgradable());
340
341    drop(tempdir);
342}
343
344#[macro_export]
345macro_rules! common_e2e_tests {
346    ($server:expr, allow_option_id = $allow_option_id:expr) => {
347        mod common_e2e {
348            use super::*;
349
350            #[::gpui::test]
351            #[cfg_attr(not(feature = "e2e"), ignore)]
352            async fn basic(cx: &mut ::gpui::TestAppContext) {
353                $crate::e2e_tests::test_basic($server, cx).await;
354            }
355
356            #[::gpui::test]
357            #[cfg_attr(not(feature = "e2e"), ignore)]
358            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
359                $crate::e2e_tests::test_path_mentions($server, cx).await;
360            }
361
362            #[::gpui::test]
363            #[cfg_attr(not(feature = "e2e"), ignore)]
364            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
365                $crate::e2e_tests::test_tool_call($server, cx).await;
366            }
367
368            #[::gpui::test]
369            #[cfg_attr(not(feature = "e2e"), ignore)]
370            async fn tool_call_with_permission(cx: &mut ::gpui::TestAppContext) {
371                $crate::e2e_tests::test_tool_call_with_permission(
372                    $server,
373                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
374                    cx,
375                )
376                .await;
377            }
378
379            #[::gpui::test]
380            #[cfg_attr(not(feature = "e2e"), ignore)]
381            async fn cancel(cx: &mut ::gpui::TestAppContext) {
382                $crate::e2e_tests::test_cancel($server, cx).await;
383            }
384
385            #[::gpui::test]
386            #[cfg_attr(not(feature = "e2e"), ignore)]
387            async fn thread_drop(cx: &mut ::gpui::TestAppContext) {
388                $crate::e2e_tests::test_thread_drop($server, cx).await;
389            }
390        }
391    };
392}
393
394// Helpers
395
396pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
397    env_logger::try_init().ok();
398
399    cx.update(|cx| {
400        let settings_store = SettingsStore::test(cx);
401        cx.set_global(settings_store);
402        Project::init_settings(cx);
403        language::init(cx);
404        crate::settings::init(cx);
405
406        crate::AllAgentServersSettings::override_global(
407            AllAgentServersSettings {
408                claude: Some(AgentServerSettings {
409                    command: crate::claude::tests::local_command(),
410                }),
411                gemini: Some(AgentServerSettings {
412                    command: crate::gemini::tests::local_command(),
413                }),
414            },
415            cx,
416        );
417    });
418
419    cx.executor().allow_parking();
420
421    FakeFs::new(cx.executor())
422}
423
424pub async fn new_test_thread(
425    server: impl AgentServer + 'static,
426    project: Entity<Project>,
427    current_dir: impl AsRef<Path>,
428    cx: &mut TestAppContext,
429) -> Entity<AcpThread> {
430    let connection = cx
431        .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
432        .await
433        .unwrap();
434
435    let thread = connection
436        .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
437        .await
438        .unwrap();
439
440    thread
441}
442
443pub async fn run_until_first_tool_call(
444    thread: &Entity<AcpThread>,
445    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
446    cx: &mut TestAppContext,
447) -> usize {
448    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
449
450    let subscription = cx.update(|cx| {
451        cx.subscribe(thread, move |thread, _, cx| {
452            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
453                if wait_until(entry) {
454                    return tx.try_send(ix).unwrap();
455                }
456            }
457        })
458    });
459
460    select! {
461        // We have to use a smol timer here because
462        // cx.background_executor().timer isn't real in the test context
463        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
464            panic!("Timeout waiting for tool call")
465        }
466        ix = rx.next().fuse() => {
467            drop(subscription);
468            ix.unwrap()
469        }
470    }
471}
472
473pub fn get_zed_path() -> PathBuf {
474    let mut zed_path = std::env::current_exe().unwrap();
475
476    while zed_path
477        .file_name()
478        .map_or(true, |name| name.to_string_lossy() != "debug")
479    {
480        if !zed_path.pop() {
481            panic!("Could not find target directory");
482        }
483    }
484
485    zed_path.push("zed");
486
487    if !zed_path.exists() {
488        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
489    }
490
491    zed_path
492}