e2e_tests.rs

  1use std::{path::Path, sync::Arc, time::Duration};
  2
  3use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
  4use acp_thread::{
  5    AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus,
  6};
  7use agentic_coding_protocol as acp;
  8use futures::{FutureExt, StreamExt, channel::mpsc, select};
  9use gpui::{Entity, TestAppContext};
 10use indoc::indoc;
 11use project::{FakeFs, Project};
 12use serde_json::json;
 13use settings::{Settings, SettingsStore};
 14use util::path;
 15
 16pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 17    let fs = init_test(cx).await;
 18    let project = Project::test(fs, [], cx).await;
 19    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
 20
 21    thread
 22        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 23        .await
 24        .unwrap();
 25
 26    thread.read_with(cx, |thread, _| {
 27        assert_eq!(thread.entries().len(), 2);
 28        assert!(matches!(
 29            thread.entries()[0],
 30            AgentThreadEntry::UserMessage(_)
 31        ));
 32        assert!(matches!(
 33            thread.entries()[1],
 34            AgentThreadEntry::AssistantMessage(_)
 35        ));
 36    });
 37}
 38
 39pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 40    let _fs = init_test(cx).await;
 41
 42    let tempdir = tempfile::tempdir().unwrap();
 43    std::fs::write(
 44        tempdir.path().join("foo.rs"),
 45        indoc! {"
 46            fn main() {
 47                println!(\"Hello, world!\");
 48            }
 49        "},
 50    )
 51    .expect("failed to write file");
 52    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 53    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
 54    thread
 55        .update(cx, |thread, cx| {
 56            thread.send(
 57                acp::SendUserMessageParams {
 58                    chunks: vec![
 59                        acp::UserMessageChunk::Text {
 60                            text: "Read the file ".into(),
 61                        },
 62                        acp::UserMessageChunk::Path {
 63                            path: Path::new("foo.rs").into(),
 64                        },
 65                        acp::UserMessageChunk::Text {
 66                            text: " and tell me what the content of the println! is".into(),
 67                        },
 68                    ],
 69                },
 70                cx,
 71            )
 72        })
 73        .await
 74        .unwrap();
 75
 76    thread.read_with(cx, |thread, cx| {
 77        assert_eq!(thread.entries().len(), 3);
 78        assert!(matches!(
 79            thread.entries()[0],
 80            AgentThreadEntry::UserMessage(_)
 81        ));
 82        assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
 83        let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
 84            panic!("Expected AssistantMessage")
 85        };
 86        assert!(
 87            assistant_message.to_markdown(cx).contains("Hello, world!"),
 88            "unexpected assistant message: {:?}",
 89            assistant_message.to_markdown(cx)
 90        );
 91    });
 92}
 93
 94pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 95    let fs = init_test(cx).await;
 96    fs.insert_tree(
 97        path!("/private/tmp"),
 98        json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
 99    )
