Allow DeterministicExecutor to park until forbid_parking is called

Nathan Sobo created

This allows us to perform async setup such as talking to the database.

Change summary

Cargo.lock           |   3 
gpui/Cargo.toml      |   3 
gpui/src/executor.rs | 106 +++++++++++++++++++++++++++++----------------
server/src/tests.rs  |  17 ++++++
4 files changed, 89 insertions(+), 40 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2159,13 +2159,13 @@ dependencies = [
  "etagere",
  "font-kit",
  "foreign-types",
- "futures",
  "gpui_macros",
  "log",
  "metal",
  "num_cpus",
  "objc",
  "ordered-float",
+ "parking",
  "parking_lot",
  "pathfinder_color",
  "pathfinder_geometry",
@@ -2183,6 +2183,7 @@ dependencies = [
  "tiny-skia",
  "tree-sitter",
  "usvg",
+ "waker-fn",
 ]
 
 [[package]]

gpui/Cargo.toml 🔗

@@ -9,11 +9,11 @@ async-task = "4.0.3"
 backtrace = "0.3"
 ctor = "0.1"
 etagere = "0.2"
-futures = "0.3"
 gpui_macros = { path = "../gpui_macros" }
 log = "0.4"
 num_cpus = "1.13"
 ordered-float = "2.1.1"
+parking = "2.0.0"
 parking_lot = "0.11.1"
 pathfinder_color = "0.5"
 pathfinder_geometry = "0.5"
@@ -29,6 +29,7 @@ smol = "1.2"
 tiny-skia = "0.5"
 tree-sitter = "0.19"
 usvg = "0.14"
+waker-fn = "1.1.0"
 
 [build-dependencies]
 bindgen = "0.58.1"

gpui/src/executor.rs 🔗

@@ -2,7 +2,6 @@ use anyhow::{anyhow, Result};
 use async_task::Runnable;
 pub use async_task::Task;
 use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
-use futures::task::noop_waker;
 use parking_lot::Mutex;
 use rand::prelude::*;
 use smol::{channel, prelude::*, Executor};
@@ -20,6 +19,7 @@ use std::{
     thread,
     time::Duration,
 };
+use waker_fn::waker_fn;
 
 use crate::{platform, util};
 
@@ -45,18 +45,26 @@ struct DeterministicState {
     seed: u64,
     scheduled: Vec<(Runnable, Backtrace)>,
     spawned_from_foreground: Vec<(Runnable, Backtrace)>,
+    forbid_parking: bool,
 }
 
-pub struct Deterministic(Arc<Mutex<DeterministicState>>);
+pub struct Deterministic {
+    state: Arc<Mutex<DeterministicState>>,
+    parker: Mutex<parking::Parker>,
+}
 
 impl Deterministic {
     fn new(seed: u64) -> Self {
-        Self(Arc::new(Mutex::new(DeterministicState {
-            rng: StdRng::seed_from_u64(seed),
-            seed,
-            scheduled: Default::default(),
-            spawned_from_foreground: Default::default(),
-        })))
+        Self {
+            state: Arc::new(Mutex::new(DeterministicState {
+                rng: StdRng::seed_from_u64(seed),
+                seed,
+                scheduled: Default::default(),
+                spawned_from_foreground: Default::default(),
+                forbid_parking: false,
+            })),
+            parker: Default::default(),
+        }
     }
 
     pub fn spawn_from_foreground<F, T>(&self, future: F) -> Task<T>
@@ -66,7 +74,8 @@ impl Deterministic {
     {
         let backtrace = Backtrace::new_unresolved();
         let scheduled_once = AtomicBool::new(false);
-        let state = self.0.clone();
+        let state = self.state.clone();
+        let unparker = self.parker.lock().unparker();
         let (runnable, task) = async_task::spawn_local(future, move |runnable| {
             let mut state = state.lock();
             let backtrace = backtrace.clone();
@@ -75,6 +84,7 @@ impl Deterministic {
             } else {
                 state.spawned_from_foreground.push((runnable, backtrace));
             }
+            unparker.unpark();
         });
         runnable.schedule();
         task
@@ -86,9 +96,12 @@ impl Deterministic {
         F: 'static + Send + Future<Output = T>,
     {
         let backtrace = Backtrace::new_unresolved();
-        let state = self.0.clone();
+        let state = self.state.clone();
+        let unparker = self.parker.lock().unparker();
         let (runnable, task) = async_task::spawn(future, move |runnable| {
-            state.lock().scheduled.push((runnable, backtrace.clone()));
+            let mut state = state.lock();
+            state.scheduled.push((runnable, backtrace.clone()));
+            unparker.unpark();
         });
         runnable.schedule();
         task
@@ -109,36 +122,45 @@ impl Deterministic {
     {
         smol::pin!(future);
 
-        let waker = noop_waker();
+        let unparker = self.parker.lock().unparker();
+        let waker = waker_fn(move || {
+            unparker.unpark();
+        });
+
         let mut cx = Context::from_waker(&waker);
         let mut trace = Trace::default();
         for _ in 0..max_ticks {
-            let runnable = {
-                let state = &mut *self.0.lock();
-                let runnable_count = state.scheduled.len() + state.spawned_from_foreground.len();
-                let ix = state.rng.gen_range(0..=runnable_count);
-                if ix < state.scheduled.len() {
-                    let (_, backtrace) = &state.scheduled[ix];
-                    trace.record(&state, backtrace.clone());
-                    state.scheduled.remove(ix).0
-                } else if ix < runnable_count {
-                    let (_, backtrace) = &state.spawned_from_foreground[0];
-                    trace.record(&state, backtrace.clone());
-                    state.spawned_from_foreground.remove(0).0
-                } else {
-                    if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
-                        return Some(result);
-                    }
-
-                    if state.scheduled.is_empty() && state.spawned_from_foreground.is_empty() {
-                        panic!("detected non-determinism in deterministic executor");
-                    } else {
-                        continue;
+            let mut state = self.state.lock();
+            let runnable_count = state.scheduled.len() + state.spawned_from_foreground.len();
+            let ix = state.rng.gen_range(0..=runnable_count);
+            if ix < state.scheduled.len() {
+                let (_, backtrace) = &state.scheduled[ix];
+                trace.record(&state, backtrace.clone());
+                let runnable = state.scheduled.remove(ix).0;
+                drop(state);
+                runnable.run();
+            } else if ix < runnable_count {
+                let (_, backtrace) = &state.spawned_from_foreground[0];
+                trace.record(&state, backtrace.clone());
+                let runnable = state.spawned_from_foreground.remove(0).0;
+                drop(state);
+                runnable.run();
+            } else {
+                drop(state);
+                if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
+                    return Some(result);
+                }
+                let state = &mut *self.state.lock();
+                if state.scheduled.is_empty() && state.spawned_from_foreground.is_empty() {
+                    if state.forbid_parking {
+                        panic!("deterministic executor parked after a call to forbid_parking");
                     }
+                    drop(state);
+                    self.parker.lock().park();
                 }
-            };
 
-            runnable.run();
+                continue;
+            }
         }
 
         None
@@ -311,11 +333,21 @@ impl Foreground {
             Self::Platform { .. } => panic!("can't call this method on a platform executor"),
             Self::Test(_) => panic!("can't call this method on a test executor"),
             Self::Deterministic(executor) => {
-                let state = &mut *executor.0.lock();
+                let state = &mut *executor.state.lock();
                 state.rng = StdRng::seed_from_u64(state.seed);
             }
         }
     }
+
+    pub fn forbid_parking(&self) {
+        match self {
+            Self::Platform { .. } => panic!("can't call this method on a platform executor"),
+            Self::Test(_) => panic!("can't call this method on a test executor"),
+            Self::Deterministic(executor) => {
+                executor.state.lock().forbid_parking = true;
+            }
+        }
+    }
 }
 
 impl Background {
@@ -363,7 +395,7 @@ impl Background {
                 smol::block_on(async move { util::timeout(timeout, future).await.ok() })
             }
             Self::Deterministic(executor) => {
-                let max_ticks = executor.0.lock().rng.gen_range(1..=1000);
+                let max_ticks = executor.state.lock().rng.gen_range(1..=1000);
                 executor.block_on(max_ticks, future)
             }
         }

server/src/tests.rs 🔗

@@ -35,6 +35,8 @@ async fn test_share_worktree(mut cx_a: TestAppContext, mut cx_b: TestAppContext)
     let client_a = server.create_client(&mut cx_a, "user_a").await;
     let client_b = server.create_client(&mut cx_b, "user_b").await;
 
+    cx_a.foreground().forbid_parking();
+
     // Share a local worktree as client A
     let fs = Arc::new(FakeFs::new());
     fs.insert_tree(
@@ -141,6 +143,8 @@ async fn test_propagate_saves_and_fs_changes_in_shared_worktree(
     let client_b = server.create_client(&mut cx_b, "user_b").await;
     let client_c = server.create_client(&mut cx_c, "user_c").await;
 
+    cx_a.foreground().forbid_parking();
+
     let fs = Arc::new(FakeFs::new());
 
     // Share a worktree as client A.
@@ -280,6 +284,8 @@ async fn test_buffer_conflict_after_save(mut cx_a: TestAppContext, mut cx_b: Tes
     let client_a = server.create_client(&mut cx_a, "user_a").await;
     let client_b = server.create_client(&mut cx_b, "user_b").await;
 
+    cx_a.foreground().forbid_parking();
+
     // Share a local worktree as client A
     let fs = Arc::new(FakeFs::new());
     fs.save(Path::new("/a.txt"), &"a-contents".into())
@@ -357,6 +363,8 @@ async fn test_editing_while_guest_opens_buffer(mut cx_a: TestAppContext, mut cx_
     let client_a = server.create_client(&mut cx_a, "user_a").await;
     let client_b = server.create_client(&mut cx_b, "user_b").await;
 
+    cx_a.foreground().forbid_parking();
+
     // Share a local worktree as client A
     let fs = Arc::new(FakeFs::new());
     fs.save(Path::new("/a.txt"), &"a-contents".into())
@@ -416,6 +424,8 @@ async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext)
     let client_a = server.create_client(&mut cx_a, "user_a").await;
     let client_b = server.create_client(&mut cx_a, "user_b").await;
 
+    cx_a.foreground().forbid_parking();
+
     // Share a local worktree as client A
     let fs = Arc::new(FakeFs::new());
     fs.insert_tree(
@@ -529,7 +539,12 @@ impl TestServer {
             .connect(&config.database_url)
             .await
             .expect("failed to connect to postgres database");
-        let migrator = Migrator::new(Path::new("./migrations")).await.unwrap();
+        let migrator = Migrator::new(Path::new(concat!(
+            env!("CARGO_MANIFEST_DIR"),
+            "/migrations"
+        )))
+        .await
+        .unwrap();
         migrator.run(&db).await.unwrap();
 
         let github_client = github::AppClient::test();