e2e_tests.rs

  1use std::{
  2    path::{Path, PathBuf},
  3    sync::Arc,
  4    time::Duration,
  5};
  6
  7use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
  8use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  9use agent_client_protocol as acp;
 10
 11use futures::{FutureExt, StreamExt, channel::mpsc, select};
 12use gpui::{Entity, TestAppContext};
 13use indoc::indoc;
 14use project::{FakeFs, Project};
 15use serde_json::json;
 16use settings::{Settings, SettingsStore};
 17use util::path;
 18
 19pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 20    let fs = init_test(cx).await;
 21    let project = Project::test(fs, [], cx).await;
 22    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
 23
 24    thread
 25        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 26        .await
 27        .unwrap();
 28
 29    thread.read_with(cx, |thread, _| {
 30        assert_eq!(thread.entries().len(), 2);
 31        assert!(matches!(
 32            thread.entries()[0],
 33            AgentThreadEntry::UserMessage(_)
 34        ));
 35        assert!(matches!(
 36            thread.entries()[1],
 37            AgentThreadEntry::AssistantMessage(_)
 38        ));
 39    });
 40}
 41
 42pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 43    let _fs = init_test(cx).await;
 44
 45    let tempdir = tempfile::tempdir().unwrap();
 46    std::fs::write(
 47        tempdir.path().join("foo.rs"),
 48        indoc! {"
 49            fn main() {
 50                println!(\"Hello, world!\");
 51            }
 52        "},
 53    )
 54    .expect("failed to write file");
 55    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 56    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
 57    thread
 58        .update(cx, |thread, cx| {
 59            thread.send(
 60                vec![
 61                    acp::ContentBlock::Text(acp::TextContent {
 62                        text: "Read the file ".into(),
 63                        annotations: None,
 64                    }),
 65                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
 66                        uri: "foo.rs".into(),
 67                        name: "foo.rs".into(),
 68                        annotations: None,
 69                        description: None,
 70                        mime_type: None,
 71                        size: None,
 72                        title: None,
 73                    }),
 74                    acp::ContentBlock::Text(acp::TextContent {
 75                        text: " and tell me what the content of the println! is".into(),
 76                        annotations: None,
 77                    }),
 78                ],
 79                cx,
 80            )
 81        })
 82        .await
 83        .unwrap();
 84
 85    thread.read_with(cx, |thread, cx| {
 86        assert!(matches!(
 87            thread.entries()[0],
 88            AgentThreadEntry::UserMessage(_)
 89        ));
 90        let assistant_message = &thread
 91            .entries()
 92            .iter()
 93            .rev()
 94            .find_map(|entry| match entry {
 95                AgentThreadEntry::AssistantMessage(msg) => Some(msg),
 96                _ => None,
 97            })
 98            .unwrap();
 99
100        assert!(
101            assistant_message.to_markdown(cx).contains("Hello, world!"),
102            "unexpected assistant message: {:?}",
103            assistant_message.to_markdown(cx)
104        );
105    });
106
107    drop(tempdir);
108}
109
110pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
111    let fs = init_test(cx).await;
112    fs.insert_tree(
113        path!("/private/tmp"),
114        json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
115    )
116    .await;
117    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
118    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
119
120    thread
121        .update(cx, |thread, cx| {
122            thread.send_raw(
123                "Read the '/private/tmp/foo' file and tell me what you see.",
124                cx,
125            )
126        })
127        .await
128        .unwrap();
129    thread.read_with(cx, |thread, _cx| {
130        assert!(thread.entries().iter().any(|entry| {
131            matches!(
132                entry,
133                AgentThreadEntry::ToolCall(ToolCall {
134                    status: ToolCallStatus::Allowed { .. },
135                    ..
136                })
137            )
138        }));
139        assert!(
140            thread
141                .entries()
142                .iter()
143                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
144        );
145    });
146}
147
148pub async fn test_tool_call_with_confirmation(
149    server: impl AgentServer + 'static,
150    allow_option_id: acp::PermissionOptionId,
151    cx: &mut TestAppContext,
152) {
153    let fs = init_test(cx).await;
154    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
155    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
156    let full_turn = thread.update(cx, |thread, cx| {
157        thread.send_raw(
158            r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
159            cx,
160        )
161    });
162
163    run_until_first_tool_call(
164        &thread,
165        |entry| {
166            matches!(
167                entry,
168                AgentThreadEntry::ToolCall(ToolCall {
169                    status: ToolCallStatus::WaitingForConfirmation { .. },
170                    ..
171                })
172            )
173        },
174        cx,
175    )
176    .await;
177
178    let tool_call_id = thread.read_with(cx, |thread, _cx| {
179        let AgentThreadEntry::ToolCall(ToolCall {
180            id,
181            content,
182            status: ToolCallStatus::WaitingForConfirmation { .. },
183            ..
184        }) = &thread
185            .entries()
186            .iter()
187            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
188            .unwrap()
189        else {
190            panic!();
191        };
192
193        assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
194
195        id.clone()
196    });
197
198    thread.update(cx, |thread, cx| {
199        thread.authorize_tool_call(
200            tool_call_id,
201            allow_option_id,
202            acp::PermissionOptionKind::AllowOnce,
203            cx,
204        );
205
206        assert!(thread.entries().iter().any(|entry| matches!(
207            entry,
208            AgentThreadEntry::ToolCall(ToolCall {
209                status: ToolCallStatus::Allowed { .. },
210                ..
211            })
212        )));
213    });
214
215    full_turn.await.unwrap();
216
217    thread.read_with(cx, |thread, cx| {
218        let AgentThreadEntry::ToolCall(ToolCall {
219            content,
220            status: ToolCallStatus::Allowed { .. },
221            ..
222        }) = thread
223            .entries()
224            .iter()
225            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
226            .unwrap()
227        else {
228            panic!();
229        };
230
231        assert!(
232            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
233            "Expected content to contain 'Hello'"
234        );
235    });
236}
237
238pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
239    let fs = init_test(cx).await;
240
241    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
242    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
243    let full_turn = thread.update(cx, |thread, cx| {
244        thread.send_raw(
245            r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
246            cx,
247        )
248    });
249
250    let first_tool_call_ix = run_until_first_tool_call(
251        &thread,
252        |entry| {
253            matches!(
254                entry,
255                AgentThreadEntry::ToolCall(ToolCall {
256                    status: ToolCallStatus::WaitingForConfirmation { .. },
257                    ..
258                })
259            )
260        },
261        cx,
262    )
263    .await;
264
265    thread.read_with(cx, |thread, _cx| {
266        let AgentThreadEntry::ToolCall(ToolCall {
267            id,
268            content,
269            status: ToolCallStatus::WaitingForConfirmation { .. },
270            ..
271        }) = &thread.entries()[first_tool_call_ix]
272        else {
273            panic!("{:?}", thread.entries()[1]);
274        };
275
276        assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
277
278        id.clone()
279    });
280
281    let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
282    full_turn.await.unwrap();
283    thread.read_with(cx, |thread, _| {
284        let AgentThreadEntry::ToolCall(ToolCall {
285            status: ToolCallStatus::Canceled,
286            ..
287        }) = &thread.entries()[first_tool_call_ix]
288        else {
289            panic!();
290        };
291    });
292
293    thread
294        .update(cx, |thread, cx| {
295            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
296        })
297        .await
298        .unwrap();
299    thread.read_with(cx, |thread, _| {
300        assert!(matches!(
301            &thread.entries().last().unwrap(),
302            AgentThreadEntry::AssistantMessage(..),
303        ))
304    });
305}
306
/// Generates a `common_e2e` test module that runs every shared end-to-end test
/// from `e2e_tests` against a single agent server.
///
/// - `$server`: expression producing the server under test. It is expanded
///   separately inside each generated test, so every test gets its own value.
/// - `allow_option_id`: value converted with `.into()` into the
///   `agent_client_protocol::PermissionOptionId` passed to
///   `test_tool_call_with_confirmation`.
///
/// Each generated test is marked `#[ignore]` unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr, allow_option_id = $allow_option_id:expr) => {
        mod common_e2e {
            // Bring the invoking module's items (e.g. the server type used in
            // `$server`) into scope.
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_confirmation(
                    $server,
                    ::agent_client_protocol::PermissionOptionId($allow_option_id.into()),
                    cx,
                )
                .await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }
        }
    };
}
350
351// Helpers
352
353pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
354    env_logger::try_init().ok();
355
356    cx.update(|cx| {
357        let settings_store = SettingsStore::test(cx);
358        cx.set_global(settings_store);
359        Project::init_settings(cx);
360        language::init(cx);
361        crate::settings::init(cx);
362
363        crate::AllAgentServersSettings::override_global(
364            AllAgentServersSettings {
365                claude: Some(AgentServerSettings {
366                    command: crate::claude::tests::local_command(),
367                }),
368                gemini: Some(AgentServerSettings {
369                    command: crate::gemini::tests::local_command(),
370                }),
371                codex: Some(AgentServerSettings {
372                    command: crate::codex::tests::local_command(),
373                }),
374            },
375            cx,
376        );
377    });
378
379    cx.executor().allow_parking();
380
381    FakeFs::new(cx.executor())
382}
383
384pub async fn new_test_thread(
385    server: impl AgentServer + 'static,
386    project: Entity<Project>,
387    current_dir: impl AsRef<Path>,
388    cx: &mut TestAppContext,
389) -> Entity<AcpThread> {
390    let connection = cx
391        .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
392        .await
393        .unwrap();
394
395    let thread = connection
396        .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
397        .await
398        .unwrap();
399
400    thread
401}
402
403pub async fn run_until_first_tool_call(
404    thread: &Entity<AcpThread>,
405    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
406    cx: &mut TestAppContext,
407) -> usize {
408    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
409
410    let subscription = cx.update(|cx| {
411        cx.subscribe(thread, move |thread, _, cx| {
412            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
413                if wait_until(entry) {
414                    return tx.try_send(ix).unwrap();
415                }
416            }
417        })
418    });
419
420    select! {
421        // We have to use a smol timer here because
422        // cx.background_executor().timer isn't real in the test context
423        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
424            panic!("Timeout waiting for tool call")
425        }
426        ix = rx.next().fuse() => {
427            drop(subscription);
428            ix.unwrap()
429        }
430    }
431}
432
/// Returns the path to the `zed` binary in the nearest `debug` directory that
/// contains the currently running executable.
///
/// # Panics
///
/// Panics in any of these cases:
/// - the current executable path cannot be determined;
/// - no ancestor of that path is named `debug`;
/// - `debug/zed` does not exist, i.e. `cargo build` has not been run yet.
pub fn get_zed_path() -> PathBuf {
    let current_exe = std::env::current_exe().unwrap();

    // `ancestors()` yields the path itself first, then each parent up to the root.
    let Some(debug_dir) = current_exe
        .ancestors()
        .find(|dir| dir.file_name().is_some_and(|name| name == "debug"))
    else {
        panic!("Could not find target directory");
    };

    let zed_path = debug_dir.join("zed");
    if !zed_path.exists() {
        panic!("\n🚨 Run `cargo build` at least once before running e2e tests\n\n");
    }
    zed_path
}