Use Windows native API to impl dispatcher (#9280)

张小白 created

This PR implements the dispatcher interfaces using `Threadpool` and
`TimerQueue`, both of which are provided by the native Windows APIs. With
them, we no longer need to sort tasks ourselves, since Windows will
presumably handle the scheduling in a more efficient manner. I am unsure
whether Zed would welcome this PR; suggestions are welcome.

Release Notes:

- N/A

Change summary

crates/gpui/src/platform/windows/dispatcher.rs | 206 +++++++++++--------
1 file changed, 115 insertions(+), 91 deletions(-)

Detailed changes

crates/gpui/src/platform/windows/dispatcher.rs 🔗

@@ -1,94 +1,78 @@
 use std::{
-    cmp::Ordering,
-    thread::{current, JoinHandle, ThreadId},
-    time::{Duration, Instant},
+    sync::{
+        atomic::{AtomicIsize, Ordering},
+        Arc,
+    },
+    thread::{current, ThreadId},
 };
 
 use async_task::Runnable;
-use collections::BinaryHeap;
-use flume::{RecvTimeoutError, Sender};
+use flume::Sender;
 use parking::Parker;
 use parking_lot::Mutex;
-use windows::Win32::{Foundation::HANDLE, System::Threading::SetEvent};
+use windows::Win32::{
+    Foundation::{BOOLEAN, HANDLE},
+    System::Threading::{
+        CreateThreadpool, CreateThreadpoolWork, CreateTimerQueueTimer, DeleteTimerQueueTimer,
+        SetEvent, SetThreadpoolThreadMinimum, SubmitThreadpoolWork, PTP_CALLBACK_INSTANCE,
+        PTP_POOL, PTP_WORK, TP_CALLBACK_ENVIRON_V3, TP_CALLBACK_PRIORITY_NORMAL,
+        WT_EXECUTEONLYONCE,
+    },
+};
 
 use crate::{PlatformDispatcher, TaskLabel};
 
 pub(crate) struct WindowsDispatcher {
-    background_sender: Sender<(Runnable, Option<TaskLabel>)>,
+    threadpool: PTP_POOL,
     main_sender: Sender<Runnable>,
-    timer_sender: Sender<(Runnable, Duration)>,
-    background_threads: Vec<JoinHandle<()>>,
-    timer_thread: JoinHandle<()>,
     parker: Mutex<Parker>,
     main_thread_id: ThreadId,
-    event: HANDLE,
+    dispatch_event: HANDLE,
 }
 
 impl WindowsDispatcher {
-    pub(crate) fn new(main_sender: Sender<Runnable>, event: HANDLE) -> Self {
+    pub(crate) fn new(main_sender: Sender<Runnable>, dispatch_event: HANDLE) -> Self {
         let parker = Mutex::new(Parker::new());
-        let (background_sender, background_receiver) =
-            flume::unbounded::<(Runnable, Option<TaskLabel>)>();
-        let background_threads = (0..std::thread::available_parallelism()
-            .map(|i| i.get())
-            .unwrap_or(1))
-            .map(|_| {
-                let receiver = background_receiver.clone();
-                std::thread::spawn(move || {
-                    for (runnable, label) in receiver {
-                        if let Some(label) = label {
-                            log::debug!("TaskLabel: {label:?}");
-                        }
-                        runnable.run();
-                    }
-                })
-            })
-            .collect::<Vec<_>>();
-        let (timer_sender, timer_receiver) = flume::unbounded::<(Runnable, Duration)>();
-        let timer_thread = std::thread::spawn(move || {
-            let mut runnables = BinaryHeap::<RunnableAfter>::new();
-            let mut timeout_dur = None;
-            loop {
-                let recv = if let Some(dur) = timeout_dur {
-                    match timer_receiver.recv_timeout(dur) {
-                        Ok(recv) => Some(recv),
-                        Err(RecvTimeoutError::Timeout) => None,
-                        Err(RecvTimeoutError::Disconnected) => break,
-                    }
-                } else if let Ok(recv) = timer_receiver.recv() {
-                    Some(recv)
-                } else {
-                    break;
-                };
-                let now = Instant::now();
-                if let Some((runnable, dur)) = recv {
-                    runnables.push(RunnableAfter {
-                        runnable,
-                        instant: now + dur,
-                    });
-                    while let Ok((runnable, dur)) = timer_receiver.try_recv() {
-                        runnables.push(RunnableAfter {
-                            runnable,
-                            instant: now + dur,
-                        })
-                    }
-                }
-                while runnables.peek().is_some_and(|entry| entry.instant <= now) {
-                    runnables.pop().unwrap().runnable.run();
-                }
-                timeout_dur = runnables.peek().map(|entry| entry.instant - now);
+        let threadpool = unsafe {
+            let ret = CreateThreadpool(None);
+            if ret.0 == 0 {
+                panic!(
+                    "unable to initialize a thread pool: {}",
+                    std::io::Error::last_os_error()
+                );
             }
-        });
+            // set minimum 1 thread in threadpool
+            let _ = SetThreadpoolThreadMinimum(ret, 1)
+                .inspect_err(|_| log::error!("unable to configure thread pool"));
+
+            ret
+        };
         let main_thread_id = current().id();
-        Self {
-            background_sender,
+        WindowsDispatcher {
+            threadpool,
             main_sender,
-            timer_sender,
-            background_threads,
-            timer_thread,
             parker,
             main_thread_id,
-            event,
+            dispatch_event,
+        }
+    }
+
+    fn dispatch_on_threadpool(&self, runnable: Runnable) {
+        unsafe {
+            let ptr = Box::into_raw(Box::new(runnable));
+            let environment = get_threadpool_environment(self.threadpool);
+            let Ok(work) =
+                CreateThreadpoolWork(Some(threadpool_runner), Some(ptr as _), Some(&environment))
+                    .inspect_err(|_| {
+                        log::error!(
+                            "unable to dispatch work on thread pool: {}",
+                            std::io::Error::last_os_error()
+                        )
+                    })
+            else {
+                return;
+            };
+            SubmitThreadpoolWork(work);
         }
     }
 }
@@ -99,10 +83,10 @@ impl PlatformDispatcher for WindowsDispatcher {
     }
 
     fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
-        self.background_sender
-            .send((runnable, label))
-            .inspect_err(|e| log::error!("Dispatch failed: {e}"))
-            .ok();
+        self.dispatch_on_threadpool(runnable);
+        if let Some(label) = label {
+            log::debug!("TaskLabel: {label:?}");
+        }
     }
 
     fn dispatch_on_main_thread(&self, runnable: Runnable) {
@@ -110,14 +94,34 @@ impl PlatformDispatcher for WindowsDispatcher {
             .send(runnable)
             .inspect_err(|e| log::error!("Dispatch failed: {e}"))
             .ok();
-        unsafe { SetEvent(self.event) }.ok();
+        unsafe { SetEvent(self.dispatch_event) }.ok();
     }
 
     fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
-        self.timer_sender
-            .send((runnable, duration))
-            .inspect_err(|e| log::error!("Dispatch failed: {e}"))
-            .ok();
+        if duration.as_millis() == 0 {
+            self.dispatch_on_threadpool(runnable);
+            return;
+        }
+        unsafe {
+            let mut handle = std::mem::zeroed();
+            let task = Arc::new(DelayedTask::new(runnable));
+            let _ = CreateTimerQueueTimer(
+                &mut handle,
+                None,
+                Some(timer_queue_runner),
+                Some(Arc::into_raw(task.clone()) as _),
+                duration.as_millis() as u32,
+                0,
+                WT_EXECUTEONLYONCE,
+            )
+            .inspect_err(|_| {
+                log::error!(
+                    "unable to dispatch delayed task: {}",
+                    std::io::Error::last_os_error()
+                )
+            });
+            task.raw_timer_handle.store(handle.0, Ordering::SeqCst);
+        }
     }
 
     fn tick(&self, _background_only: bool) -> bool {
@@ -133,27 +137,47 @@ impl PlatformDispatcher for WindowsDispatcher {
     }
 }
 
-struct RunnableAfter {
-    runnable: Runnable,
-    instant: Instant,
+extern "system" fn threadpool_runner(
+    _: PTP_CALLBACK_INSTANCE,
+    ptr: *mut std::ffi::c_void,
+    _: PTP_WORK,
+) {
+    unsafe {
+        let runnable = Box::from_raw(ptr as *mut Runnable);
+        runnable.run();
+    }
 }
 
-impl PartialEq for RunnableAfter {
-    fn eq(&self, other: &Self) -> bool {
-        self.instant == other.instant
+unsafe extern "system" fn timer_queue_runner(ptr: *mut std::ffi::c_void, _: BOOLEAN) {
+    let task = Arc::from_raw(ptr as *mut DelayedTask);
+    task.runnable.lock().take().unwrap().run();
+    unsafe {
+        let timer = task.raw_timer_handle.load(Ordering::SeqCst);
+        let _ = DeleteTimerQueueTimer(None, HANDLE(timer), None);
     }
 }
 
-impl Eq for RunnableAfter {}
+struct DelayedTask {
+    runnable: Mutex<Option<Runnable>>,
+    raw_timer_handle: AtomicIsize,
+}
 
-impl Ord for RunnableAfter {
-    fn cmp(&self, other: &Self) -> Ordering {
-        self.instant.cmp(&other.instant).reverse()
+impl DelayedTask {
+    pub fn new(runnable: Runnable) -> Self {
+        DelayedTask {
+            runnable: Mutex::new(Some(runnable)),
+            raw_timer_handle: AtomicIsize::new(0),
+        }
     }
 }
 
-impl PartialOrd for RunnableAfter {
-    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
-        Some(self.cmp(other))
+#[inline]
+fn get_threadpool_environment(pool: PTP_POOL) -> TP_CALLBACK_ENVIRON_V3 {
+    TP_CALLBACK_ENVIRON_V3 {
+        Version: 3, // Win7+, otherwise this value should be 1
+        Pool: pool,
+        CallbackPriority: TP_CALLBACK_PRIORITY_NORMAL,
+        Size: std::mem::size_of::<TP_CALLBACK_ENVIRON_V3>() as _,
+        ..Default::default()
     }
 }