e2e_tests.rs

  1use std::{path::Path, sync::Arc, time::Duration};
  2
  3use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
  4use acp_thread::{
  5    AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus,
  6};
  7use agentic_coding_protocol as acp;
  8use futures::{FutureExt, StreamExt, channel::mpsc, select};
  9use gpui::{Entity, TestAppContext};
 10use indoc::indoc;
 11use project::{FakeFs, Project};
 12use serde_json::json;
 13use settings::{Settings, SettingsStore};
 14use util::path;
 15
 16pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 17    let fs = init_test(cx).await;
 18    let project = Project::test(fs, [], cx).await;
 19    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
 20
 21    thread
 22        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 23        .await
 24        .unwrap();
 25
 26    thread.read_with(cx, |thread, _| {
 27        assert_eq!(thread.entries().len(), 2);
 28        assert!(matches!(
 29            thread.entries()[0],
 30            AgentThreadEntry::UserMessage(_)
 31        ));
 32        assert!(matches!(
 33            thread.entries()[1],
 34            AgentThreadEntry::AssistantMessage(_)
 35        ));
 36    });
 37}
 38
 39pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 40    let _fs = init_test(cx).await;
 41
 42    let tempdir = tempfile::tempdir().unwrap();
 43    std::fs::write(
 44        tempdir.path().join("foo.rs"),
 45        indoc! {"
 46            fn main() {
 47                println!(\"Hello, world!\");
 48            }
 49        "},
 50    )
 51    .expect("failed to write file");
 52    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 53    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
 54    thread
 55        .update(cx, |thread, cx| {
 56            thread.send(
 57                acp::SendUserMessageParams {
 58                    chunks: vec![
 59                        acp::UserMessageChunk::Text {
 60                            text: "Read the file ".into(),
 61                        },
 62                        acp::UserMessageChunk::Path {
 63                            path: Path::new("foo.rs").into(),
 64                        },
 65                        acp::UserMessageChunk::Text {
 66                            text: " and tell me what the content of the println! is".into(),
 67                        },
 68                    ],
 69                },
 70                cx,
 71            )
 72        })
 73        .await
 74        .unwrap();
 75
 76    thread.read_with(cx, |thread, cx| {
 77        assert_eq!(thread.entries().len(), 3);
 78        assert!(matches!(
 79            thread.entries()[0],
 80            AgentThreadEntry::UserMessage(_)
 81        ));
 82        assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
 83        let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
 84            panic!("Expected AssistantMessage")
 85        };
 86        assert!(
 87            assistant_message.to_markdown(cx).contains("Hello, world!"),
 88            "unexpected assistant message: {:?}",
 89            assistant_message.to_markdown(cx)
 90        );
 91    });
 92}
 93
 94pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 95    let fs = init_test(cx).await;
 96    fs.insert_tree(
 97        path!("/private/tmp"),
 98        json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
 99    )
