e2e_tests.rs

  1use std::{path::Path, sync::Arc, time::Duration};
  2
  3use crate::{AgentServer, AgentServerSettings, AllAgentServersSettings};
  4use acp_thread::{
  5    AcpThread, AgentThreadEntry, ToolCall, ToolCallConfirmation, ToolCallContent, ToolCallStatus,
  6};
  7use agentic_coding_protocol as acp_old;
  8use futures::{FutureExt, StreamExt, channel::mpsc, select};
  9use gpui::{Entity, TestAppContext};
 10use indoc::indoc;
 11use project::{FakeFs, Project};
 12use serde_json::json;
 13use settings::{Settings, SettingsStore};
 14use util::path;
 15
 16pub async fn test_basic(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 17    let fs = init_test(cx).await;
 18    let project = Project::test(fs, [], cx).await;
 19    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
 20
 21    thread
 22        .update(cx, |thread, cx| thread.send_raw("Hello from Zed!", cx))
 23        .await
 24        .unwrap();
 25
 26    thread.read_with(cx, |thread, _| {
 27        assert_eq!(thread.entries().len(), 2);
 28        assert!(matches!(
 29            thread.entries()[0],
 30            AgentThreadEntry::UserMessage(_)
 31        ));
 32        assert!(matches!(
 33            thread.entries()[1],
 34            AgentThreadEntry::AssistantMessage(_)
 35        ));
 36    });
 37}
 38
 39pub async fn test_path_mentions(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 40    let _fs = init_test(cx).await;
 41
 42    let tempdir = tempfile::tempdir().unwrap();
 43    std::fs::write(
 44        tempdir.path().join("foo.rs"),
 45        indoc! {"
 46            fn main() {
 47                println!(\"Hello, world!\");
 48            }
 49        "},
 50    )
 51    .expect("failed to write file");
 52    let project = Project::example([tempdir.path()], &mut cx.to_async()).await;
 53    let thread = new_test_thread(server, project.clone(), tempdir.path(), cx).await;
 54    thread
 55        .update(cx, |thread, cx| {
 56            thread.send(
 57                acp_old::SendUserMessageParams {
 58                    chunks: vec![
 59                        acp_old::UserMessageChunk::Text {
 60                            text: "Read the file ".into(),
 61                        },
 62                        acp_old::UserMessageChunk::Path {
 63                            path: Path::new("foo.rs").into(),
 64                        },
 65                        acp_old::UserMessageChunk::Text {
 66                            text: " and tell me what the content of the println! is".into(),
 67                        },
 68                    ],
 69                },
 70                cx,
 71            )
 72        })
 73        .await
 74        .unwrap();
 75
 76    thread.read_with(cx, |thread, cx| {
 77        assert_eq!(thread.entries().len(), 3);
 78        assert!(matches!(
 79            thread.entries()[0],
 80            AgentThreadEntry::UserMessage(_)
 81        ));
 82        assert!(matches!(thread.entries()[1], AgentThreadEntry::ToolCall(_)));
 83        let AgentThreadEntry::AssistantMessage(assistant_message) = &thread.entries()[2] else {
 84            panic!("Expected AssistantMessage")
 85        };
 86        assert!(
 87            assistant_message.to_markdown(cx).contains("Hello, world!"),
 88            "unexpected assistant message: {:?}",
 89            assistant_message.to_markdown(cx)
 90        );
 91    });
 92}
 93
 94pub async fn test_tool_call(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
 95    let fs = init_test(cx).await;
 96    fs.insert_tree(
 97        path!("/private/tmp"),
 98        json!({"foo": "Lorem ipsum dolor", "bar": "bar", "baz": "baz"}),
 99    )
100    .await;
101    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
102    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
103
104    thread
105        .update(cx, |thread, cx| {
106            thread.send_raw(
107                "Read the '/private/tmp/foo' file and tell me what you see.",
108                cx,
109            )
110        })
111        .await
112        .unwrap();
113    thread.read_with(cx, |thread, _cx| {
114        assert!(thread.entries().iter().any(|entry| {
115            matches!(
116                entry,
117                AgentThreadEntry::ToolCall(ToolCall {
118                    status: ToolCallStatus::Allowed { .. },
119                    ..
120                })
121            )
122        }));
123        assert!(
124            thread
125                .entries()
126                .iter()
127                .any(|entry| { matches!(entry, AgentThreadEntry::AssistantMessage(_)) })
128        );
129    });
130}
131
132pub async fn test_tool_call_with_confirmation(
133    server: impl AgentServer + 'static,
134    cx: &mut TestAppContext,
135) {
136    let fs = init_test(cx).await;
137    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
138    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
139    let full_turn = thread.update(cx, |thread, cx| {
140        thread.send_raw(
141            r#"Run `touch hello.txt && echo "Hello, world!" | tee hello.txt`"#,
142            cx,
143        )
144    });
145
146    run_until_first_tool_call(
147        &thread,
148        |entry| {
149            matches!(
150                entry,
151                AgentThreadEntry::ToolCall(ToolCall {
152                    status: ToolCallStatus::WaitingForConfirmation { .. },
153                    ..
154                })
155            )
156        },
157        cx,
158    )
159    .await;
160
161    let tool_call_id = thread.read_with(cx, |thread, _cx| {
162        let AgentThreadEntry::ToolCall(ToolCall {
163            id,
164            content: Some(content),
165            status: ToolCallStatus::WaitingForConfirmation { .. },
166            ..
167        }) = &thread
168            .entries()
169            .iter()
170            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
171            .unwrap()
172        else {
173            panic!();
174        };
175
176        assert!(content.to_markdown(cx).contains("touch"));
177
178        *id
179    });
180
181    thread.update(cx, |thread, cx| {
182        thread.authorize_tool_call(
183            tool_call_id,
184            acp_old::ToolCallConfirmationOutcome::Allow,
185            cx,
186        );
187
188        assert!(thread.entries().iter().any(|entry| matches!(
189            entry,
190            AgentThreadEntry::ToolCall(ToolCall {
191                status: ToolCallStatus::Allowed { .. },
192                ..
193            })
194        )));
195    });
196
197    full_turn.await.unwrap();
198
199    thread.read_with(cx, |thread, cx| {
200        let AgentThreadEntry::ToolCall(ToolCall {
201            content: Some(ToolCallContent::Markdown { markdown }),
202            status: ToolCallStatus::Allowed { .. },
203            ..
204        }) = thread
205            .entries()
206            .iter()
207            .find(|entry| matches!(entry, AgentThreadEntry::ToolCall(_)))
208            .unwrap()
209        else {
210            panic!();
211        };
212
213        markdown.read_with(cx, |md, _cx| {
214            assert!(
215                md.source().contains("Hello"),
216                r#"Expected '{}' to contain "Hello""#,
217                md.source()
218            );
219        });
220    });
221}
222
223pub async fn test_cancel(server: impl AgentServer + 'static, cx: &mut TestAppContext) {
224    let fs = init_test(cx).await;
225
226    let project = Project::test(fs, [path!("/private/tmp").as_ref()], cx).await;
227    let thread = new_test_thread(server, project.clone(), "/private/tmp", cx).await;
228    let full_turn = thread.update(cx, |thread, cx| {
229        thread.send_raw(
230            r#"Run `touch hello.txt && echo "Hello, world!" >> hello.txt`"#,
231            cx,
232        )
233    });
234
235    let first_tool_call_ix = run_until_first_tool_call(
236        &thread,
237        |entry| {
238            matches!(
239                entry,
240                AgentThreadEntry::ToolCall(ToolCall {
241                    status: ToolCallStatus::WaitingForConfirmation { .. },
242                    ..
243                })
244            )
245        },
246        cx,
247    )
248    .await;
249
250    thread.read_with(cx, |thread, _cx| {
251        let AgentThreadEntry::ToolCall(ToolCall {
252            id,
253            content: Some(content),
254            status: ToolCallStatus::WaitingForConfirmation { .. },
255            ..
256        }) = &thread.entries()[first_tool_call_ix]
257        else {
258            panic!("{:?}", thread.entries()[1]);
259        };
260
261        assert!(content.to_markdown(cx).contains("touch"));
262
263        *id
264    });
265
266    thread
267        .update(cx, |thread, cx| thread.cancel(cx))
268        .await
269        .unwrap();
270    full_turn.await.unwrap();
271    thread.read_with(cx, |thread, _| {
272        let AgentThreadEntry::ToolCall(ToolCall {
273            status: ToolCallStatus::Canceled,
274            ..
275        }) = &thread.entries()[first_tool_call_ix]
276        else {
277            panic!();
278        };
279    });
280
281    thread
282        .update(cx, |thread, cx| {
283            thread.send_raw(r#"Stop running and say goodbye to me."#, cx)
284        })
285        .await
286        .unwrap();
287    thread.read_with(cx, |thread, _| {
288        assert!(matches!(
289            &thread.entries().last().unwrap(),
290            AgentThreadEntry::AssistantMessage(..),
291        ))
292    });
293}
294
/// Expands to a `common_e2e` module containing one `#[gpui::test]` per shared
/// end-to-end scenario (`basic`, `path_mentions`, `tool_call`,
/// `tool_call_with_confirmation`, `cancel`).
///
/// The single argument is an expression producing the agent server under
/// test; it is evaluated once inside each generated test. Every generated
/// test is marked `ignore` unless the `e2e` feature is enabled.
#[macro_export]
macro_rules! common_e2e_tests {
    ($agent_server:expr) => {
        mod common_e2e {
            use super::*;

            // Single plain-text prompt round trip.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn basic(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_basic($agent_server, cx).await;
            }

            // Prompt that mentions a file path.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn path_mentions(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_path_mentions($agent_server, cx).await;
            }

            // Tool call that ends up allowed.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call($agent_server, cx).await;
            }

            // Tool call that must be explicitly authorized.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn tool_call_with_confirmation(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_tool_call_with_confirmation($agent_server, cx).await;
            }

            // Cancellation while a tool call awaits confirmation.
            #[::gpui::test]
            #[cfg_attr(not(feature = "e2e"), ignore)]
            async fn cancel(cx: &mut ::gpui::TestAppContext) {
                $crate::e2e_tests::test_cancel($agent_server, cx).await;
            }
        }
    };
}
333
334// Helpers
335
336pub async fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
337    env_logger::try_init().ok();
338
339    cx.update(|cx| {
340        let settings_store = SettingsStore::test(cx);
341        cx.set_global(settings_store);
342        Project::init_settings(cx);
343        language::init(cx);
344        crate::settings::init(cx);
345
346        crate::AllAgentServersSettings::override_global(
347            AllAgentServersSettings {
348                claude: Some(AgentServerSettings {
349                    command: crate::claude::tests::local_command(),
350                }),
351                codex: Some(AgentServerSettings {
352                    command: crate::codex::tests::local_command(),
353                }),
354                gemini: Some(AgentServerSettings {
355                    command: crate::gemini::tests::local_command(),
356                }),
357            },
358            cx,
359        );
360    });
361
362    cx.executor().allow_parking();
363
364    FakeFs::new(cx.executor())
365}
366
367pub async fn new_test_thread(
368    server: impl AgentServer + 'static,
369    project: Entity<Project>,
370    current_dir: impl AsRef<Path>,
371    cx: &mut TestAppContext,
372) -> Entity<AcpThread> {
373    let thread = cx
374        .update(|cx| server.new_thread(current_dir.as_ref(), &project, cx))
375        .await
376        .unwrap();
377
378    thread
379        .update(cx, |thread, _| thread.initialize())
380        .await
381        .unwrap();
382    thread
383}
384
385pub async fn run_until_first_tool_call(
386    thread: &Entity<AcpThread>,
387    wait_until: impl Fn(&AgentThreadEntry) -> bool + 'static,
388    cx: &mut TestAppContext,
389) -> usize {
390    let (mut tx, mut rx) = mpsc::channel::<usize>(1);
391
392    let subscription = cx.update(|cx| {
393        cx.subscribe(thread, move |thread, _, cx| {
394            for (ix, entry) in thread.read(cx).entries().iter().enumerate() {
395                if wait_until(entry) {
396                    return tx.try_send(ix).unwrap();
397                }
398            }
399        })
400    });
401
402    select! {
403        // We have to use a smol timer here because
404        // cx.background_executor().timer isn't real in the test context
405        _ = futures::FutureExt::fuse(smol::Timer::after(Duration::from_secs(20))) => {
406            panic!("Timeout waiting for tool call")
407        }
408        ix = rx.next().fuse() => {
409            drop(subscription);
410            ix.unwrap()
411        }
412    }
413}