e2e_tests.rs

  1use std::{path::Path, sync::Arc, time::Duration};
  2
  3use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
  4use acp_thread::{AcpThread, AgentThreadEntry, ToolCall, ToolCallStatus};
  5use agent_client_protocol as acp;
  6
  7use futures::{FutureExt, StreamExt, channel::mpsc, select};
  8use gpui::{Entity, TestAppContext};
  9use indoc::indoc;
 10use project::{FakeFs, Project};
 11use serde_json::json;
 12use settings::{Settings, SettingsStore};
 13use util::path;
 14
 15pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 16    let fs = init_test(cx).await;
 17    let project = Project::test(fs, [], cx).await;
 18    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
 19
 20    thread
 21        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 22        .await
 23        .unwrap();
 24
 25    thread.read_with(cx, |thread, _| {
 26        assert_eq!(thread.entries().len(), 2);
 27        assert!(matches!(
 28            thread.entries()[0],
 29            AgentThreadEntry::UserMessage(_)
 30        ));
 31        assert!(matches!(
 32            thread.entries()[1],
 33            AgentThreadEntry::AssistantMessage(_)
 34        ));
 35    });
 36}
 37
 38pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 39    let _fs = init_test(cx).await;
 40
 41    let tempdir = tempfile::tempdir().unwrap();
 42    std::fs::write(
 43        tempdir.path().join("foo.rs"),
 44        indoc! {"
 45            fn main() {
 46                println!(\"Hello, world!\");
 47            }
 48        "},
 49    )
 50    .expect("failed to write file");
 51    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 52    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
 53    thread
 54        .update(cx, |thread, cx| {
 55            thread.send(
 56                vec![
 57                    acp::ContentBlock::Text(acp::TextContent {
 58                        text: "Read the file ".into(),
 59                        annotations: None,
 60                    }),
 61                    acp::ContentBlock::ResourceLink(acp::ResourceLink {
 62                        uri: "foo.rs".into(),
 63                        name: "foo.rs".into(),
 64                        annotations: None,
 65                        description: None,
 66                        mime_type: None,
 67                        size: None,
 68                        title: None,
 69                    }),
 70                    acp::ContentBlock::Text(acp::TextContent {
 71                        text: " and tell me what the content of the println! is".into(),
 72                        annotations: None,
 73                    }),
 74                ],
 75                cx,
 76            )
 77        })
 78        .await
 79        .unwrap();
 80
 81    thread.read_with(cx, |thread, cx| {
 82        assert_eq!(thread.entries().len(), 3);
 83        assert!(matches!(
 84            thread.entries()[0],
 85            AgentThreadEntry::UserMessage(_)
 86        ));
 87        assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
 88        let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
 89            panic!("Expected AssistantMessage")
 90        };
 91        assert!(
 92            assistant_message.to_markdown(cx).contains("Hello, world!"),
 93            "unexpected assistant message: {:?}",
 94            assistant_message.to_markdown(cx)
 95        );
 96    });
 97}
 98
 99pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
