Add the ability to deprioritize specific labeled tasks in tests

Max Brunsfeld created

Change summary

crates/gpui2/src/executor.rs                 | 64 ++++++++++++++---
crates/gpui2/src/platform.rs                 |  4 
crates/gpui2/src/platform/mac/dispatcher.rs  |  4 
crates/gpui2/src/platform/test/dispatcher.rs | 79 ++++++++++++++-------
4 files changed, 108 insertions(+), 43 deletions(-)

Detailed changes

crates/gpui2/src/executor.rs 🔗

@@ -5,10 +5,11 @@ use std::{
     fmt::Debug,
     marker::PhantomData,
     mem,
+    num::NonZeroUsize,
     pin::Pin,
     rc::Rc,
     sync::{
-        atomic::{AtomicBool, Ordering::SeqCst},
+        atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
         Arc,
     },
     task::{Context, Poll},
@@ -71,30 +72,57 @@ impl<T> Future for Task<T> {
         }
     }
 }
+
+#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
+pub struct TaskLabel(NonZeroUsize);
+
+impl TaskLabel {
+    pub fn new() -> Self {
+        static NEXT_TASK_LABEL: AtomicUsize = AtomicUsize::new(1);
+        Self(NEXT_TASK_LABEL.fetch_add(1, SeqCst).try_into().unwrap())
+    }
+}
+
 type AnyLocalFuture<R> = Pin<Box<dyn 'static + Future<Output = R>>>;
+
 type AnyFuture<R> = Pin<Box<dyn 'static + Send + Future<Output = R>>>;
+
 impl BackgroundExecutor {
     pub fn new(dispatcher: Arc<dyn PlatformDispatcher>) -> Self {
         Self { dispatcher }
     }
 
-    /// Enqueues the given closure to be run on any thread. The closure returns
-    /// a future which will be run to completion on any available thread.
+    /// Enqueues the given future to be run to completion on a background thread.
     pub fn spawn<R>(&self, future: impl Future<Output = R> + Send + 'static) -> Task<R>
     where
         R: Send + 'static,
     {
+        self.spawn_internal::<R>(Box::pin(future), None)
+    }
+
+    /// Enqueues the given future to be run to completion on a background thread.
+    /// The given label can be used to control the priority of the task in tests.
+    pub fn spawn_labeled<R>(
+        &self,
+        label: TaskLabel,
+        future: impl Future<Output = R> + Send + 'static,
+    ) -> Task<R>
+    where
+        R: Send + 'static,
+    {
+        self.spawn_internal::<R>(Box::pin(future), Some(label))
+    }
+
+    fn spawn_internal<R: Send + 'static>(
+        &self,
+        future: AnyFuture<R>,
+        label: Option<TaskLabel>,
+    ) -> Task<R> {
         let dispatcher = self.dispatcher.clone();
-        fn inner<R: Send + 'static>(
-            dispatcher: Arc<dyn PlatformDispatcher>,
-            future: AnyFuture<R>,
-        ) -> Task<R> {
-            let (runnable, task) =
-                async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable));
-            runnable.schedule();
-            Task::Spawned(task)
-        }
-        inner::<R>(dispatcher, Box::pin(future))
+        let (runnable, task) =
+            async_task::spawn(future, move |runnable| dispatcher.dispatch(runnable, label));
+        runnable.schedule();
+        Task::Spawned(task)
     }
 
     #[cfg(any(test, feature = "test-support"))]
