@@ -1,94 +1,78 @@
use std::{
- cmp::Ordering,
- thread::{current, JoinHandle, ThreadId},
- time::{Duration, Instant},
+ sync::{
+ atomic::{AtomicIsize, Ordering},
+ Arc,
+ },
+ thread::{current, ThreadId},
};
use async_task::Runnable;
-use collections::BinaryHeap;
-use flume::{RecvTimeoutError, Sender};
+use flume::Sender;
use parking::Parker;
use parking_lot::Mutex;
-use windows::Win32::{Foundation::HANDLE, System::Threading::SetEvent};
+use windows::Win32::{
+ Foundation::{BOOLEAN, HANDLE},
+ System::Threading::{
+ CreateThreadpool, CreateThreadpoolWork, CreateTimerQueueTimer, DeleteTimerQueueTimer,
+ SetEvent, SetThreadpoolThreadMinimum, SubmitThreadpoolWork, PTP_CALLBACK_INSTANCE,
+ PTP_POOL, PTP_WORK, TP_CALLBACK_ENVIRON_V3, TP_CALLBACK_PRIORITY_NORMAL,
+ WT_EXECUTEONLYONCE,
+ },
+};
use crate::{PlatformDispatcher, TaskLabel};
pub(crate) struct WindowsDispatcher {
- background_sender: Sender<(Runnable, Option<TaskLabel>)>,
+ threadpool: PTP_POOL,
main_sender: Sender<Runnable>,
- timer_sender: Sender<(Runnable, Duration)>,
- background_threads: Vec<JoinHandle<()>>,
- timer_thread: JoinHandle<()>,
parker: Mutex<Parker>,
main_thread_id: ThreadId,
- event: HANDLE,
+ dispatch_event: HANDLE,
}
impl WindowsDispatcher {
- pub(crate) fn new(main_sender: Sender<Runnable>, event: HANDLE) -> Self {
+ pub(crate) fn new(main_sender: Sender<Runnable>, dispatch_event: HANDLE) -> Self {
let parker = Mutex::new(Parker::new());
- let (background_sender, background_receiver) =
- flume::unbounded::<(Runnable, Option<TaskLabel>)>();
- let background_threads = (0..std::thread::available_parallelism()
- .map(|i| i.get())
- .unwrap_or(1))
- .map(|_| {
- let receiver = background_receiver.clone();
- std::thread::spawn(move || {
- for (runnable, label) in receiver {
- if let Some(label) = label {
- log::debug!("TaskLabel: {label:?}");
- }
- runnable.run();
- }
- })
- })
- .collect::<Vec<_>>();
- let (timer_sender, timer_receiver) = flume::unbounded::<(Runnable, Duration)>();
- let timer_thread = std::thread::spawn(move || {
- let mut runnables = BinaryHeap::<RunnableAfter>::new();
- let mut timeout_dur = None;
- loop {
- let recv = if let Some(dur) = timeout_dur {
- match timer_receiver.recv_timeout(dur) {
- Ok(recv) => Some(recv),
- Err(RecvTimeoutError::Timeout) => None,
- Err(RecvTimeoutError::Disconnected) => break,
- }
- } else if let Ok(recv) = timer_receiver.recv() {
- Some(recv)
- } else {
- break;
- };
- let now = Instant::now();
- if let Some((runnable, dur)) = recv {
- runnables.push(RunnableAfter {
- runnable,
- instant: now + dur,
- });
- while let Ok((runnable, dur)) = timer_receiver.try_recv() {
- runnables.push(RunnableAfter {
- runnable,
- instant: now + dur,
- })
- }
- }
- while runnables.peek().is_some_and(|entry| entry.instant <= now) {
- runnables.pop().unwrap().runnable.run();
- }
- timeout_dur = runnables.peek().map(|entry| entry.instant - now);
+ let threadpool = unsafe {
+ let ret = CreateThreadpool(None);
+ if ret.0 == 0 {
+ panic!(
+ "unable to initialize a thread pool: {}",
+ std::io::Error::last_os_error()
+ );
}
- });
+ // set minimum 1 thread in threadpool
+ let _ = SetThreadpoolThreadMinimum(ret, 1)
+ .inspect_err(|_| log::error!("unable to configure thread pool"));
+
+ ret
+ };
let main_thread_id = current().id();
- Self {
- background_sender,
+ WindowsDispatcher {
+ threadpool,
main_sender,
- timer_sender,
- background_threads,
- timer_thread,
parker,
main_thread_id,
- event,
+ dispatch_event,
+ }
+ }
+
+ fn dispatch_on_threadpool(&self, runnable: Runnable) {
+ unsafe {
+ let ptr = Box::into_raw(Box::new(runnable));
+ let environment = get_threadpool_environment(self.threadpool);
+ let Ok(work) =
+ CreateThreadpoolWork(Some(threadpool_runner), Some(ptr as _), Some(&environment))
+ .inspect_err(|_| {
+ log::error!(
+ "unable to dispatch work on thread pool: {}",
+ std::io::Error::last_os_error()
+ )
+ })
+ else {
+ return;
+ };
+ SubmitThreadpoolWork(work);
}
}
}
@@ -99,10 +83,10 @@ impl PlatformDispatcher for WindowsDispatcher {
}
fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
- self.background_sender
- .send((runnable, label))
- .inspect_err(|e| log::error!("Dispatch failed: {e}"))
- .ok();
+ self.dispatch_on_threadpool(runnable);
+ if let Some(label) = label {
+ log::debug!("TaskLabel: {label:?}");
+ }
}
fn dispatch_on_main_thread(&self, runnable: Runnable) {
@@ -110,14 +94,34 @@ impl PlatformDispatcher for WindowsDispatcher {
.send(runnable)
.inspect_err(|e| log::error!("Dispatch failed: {e}"))
.ok();
- unsafe { SetEvent(self.event) }.ok();
+ unsafe { SetEvent(self.dispatch_event) }.ok();
}
fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
- self.timer_sender
- .send((runnable, duration))
- .inspect_err(|e| log::error!("Dispatch failed: {e}"))
- .ok();
+ if duration.as_millis() == 0 {
+ self.dispatch_on_threadpool(runnable);
+ return;
+ }
+ unsafe {
+ let mut handle = std::mem::zeroed();
+ let task = Arc::new(DelayedTask::new(runnable));
+ let _ = CreateTimerQueueTimer(
+ &mut handle,
+ None,
+ Some(timer_queue_runner),
+ Some(Arc::into_raw(task.clone()) as _),
+ duration.as_millis() as u32,
+ 0,
+ WT_EXECUTEONLYONCE,
+ )
+ .inspect_err(|_| {
+ log::error!(
+ "unable to dispatch delayed task: {}",
+ std::io::Error::last_os_error()
+ )
+ });
+ task.raw_timer_handle.store(handle.0, Ordering::SeqCst);
+ }
}
fn tick(&self, _background_only: bool) -> bool {
@@ -133,27 +137,47 @@ impl PlatformDispatcher for WindowsDispatcher {
}
}
-struct RunnableAfter {
- runnable: Runnable,
- instant: Instant,
+extern "system" fn threadpool_runner(
+ _: PTP_CALLBACK_INSTANCE,
+ ptr: *mut std::ffi::c_void,
+ _: PTP_WORK,
+) {
+ unsafe {
+ let runnable = Box::from_raw(ptr as *mut Runnable);
+ runnable.run();
+ }
}
-impl PartialEq for RunnableAfter {
- fn eq(&self, other: &Self) -> bool {
- self.instant == other.instant
+unsafe extern "system" fn timer_queue_runner(ptr: *mut std::ffi::c_void, _: BOOLEAN) {
+ let task = Arc::from_raw(ptr as *mut DelayedTask);
+ task.runnable.lock().take().unwrap().run();
+ unsafe {
+ let timer = task.raw_timer_handle.load(Ordering::SeqCst);
+ let _ = DeleteTimerQueueTimer(None, HANDLE(timer), None);
}
}
-impl Eq for RunnableAfter {}
+struct DelayedTask {
+ runnable: Mutex<Option<Runnable>>,
+ raw_timer_handle: AtomicIsize,
+}
-impl Ord for RunnableAfter {
- fn cmp(&self, other: &Self) -> Ordering {
- self.instant.cmp(&other.instant).reverse()
+impl DelayedTask {
+ pub fn new(runnable: Runnable) -> Self {
+ DelayedTask {
+ runnable: Mutex::new(Some(runnable)),
+ raw_timer_handle: AtomicIsize::new(0),
+ }
}
}
-impl PartialOrd for RunnableAfter {
- fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
- Some(self.cmp(other))
+#[inline]
+fn get_threadpool_environment(pool: PTP_POOL) -> TP_CALLBACK_ENVIRON_V3 {
+ TP_CALLBACK_ENVIRON_V3 {
+ Version: 3, // Win7+, otherwise this value should be 1
+ Pool: pool,
+ CallbackPriority: TP_CALLBACK_PRIORITY_NORMAL,
+ Size: std::mem::size_of::<TP_CALLBACK_ENVIRON_V3>() as _,
+ ..Default::default()
}
}