100    let fs = init_test(cx).await;
101    fs.insert_tree(
102        path!("/private/tmp"),
103        json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
104    )
105    .await;
106    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
107    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
108
109    thread
110        .update(cx, |thread, cx| {
111            thread.send_raw(
112                "Read the '/private/tmp/foo' file and tell me what you see.",
113                cx,
114            )
115        })
116        .await
117        .unwrap();
118    thread.read_with(cx, |thread, _cx| {
119        assert!(thread.entries().iter().any(|entry| {
120            matches!(
121                entry,
122                AgentThreadEntry::ToolCall(ToolCall {
123                    status: ToolCallStatus::Allowed { .. },
124                    ..
125                })
126            )
127        }));
128        assert!(
129            thread
130                .entries()
131                .iter()
132                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
133        );
134    });
135}
136
137pub async fn test_tool_call_with_confirmation(
138    server: impl AgentServer + 'static,
139    cx: &mut TestAppContext,
140) {
141    let fs = init_test(cx).await;
142    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
143    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
144    let full_turn = thread.update(cx, |thread, cx| {
145        thread.send_raw(
146            r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
147            cx,
148        )
149    });
150
151    run_until_first_tool_call(
152        &thread,
153        |entry| {
154            matches!(
155                entry,
156                AgentThreadEntry::ToolCall(ToolCall {
157                    status: ToolCallStatus::WaitingForConfirmation { .. },
158                    ..
159                })
160            )
161        },
162        cx,
163    )
164    .await;
165
166    let tool_call_id = thread.read_with(cx, |thread, _cx| {
167        let AgentThreadEntry::ToolCall(ToolCall {
168            id,
169            content,
170            status: ToolCallStatus::WaitingForConfirmation { .. },
171            ..
172        }) = &thread
173            .entries()
174            .iter()
175            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
176            .unwrap()
177        else {
178            panic!();
179        };
180
181        assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
182
183        id.clone()
184    });
185
186    thread.update(cx, |thread, cx| {
187        thread.authorize_tool_call(
188            tool_call_id,
189            acp::PermissionOptionId("0".into()),
190            acp::PermissionOptionKind::AllowOnce,
191            cx,
192        );
193
194        assert!(thread.entries().iter().any(|entry| matches!(
195            entry,
196            AgentThreadEntry::ToolCall(ToolCall {
197                status: ToolCallStatus::Allowed { .. },
198                ..
199            })
200        )));
201    });
202
203    full_turn.await.unwrap();
204
205    thread.read_with(cx, |thread, cx| {
206        let AgentThreadEntry::ToolCall(ToolCall {
207            content,
208            status: ToolCallStatus::Allowed { .. },
209            ..
210        }) = thread
211            .entries()
212            .iter()
213            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
214            .unwrap()
215        else {
216            panic!();
217        };
218
219        assert!(
220            content.iter().any(|c| c.to_markdown(cx).contains("Hello")),
221            "Expected content to contain 'Hello'"
222        );
223    });
224}
225
226pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
227    let fs = init_test(cx).await;
228
229    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
230    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
231    let full_turn = thread.update(cx, |thread, cx| {
232        thread.send_raw(
233            r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
234            cx,
235        )
236    });
237
238    let first_tool_call_ix = run_until_first_tool_call(
239        &thread,
240        |entry| {
241            matches!(
242                entry,
243                AgentThreadEntry::ToolCall(ToolCall {
244                    status: ToolCallStatus::WaitingForConfirmation { .. },
245                    ..
246                })
247            )
248        },
249        cx,
250    )
251    .await;
252
253    thread.read_with(cx, |thread, _cx| {
254        let AgentThreadEntry::ToolCall(ToolCall {
255            id,
256            content,
257            status: ToolCallStatus::WaitingForConfirmation { .. },
258            ..
259        }) = &thread.entries()[first_tool_call_ix]
260        else {
261            panic!("{:?}", thread.entries()[1]);
262        };
263
264        assert!(content.iter().any(|c| c.to_markdown(_cx).contains("touch")));
265
266        id.clone()
267    });
268
269    let _ = thread.update(cx, |thread, cx| thread.cancel(cx));
270    full_turn.await.unwrap();
271    thread.read_with(cx, |thread, _| {
272        let AgentThreadEntry::ToolCall(ToolCall {
273            status: ToolCallStatus::Canceled,
274            ..
275        }) = &thread.entries()[first_tool_call_ix]
276        else {
277            panic!();
278        };
279    });
280
281    thread
282        .update(cx, |thread, cx| {
283            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
284        })
285        .await
286        .unwrap();
287    thread.read_with(cx, |thread, _| {
288        assert!(matches!(
289            &thread.entries().last().unwrap(),
290            AgentThreadEntry::AssistantMessage(..),
291        ))
292    });
293}
294
/// Generates the shared end-to-end test suite for one agent server.
///
/// Expands to a `common_e2e` module containing one `#[gpui::test]` per shared
/// scenario (`basic`, `path_mentions`, `tool_call`,
/// `tool_call_with_confirmation`, `cancel`). Each test forwards to the
/// matching `test_*` function in `$crate::e2e_tests`.
///
/// `$server` is an expression that is pasted into every generated test, so it
/// is evaluated once per test. The module starts with `use super::*;`, so the
/// expression may refer to items visible at the macro call site.
///
/// Every generated test is marked `ignore` unless the `e2e` feature is
/// enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr) => {
        mod common_e2e {
            // Bring the invoking module's items into scope for `$server`.
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }
        }
    };
}
333
334// Helpers
335
336pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
337    env_logger::try_init().ok();
338
339    cx.update(|cx| {
340        let settings_store = SettingsStore::test(cx);
341        cx.set_global(settings_store);
342        Project::init_settings(cx);
343        language::init(cx);
344        crate::settings::init(cx);
345
346        crate::AllAgentServersSettings::override_global(
347            AllAgentServersSettings {
348                claude: Some(AgentServerSettings {
349                    command: crate::claude::tests::local_command(),
350                }),
351                gemini: Some(AgentServerSettings {
352                    command: crate::gemini::tests::local_command(),
353                }),
354                codex: Some(AgentServerSettings {
355                    command: crate::codex::tests::local_command(),
356                }),
357            },
358            cx,
359        );
360    });
361
362    cx.executor().allow_parking();
363
364    FakeFs::new(cx.executor())
365}
366
367pub async fn new_test_thread(
368    server: impl AgentServer + 'static,
369    project: Entity<Project>,
370    current_dir: impl AsRef<Path>,
371    cx: &mut TestAppContext,
372) -> Entity<AcpThread> {
373    let connection = cx
374        .update(|cx| server.connect(current_dir.as_ref(), &project, cx))
375        .await
376        .unwrap();
377
378    let thread = connection
379        .new_thread(project.clone(), current_dir.as_ref(), &mut cx.to_async())
380        .await
381        .unwrap();
382
383    thread
384}
385
386pub async fn run_until_first_tool_call(
387    thread: &Entity<AcpThread>,
388    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
389    cx: &mut TestAppContext,
390) -> usize {
391    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
392
393    let subscription = cx.update(|cx| {
394        cx.subscribe(thread, move |thread, _, cx| {
395            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
396                if wait_until(entry) {
397                    return tx.try_send(ix).unwrap();
398                }
399            }
400        })
401    });
402
403    select! {
404        // We have to use a smol timer here because
405        // cx.background_executor().timer isn't real in the test context
406        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
407            panic!("Timeout waiting for tool call")
408        }
409        ix = rx.next().fuse() => {
410            drop(subscription);
411            ix.unwrap()
412        }
413    }
414}