Allow resetting `executor::Deterministic`'s RNG

Antonio Scandurra created

Change summary

gpui/src/executor.rs | 80 +++++++++++++++++++++++++++------------------
1 file changed, 47 insertions(+), 33 deletions(-)

Detailed changes

gpui/src/executor.rs 🔗

@@ -9,7 +9,11 @@ use std::{
     mem,
     pin::Pin,
     rc::Rc,
-    sync::{mpsc::SyncSender, Arc},
+    sync::{
+        atomic::{AtomicBool, Ordering::SeqCst},
+        mpsc::SyncSender,
+        Arc,
+    },
     thread,
 };
 
@@ -33,24 +37,25 @@ pub enum Background {
     },
 }
 
-#[derive(Default)]
-struct Runnables {
+struct DeterministicState {
+    rng: StdRng,
+    seed: u64,
     scheduled: Vec<Runnable>,
     spawned_from_foreground: Vec<Runnable>,
     waker: Option<SyncSender<()>>,
 }
 
-pub struct Deterministic {
-    seed: u64,
-    runnables: Arc<Mutex<Runnables>>,
-}
+pub struct Deterministic(Arc<Mutex<DeterministicState>>);
 
 impl Deterministic {
     fn new(seed: u64) -> Self {
-        Self {
+        Self(Arc::new(Mutex::new(DeterministicState {
+            rng: StdRng::seed_from_u64(seed),
             seed,
-            runnables: Default::default(),
-        }
+            scheduled: Default::default(),
+            spawned_from_foreground: Default::default(),
+            waker: None,
+        })))
     }
 
     pub fn spawn_from_foreground<F, T>(&self, future: F) -> Task<T>
@@ -58,17 +63,16 @@ impl Deterministic {
         T: 'static,
         F: Future<Output = T> + 'static,
     {
-        let scheduled_once = Mutex::new(false);
-        let runnables = self.runnables.clone();
+        let scheduled_once = AtomicBool::new(false);
+        let state = self.0.clone();
         let (runnable, task) = async_task::spawn_local(future, move |runnable| {
-            let mut runnables = runnables.lock();
-            if *scheduled_once.lock() {
-                runnables.scheduled.push(runnable);
+            let mut state = state.lock();
+            if scheduled_once.fetch_or(true, SeqCst) {
+                state.scheduled.push(runnable);
             } else {
-                runnables.spawned_from_foreground.push(runnable);
-                *scheduled_once.lock() = true;
+                state.spawned_from_foreground.push(runnable);
             }
-            if let Some(waker) = runnables.waker.as_ref() {
+            if let Some(waker) = state.waker.as_ref() {
                 waker.send(()).ok();
             }
         });
@@ -81,11 +85,11 @@ impl Deterministic {
         T: 'static + Send,
         F: 'static + Send + Future<Output = T>,
     {
-        let runnables = self.runnables.clone();
+        let state = self.0.clone();
         let (runnable, task) = async_task::spawn(future, move |runnable| {
-            let mut runnables = runnables.lock();
-            runnables.scheduled.push(runnable);
-            if let Some(waker) = runnables.waker.as_ref() {
+            let mut state = state.lock();
+            state.scheduled.push(runnable);
+            if let Some(waker) = state.waker.as_ref() {
                 waker.send(()).ok();
             }
         });
@@ -99,8 +103,8 @@ impl Deterministic {
         F: Future<Output = T> + 'static,
     {
         let (wake_tx, wake_rx) = std::sync::mpsc::sync_channel(32);
-        let runnables = self.runnables.clone();
-        runnables.lock().waker = Some(wake_tx);
+        let state = self.0.clone();
+        state.lock().waker = Some(wake_tx);
 
         let (output_tx, output_rx) = std::sync::mpsc::channel();
         self.spawn_from_foreground(async move {
@@ -109,23 +113,22 @@ impl Deterministic {
         })
         .detach();
 
-        let mut rng = StdRng::seed_from_u64(self.seed);
         loop {
             if let Ok(value) = output_rx.try_recv() {
-                runnables.lock().waker = None;
+                state.lock().waker = None;
                 return value;
             }
 
             wake_rx.recv().unwrap();
             let runnable = {
-                let mut runnables = runnables.lock();
-                let ix = rng.gen_range(
-                    0..runnables.scheduled.len() + runnables.spawned_from_foreground.len(),
-                );
-                if ix < runnables.scheduled.len() {
-                    runnables.scheduled.remove(ix)
+                let state = &mut *state.lock();
+                let ix = state
+                    .rng
+                    .gen_range(0..state.scheduled.len() + state.spawned_from_foreground.len());
+                if ix < state.scheduled.len() {
+                    state.scheduled.remove(ix)
                 } else {
-                    runnables.spawned_from_foreground.remove(0)
+                    state.spawned_from_foreground.remove(0)
                 }
             };
 
@@ -171,6 +174,17 @@ impl Foreground {
             Self::Deterministic(executor) => executor.run(future),
         }
     }
+
+    pub fn reset(&self) {
+        match self {
+            Self::Platform { .. } => panic!("can't call this method on a platform executor"),
+            Self::Test(_) => panic!("can't call this method on a test executor"),
+            Self::Deterministic(executor) => {
+                let state = &mut *executor.0.lock();
+                state.rng = StdRng::seed_from_u64(state.seed);
+            }
+        }
+    }
 }
 
 impl Background {