dispatcher.rs

use std::{
    sync::atomic::{AtomicBool, Ordering},
    thread::{ThreadId, current},
    time::{Duration, Instant},
};

use anyhow::Context;
use util::ResultExt;
use windows::{
    System::Threading::{
        ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
    },
    Win32::{
        Foundation::{LPARAM, WPARAM},
        System::Threading::{
            GetCurrentProcess, GetCurrentThread, HIGH_PRIORITY_CLASS, SetPriorityClass,
            SetThreadPriority, THREAD_PRIORITY_HIGHEST, THREAD_PRIORITY_TIME_CRITICAL,
        },
        UI::WindowsAndMessaging::PostMessageW,
    },
};

use crate::{
    GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, Priority, PriorityQueueSender,
    RealtimePriority, RunnableVariant, SafeHwnd, THREAD_TIMINGS, TaskLabel, TaskTiming,
    ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD, profiler,
};
 28
/// Windows implementation of [`PlatformDispatcher`].
///
/// Background work is handed to the WinRT thread pool. Main-thread work is pushed onto
/// `main_sender`, and a wake-up message is posted to `platform_window_handle`.
pub(crate) struct WindowsDispatcher {
    /// Set when a wake-up message has been posted for queued main-thread work.
    /// It lets concurrent dispatches post only one message at a time.
    /// NOTE(review): presumably cleared by the message handler outside this file — confirm.
    pub(crate) wake_posted: AtomicBool,
    /// Producer side of the main-thread task queue.
    main_sender: PriorityQueueSender<RunnableVariant>,
    /// Id of the thread that constructed the dispatcher; compared in `is_main_thread`.
    main_thread_id: ThreadId,
    /// Window that receives `WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD` wake-up messages.
    pub(crate) platform_window_handle: SafeHwnd,
    /// Sent as the `WPARAM` of every wake-up message.
    /// NOTE(review): presumably checked by the receiver to reject stray messages — confirm.
    validation_number: usize,
}
 36
 37impl WindowsDispatcher {
 38    pub(crate) fn new(
 39        main_sender: PriorityQueueSender<RunnableVariant>,
 40        platform_window_handle: HWND,
 41        validation_number: usize,
 42    ) -> Self {
 43        let main_thread_id = current().id();
 44        let platform_window_handle = platform_window_handle.into();
 45
 46        WindowsDispatcher {
 47            main_sender,
 48            main_thread_id,
 49            platform_window_handle,
 50            validation_number,
 51            wake_posted: AtomicBool::new(false),
 52        }
 53    }
 54
 55    fn dispatch_on_threadpool(&self, priority: WorkItemPriority, runnable: RunnableVariant) {
 56        let handler = {
 57            let mut task_wrapper = Some(runnable);
 58            WorkItemHandler::new(move |_| {
 59                Self::execute_runnable(task_wrapper.take().unwrap());
 60                Ok(())
 61            })
 62        };
 63
 64        ThreadPool::RunWithPriorityAsync(&handler, priority).log_err();
 65    }
 66
 67    fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
 68        let handler = {
 69            let mut task_wrapper = Some(runnable);
 70            TimerElapsedHandler::new(move |_| {
 71                Self::execute_runnable(task_wrapper.take().unwrap());
 72                Ok(())
 73            })
 74        };
 75        ThreadPoolTimer::CreateTimer(&handler, duration.into()).log_err();
 76    }
 77
 78    #[inline(always)]
 79    pub(crate) fn execute_runnable(runnable: RunnableVariant) {
 80        let start = Instant::now();
 81
 82        let mut timing = match runnable {
 83            RunnableVariant::Meta(runnable) => {
 84                let location = runnable.metadata().location;
 85                let timing = TaskTiming {
 86                    location,
 87                    start,
 88                    end: None,
 89                };
 90                profiler::add_task_timing(timing);
 91
 92                runnable.run();
 93
 94                timing
 95            }
 96            RunnableVariant::Compat(runnable) => {
 97                let timing = TaskTiming {
 98                    location: core::panic::Location::caller(),
 99                    start,
100                    end: None,
101                };
102                profiler::add_task_timing(timing);
103
104                runnable.run();
105
106                timing
107            }
108        };
109
110        let end = Instant::now();
111        timing.end = Some(end);
112
113        profiler::add_task_timing(timing);
114    }
115}
116
117impl PlatformDispatcher for WindowsDispatcher {
118    fn get_all_timings(&self) -> Vec<ThreadTaskTimings> {
119        let global_thread_timings = GLOBAL_THREAD_TIMINGS.lock();
120        ThreadTaskTimings::convert(&global_thread_timings)
121    }
122
123    fn get_current_thread_timings(&self) -> Vec<crate::TaskTiming> {
124        THREAD_TIMINGS.with(|timings| {
125            let timings = timings.lock();
126            let timings = &timings.timings;
127
128            let mut vec = Vec::with_capacity(timings.len());
129
130            let (s1, s2) = timings.as_slices();
131            vec.extend_from_slice(s1);
132            vec.extend_from_slice(s2);
133            vec
134        })
135    }
136
137    fn is_main_thread(&self) -> bool {
138        current().id() == self.main_thread_id
139    }
140
141    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority) {
142        let priority = match priority {
143            Priority::Realtime(_) => unreachable!(),
144            Priority::High => WorkItemPriority::High,
145            Priority::Medium => WorkItemPriority::Normal,
146            Priority::Low => WorkItemPriority::Low,
147        };
148        self.dispatch_on_threadpool(priority, runnable);
149
150        if let Some(label) = label {
151            log::debug!("TaskLabel: {label:?}");
152        }
153    }
154
155    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
156        match self.main_sender.send(priority, runnable) {
157            Ok(_) => {
158                if !self.wake_posted.swap(true, Ordering::AcqRel) {
159                    unsafe {
160                        PostMessageW(
161                            Some(self.platform_window_handle.as_raw()),
162                            WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
163                            WPARAM(self.validation_number),
164                            LPARAM(0),
165                        )
166                        .log_err();
167                    }
168                }
169            }
170            Err(runnable) => {
171                // NOTE: Runnable may wrap a Future that is !Send.
172                //
173                // This is usually safe because we only poll it on the main thread.
174                // However if the send fails, we know that:
175                // 1. main_receiver has been dropped (which implies the app is shutting down)
176                // 2. we are on a background thread.
177                // It is not safe to drop something !Send on the wrong thread, and
178                // the app will exit soon anyway, so we must forget the runnable.
179                std::mem::forget(runnable);
180            }
181        }
182    }
183
184    fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
185        self.dispatch_on_threadpool_after(runnable, duration);
186    }
187
188    fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
189        std::thread::spawn(move || {
190            // SAFETY: always safe to call
191            let thread_handle = unsafe { GetCurrentThread() };
192
193            let thread_priority = match priority {
194                RealtimePriority::Audio => THREAD_PRIORITY_TIME_CRITICAL,
195                RealtimePriority::Other => THREAD_PRIORITY_HIGHEST,
196            };
197
198            // SAFETY: thread_handle is a valid handle to a thread
199            unsafe { SetPriorityClass(thread_handle, HIGH_PRIORITY_CLASS) }
200                .context("thread priority class")
201                .log_err();
202
203            // SAFETY: thread_handle is a valid handle to a thread
204            unsafe { SetThreadPriority(thread_handle, thread_priority) }
205                .context("thread priority")
206                .log_err();
207
208            f();
209        });
210    }
211}