// dispatcher.rs

  1use std::{
  2    sync::{
  3        atomic::{AtomicIsize, Ordering},
  4        Arc,
  5    },
  6    thread::{current, ThreadId},
  7};
  8
  9use async_task::Runnable;
 10use flume::Sender;
 11use parking::Parker;
 12use parking_lot::Mutex;
 13use windows::Win32::{Foundation::*, System::Threading::*};
 14
 15use crate::{PlatformDispatcher, TaskLabel};
 16
 17pub(crate) struct WindowsDispatcher {
 18    threadpool: PTP_POOL,
 19    main_sender: Sender<Runnable>,
 20    parker: Mutex<Parker>,
 21    main_thread_id: ThreadId,
 22    dispatch_event: HANDLE,
 23}
 24
 25impl WindowsDispatcher {
 26    pub(crate) fn new(main_sender: Sender<Runnable>, dispatch_event: HANDLE) -> Self {
 27        let parker = Mutex::new(Parker::new());
 28        let threadpool = unsafe {
 29            let ret = CreateThreadpool(None);
 30            if ret.0 == 0 {
 31                panic!(
 32                    "unable to initialize a thread pool: {}",
 33                    std::io::Error::last_os_error()
 34                );
 35            }
 36            // set minimum 1 thread in threadpool
 37            let _ = SetThreadpoolThreadMinimum(ret, 1)
 38                .inspect_err(|_| log::error!("unable to configure thread pool"));
 39
 40            ret
 41        };
 42        let main_thread_id = current().id();
 43        WindowsDispatcher {
 44            threadpool,
 45            main_sender,
 46            parker,
 47            main_thread_id,
 48            dispatch_event,
 49        }
 50    }
 51
 52    fn dispatch_on_threadpool(&self, runnable: Runnable) {
 53        unsafe {
 54            let ptr = Box::into_raw(Box::new(runnable));
 55            let environment = get_threadpool_environment(self.threadpool);
 56            let Ok(work) =
 57                CreateThreadpoolWork(Some(threadpool_runner), Some(ptr as _), Some(&environment))
 58                    .inspect_err(|_| {
 59                        log::error!(
 60                            "unable to dispatch work on thread pool: {}",
 61                            std::io::Error::last_os_error()
 62                        )
 63                    })
 64            else {
 65                return;
 66            };
 67            SubmitThreadpoolWork(work);
 68        }
 69    }
 70}
 71
 72impl PlatformDispatcher for WindowsDispatcher {
 73    fn is_main_thread(&self) -> bool {
 74        current().id() == self.main_thread_id
 75    }
 76
 77    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
 78        self.dispatch_on_threadpool(runnable);
 79        if let Some(label) = label {
 80            log::debug!("TaskLabel: {label:?}");
 81        }
 82    }
 83
 84    fn dispatch_on_main_thread(&self, runnable: Runnable) {
 85        self.main_sender
 86            .send(runnable)
 87            .inspect_err(|e| log::error!("Dispatch failed: {e}"))
 88            .ok();
 89        unsafe { SetEvent(self.dispatch_event) }.ok();
 90    }
 91
 92    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
 93        if duration.as_millis() == 0 {
 94            self.dispatch_on_threadpool(runnable);
 95            return;
 96        }
 97        unsafe {
 98            let mut handle = std::mem::zeroed();
 99            let task = Arc::new(DelayedTask::new(runnable));
100            let _ = CreateTimerQueueTimer(
101                &mut handle,
102                None,
103                Some(timer_queue_runner),
104                Some(Arc::into_raw(task.clone()) as _),
105                duration.as_millis() as u32,
106                0,
107                WT_EXECUTEONLYONCE,
108            )
109            .inspect_err(|_| {
110                log::error!(
111                    "unable to dispatch delayed task: {}",
112                    std::io::Error::last_os_error()
113                )
114            });
115            task.raw_timer_handle.store(handle.0, Ordering::SeqCst);
116        }
117    }
118
119    fn tick(&self, _background_only: bool) -> bool {
120        false
121    }
122
123    fn park(&self) {
124        self.parker.lock().park();
125    }
126
127    fn unparker(&self) -> parking::Unparker {
128        self.parker.lock().unparker()
129    }
130}
131
132extern "system" fn threadpool_runner(
133    _: PTP_CALLBACK_INSTANCE,
134    ptr: *mut std::ffi::c_void,
135    _: PTP_WORK,
136) {
137    unsafe {
138        let runnable = Box::from_raw(ptr as *mut Runnable);
139        runnable.run();
140    }
141}
142
143unsafe extern "system" fn timer_queue_runner(ptr: *mut std::ffi::c_void, _: BOOLEAN) {
144    let task = Arc::from_raw(ptr as *mut DelayedTask);
145    task.runnable.lock().take().unwrap().run();
146    unsafe {
147        let timer = task.raw_timer_handle.load(Ordering::SeqCst);
148        let _ = DeleteTimerQueueTimer(None, HANDLE(timer), None);
149    }
150}
151
152struct DelayedTask {
153    runnable: Mutex<Option<Runnable>>,
154    raw_timer_handle: AtomicIsize,
155}
156
157impl DelayedTask {
158    pub fn new(runnable: Runnable) -> Self {
159        DelayedTask {
160            runnable: Mutex::new(Some(runnable)),
161            raw_timer_handle: AtomicIsize::new(0),
162        }
163    }
164}
165
166#[inline]
167fn get_threadpool_environment(pool: PTP_POOL) -> TP_CALLBACK_ENVIRON_V3 {
168    TP_CALLBACK_ENVIRON_V3 {
169        Version: 3, // Win7+, otherwise this value should be 1
170        Pool: pool,
171        CallbackPriority: TP_CALLBACK_PRIORITY_NORMAL,
172        Size: std::mem::size_of::<TP_CALLBACK_ENVIRON_V3>() as _,
173        ..Default::default()
174    }
175}