dispatcher.rs

  1use std::{
  2    cmp::Ordering,
  3    thread::{current, JoinHandle, ThreadId},
  4    time::{Duration, Instant},
  5};
  6
  7use async_task::Runnable;
  8use collections::BinaryHeap;
  9use flume::{RecvTimeoutError, Sender};
 10use parking::Parker;
 11use parking_lot::Mutex;
 12use windows::Win32::{Foundation::HANDLE, System::Threading::SetEvent};
 13
 14use crate::{PlatformDispatcher, TaskLabel};
 15
 16pub(crate) struct WindowsDispatcher {
 17    background_sender: Sender<(Runnable, Option<TaskLabel>)>,
 18    main_sender: Sender<Runnable>,
 19    timer_sender: Sender<(Runnable, Duration)>,
 20    background_threads: Vec<JoinHandle<()>>,
 21    timer_thread: JoinHandle<()>,
 22    parker: Mutex<Parker>,
 23    main_thread_id: ThreadId,
 24    event: HANDLE,
 25}
 26
 27impl WindowsDispatcher {
 28    pub(crate) fn new(main_sender: Sender<Runnable>, event: HANDLE) -> Self {
 29        let parker = Mutex::new(Parker::new());
 30        let (background_sender, background_receiver) =
 31            flume::unbounded::<(Runnable, Option<TaskLabel>)>();
 32        let background_threads = (0..std::thread::available_parallelism()
 33            .map(|i| i.get())
 34            .unwrap_or(1))
 35            .map(|_| {
 36                let receiver = background_receiver.clone();
 37                std::thread::spawn(move || {
 38                    for (runnable, label) in receiver {
 39                        if let Some(label) = label {
 40                            log::debug!("TaskLabel: {label:?}");
 41                        }
 42                        runnable.run();
 43                    }
 44                })
 45            })
 46            .collect::<Vec<_>>();
 47        let (timer_sender, timer_receiver) = flume::unbounded::<(Runnable, Duration)>();
 48        let timer_thread = std::thread::spawn(move || {
 49            let mut runnables = BinaryHeap::<RunnableAfter>::new();
 50            let mut timeout_dur = None;
 51            loop {
 52                let recv = if let Some(dur) = timeout_dur {
 53                    match timer_receiver.recv_timeout(dur) {
 54                        Ok(recv) => Some(recv),
 55                        Err(RecvTimeoutError::Timeout) => None,
 56                        Err(RecvTimeoutError::Disconnected) => break,
 57                    }
 58                } else if let Ok(recv) = timer_receiver.recv() {
 59                    Some(recv)
 60                } else {
 61                    break;
 62                };
 63                let now = Instant::now();
 64                if let Some((runnable, dur)) = recv {
 65                    runnables.push(RunnableAfter {
 66                        runnable,
 67                        instant: now + dur,
 68                    });
 69                    while let Ok((runnable, dur)) = timer_receiver.try_recv() {
 70                        runnables.push(RunnableAfter {
 71                            runnable,
 72                            instant: now + dur,
 73                        })
 74                    }
 75                }
 76                while runnables.peek().is_some_and(|entry| entry.instant <= now) {
 77                    runnables.pop().unwrap().runnable.run();
 78                }
 79                timeout_dur = runnables.peek().map(|entry| entry.instant - now);
 80            }
 81        });
 82        let main_thread_id = current().id();
 83        Self {
 84            background_sender,
 85            main_sender,
 86            timer_sender,
 87            background_threads,
 88            timer_thread,
 89            parker,
 90            main_thread_id,
 91            event,
 92        }
 93    }
 94}
 95
 96impl PlatformDispatcher for WindowsDispatcher {
 97    fn is_main_thread(&self) -> bool {
 98        current().id() == self.main_thread_id
 99    }
100
101    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
102        self.background_sender
103            .send((runnable, label))
104            .inspect_err(|e| log::error!("Dispatch failed: {e}"))
105            .ok();
106    }
107
108    fn dispatch_on_main_thread(&self, runnable: Runnable) {
109        self.main_sender
110            .send(runnable)
111            .inspect_err(|e| log::error!("Dispatch failed: {e}"))
112            .ok();
113        unsafe { SetEvent(self.event) }.ok();
114    }
115
116    fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
117        self.timer_sender
118            .send((runnable, duration))
119            .inspect_err(|e| log::error!("Dispatch failed: {e}"))
120            .ok();
121    }
122
123    fn tick(&self, _background_only: bool) -> bool {
124        false
125    }
126
127    fn park(&self) {
128        self.parker.lock().park();
129    }
130
131    fn unparker(&self) -> parking::Unparker {
132        self.parker.lock().unparker()
133    }
134}
135
136struct RunnableAfter {
137    runnable: Runnable,
138    instant: Instant,
139}
140
141impl PartialEq for RunnableAfter {
142    fn eq(&self, other: &Self) -> bool {
143        self.instant == other.instant
144    }
145}
146
147impl Eq for RunnableAfter {}
148
149impl Ord for RunnableAfter {
150    fn cmp(&self, other: &Self) -> Ordering {
151        self.instant.cmp(&other.instant).reverse()
152    }
153}
154
155impl PartialOrd for RunnableAfter {
156    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
157        Some(self.cmp(other))
158    }
159}