dispatcher.rs

  1use std::{
  2    sync::atomic::{AtomicBool, Ordering},
  3    thread::{ThreadId, current},
  4    time::{Duration, Instant},
  5};
  6
  7use flume::Sender;
  8use util::ResultExt;
  9use windows::{
 10    System::Threading::{
 11        ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
 12    },
 13    Win32::{
 14        Foundation::{LPARAM, WPARAM},
 15        UI::WindowsAndMessaging::PostMessageW,
 16    },
 17};
 18
 19use crate::{
 20    GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, RunnableVariant, SafeHwnd, THREAD_TIMINGS,
 21    TaskLabel, TaskTiming, ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
 22};
 23
 24pub(crate) struct WindowsDispatcher {
 25    pub(crate) wake_posted: AtomicBool,
 26    main_sender: Sender<RunnableVariant>,
 27    main_thread_id: ThreadId,
 28    platform_window_handle: SafeHwnd,
 29    validation_number: usize,
 30}
 31
 32impl WindowsDispatcher {
 33    pub(crate) fn new(
 34        main_sender: Sender<RunnableVariant>,
 35        platform_window_handle: HWND,
 36        validation_number: usize,
 37    ) -> Self {
 38        let main_thread_id = current().id();
 39        let platform_window_handle = platform_window_handle.into();
 40
 41        WindowsDispatcher {
 42            main_sender,
 43            main_thread_id,
 44            platform_window_handle,
 45            validation_number,
 46            wake_posted: AtomicBool::new(false),
 47        }
 48    }
 49
 50    fn dispatch_on_threadpool(&self, runnable: RunnableVariant) {
 51        let handler = {
 52            let mut task_wrapper = Some(runnable);
 53            WorkItemHandler::new(move |_| {
 54                Self::execute_runnable(task_wrapper.take().unwrap());
 55                Ok(())
 56            })
 57        };
 58        ThreadPool::RunWithPriorityAsync(&handler, WorkItemPriority::High).log_err();
 59    }
 60
 61    fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
 62        let handler = {
 63            let mut task_wrapper = Some(runnable);
 64            TimerElapsedHandler::new(move |_| {
 65                Self::execute_runnable(task_wrapper.take().unwrap());
 66                Ok(())
 67            })
 68        };
 69        ThreadPoolTimer::CreateTimer(&handler, duration.into()).log_err();
 70    }
 71
 72    #[inline(always)]
 73    pub(crate) fn execute_runnable(runnable: RunnableVariant) {
 74        let start = Instant::now();
 75
 76        let mut timing = match runnable {
 77            RunnableVariant::Meta(runnable) => {
 78                let location = runnable.metadata().location;
 79                let timing = TaskTiming {
 80                    location,
 81                    start,
 82                    end: None,
 83                };
 84                Self::add_task_timing(timing);
 85
 86                runnable.run();
 87
 88                timing
 89            }
 90            RunnableVariant::Compat(runnable) => {
 91                let timing = TaskTiming {
 92                    location: core::panic::Location::caller(),
 93                    start,
 94                    end: None,
 95                };
 96                Self::add_task_timing(timing);
 97
 98                runnable.run();
 99
100                timing
101            }
102        };
103
104        let end = Instant::now();
105        timing.end = Some(end);
106
107        Self::add_task_timing(timing);
108    }
109
110    pub(crate) fn add_task_timing(timing: TaskTiming) {
111        THREAD_TIMINGS.with(|timings| {
112            let mut timings = timings.lock();
113            let timings = &mut timings.timings;
114
115            if let Some(last_timing) = timings.iter_mut().rev().next() {
116                if last_timing.location == timing.location {
117                    last_timing.end = timing.end;
118                    return;
119                }
120            }
121
122            timings.push_back(timing);
123        });
124    }
125}
126
127impl PlatformDispatcher for WindowsDispatcher {
128    fn get_all_timings(&self) -> Vec<ThreadTaskTimings> {
129        let global_thread_timings = GLOBAL_THREAD_TIMINGS.lock();
130        ThreadTaskTimings::convert(&global_thread_timings)
131    }
132
133    fn get_current_thread_timings(&self) -> Vec<crate::TaskTiming> {
134        THREAD_TIMINGS.with(|timings| {
135            let timings = timings.lock();
136            let timings = &timings.timings;
137
138            let mut vec = Vec::with_capacity(timings.len());
139
140            let (s1, s2) = timings.as_slices();
141            vec.extend_from_slice(s1);
142            vec.extend_from_slice(s2);
143            vec
144        })
145    }
146
147    fn is_main_thread(&self) -> bool {
148        current().id() == self.main_thread_id
149    }
150
151    fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
152        self.dispatch_on_threadpool(runnable);
153        if let Some(label) = label {
154            log::debug!("TaskLabel: {label:?}");
155        }
156    }
157
158    fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
159        match self.main_sender.send(runnable) {
160            Ok(_) => {
161                if !self.wake_posted.swap(true, Ordering::AcqRel) {
162                    unsafe {
163                        PostMessageW(
164                            Some(self.platform_window_handle.as_raw()),
165                            WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
166                            WPARAM(self.validation_number),
167                            LPARAM(0),
168                        )
169                        .log_err();
170                    }
171                }
172            }
173            Err(runnable) => {
174                // NOTE: Runnable may wrap a Future that is !Send.
175                //
176                // This is usually safe because we only poll it on the main thread.
177                // However if the send fails, we know that:
178                // 1. main_receiver has been dropped (which implies the app is shutting down)
179                // 2. we are on a background thread.
180                // It is not safe to drop something !Send on the wrong thread, and
181                // the app will exit soon anyway, so we must forget the runnable.
182                std::mem::forget(runnable);
183            }
184        }
185    }
186
187    fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
188        self.dispatch_on_threadpool_after(runnable, duration);
189    }
190}