100    .await;
101    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
102    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
103
104    thread
105        .update(cx, |thread, cx| {
106            thread.send_raw(
107                "Read the '/private/tmp/foo' file and tell me what you see.",
108                cx,
109            )
110        })
111        .await
112        .unwrap();
113    thread.read_with(cx, |thread, _cx| {
114        assert!(matches!(
115            &thread.entries()[2],
116            AgentThreadEntry::ToolCall(ToolCall {
117                status: ToolCallStatus::Allowed { .. },
118                ..
119            })
120        ));
121
122        assert!(matches!(
123            thread.entries()[3],
124            AgentThreadEntry::AssistantMessage(_)
125        ));
126    });
127}
128
129pub async fn test_tool_call_with_confirmation(
130    server: impl AgentServer + 'static,
131    cx: &mut TestAppContext,
132) {
133    let fs = init_test(cx).await;
134    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
135    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
136    let full_turn = thread.update(cx, |thread, cx| {
137        thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
138    });
139
140    run_until_first_tool_call(&thread, cx).await;
141
142    let tool_call_id = thread.read_with(cx, |thread, _cx| {
143        let AgentThreadEntry::ToolCall(ToolCall {
144            id,
145            status:
146                ToolCallStatus::WaitingForConfirmation {
147                    confirmation: ToolCallConfirmation::Execute { root_command, .. },
148                    ..
149                },
150            ..
151        }) = &thread.entries()[2]
152        else {
153            panic!();
154        };
155
156        assert_eq!(root_command, "echo");
157
158        *id
159    });
160
161    thread.update(cx, |thread, cx| {
162        thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
163
164        assert!(matches!(
165            &thread.entries()[2],
166            AgentThreadEntry::ToolCall(ToolCall {
167                status: ToolCallStatus::Allowed { .. },
168                ..
169            })
170        ));
171    });
172
173    full_turn.await.unwrap();
174
175    thread.read_with(cx, |thread, cx| {
176        let AgentThreadEntry::ToolCall(ToolCall {
177            content: Some(ToolCallContent::Markdown { markdown }),
178            status: ToolCallStatus::Allowed { .. },
179            ..
180        }) = &thread.entries()[2]
181        else {
182            panic!();
183        };
184
185        markdown.read_with(cx, |md, _cx| {
186            assert!(
187                md.source().contains("Hello, world!"),
188                r#"Expected '{}' to contain "Hello, world!""#,
189                md.source()
190            );
191        });
192    });
193}
194
195pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
196    let fs = init_test(cx).await;
197
198    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
199    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
200    let full_turn = thread.update(cx, |thread, cx| {
201        thread.send_raw(r#"Run `echo "Hello, world!"`"#, cx)
202    });
203
204    let first_tool_call_ix = run_until_first_tool_call(&thread, cx).await;
205
206    thread.read_with(cx, |thread, _cx| {
207        let AgentThreadEntry::ToolCall(ToolCall {
208            id,
209            status:
210                ToolCallStatus::WaitingForConfirmation {
211                    confirmation: ToolCallConfirmation::Execute { root_command, .. },
212                    ..
213                },
214            ..
215        }) = &thread.entries()[first_tool_call_ix]
216        else {
217            panic!("{:?}", thread.entries()[1]);
218        };
219
220        assert_eq!(root_command, "echo");
221
222        *id
223    });
224
225    thread
226        .update(cx, |thread, cx| thread.cancel(cx))
227        .await
228        .unwrap();
229    full_turn.await.unwrap();
230    thread.read_with(cx, |thread, _| {
231        let AgentThreadEntry::ToolCall(ToolCall {
232            status: ToolCallStatus::Canceled,
233            ..
234        }) = &thread.entries()[first_tool_call_ix]
235        else {
236            panic!();
237        };
238    });
239
240    thread
241        .update(cx, |thread, cx| {
242            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
243        })
244        .await
245        .unwrap();
246    thread.read_with(cx, |thread, _| {
247        assert!(matches!(
248            &thread.entries().last().unwrap(),
249            AgentThreadEntry::AssistantMessage(..),
250        ))
251    });
252}
253
/// Expands to a `common_e2e` module containing one `gpui` test per shared
/// end-to-end scenario in `e2e_tests`, each run against `$server`.
///
/// `$server` is evaluated once inside every generated test. The tests are
/// marked `ignore` unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr) => {
        mod common_e2e {
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                let server = $server;
                $crate::e2e_tests::test_basic(server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                let server = $server;
                $crate::e2e_tests::test_path_mentions(server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                let server = $server;
                $crate::e2e_tests::test_tool_call(server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                let server = $server;
                $crate::e2e_tests::test_tool_call_with_confirmation(server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                let server = $server;
                $crate::e2e_tests::test_cancel(server, cx).await;
            }
        }
    };
}
292
293// Helpers
294
295pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
296    env_logger::try_init().ok();
297
298    cx.update(|cx| {
299        let settings_store = SettingsStore::test(cx);
300        cx.set_global(settings_store);
301        Project::init_settings(cx);
302        language::init(cx);
303        crate::settings::init(cx);
304
305        crate::AllAgentServersSettings::override_global(
306            AllAgentServersSettings {
307                claude: Some(AgentServerSettings {
308                    command: crate::claude::tests::local_command(),
309                }),
310                gemini: Some(AgentServerSettings {
311                    command: crate::gemini::tests::local_command(),
312                }),
313            },
314            cx,
315        );
316    });
317
318    cx.executor().allow_parking();
319
320    FakeFs::new(cx.executor())
321}
322
323pub async fn new_test_thread(
324    server: impl AgentServer + 'static,
325    project: Entity<Project>,
326    current_dir: impl AsRef<Path>,
327    cx: &mut TestAppContext,
328) -> Entity<AcpThread> {
329    let thread = cx
330        .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx))
331        .await
332        .unwrap();
333
334    thread
335        .update(cx, |thread, _| thread.initialize())
336        .await
337        .unwrap();
338    thread
339}
340
341pub async fn run_until_first_tool_call(
342    thread: &Entity<AcpThread>,
343    cx: &mut TestAppContext,
344) -> usize {
345    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
346
347    let subscription = cx.update(|cx| {
348        cx.subscribe(thread, move |thread, _, cx| {
349            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
350                if matches!(entry, AgentThreadEntry::ToolCall(_)) {
351                    return tx.try_send(ix).unwrap();
352                }
353            }
354        })
355    });
356
357    select! {
358        // We have to use a smol timer here because
359        // cx.background_executor().timer isn't real in the test context
360        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(10))) => {
361            panic!("Timeout waiting for tool call")
362        }
363        ix = rx.next().fuse() => {
364            drop(subscription);
365            ix.unwrap()
366        }
367    }
368}