1use std::{
2 cmp::Ordering,
3 thread::{current, JoinHandle, ThreadId},
4 time::{Duration, Instant},
5};
6
7use async_task::Runnable;
8use collections::BinaryHeap;
9use flume::{RecvTimeoutError, Sender};
10use parking::Parker;
11use parking_lot::Mutex;
12use windows::Win32::{Foundation::HANDLE, System::Threading::SetEvent};
13
14use crate::{PlatformDispatcher, TaskLabel};
15
16pub(crate) struct WindowsDispatcher {
17 background_sender: Sender<(Runnable, Option<TaskLabel>)>,
18 main_sender: Sender<Runnable>,
19 timer_sender: Sender<(Runnable, Duration)>,
20 background_threads: Vec<JoinHandle<()>>,
21 timer_thread: JoinHandle<()>,
22 parker: Mutex<Parker>,
23 main_thread_id: ThreadId,
24 event: HANDLE,
25}
26
27impl WindowsDispatcher {
28 pub(crate) fn new(main_sender: Sender<Runnable>, event: HANDLE) -> Self {
29 let parker = Mutex::new(Parker::new());
30 let (background_sender, background_receiver) =
31 flume::unbounded::<(Runnable, Option<TaskLabel>)>();
32 let background_threads = (0..std::thread::available_parallelism()
33 .map(|i| i.get())
34 .unwrap_or(1))
35 .map(|_| {
36 let receiver = background_receiver.clone();
37 std::thread::spawn(move || {
38 for (runnable, label) in receiver {
39 if let Some(label) = label {
40 log::debug!("TaskLabel: {label:?}");
41 }
42 runnable.run();
43 }
44 })
45 })
46 .collect::<Vec<_>>();
47 let (timer_sender, timer_receiver) = flume::unbounded::<(Runnable, Duration)>();
48 let timer_thread = std::thread::spawn(move || {
49 let mut runnables = BinaryHeap::<RunnableAfter>::new();
50 let mut timeout_dur = None;
51 loop {
52 let recv = if let Some(dur) = timeout_dur {
53 match timer_receiver.recv_timeout(dur) {
54 Ok(recv) => Some(recv),
55 Err(RecvTimeoutError::Timeout) => None,
56 Err(RecvTimeoutError::Disconnected) => break,
57 }
58 } else if let Ok(recv) = timer_receiver.recv() {
59 Some(recv)
60 } else {
61 break;
62 };
63 let now = Instant::now();
64 if let Some((runnable, dur)) = recv {
65 runnables.push(RunnableAfter {
66 runnable,
67 instant: now + dur,
68 });
69 while let Ok((runnable, dur)) = timer_receiver.try_recv() {
70 runnables.push(RunnableAfter {
71 runnable,
72 instant: now + dur,
73 })
74 }
75 }
76 while runnables.peek().is_some_and(|entry| entry.instant <= now) {
77 runnables.pop().unwrap().runnable.run();
78 }
79 timeout_dur = runnables.peek().map(|entry| entry.instant - now);
80 }
81 });
82 let main_thread_id = current().id();
83 Self {
84 background_sender,
85 main_sender,
86 timer_sender,
87 background_threads,
88 timer_thread,
89 parker,
90 main_thread_id,
91 event,
92 }
93 }
94}
95
96impl PlatformDispatcher for WindowsDispatcher {
97 fn is_main_thread(&self) -> bool {
98 current().id() == self.main_thread_id
99 }
100
101 fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
102 self.background_sender
103 .send((runnable, label))
104 .inspect_err(|e| log::error!("Dispatch failed: {e}"))
105 .ok();
106 }
107
108 fn dispatch_on_main_thread(&self, runnable: Runnable) {
109 self.main_sender
110 .send(runnable)
111 .inspect_err(|e| log::error!("Dispatch failed: {e}"))
112 .ok();
113 unsafe { SetEvent(self.event) }.ok();
114 }
115
116 fn dispatch_after(&self, duration: std::time::Duration, runnable: Runnable) {
117 self.timer_sender
118 .send((runnable, duration))
119 .inspect_err(|e| log::error!("Dispatch failed: {e}"))
120 .ok();
121 }
122
123 fn tick(&self, _background_only: bool) -> bool {
124 false
125 }
126
127 fn park(&self) {
128 self.parker.lock().park();
129 }
130
131 fn unparker(&self) -> parking::Unparker {
132 self.parker.lock().unparker()
133 }
134}
135
136struct RunnableAfter {
137 runnable: Runnable,
138 instant: Instant,
139}
140
141impl PartialEq for RunnableAfter {
142 fn eq(&self, other: &Self) -> bool {
143 self.instant == other.instant
144 }
145}
146
147impl Eq for RunnableAfter {}
148
149impl Ord for RunnableAfter {
150 fn cmp(&self, other: &Self) -> Ordering {
151 self.instant.cmp(&other.instant).reverse()
152 }
153}
154
155impl PartialOrd for RunnableAfter {
156 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
157 Some(self.cmp(other))
158 }
159}