Don't poll foreground futures during `DeterministicExecutor::block_on`

Antonio Scandurra created

Change summary

gpui/src/executor.rs | 104 ++++++++++++++++++++++++++++++++++-----------
1 file changed, 78 insertions(+), 26 deletions(-)

Detailed changes

gpui/src/executor.rs 🔗

@@ -44,7 +44,8 @@ pub enum Background {
 struct DeterministicState {
     rng: StdRng,
     seed: u64,
-    scheduled: Vec<(Runnable, Backtrace)>,
+    scheduled_from_foreground: Vec<(Runnable, Backtrace)>,
+    scheduled_from_background: Vec<(Runnable, Backtrace)>,
     spawned_from_foreground: Vec<(Runnable, Backtrace)>,
     forbid_parking: bool,
     block_on_ticks: RangeInclusive<usize>,
@@ -61,7 +62,8 @@ impl Deterministic {
             state: Arc::new(Mutex::new(DeterministicState {
                 rng: StdRng::seed_from_u64(seed),
                 seed,
-                scheduled: Default::default(),
+                scheduled_from_foreground: Default::default(),
+                scheduled_from_background: Default::default(),
                 spawned_from_foreground: Default::default(),
                 forbid_parking: false,
                 block_on_ticks: 0..=1000,
@@ -83,7 +85,7 @@ impl Deterministic {
             let mut state = state.lock();
             let backtrace = backtrace.clone();
             if scheduled_once.fetch_or(true, SeqCst) {
-                state.scheduled.push((runnable, backtrace));
+                state.scheduled_from_foreground.push((runnable, backtrace));
             } else {
                 state.spawned_from_foreground.push((runnable, backtrace));
             }
@@ -103,7 +105,9 @@ impl Deterministic {
         let unparker = self.parker.lock().unparker();
         let (runnable, task) = async_task::spawn(future, move |runnable| {
             let mut state = state.lock();
-            state.scheduled.push((runnable, backtrace.clone()));
+            state
+                .scheduled_from_background
+                .push((runnable, backtrace.clone()));
             unparker.unpark();
         });
         runnable.schedule();
@@ -115,10 +119,66 @@ impl Deterministic {
         T: 'static,
         F: Future<Output = T> + 'static,
     {
-        self.block_on(usize::MAX, future).unwrap()
+        smol::pin!(future);
+
+        let unparker = self.parker.lock().unparker();
+        let waker = waker_fn(move || {
+            unparker.unpark();
+        });
+
+        let mut cx = Context::from_waker(&waker);
+        let mut trace = Trace::default();
+        loop {
+            let mut state = self.state.lock();
+            let runnable_count = state.scheduled_from_foreground.len()
+                + state.scheduled_from_background.len()
+                + state.spawned_from_foreground.len();
+
+            let ix = state.rng.gen_range(0..=runnable_count);
+            if ix < state.scheduled_from_foreground.len() {
+                let (_, backtrace) = &state.scheduled_from_foreground[ix];
+                trace.record(&state, backtrace.clone());
+                let runnable = state.scheduled_from_foreground.remove(ix).0;
+                drop(state);
+                runnable.run();
+            } else if ix - state.scheduled_from_foreground.len()
+                < state.scheduled_from_background.len()
+            {
+                let ix = ix - state.scheduled_from_foreground.len();
+                let (_, backtrace) = &state.scheduled_from_background[ix];
+                trace.record(&state, backtrace.clone());
+                let runnable = state.scheduled_from_background.remove(ix).0;
+                drop(state);
+                runnable.run();
+            } else if ix < runnable_count {
+                let (_, backtrace) = &state.spawned_from_foreground[0];
+                trace.record(&state, backtrace.clone());
+                let runnable = state.spawned_from_foreground.remove(0).0;
+                drop(state);
+                runnable.run();
+            } else {
+                drop(state);
+                if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
+                    return result;
+                }
+                let state = self.state.lock();
+                if state.scheduled_from_foreground.is_empty()
+                    && state.scheduled_from_background.is_empty()
+                    && state.spawned_from_foreground.is_empty()
+                {
+                    if state.forbid_parking {
+                        panic!("deterministic executor parked after a call to forbid_parking");
+                    }
+                    drop(state);
+                    self.parker.lock().park();
+                }
+
+                continue;
+            }
+        }
     }
 
-    pub fn block_on<F, T>(&self, max_ticks: usize, future: F) -> Option<T>
+    pub fn block_on<F, T>(&self, future: F) -> Option<T>
     where
         T: 'static,
         F: Future<Output = T>,
@@ -129,23 +189,22 @@ impl Deterministic {
         let waker = waker_fn(move || {
             unparker.unpark();
         });
+        let max_ticks = {
+            let mut state = self.state.lock();
+            let range = state.block_on_ticks.clone();
+            state.rng.gen_range(range)
+        };
 
         let mut cx = Context::from_waker(&waker);
         let mut trace = Trace::default();
         for _ in 0..max_ticks {
             let mut state = self.state.lock();
-            let runnable_count = state.scheduled.len() + state.spawned_from_foreground.len();
+            let runnable_count = state.scheduled_from_background.len();
             let ix = state.rng.gen_range(0..=runnable_count);
-            if ix < state.scheduled.len() {
-                let (_, backtrace) = &state.scheduled[ix];
-                trace.record(&state, backtrace.clone());
-                let runnable = state.scheduled.remove(ix).0;
-                drop(state);
-                runnable.run();
-            } else if ix < runnable_count {
-                let (_, backtrace) = &state.spawned_from_foreground[0];
+            if ix < state.scheduled_from_background.len() {
+                let (_, backtrace) = &state.scheduled_from_background[ix];
                 trace.record(&state, backtrace.clone());
-                let runnable = state.spawned_from_foreground.remove(0).0;
+                let runnable = state.scheduled_from_background.remove(ix).0;
                 drop(state);
                 runnable.run();
             } else {
@@ -154,7 +213,7 @@ impl Deterministic {
                     return Some(result);
                 }
                 let state = self.state.lock();
-                if state.scheduled.is_empty() && state.spawned_from_foreground.is_empty() {
+                if state.scheduled_from_background.is_empty() {
                     if state.forbid_parking {
                         panic!("deterministic executor parked after a call to forbid_parking");
                     }
@@ -181,7 +240,7 @@ impl Trace {
     fn record(&mut self, state: &DeterministicState, executed: Backtrace) {
         self.scheduled.push(
             state
-                .scheduled
+                .scheduled_from_foreground
                 .iter()
                 .map(|(_, backtrace)| backtrace.clone())
                 .collect(),
@@ -394,14 +453,7 @@ impl Background {
             Self::Production { .. } => {
                 smol::block_on(util::timeout(timeout, Pin::new(&mut future))).ok()
             }
-            Self::Deterministic(executor) => {
-                let max_ticks = {
-                    let mut state = executor.state.lock();
-                    let range = state.block_on_ticks.clone();
-                    state.rng.gen_range(range)
-                };
-                executor.block_on(max_ticks, Pin::new(&mut future))
-            }
+            Self::Deterministic(executor) => executor.block_on(Pin::new(&mut future)),
         };
 
         if let Some(output) = output {