Allow different app contexts to race

Antonio Scandurra , Max , Kyle , and Conrad created

Co-Authored-By: Max <max@zed.dev>
Co-Authored-By: Kyle <kyle@zed.dev>
Co-Authored-By: Conrad <conrad@zed.dev>

Change summary

crates/gpui2/src/platform/test/dispatcher.rs | 88 ++++++++++++++++-----
1 file changed, 65 insertions(+), 23 deletions(-)

Detailed changes

crates/gpui2/src/platform/test/dispatcher.rs 🔗

@@ -1,57 +1,87 @@
 use crate::PlatformDispatcher;
 use async_task::Runnable;
-use collections::{BTreeMap, VecDeque};
+use collections::{BTreeMap, HashMap, VecDeque};
 use parking_lot::Mutex;
 use rand::prelude::*;
-use std::time::{Duration, Instant};
-
-pub struct TestDispatcher(Mutex<TestDispatcherState>);
+use std::{
+    sync::Arc,
+    time::{Duration, Instant},
+};
+use util::post_inc;
+
+#[derive(Copy, Clone, PartialEq, Eq, Hash)]
+struct TestDispatcherId(usize);
+
+pub struct TestDispatcher {
+    id: TestDispatcherId,
+    state: Arc<Mutex<TestDispatcherState>>,
+}
 
 struct TestDispatcherState {
     random: StdRng,
-    foreground: VecDeque<Runnable>,
+    foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
     background: Vec<Runnable>,
     delayed: BTreeMap<Instant, Runnable>,
     time: Instant,
     is_main_thread: bool,
+    next_id: TestDispatcherId,
 }
 
 impl TestDispatcher {
     pub fn new(random: StdRng) -> Self {
         let state = TestDispatcherState {
             random,
-            foreground: VecDeque::new(),
+            foreground: HashMap::default(),
             background: Vec::new(),
             delayed: BTreeMap::new(),
             time: Instant::now(),
             is_main_thread: true,
+            next_id: TestDispatcherId(1),
         };
 
-        TestDispatcher(Mutex::new(state))
+        TestDispatcher {
+            id: TestDispatcherId(0),
+            state: Arc::new(Mutex::new(state)),
+        }
+    }
+}
+
+impl Clone for TestDispatcher {
+    fn clone(&self) -> Self {
+        let id = post_inc(&mut self.state.lock().next_id.0);
+        Self {
+            id: TestDispatcherId(id),
+            state: self.state.clone(),
+        }
     }
 }
 
 impl PlatformDispatcher for TestDispatcher {
     fn is_main_thread(&self) -> bool {
-        self.0.lock().is_main_thread
+        self.state.lock().is_main_thread
     }
 
     fn dispatch(&self, runnable: Runnable) {
-        self.0.lock().background.push(runnable);
+        self.state.lock().background.push(runnable);
     }
 
     fn dispatch_on_main_thread(&self, runnable: Runnable) {
-        self.0.lock().foreground.push_back(runnable);
+        self.state
+            .lock()
+            .foreground
+            .entry(self.id)
+            .or_default()
+            .push_back(runnable);
     }
 
     fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
-        let mut state = self.0.lock();
+        let mut state = self.state.lock();
         let next_time = state.time + duration;
         state.delayed.insert(next_time, runnable);
     }
 
     fn poll(&self) -> bool {
-        let mut state = self.0.lock();
+        let mut state = self.state.lock();
 
         while let Some((deadline, _)) = state.delayed.first_key_value() {
             if *deadline > state.time {
@@ -61,36 +91,48 @@ impl PlatformDispatcher for TestDispatcher {
             state.background.push(runnable);
         }
 
-        if state.foreground.is_empty() && state.background.is_empty() {
+        let foreground_len: usize = state
+            .foreground
+            .values()
+            .map(|runnables| runnables.len())
+            .sum();
+        let background_len = state.background.len();
+
+        if foreground_len == 0 && background_len == 0 {
             return false;
         }
 
-        let foreground_len = state.foreground.len();
-        let background_len = state.background.len();
-        let main_thread = background_len == 0
-            || state
-                .random
-                .gen_ratio(foreground_len as u32, background_len as u32);
+        let main_thread = state.random.gen_ratio(
+            foreground_len as u32,
+            (foreground_len + background_len) as u32,
+        );
         let was_main_thread = state.is_main_thread;
         state.is_main_thread = main_thread;
 
         let runnable = if main_thread {
-            state.foreground.pop_front().unwrap()
+            let state = &mut *state;
+            let runnables = state
+                .foreground
+                .values_mut()
+                .filter(|runnables| !runnables.is_empty())
+                .choose(&mut state.random)
+                .unwrap();
+            runnables.pop_front().unwrap()
         } else {
             let ix = state.random.gen_range(0..background_len);
-            state.background.remove(ix)
+            state.background.swap_remove(ix)
         };
 
         drop(state);
         runnable.run();
 
-        self.0.lock().is_main_thread = was_main_thread;
+        self.state.lock().is_main_thread = was_main_thread;
 
         true
     }
 
     fn advance_clock(&self, by: Duration) {
-        self.0.lock().time += by;
+        self.state.lock().time += by;
     }
 }