Add `seed` argument to `#[gpui::test]` attribute macro (#26764)

João Marcos created

This PR introduces the arguments `seed` and `seeds` to `gpui::test`,
e.g.:
- `#[gpui::test(seed = 10)]`
- `#[gpui::test(seeds(10, 20, 30, 40))]`

Which allows us to run a test against a specific seed value without
slowing
down our tests like `iterations` does with high values.

This was motivated by a diff hunk test that only fails in a 400+ seed,
but is
slow to run 400+ times for every `cargo test`.

If your test failed with a specific seed, you can now add the `seed` arg
to
increase the chances of detecting a regression.

There are now three ways of setting seeds, the `SEED` env var,
`iterations`,
and the args this PR adds. See docs in `gpui::test`.

---

I also relaxed the limitation on `retries` not working with
`iterations`, as
that seemed unnecessary.

Release Notes:

- N/A

Change summary

crates/gpui/Cargo.toml                |   1 
crates/gpui/src/test.rs               |  77 +++++++++++---
crates/gpui_macros/src/gpui_macros.rs |  25 +++
crates/gpui_macros/src/test.rs        | 150 ++++++++++++++++------------
4 files changed, 164 insertions(+), 89 deletions(-)

Detailed changes

crates/gpui/Cargo.toml 🔗

@@ -4,6 +4,7 @@ version = "0.1.0"
 edition.workspace = true
 authors = ["Nathan Sobo <nathan@zed.dev>"]
 description = "Zed's GPU-accelerated UI framework"
+repository = "https://github.com/zed-industries/zed"
 publish.workspace = true
 license = "Apache-2.0"
 

crates/gpui/src/test.rs 🔗

@@ -39,23 +39,18 @@ use std::{
 /// This is intended for use with the `gpui::test` macro
 /// and generally should not be used directly.
 pub fn run_test(
-    mut num_iterations: u64,
+    num_iterations: usize,
+    explicit_seeds: &[u64],
     max_retries: usize,
     test_fn: &mut (dyn RefUnwindSafe + Fn(TestDispatcher, u64)),
     on_fail_fn: Option<fn()>,
 ) {
-    let starting_seed = env::var("SEED")
-        .map(|seed| seed.parse().expect("invalid SEED variable"))
-        .unwrap_or(0);
-    if let Ok(iterations) = env::var("ITERATIONS") {
-        num_iterations = iterations.parse().expect("invalid ITERATIONS variable");
-    }
-    let is_randomized = num_iterations > 1;
+    let (seeds, is_multiple_runs) = calculate_seeds(num_iterations as u64, explicit_seeds);
 
-    for seed in starting_seed..starting_seed + num_iterations {
-        let mut retry = 0;
+    for seed in seeds {
+        let mut attempt = 0;
         loop {
-            if is_randomized {
+            if is_multiple_runs {
                 eprintln!("seed = {seed}");
             }
             let result = panic::catch_unwind(|| {
@@ -66,15 +61,15 @@ pub fn run_test(
             match result {
                 Ok(_) => break,
                 Err(error) => {
-                    if retry < max_retries {
-                        println!("retrying: attempt {}", retry);
-                        retry += 1;
+                    if attempt < max_retries {
+                        println!("attempt {} failed, retrying", attempt);
+                        attempt += 1;
                     } else {
-                        if is_randomized {
+                        if is_multiple_runs {
                             eprintln!("failing seed: {}", seed);
                         }
-                        if let Some(f) = on_fail_fn {
-                            f()
+                        if let Some(on_fail_fn) = on_fail_fn {
+                            on_fail_fn()
                         }
                         panic::resume_unwind(error);
                     }
@@ -84,6 +79,54 @@ pub fn run_test(
     }
 }
 
+fn calculate_seeds(
+    iterations: u64,
+    explicit_seeds: &[u64],
+) -> (impl Iterator<Item = u64> + '_, bool) {
+    let iterations = env::var("ITERATIONS")
+        .ok()
+        .map(|var| var.parse().expect("invalid ITERATIONS variable"))
+        .unwrap_or(iterations);
+
+    let env_num = env::var("SEED")
+        .map(|seed| seed.parse().expect("invalid SEED variable as integer"))
+        .ok();
+
+    let empty_range = || 0..0;
+
+    let iter = {
+        let env_range = if let Some(env_num) = env_num {
+            env_num..env_num + 1
+        } else {
+            empty_range()
+        };
+
+        // if `iterations` is 1 and !(`explicit_seeds` is non-empty || `SEED` is set), then add     the run `0`
+        // if `iterations` is 1 and  (`explicit_seeds` is non-empty || `SEED` is set), then discard the run `0`
+        // if `iterations` isn't 1 and `SEED` is set, do `SEED..SEED+iterations`
+        // otherwise, do `0..iterations`
+        let iterations_range = match (iterations, env_num) {
+            (1, None) if explicit_seeds.is_empty() => 0..1,
+            (1, None) | (1, Some(_)) => empty_range(),
+            (iterations, Some(env)) => env..env + iterations,
+            (iterations, None) => 0..iterations,
+        };
+
+        // if `SEED` is set, ignore `explicit_seeds`
+        let explicit_seeds = if env_num.is_some() {
+            &[]
+        } else {
+            explicit_seeds
+        };
+
+        env_range
+            .chain(iterations_range)
+            .chain(explicit_seeds.iter().copied())
+    };
+    let is_multiple_runs = iter.clone().nth(1).is_some();
+    (iter, is_multiple_runs)
+}
+
 /// A test struct for converting an observation callback into a stream.
 pub struct Observation<T> {
     rx: Pin<Box<channel::Receiver<T>>>,

crates/gpui_macros/src/gpui_macros.rs 🔗

@@ -144,7 +144,8 @@ pub fn box_shadow_style_methods(input: TokenStream) -> TokenStream {
 }
 
 /// `#[gpui::test]` can be used to annotate test functions that run with GPUI support.
-/// it supports both synchronous and asynchronous tests, and can provide you with
+///
+/// It supports both synchronous and asynchronous tests, and can provide you with
 /// as many `TestAppContext` instances as you need.
 /// The output contains a `#[test]` annotation so this can be used with any existing
 /// test harness (`cargo test` or `cargo-nextest`).
@@ -160,11 +161,25 @@ pub fn box_shadow_style_methods(input: TokenStream) -> TokenStream {
 /// Using the same `StdRng` for behavior in your test will allow you to exercise a wide
 /// variety of scenarios and interleavings just by changing the seed.
 ///
-/// `#[gpui::test]` also takes three different arguments:
-/// - `#[gpui::test(iterations=10)]` will run the test ten times with a different initial SEED.
-/// - `#[gpui::test(retries=3)]` will run the test up to four times if it fails to try and make it pass.
-/// - `#[gpui::test(on_failure="crate::test::report_failure")]` will call the specified function after the
+/// # Arguments
+///
+/// - `#[gpui::test]` with no arguments runs once with the seed `0` or `SEED` env var if set.
+/// - `#[gpui::test(seed = 10)]` runs once with the seed `10`.
+/// - `#[gpui::test(seeds(10, 20, 30))]` runs three times with seeds `10`, `20`, and `30`.
+/// - `#[gpui::test(iterations = 5)]` runs five times, providing as seed the values in the range `0..5`.
+/// - `#[gpui::test(retries = 3)]` runs up to four times if it fails to try and make it pass.
+/// - `#[gpui::test(on_failure = "crate::test::report_failure")]` will call the specified function after the
 ///    tests fail so that you can write out more detail about the failure.
+///
+/// You can combine `iterations = ...` with `seeds(...)`:
+/// - `#[gpui::test(iterations = 5, seed = 10)]` is equivalent to `#[gpui::test(seeds(0, 1, 2, 3, 4, 10))]`.
+/// - `#[gpui::test(iterations = 5, seeds(10, 20, 30)]` is equivalent to `#[gpui::test(seeds(0, 1, 2, 3, 4, 10, 20, 30))]`.
+/// - `#[gpui::test(seeds(10, 20, 30), iterations = 5]` is equivalent to `#[gpui::test(seeds(0, 1, 2, 3, 4, 10, 20, 30))]`.
+///
+/// # Environment Variables
+///
+/// - `SEED`: sets a seed for the first run
+/// - `ITERATIONS`: forces the value of the `iterations` argument
 #[proc_macro_attribute]
 pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
     test::test(args, function)

crates/gpui_macros/src/test.rs 🔗

@@ -3,73 +3,72 @@ use proc_macro2::Ident;
 use quote::{format_ident, quote};
 use std::mem;
 use syn::{
-    parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, FnArg, ItemFn, Lit, Meta,
-    NestedMeta, Type,
+    parse_quote, spanned::Spanned, AttributeArgs, FnArg, ItemFn, Lit, Meta, MetaList, NestedMeta,
+    PathSegment, Type,
 };
 
 pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
     let args = syn::parse_macro_input!(args as AttributeArgs);
+    try_test(args, function).unwrap_or_else(|err| err)
+}
+
+fn try_test(args: Vec<NestedMeta>, function: TokenStream) -> Result<TokenStream, TokenStream> {
+    let mut seeds = Vec::<u64>::new();
     let mut max_retries = 0;
     let mut num_iterations = 1;
     let mut on_failure_fn_name = quote!(None);
 
     for arg in args {
-        match arg {
-            NestedMeta::Meta(Meta::NameValue(meta)) => {
-                let key_name = meta.path.get_ident().map(|i| i.to_string());
-                let result = (|| {
-                    match key_name.as_deref() {
-                        Some("retries") => max_retries = parse_int(&meta.lit)?,
-                        Some("iterations") => num_iterations = parse_int(&meta.lit)?,
-                        Some("on_failure") => {
-                            if let Lit::Str(name) = meta.lit {
-                                let mut path = syn::Path {
-                                    leading_colon: None,
-                                    segments: Default::default(),
-                                };
-                                for part in name.value().split("::") {
-                                    path.segments.push(Ident::new(part, name.span()).into());
-                                }
-                                on_failure_fn_name = quote!(Some(#path));
-                            } else {
-                                return Err(TokenStream::from(
-                                    syn::Error::new(
-                                        meta.lit.span(),
-                                        "on_failure argument must be a string",
-                                    )
-                                    .into_compile_error(),
-                                ));
-                            }
-                        }
-                        _ => {
-                            return Err(TokenStream::from(
-                                syn::Error::new(meta.path.span(), "invalid argument")
-                                    .into_compile_error(),
-                            ))
-                        }
-                    }
-                    Ok(())
-                })();
+        let NestedMeta::Meta(arg) = arg else {
+            return Err(error_with_message("unexpected literal", arg));
+        };
 
-                if let Err(tokens) = result {
-                    return tokens;
-                }
+        let ident = {
+            let meta_path = match &arg {
+                Meta::NameValue(meta) => &meta.path,
+                Meta::List(list) => &list.path,
+                Meta::Path(path) => return Err(error_with_message("invalid path argument", path)),
+            };
+            let Some(ident) = meta_path.get_ident() else {
+                return Err(error_with_message("unexpected path", meta_path));
+            };
+            ident.to_string()
+        };
+
+        match (&arg, ident.as_str()) {
+            (Meta::NameValue(meta), "retries") => max_retries = parse_usize(&meta.lit)?,
+            (Meta::NameValue(meta), "iterations") => num_iterations = parse_usize(&meta.lit)?,
+            (Meta::NameValue(meta), "on_failure") => {
+                let Lit::Str(name) = &meta.lit else {
+                    return Err(error_with_message(
+                        "on_failure argument must be a string",
+                        &meta.lit,
+                    ));
+                };
+                let segments = name
+                    .value()
+                    .split("::")
+                    .map(|part| PathSegment::from(Ident::new(part, name.span())))
+                    .collect();
+                let path = syn::Path {
+                    leading_colon: None,
+                    segments,
+                };
+                on_failure_fn_name = quote!(Some(#path));
             }
-            other => {
-                return TokenStream::from(
-                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
-                )
+            (Meta::NameValue(meta), "seed") => seeds = vec![parse_usize(&meta.lit)? as u64],
+            (Meta::List(list), "seeds") => seeds = parse_u64_array(&list)?,
+            (Meta::Path(path), _) => {
+                return Err(error_with_message("invalid path argument", path));
+            }
+            (_, _) => {
+                return Err(error_with_message("invalid argument name", arg));
             }
         }
     }
+    let seeds = quote!( #(#seeds),* );
 
-    let mut inner_fn = parse_macro_input!(function as ItemFn);
-    if max_retries > 0 && num_iterations > 1 {
-        return TokenStream::from(
-            syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
-                .into_compile_error(),
-        );
-    }
+    let mut inner_fn = syn::parse::<ItemFn>(function).map_err(error_to_stream)?;
     let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
     let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
     let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
@@ -122,9 +121,7 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
                 }
             }
 
-            return TokenStream::from(
-                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
-            );
+            return Err(error_with_message("invalid function signature", arg));
         }
 
         parse_quote! {
@@ -133,7 +130,8 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
                 #inner_fn
 
                 gpui::run_test(
-                    #num_iterations as u64,
+                    #num_iterations,
+                    &[#seeds],
                     #max_retries,
                     &mut |dispatcher, _seed| {
                         let executor = gpui::BackgroundExecutor::new(std::sync::Arc::new(dispatcher.clone()));
@@ -205,9 +203,7 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
                 }
             }
 
-            return TokenStream::from(
-                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
-            );
+            return Err(error_with_message("invalid function signature", arg));
         }
 
         parse_quote! {
@@ -216,7 +212,8 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
                 #inner_fn
 
                 gpui::run_test(
-                    #num_iterations as u64,
+                    #num_iterations,
+                    &[#seeds],
                     #max_retries,
                     &mut |dispatcher, _seed| {
                         #cx_vars
@@ -230,15 +227,34 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
     };
     outer_fn.attrs.extend(inner_fn_attributes);
 
-    TokenStream::from(quote!(#outer_fn))
+    Ok(TokenStream::from(quote!(#outer_fn)))
 }
 
-fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
-    let result = if let Lit::Int(int) = &literal {
-        int.base10_parse()
-    } else {
-        Err(syn::Error::new(literal.span(), "must be an integer"))
+fn parse_usize(literal: &Lit) -> Result<usize, TokenStream> {
+    let Lit::Int(int) = &literal else {
+        return Err(error_with_message("expected an usize", literal));
     };
+    int.base10_parse().map_err(error_to_stream)
+}
+
+fn parse_u64_array(meta_list: &MetaList) -> Result<Vec<u64>, TokenStream> {
+    meta_list
+        .nested
+        .iter()
+        .map(|meta| {
+            if let NestedMeta::Lit(literal) = &meta {
+                parse_usize(literal).map(|value| value as u64)
+            } else {
+                Err(error_with_message("expected an integer", meta.span()))
+            }
+        })
+        .collect()
+}
+
+fn error_with_message(message: &str, spanned: impl Spanned) -> TokenStream {
+    error_to_stream(syn::Error::new(spanned.span(), message))
+}
 
-    result.map_err(|err| TokenStream::from(err.into_compile_error()))
+fn error_to_stream(err: syn::Error) -> TokenStream {
+    TokenStream::from(err.into_compile_error())
 }