Get most tests passing when respecting wake order for foreground tasks in Deterministic executor

Nathan Sobo and Max Brunsfeld created

Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>

Change summary

crates/gpui/src/executor.rs | 81 ++++++++++++++------------------------
1 file changed, 30 insertions(+), 51 deletions(-)

Detailed changes

crates/gpui/src/executor.rs 🔗

@@ -73,7 +73,7 @@ unsafe impl<T: Send> Send for Task<T> {}
 struct DeterministicState {
     rng: StdRng,
     seed: u64,
-    scheduled_from_foreground: HashMap<usize, Vec<ScheduledForeground>>,
+    scheduled_from_foreground: HashMap<usize, Vec<ForegroundRunnable>>,
     scheduled_from_background: Vec<Runnable>,
     forbid_parking: bool,
     block_on_ticks: RangeInclusive<usize>,
@@ -82,9 +82,9 @@ struct DeterministicState {
     waiting_backtrace: Option<Backtrace>,
 }
 
-enum ScheduledForeground {
-    MainFuture,
-    Runnable(Runnable),
+struct ForegroundRunnable {
+    runnable: Runnable,
+    main: bool,
 }
 
 pub struct Deterministic {
@@ -123,7 +123,12 @@ impl Deterministic {
         })
     }
 
-    fn spawn_from_foreground(&self, cx_id: usize, future: AnyLocalFuture) -> AnyLocalTask {
+    fn spawn_from_foreground(
+        &self,
+        cx_id: usize,
+        future: AnyLocalFuture,
+        main: bool,
+    ) -> AnyLocalTask {
         let state = self.state.clone();
         let unparker = self.parker.lock().unparker();
         let (runnable, task) = async_task::spawn_local(future, move |runnable| {
@@ -132,7 +137,7 @@ impl Deterministic {
                 .scheduled_from_foreground
                 .entry(cx_id)
                 .or_default()
-                .push(ScheduledForeground::Runnable(runnable));
+                .push(ForegroundRunnable { runnable, main });
             unparker.unpark();
         });
         runnable.schedule();
@@ -151,10 +156,12 @@ impl Deterministic {
         task
     }
 
-    fn run(&self, cx_id: usize, mut future: AnyLocalFuture) -> Box<dyn Any> {
+    fn run(&self, cx_id: usize, main_future: AnyLocalFuture) -> Box<dyn Any> {
         let woken = Arc::new(AtomicBool::new(false));
+        let mut main_task = self.spawn_from_foreground(cx_id, main_future, true);
+
         loop {
-            if let Some(result) = self.run_internal(cx_id, woken.clone(), &mut future) {
+            if let Some(result) = self.run_internal(woken.clone(), Some(&mut main_task)) {
                 return result;
             }
 
@@ -167,44 +174,20 @@ impl Deterministic {
         }
     }
 
-    fn run_until_parked(&self, cx_id: usize) {
+    fn run_until_parked(&self) {
         let woken = Arc::new(AtomicBool::new(false));
-        let mut future = any_local_future(std::future::pending::<()>());
-        self.run_internal(cx_id, woken, &mut future);
+        self.run_internal(woken, None);
     }
 
     fn run_internal(
         &self,
-        cx_id: usize,
         woken: Arc<AtomicBool>,
-        future: &mut AnyLocalFuture,
+        mut main_task: Option<&mut AnyLocalTask>,
     ) -> Option<Box<dyn Any>> {
         let unparker = self.parker.lock().unparker();
-        let scheduled_main_future = Arc::new(AtomicBool::new(true));
-        self.state
-            .lock()
-            .scheduled_from_foreground
-            .entry(cx_id)
-            .or_default()
-            .insert(0, ScheduledForeground::MainFuture);
-
-        let waker = waker_fn({
-            let state = self.state.clone();
-            let scheduled_main_future = scheduled_main_future.clone();
-            move || {
-                woken.store(true, SeqCst);
-                if !scheduled_main_future.load(SeqCst) {
-                    scheduled_main_future.store(true, SeqCst);
-                    state
-                        .lock()
-                        .scheduled_from_foreground
-                        .entry(cx_id)
-                        .or_default()
-                        .push(ScheduledForeground::MainFuture);
-                }
-
-                unparker.unpark();
-            }
+        let waker = waker_fn(move || {
+            woken.store(true, SeqCst);
+            unparker.unpark();
         });
 
         let mut cx = Context::from_waker(&waker);
@@ -234,25 +217,21 @@ impl Deterministic {
                     .scheduled_from_foreground
                     .get_mut(&cx_id_to_run)
                     .unwrap();
-                let runnable = scheduled_from_cx.remove(0);
+                let foreground_runnable = scheduled_from_cx.remove(0);
                 if scheduled_from_cx.is_empty() {
                     state.scheduled_from_foreground.remove(&cx_id_to_run);
                 }
 
                 drop(state);
-                match runnable {
-                    ScheduledForeground::MainFuture => {
-                        scheduled_main_future.store(false, SeqCst);
-                        if let Poll::Ready(result) = future.poll(&mut cx) {
+
+                foreground_runnable.runnable.run();
+                if let Some(main_task) = main_task.as_mut() {
+                    if foreground_runnable.main {
+                        if let Poll::Ready(result) = main_task.poll(&mut cx) {
                             return Some(result);
                         }
                     }
-                    ScheduledForeground::Runnable(runnable) => {
-                        runnable.run();
-                    }
                 }
-            } else {
-                return None;
             }
         }
     }
@@ -364,7 +343,7 @@ impl Foreground {
         let future = any_local_future(future);
         let any_task = match self {
             Self::Deterministic { cx_id, executor } => {
-                executor.spawn_from_foreground(*cx_id, future)
+                executor.spawn_from_foreground(*cx_id, future, false)
             }
             Self::Platform { dispatcher, .. } => {
                 fn spawn_inner(
@@ -448,8 +427,8 @@ impl Foreground {
 
     pub fn advance_clock(&self, duration: Duration) {
         match self {
-            Self::Deterministic { cx_id, executor } => {
-                executor.run_until_parked(*cx_id);
+            Self::Deterministic { executor, .. } => {
+                executor.run_until_parked();
 
                 let mut state = executor.state.lock();
                 state.now += duration;