@@ -216,11 +244,21 @@ impl BackgroundExecutor {
         self.dispatcher.as_test().unwrap().simulate_random_delay()
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn deprioritize_task(&self, task_label: TaskLabel) {
+        self.dispatcher.as_test().unwrap().deprioritize(task_label)
+    }
+
     #[cfg(any(test, feature = "test-support"))]
     pub fn advance_clock(&self, duration: Duration) {
         self.dispatcher.as_test().unwrap().advance_clock(duration)
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn run_step(&self) -> bool {
+        self.dispatcher.as_test().unwrap().poll(false)
+    }
+
     #[cfg(any(test, feature = "test-support"))]
     pub fn run_until_parked(&self) {
         self.dispatcher.as_test().unwrap().run_until_parked()

crates/gpui2/src/platform.rs 🔗

@@ -8,7 +8,7 @@ use crate::{
     point, size, AnyWindowHandle, BackgroundExecutor, Bounds, DevicePixels, Font, FontId,
     FontMetrics, FontRun, ForegroundExecutor, GlobalPixels, GlyphId, InputEvent, LineLayout,
     Pixels, Point, RenderGlyphParams, RenderImageParams, RenderSvgParams, Result, Scene,
-    SharedString, Size,
+    SharedString, Size, TaskLabel,
 };
 use anyhow::{anyhow, bail};
 use async_task::Runnable;
@@ -162,7 +162,7 @@ pub(crate) trait PlatformWindow {
 
 pub trait PlatformDispatcher: Send + Sync {
     fn is_main_thread(&self) -> bool;
-    fn dispatch(&self, runnable: Runnable);
+    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>);
     fn dispatch_on_main_thread(&self, runnable: Runnable);
     fn dispatch_after(&self, duration: Duration, runnable: Runnable);
     fn poll(&self, background_only: bool) -> bool;

crates/gpui2/src/platform/mac/dispatcher.rs 🔗

@@ -2,7 +2,7 @@
 #![allow(non_camel_case_types)]
 #![allow(non_snake_case)]
 
-use crate::PlatformDispatcher;
+use crate::{PlatformDispatcher, TaskLabel};
 use async_task::Runnable;
 use objc::{
     class, msg_send,
@@ -37,7 +37,7 @@ impl PlatformDispatcher for MacDispatcher {
         is_main_thread == YES
     }
 
-    fn dispatch(&self, runnable: Runnable) {
+    fn dispatch(&self, runnable: Runnable, _: Option<TaskLabel>) {
         unsafe {
             dispatch_async_f(
                 dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_DEFAULT.try_into().unwrap(), 0),

crates/gpui2/src/platform/test/dispatcher.rs 🔗

@@ -1,7 +1,7 @@
-use crate::PlatformDispatcher;
+use crate::{PlatformDispatcher, TaskLabel};
 use async_task::Runnable;
 use backtrace::Backtrace;
-use collections::{HashMap, VecDeque};
+use collections::{HashMap, HashSet, VecDeque};
 use parking::{Parker, Unparker};
 use parking_lot::Mutex;
 use rand::prelude::*;
@@ -28,12 +28,14 @@ struct TestDispatcherState {
     random: StdRng,
     foreground: HashMap<TestDispatcherId, VecDeque<Runnable>>,
     background: Vec<Runnable>,
+    deprioritized_background: Vec<Runnable>,
     delayed: Vec<(Duration, Runnable)>,
     time: Duration,
     is_main_thread: bool,
     next_id: TestDispatcherId,
     allow_parking: bool,
     waiting_backtrace: Option<Backtrace>,
+    deprioritized_task_labels: HashSet<TaskLabel>,
 }
 
 impl TestDispatcher {
@@ -43,12 +45,14 @@ impl TestDispatcher {
             random,
             foreground: HashMap::default(),
             background: Vec::new(),
+            deprioritized_background: Vec::new(),
             delayed: Vec::new(),
             time: Duration::ZERO,
             is_main_thread: true,
             next_id: TestDispatcherId(1),
             allow_parking: false,
             waiting_backtrace: None,
+            deprioritized_task_labels: Default::default(),
         };
 
         TestDispatcher {
@@ -101,6 +105,13 @@ impl TestDispatcher {
         }
     }
 
+    pub fn deprioritize(&self, task_label: TaskLabel) {
+        self.state
+            .lock()
+            .deprioritized_task_labels
+            .insert(task_label);
+    }
+
     pub fn run_until_parked(&self) {
         while self.poll(false) {}
     }
@@ -150,8 +161,17 @@ impl PlatformDispatcher for TestDispatcher {
         self.state.lock().is_main_thread
     }
 
-    fn dispatch(&self, runnable: Runnable) {
-        self.state.lock().background.push(runnable);
+    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
+        {
+            let mut state = self.state.lock();
+            if label.map_or(false, |label| {
+                state.deprioritized_task_labels.contains(&label)
+            }) {
+                state.deprioritized_background.push(runnable);
+            } else {
+                state.background.push(runnable);
+            }
+        }
         self.unparker.unpark();
     }
 
@@ -196,34 +216,41 @@ impl PlatformDispatcher for TestDispatcher {
         };
         let background_len = state.background.len();
 
+        let runnable;
+        let main_thread;
         if foreground_len == 0 && background_len == 0 {
-            return false;
-        }
-
-        let main_thread = state.random.gen_ratio(
-            foreground_len as u32,
-            (foreground_len + background_len) as u32,
-        );
-        let was_main_thread = state.is_main_thread;
-        state.is_main_thread = main_thread;
-
-        let runnable = if main_thread {
-            let state = &mut *state;
-            let runnables = state
-                .foreground
-                .values_mut()
-                .filter(|runnables| !runnables.is_empty())
-                .choose(&mut state.random)
-                .unwrap();
-            runnables.pop_front().unwrap()
+            let deprioritized_background_len = state.deprioritized_background.len();
+            if deprioritized_background_len == 0 {
+                return false;
+            }
+            let ix = state.random.gen_range(0..deprioritized_background_len);
+            main_thread = false;
+            runnable = state.deprioritized_background.swap_remove(ix);
         } else {
-            let ix = state.random.gen_range(0..background_len);
-            state.background.swap_remove(ix)
+            main_thread = state.random.gen_ratio(
+                foreground_len as u32,
+                (foreground_len + background_len) as u32,
+            );
+            if main_thread {
+                let state = &mut *state;
+                runnable = state
+                    .foreground
+                    .values_mut()
+                    .filter(|runnables| !runnables.is_empty())
+                    .choose(&mut state.random)
+                    .unwrap()
+                    .pop_front()
+                    .unwrap();
+            } else {
+                let ix = state.random.gen_range(0..background_len);
+                runnable = state.background.swap_remove(ix);
+            };
         };
 
+        let was_main_thread = state.is_main_thread;
+        state.is_main_thread = main_thread;
         drop(state);
         runnable.run();
-
         self.state.lock().is_main_thread = was_main_thread;
 
         true