1use std::{
2 sync::atomic::{AtomicBool, Ordering},
3 thread::{ThreadId, current},
4 time::{Duration, Instant},
5};
6
7use anyhow::Context;
8use util::ResultExt;
9use windows::{
10 System::Threading::{
11 ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
12 },
13 Win32::{
14 Foundation::{LPARAM, WPARAM},
15 System::Threading::{
16 GetCurrentThread, HIGH_PRIORITY_CLASS, SetPriorityClass, SetThreadPriority,
17 THREAD_PRIORITY_HIGHEST, THREAD_PRIORITY_TIME_CRITICAL,
18 },
19 UI::WindowsAndMessaging::PostMessageW,
20 },
21};
22
23use crate::{
24 GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, Priority, PriorityQueueSender,
25 RealtimePriority, RunnableVariant, SafeHwnd, THREAD_TIMINGS, TaskLabel, TaskTiming,
26 ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD, profiler,
27};
28
/// Dispatches runnables either to the Windows thread pool or to the main
/// thread's priority queue.
pub(crate) struct WindowsDispatcher {
    /// Set to `true` by `dispatch_on_main_thread` when it posts the wake-up
    /// message; while it is `true`, further dispatches skip posting.
    /// NOTE(review): the code that clears this flag is not in this file —
    /// presumably the main-thread message handler resets it; confirm.
    pub(crate) wake_posted: AtomicBool,
    /// Sending half of the priority queue that carries runnables to the main thread.
    main_sender: PriorityQueueSender<RunnableVariant>,
    /// Id of the thread that constructed the dispatcher; `is_main_thread`
    /// compares the calling thread against it.
    main_thread_id: ThreadId,
    /// Window that receives `WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD`.
    pub(crate) platform_window_handle: SafeHwnd,
    /// Value sent as the `WPARAM` of the wake-up message.
    /// NOTE(review): presumably checked by the receiver to reject stale or
    /// foreign messages — confirm against the window procedure.
    validation_number: usize,
}
36
37impl WindowsDispatcher {
38 pub(crate) fn new(
39 main_sender: PriorityQueueSender<RunnableVariant>,
40 platform_window_handle: HWND,
41 validation_number: usize,
42 ) -> Self {
43 let main_thread_id = current().id();
44 let platform_window_handle = platform_window_handle.into();
45
46 WindowsDispatcher {
47 main_sender,
48 main_thread_id,
49 platform_window_handle,
50 validation_number,
51 wake_posted: AtomicBool::new(false),
52 }
53 }
54
55 fn dispatch_on_threadpool(&self, priority: WorkItemPriority, runnable: RunnableVariant) {
56 let handler = {
57 let mut task_wrapper = Some(runnable);
58 WorkItemHandler::new(move |_| {
59 Self::execute_runnable(task_wrapper.take().unwrap());
60 Ok(())
61 })
62 };
63
64 ThreadPool::RunWithPriorityAsync(&handler, priority).log_err();
65 }
66
67 fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
68 let handler = {
69 let mut task_wrapper = Some(runnable);
70 TimerElapsedHandler::new(move |_| {
71 Self::execute_runnable(task_wrapper.take().unwrap());
72 Ok(())
73 })
74 };
75 ThreadPoolTimer::CreateTimer(&handler, duration.into()).log_err();
76 }
77
78 #[inline(always)]
79 pub(crate) fn execute_runnable(runnable: RunnableVariant) {
80 let start = Instant::now();
81
82 let mut timing = match runnable {
83 RunnableVariant::Meta(runnable) => {
84 let location = runnable.metadata().location;
85 let timing = TaskTiming {
86 location,
87 start,
88 end: None,
89 };
90 profiler::add_task_timing(timing);
91
92 runnable.run();
93
94 timing
95 }
96 RunnableVariant::Compat(runnable) => {
97 let timing = TaskTiming {
98 location: core::panic::Location::caller(),
99 start,
100 end: None,
101 };
102 profiler::add_task_timing(timing);
103
104 runnable.run();
105
106 timing
107 }
108 };
109
110 let end = Instant::now();
111 timing.end = Some(end);
112
113 profiler::add_task_timing(timing);
114 }
115}
116
117impl PlatformDispatcher for WindowsDispatcher {
118 fn get_all_timings(&self) -> Vec<ThreadTaskTimings> {
119 let global_thread_timings = GLOBAL_THREAD_TIMINGS.lock();
120 ThreadTaskTimings::convert(&global_thread_timings)
121 }
122
123 fn get_current_thread_timings(&self) -> Vec<crate::TaskTiming> {
124 THREAD_TIMINGS.with(|timings| {
125 let timings = timings.lock();
126 let timings = &timings.timings;
127
128 let mut vec = Vec::with_capacity(timings.len());
129
130 let (s1, s2) = timings.as_slices();
131 vec.extend_from_slice(s1);
132 vec.extend_from_slice(s2);
133 vec
134 })
135 }
136
137 fn is_main_thread(&self) -> bool {
138 current().id() == self.main_thread_id
139 }
140
141 fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority) {
142 let priority = match priority {
143 Priority::Realtime(_) => unreachable!(),
144 Priority::High => WorkItemPriority::High,
145 Priority::Medium => WorkItemPriority::Normal,
146 Priority::Low => WorkItemPriority::Low,
147 };
148 self.dispatch_on_threadpool(priority, runnable);
149
150 if let Some(label) = label {
151 log::debug!("TaskLabel: {label:?}");
152 }
153 }
154
155 fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
156 match self.main_sender.send(priority, runnable) {
157 Ok(_) => {
158 if !self.wake_posted.swap(true, Ordering::AcqRel) {
159 unsafe {
160 PostMessageW(
161 Some(self.platform_window_handle.as_raw()),
162 WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
163 WPARAM(self.validation_number),
164 LPARAM(0),
165 )
166 .log_err();
167 }
168 }
169 }
170 Err(runnable) => {
171 // NOTE: Runnable may wrap a Future that is !Send.
172 //
173 // This is usually safe because we only poll it on the main thread.
174 // However if the send fails, we know that:
175 // 1. main_receiver has been dropped (which implies the app is shutting down)
176 // 2. we are on a background thread.
177 // It is not safe to drop something !Send on the wrong thread, and
178 // the app will exit soon anyway, so we must forget the runnable.
179 std::mem::forget(runnable);
180 }
181 }
182 }
183
184 fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
185 self.dispatch_on_threadpool_after(runnable, duration);
186 }
187
188 fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
189 std::thread::spawn(move || {
190 // SAFETY: always safe to call
191 let thread_handle = unsafe { GetCurrentThread() };
192
193 let thread_priority = match priority {
194 RealtimePriority::Audio => THREAD_PRIORITY_TIME_CRITICAL,
195 RealtimePriority::Other => THREAD_PRIORITY_HIGHEST,
196 };
197
198 // SAFETY: thread_handle is a valid handle to a thread
199 unsafe { SetPriorityClass(thread_handle, HIGH_PRIORITY_CLASS) }
200 .context("thread priority class")
201 .log_err();
202
203 // SAFETY: thread_handle is a valid handle to a thread
204 unsafe { SetThreadPriority(thread_handle, thread_priority) }
205 .context("thread priority")
206 .log_err();
207
208 f();
209 });
210 }
211}