100    .await;
101    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
102    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
103
104    thread
105        .update(cx, |thread, cx| {
106            thread.send_raw(
107                "Read the '/private/tmp/foo' file and tell me what you see.",
108                cx,
109            )
110        })
111        .await
112        .unwrap();
113    thread.read_with(cx, |thread, _cx| {
114        assert!(thread.entries().iter().any(|entry| {
115            matches!(
116                entry,
117                AgentThreadEntry::ToolCall(ToolCall {
118                    status: ToolCallStatus::Allowed { .. },
119                    ..
120                })
121            )
122        }));
123        assert!(
124            thread
125                .entries()
126                .iter()
127                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
128        );
129    });
130}
131
132pub async fn test_tool_call_with_confirmation(
133    server: impl AgentServer + 'static,
134    cx: &mut TestAppContext,
135) {
136    let fs = init_test(cx).await;
137    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
138    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
139    let full_turn = thread.update(cx, |thread, cx| {
140        thread.send_raw(
141            r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
142            cx,
143        )
144    });
145
146    run_until_first_tool_call(
147        &thread,
148        |entry| {
149            matches!(
150                entry,
151                AgentThreadEntry::ToolCall(ToolCall {
152                    status: ToolCallStatus::WaitingForConfirmation { .. },
153                    ..
154                })
155            )
156        },
157        cx,
158    )
159    .await;
160
161    let tool_call_id = thread.read_with(cx, |thread, _cx| {
162        let AgentThreadEntry::ToolCall(ToolCall {
163            id,
164            status:
165                ToolCallStatus::WaitingForConfirmation {
166                    confirmation: ToolCallConfirmation::Execute { root_command, .. },
167                    ..
168                },
169            ..
170        }) = &thread
171            .entries()
172            .iter()
173            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
174            .unwrap()
175        else {
176            panic!();
177        };
178
179        assert!(root_command.contains("touch"));
180
181        *id
182    });
183
184    thread.update(cx, |thread, cx| {
185        thread.authorize_tool_call(tool_call_id, acp::ToolCallConfirmationOutcome::Allow, cx);
186
187        assert!(thread.entries().iter().any(|entry| matches!(
188            entry,
189            AgentThreadEntry::ToolCall(ToolCall {
190                status: ToolCallStatus::Allowed { .. },
191                ..
192            })
193        )));
194    });
195
196    full_turn.await.unwrap();
197
198    thread.read_with(cx, |thread, cx| {
199        let AgentThreadEntry::ToolCall(ToolCall {
200            content: Some(ToolCallContent::Markdown { markdown }),
201            status: ToolCallStatus::Allowed { .. },
202            ..
203        }) = thread
204            .entries()
205            .iter()
206            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
207            .unwrap()
208        else {
209            panic!();
210        };
211
212        markdown.read_with(cx, |md, _cx| {
213            assert!(
214                md.source().contains("Hello"),
215                r#"Expected '{}' to contain "Hello""#,
216                md.source()
217            );
218        });
219    });
220}
221
222pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
223    let fs = init_test(cx).await;
224
225    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
226    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
227    let full_turn = thread.update(cx, |thread, cx| {
228        thread.send_raw(
229            r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
230            cx,
231        )
232    });
233
234    let first_tool_call_ix = run_until_first_tool_call(
235        &thread,
236        |entry| {
237            matches!(
238                entry,
239                AgentThreadEntry::ToolCall(ToolCall {
240                    status: ToolCallStatus::WaitingForConfirmation { .. },
241                    ..
242                })
243            )
244        },
245        cx,
246    )
247    .await;
248
249    thread.read_with(cx, |thread, _cx| {
250        let AgentThreadEntry::ToolCall(ToolCall {
251            id,
252            status:
253                ToolCallStatus::WaitingForConfirmation {
254                    confirmation: ToolCallConfirmation::Execute { root_command, .. },
255                    ..
256                },
257            ..
258        }) = &thread.entries()[first_tool_call_ix]
259        else {
260            panic!("{:?}", thread.entries()[1]);
261        };
262
263        assert!(root_command.contains("touch"));
264
265        *id
266    });
267
268    thread
269        .update(cx, |thread, cx| thread.cancel(cx))
270        .await
271        .unwrap();
272    full_turn.await.unwrap();
273    thread.read_with(cx, |thread, _| {
274        let AgentThreadEntry::ToolCall(ToolCall {
275            status: ToolCallStatus::Canceled,
276            ..
277        }) = &thread.entries()[first_tool_call_ix]
278        else {
279            panic!();
280        };
281    });
282
283    thread
284        .update(cx, |thread, cx| {
285            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
286        })
287        .await
288        .unwrap();
289    thread.read_with(cx, |thread, _| {
290        assert!(matches!(
291            &thread.entries().last().unwrap(),
292            AgentThreadEntry::AssistantMessage(..),
293        ))
294    });
295}
296
/// Generates a `common_e2e` test module that runs each shared end-to-end test
/// in this file (`test_basic`, `test_path_mentions`, `test_tool_call`,
/// `test_tool_call_with_confirmation`, `test_cancel`) against `$server`.
///
/// `$server` is pasted into every generated test function, so the expression
/// is evaluated separately for each test. The `use super::*;` line makes the
/// invoking module's items visible inside `common_e2e`, so `$server` may refer
/// to them. Every generated test is `#[ignore]`d unless the `e2e` feature is
/// enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($server:expr) => {
        mod common_e2e {
            use super::*;

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_confirmation($server, cx).await;
            }

            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($server, cx).await;
            }
        }
    };
}
335
336// Helpers
337
338pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
339    env_logger::try_init().ok();
340
341    cx.update(|cx| {
342        let settings_store = SettingsStore::test(cx);
343        cx.set_global(settings_store);
344        Project::init_settings(cx);
345        language::init(cx);
346        crate::settings::init(cx);
347
348        crate::AllAgentServersSettings::override_global(
349            AllAgentServersSettings {
350                claude: Some(AgentServerSettings {
351                    command: crate::claude::tests::local_command(),
352                }),
353                gemini: Some(AgentServerSettings {
354                    command: crate::gemini::tests::local_command(),
355                }),
356            },
357            cx,
358        );
359    });
360
361    cx.executor().allow_parking();
362
363    FakeFs::new(cx.executor())
364}
365
366pub async fn new_test_thread(
367    server: impl AgentServer + 'static,
368    project: Entity<Project>,
369    current_dir: impl AsRef<Path>,
370    cx: &mut TestAppContext,
371) -> Entity<AcpThread> {
372    let thread = cx
373        .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx))
374        .await
375        .unwrap();
376
377    thread
378        .update(cx, |thread, _| thread.initialize())
379        .await
380        .unwrap();
381    thread
382}
383
384pub async fn run_until_first_tool_call(
385    thread: &Entity<AcpThread>,
386    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
387    cx: &mut TestAppContext,
388) -> usize {
389    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
390
391    let subscription = cx.update(|cx| {
392        cx.subscribe(thread, move |thread, _, cx| {
393            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
394                if wait_until(entry) {
395                    return tx.try_send(ix).unwrap();
396                }
397            }
398        })
399    });
400
401    select! {
402        // We have to use a smol timer here because
403        // cx.background_executor().timer isn't real in the test context
404        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
405            panic!("Timeout waiting for tool call")
406        }
407        ix = rx.next().fuse() => {
408            drop(subscription);
409            ix.unwrap()
410        }
411    }
412}