1use std::{
2 sync::atomic::{AtomicBool, Ordering},
3 thread::{ThreadId, current},
4 time::{Duration, Instant},
5};
6
7use flume::Sender;
8use util::ResultExt;
9use windows::{
10 System::Threading::{
11 ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
12 },
13 Win32::{
14 Foundation::{LPARAM, WPARAM},
15 UI::WindowsAndMessaging::PostMessageW,
16 },
17};
18
19use crate::{
20 GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, RunnableVariant, SafeHwnd, THREAD_TIMINGS,
21 TaskLabel, TaskTiming, ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
22};
23
24pub(crate) struct WindowsDispatcher {
25 pub(crate) wake_posted: AtomicBool,
26 main_sender: Sender<RunnableVariant>,
27 main_thread_id: ThreadId,
28 platform_window_handle: SafeHwnd,
29 validation_number: usize,
30}
31
32impl WindowsDispatcher {
33 pub(crate) fn new(
34 main_sender: Sender<RunnableVariant>,
35 platform_window_handle: HWND,
36 validation_number: usize,
37 ) -> Self {
38 let main_thread_id = current().id();
39 let platform_window_handle = platform_window_handle.into();
40
41 WindowsDispatcher {
42 main_sender,
43 main_thread_id,
44 platform_window_handle,
45 validation_number,
46 wake_posted: AtomicBool::new(false),
47 }
48 }
49
50 fn dispatch_on_threadpool(&self, runnable: RunnableVariant) {
51 let handler = {
52 let mut task_wrapper = Some(runnable);
53 WorkItemHandler::new(move |_| {
54 Self::execute_runnable(task_wrapper.take().unwrap());
55 Ok(())
56 })
57 };
58 ThreadPool::RunWithPriorityAsync(&handler, WorkItemPriority::High).log_err();
59 }
60
61 fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
62 let handler = {
63 let mut task_wrapper = Some(runnable);
64 TimerElapsedHandler::new(move |_| {
65 Self::execute_runnable(task_wrapper.take().unwrap());
66 Ok(())
67 })
68 };
69 ThreadPoolTimer::CreateTimer(&handler, duration.into()).log_err();
70 }
71
72 #[inline(always)]
73 pub(crate) fn execute_runnable(runnable: RunnableVariant) {
74 let start = Instant::now();
75
76 let mut timing = match runnable {
77 RunnableVariant::Meta(runnable) => {
78 let location = runnable.metadata().location;
79 let timing = TaskTiming {
80 location,
81 start,
82 end: None,
83 };
84 Self::add_task_timing(timing);
85
86 runnable.run();
87
88 timing
89 }
90 RunnableVariant::Compat(runnable) => {
91 let timing = TaskTiming {
92 location: core::panic::Location::caller(),
93 start,
94 end: None,
95 };
96 Self::add_task_timing(timing);
97
98 runnable.run();
99
100 timing
101 }
102 };
103
104 let end = Instant::now();
105 timing.end = Some(end);
106
107 Self::add_task_timing(timing);
108 }
109
110 pub(crate) fn add_task_timing(timing: TaskTiming) {
111 THREAD_TIMINGS.with(|timings| {
112 let mut timings = timings.lock();
113 let timings = &mut timings.timings;
114
115 if let Some(last_timing) = timings.iter_mut().rev().next() {
116 if last_timing.location == timing.location {
117 last_timing.end = timing.end;
118 return;
119 }
120 }
121
122 timings.push_back(timing);
123 });
124 }
125}
126
127impl PlatformDispatcher for WindowsDispatcher {
128 fn get_all_timings(&self) -> Vec<ThreadTaskTimings> {
129 let global_thread_timings = GLOBAL_THREAD_TIMINGS.lock();
130 ThreadTaskTimings::convert(&global_thread_timings)
131 }
132
133 fn get_current_thread_timings(&self) -> Vec<crate::TaskTiming> {
134 THREAD_TIMINGS.with(|timings| {
135 let timings = timings.lock();
136 let timings = &timings.timings;
137
138 let mut vec = Vec::with_capacity(timings.len());
139
140 let (s1, s2) = timings.as_slices();
141 vec.extend_from_slice(s1);
142 vec.extend_from_slice(s2);
143 vec
144 })
145 }
146
147 fn is_main_thread(&self) -> bool {
148 current().id() == self.main_thread_id
149 }
150
151 fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
152 self.dispatch_on_threadpool(runnable);
153 if let Some(label) = label {
154 log::debug!("TaskLabel: {label:?}");
155 }
156 }
157
158 fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
159 match self.main_sender.send(runnable) {
160 Ok(_) => {
161 if !self.wake_posted.swap(true, Ordering::AcqRel) {
162 unsafe {
163 PostMessageW(
164 Some(self.platform_window_handle.as_raw()),
165 WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
166 WPARAM(self.validation_number),
167 LPARAM(0),
168 )
169 .log_err();
170 }
171 }
172 }
173 Err(runnable) => {
174 // NOTE: Runnable may wrap a Future that is !Send.
175 //
176 // This is usually safe because we only poll it on the main thread.
177 // However if the send fails, we know that:
178 // 1. main_receiver has been dropped (which implies the app is shutting down)
179 // 2. we are on a background thread.
180 // It is not safe to drop something !Send on the wrong thread, and
181 // the app will exit soon anyway, so we must forget the runnable.
182 std::mem::forget(runnable);
183 }
184 }
185 }
186
187 fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
188 self.dispatch_on_threadpool_after(runnable, duration);
189 }
190}