Revert "Revert windows implementation of "Multiple priority scheduler (#44701)"" (#46066)

Yara 🏳️‍⚧️ created

Reverts zed-industries/zed#44990

Release Notes:

- N/A

Change summary

crates/gpui/src/executor.rs                    | 10 --
crates/gpui/src/gpui.rs                        |  4 
crates/gpui/src/platform/windows/dispatcher.rs | 91 +++++++++++--------
crates/gpui/src/platform/windows/events.rs     |  3 
crates/gpui/src/platform/windows/platform.rs   | 24 ++--
crates/gpui/src/platform/windows/window.rs     |  4 
6 files changed, 71 insertions(+), 65 deletions(-)

Detailed changes

crates/gpui/src/executor.rs 🔗

@@ -366,19 +366,9 @@ impl BackgroundExecutor {
         &self,
         future: AnyFuture<R>,
         label: Option<TaskLabel>,
-        #[cfg_attr(
-            target_os = "windows",
-            expect(
-                unused_variables,
-                reason = "Multi priority scheduler is broken on windows"
-            )
-        )]
         priority: Priority,
     ) -> Task<R> {
         let dispatcher = self.dispatcher.clone();
-        #[cfg(target_os = "windows")]
-        let priority = Priority::Medium; // multi-prio scheduler is broken on windows
-
         let (runnable, task) = if let Priority::Realtime(realtime) = priority {
             let location = core::panic::Location::caller();
             let (mut tx, rx) = flume::bounded::<Runnable<RunnableMeta>>(1);

crates/gpui/src/gpui.rs 🔗

@@ -31,7 +31,7 @@ mod path_builder;
 mod platform;
 pub mod prelude;
 mod profiler;
-#[cfg(target_os = "linux")]
+#[cfg(any(target_os = "windows", target_os = "linux"))]
 mod queue;
 mod scene;
 mod shared_string;
@@ -91,7 +91,7 @@ pub use keymap::*;
 pub use path_builder::*;
 pub use platform::*;
 pub use profiler::*;
-#[cfg(target_os = "linux")]
+#[cfg(any(target_os = "windows", target_os = "linux"))]
 pub(crate) use queue::{PriorityQueueReceiver, PriorityQueueSender};
 pub use refineable::*;
 pub use scene::*;

crates/gpui/src/platform/windows/dispatcher.rs 🔗

@@ -4,24 +4,31 @@ use std::{
     time::{Duration, Instant},
 };
 
-use flume::Sender;
+use anyhow::Context;
 use util::ResultExt;
 use windows::{
-    System::Threading::{ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler},
+    System::Threading::{
+        ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
+    },
     Win32::{
         Foundation::{LPARAM, WPARAM},
+        System::Threading::{
+            GetCurrentThread, HIGH_PRIORITY_CLASS, SetPriorityClass, SetThreadPriority,
+            THREAD_PRIORITY_HIGHEST, THREAD_PRIORITY_TIME_CRITICAL,
+        },
         UI::WindowsAndMessaging::PostMessageW,
     },
 };
 
 use crate::{
-    GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, RunnableVariant, SafeHwnd, THREAD_TIMINGS,
-    TaskLabel, TaskTiming, ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
+    GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, Priority, PriorityQueueSender,
+    RealtimePriority, RunnableVariant, SafeHwnd, THREAD_TIMINGS, TaskLabel, TaskTiming,
+    ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD, profiler,
 };
 
 pub(crate) struct WindowsDispatcher {
     pub(crate) wake_posted: AtomicBool,
-    main_sender: Sender<RunnableVariant>,
+    main_sender: PriorityQueueSender<RunnableVariant>,
     main_thread_id: ThreadId,
     pub(crate) platform_window_handle: SafeHwnd,
     validation_number: usize,
@@ -29,7 +36,7 @@ pub(crate) struct WindowsDispatcher {
 
 impl WindowsDispatcher {
     pub(crate) fn new(
-        main_sender: Sender<RunnableVariant>,
+        main_sender: PriorityQueueSender<RunnableVariant>,
         platform_window_handle: HWND,
         validation_number: usize,
     ) -> Self {
@@ -45,7 +52,7 @@ impl WindowsDispatcher {
         }
     }
 
-    fn dispatch_on_threadpool(&self, runnable: RunnableVariant) {
+    fn dispatch_on_threadpool(&self, priority: WorkItemPriority, runnable: RunnableVariant) {
         let handler = {
             let mut task_wrapper = Some(runnable);
             WorkItemHandler::new(move |_| {
@@ -53,7 +60,8 @@ impl WindowsDispatcher {
                 Ok(())
             })
         };
-        ThreadPool::RunAsync(&handler).log_err();
+
+        ThreadPool::RunWithPriorityAsync(&handler, priority).log_err();
     }
 
     fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
@@ -79,7 +87,7 @@ impl WindowsDispatcher {
                     start,
                     end: None,
                 };
-                Self::add_task_timing(timing);
+                profiler::add_task_timing(timing);
 
                 runnable.run();
 
@@ -91,7 +99,7 @@ impl WindowsDispatcher {
                     start,
                     end: None,
                 };
-                Self::add_task_timing(timing);
+                profiler::add_task_timing(timing);
 
                 runnable.run();
 
@@ -102,23 +110,7 @@ impl WindowsDispatcher {
         let end = Instant::now();
         timing.end = Some(end);
 
-        Self::add_task_timing(timing);
-    }
-
-    pub(crate) fn add_task_timing(timing: TaskTiming) {
-        THREAD_TIMINGS.with(|timings| {
-            let mut timings = timings.lock();
-            let timings = &mut timings.timings;
-
-            if let Some(last_timing) = timings.iter_mut().rev().next() {
-                if last_timing.location == timing.location {
-                    last_timing.end = timing.end;
-                    return;
-                }
-            }
-
-            timings.push_back(timing);
-        });
+        profiler::add_task_timing(timing);
     }
 }
 
@@ -146,20 +138,22 @@ impl PlatformDispatcher for WindowsDispatcher {
         current().id() == self.main_thread_id
     }
 
-    fn dispatch(
-        &self,
-        runnable: RunnableVariant,
-        label: Option<TaskLabel>,
-        _priority: gpui::Priority,
-    ) {
-        self.dispatch_on_threadpool(runnable);
+    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority) {
+        let priority = match priority {
+            Priority::Realtime(_) => unreachable!(),
+            Priority::High => WorkItemPriority::High,
+            Priority::Medium => WorkItemPriority::Normal,
+            Priority::Low => WorkItemPriority::Low,
+        };
+        self.dispatch_on_threadpool(priority, runnable);
+
         if let Some(label) = label {
             log::debug!("TaskLabel: {label:?}");
         }
     }
 
-    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: gpui::Priority) {
-        match self.main_sender.send(runnable) {
+    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
+        match self.main_sender.send(priority, runnable) {
             Ok(_) => {
                 if !self.wake_posted.swap(true, Ordering::AcqRel) {
                     unsafe {
@@ -191,8 +185,27 @@ impl PlatformDispatcher for WindowsDispatcher {
         self.dispatch_on_threadpool_after(runnable, duration);
     }
 
-    fn spawn_realtime(&self, _priority: crate::RealtimePriority, _f: Box<dyn FnOnce() + Send>) {
-        // disabled on windows for now.
-        unimplemented!();
+    fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
+        std::thread::spawn(move || {
+            // SAFETY: always safe to call
+            let thread_handle = unsafe { GetCurrentThread() };
+
+            let thread_priority = match priority {
+                RealtimePriority::Audio => THREAD_PRIORITY_TIME_CRITICAL,
+                RealtimePriority::Other => THREAD_PRIORITY_HIGHEST,
+            };
+
+            // SAFETY: thread_handle is a valid handle to a thread
+            unsafe { SetPriorityClass(thread_handle, HIGH_PRIORITY_CLASS) }
+                .context("thread priority class")
+                .log_err();
+
+            // SAFETY: thread_handle is a valid handle to a thread
+            unsafe { SetThreadPriority(thread_handle, thread_priority) }
+                .context("thread priority")
+                .log_err();
+
+            f();
+        });
     }
 }

crates/gpui/src/platform/windows/events.rs 🔗

@@ -248,7 +248,8 @@ impl WindowsWindowInner {
 
     fn handle_timer_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> {
         if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID {
-            for runnable in self.main_receiver.drain() {
+            let mut runnables = self.main_receiver.clone().try_iter();
+            while let Some(Ok(runnable)) = runnables.next() {
                 WindowsDispatcher::execute_runnable(runnable);
             }
             self.handle_paint_msg(handle)

crates/gpui/src/platform/windows/platform.rs 🔗

@@ -51,7 +51,7 @@ struct WindowsPlatformInner {
     raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
     // The below members will never change throughout the entire lifecycle of the app.
     validation_number: usize,
-    main_receiver: flume::Receiver<RunnableVariant>,
+    main_receiver: PriorityQueueReceiver<RunnableVariant>,
     dispatcher: Arc<WindowsDispatcher>,
 }
 
@@ -98,7 +98,7 @@ impl WindowsPlatform {
             OleInitialize(None).context("unable to initialize Windows OLE")?;
         }
         let directx_devices = DirectXDevices::new().context("Creating DirectX devices")?;
-        let (main_sender, main_receiver) = flume::unbounded::<RunnableVariant>();
+        let (main_sender, main_receiver) = PriorityQueueReceiver::new();
         let validation_number = if usize::BITS == 64 {
             rand::random::<u64>() as usize
         } else {
@@ -857,22 +857,24 @@ impl WindowsPlatformInner {
                     }
                     break 'tasks;
                 }
-                match self.main_receiver.try_recv() {
-                    Err(_) => break 'timeout_loop,
-                    Ok(runnable) => WindowsDispatcher::execute_runnable(runnable),
+                let mut main_receiver = self.main_receiver.clone();
+                match main_receiver.try_pop() {
+                    Ok(Some(runnable)) => WindowsDispatcher::execute_runnable(runnable),
+                    _ => break 'timeout_loop,
                 }
             }
 
             // Someone could enqueue a Runnable here. The flag is still true, so they will not PostMessage.
             // We need to check for those Runnables after we clear the flag.
             self.dispatcher.wake_posted.store(false, Ordering::Release);
-            match self.main_receiver.try_recv() {
-                Err(_) => break 'tasks,
-                Ok(runnable) => {
+            let mut main_receiver = self.main_receiver.clone();
+            match main_receiver.try_pop() {
+                Ok(Some(runnable)) => {
                     self.dispatcher.wake_posted.store(true, Ordering::Release);
 
                     WindowsDispatcher::execute_runnable(runnable);
                 }
+                _ => break 'tasks,
             }
         }
 
@@ -934,7 +936,7 @@ pub(crate) struct WindowCreationInfo {
     pub(crate) windows_version: WindowsVersion,
     pub(crate) drop_target_helper: IDropTargetHelper,
     pub(crate) validation_number: usize,
-    pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
+    pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
     pub(crate) platform_window_handle: HWND,
     pub(crate) disable_direct_composition: bool,
     pub(crate) directx_devices: DirectXDevices,
@@ -947,8 +949,8 @@ struct PlatformWindowCreateContext {
     inner: Option<Result<Rc<WindowsPlatformInner>>>,
     raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
     validation_number: usize,
-    main_sender: Option<flume::Sender<RunnableVariant>>,
-    main_receiver: Option<flume::Receiver<RunnableVariant>>,
+    main_sender: Option<PriorityQueueSender<RunnableVariant>>,
+    main_receiver: Option<PriorityQueueReceiver<RunnableVariant>>,
     directx_devices: Option<DirectXDevices>,
     dispatcher: Option<Arc<WindowsDispatcher>>,
 }

crates/gpui/src/platform/windows/window.rs 🔗

@@ -81,7 +81,7 @@ pub(crate) struct WindowsWindowInner {
     pub(crate) executor: ForegroundExecutor,
     pub(crate) windows_version: WindowsVersion,
     pub(crate) validation_number: usize,
-    pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
+    pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
     pub(crate) platform_window_handle: HWND,
     pub(crate) parent_hwnd: Option<HWND>,
 }
@@ -364,7 +364,7 @@ struct WindowCreateContext {
     windows_version: WindowsVersion,
     drop_target_helper: IDropTargetHelper,
     validation_number: usize,
-    main_receiver: flume::Receiver<RunnableVariant>,
+    main_receiver: PriorityQueueReceiver<RunnableVariant>,
     platform_window_handle: HWND,
     appearance: WindowAppearance,
     disable_direct_composition: bool,