dispatcher.rs

  1use std::{
  2    sync::atomic::{AtomicBool, Ordering},
  3    thread::{ThreadId, current},
  4    time::Duration,
  5};
  6
  7use async_task::Runnable;
  8use flume::Sender;
  9use util::ResultExt;
 10use windows::{
 11    System::Threading::{
 12        ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
 13    },
 14    Win32::{
 15        Foundation::{LPARAM, WPARAM},
 16        UI::WindowsAndMessaging::PostMessageW,
 17    },
 18};
 19
 20use crate::{
 21    HWND, PlatformDispatcher, SafeHwnd, TaskLabel, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
 22};
 23
 24pub(crate) struct WindowsDispatcher {
 25    pub(crate) wake_posted: AtomicBool,
 26    main_sender: Sender<Runnable>,
 27    main_thread_id: ThreadId,
 28    platform_window_handle: SafeHwnd,
 29    validation_number: usize,
 30}
 31
 32impl WindowsDispatcher {
 33    pub(crate) fn new(
 34        main_sender: Sender<Runnable>,
 35        platform_window_handle: HWND,
 36        validation_number: usize,
 37    ) -> Self {
 38        let main_thread_id = current().id();
 39        let platform_window_handle = platform_window_handle.into();
 40
 41        WindowsDispatcher {
 42            main_sender,
 43            main_thread_id,
 44            platform_window_handle,
 45            validation_number,
 46            wake_posted: AtomicBool::new(false),
 47        }
 48    }
 49
 50    fn dispatch_on_threadpool(&self, runnable: Runnable) {
 51        let handler = {
 52            let mut task_wrapper = Some(runnable);
 53            WorkItemHandler::new(move |_| {
 54                task_wrapper.take().unwrap().run();
 55                Ok(())
 56            })
 57        };
 58        ThreadPool::RunWithPriorityAsync(&handler, WorkItemPriority::High).log_err();
 59    }
 60
 61    fn dispatch_on_threadpool_after(&self, runnable: Runnable, duration: Duration) {
 62        let handler = {
 63            let mut task_wrapper = Some(runnable);
 64            TimerElapsedHandler::new(move |_| {
 65                task_wrapper.take().unwrap().run();
 66                Ok(())
 67            })
 68        };
 69        ThreadPoolTimer::CreateTimer(&handler, duration.into()).log_err();
 70    }
 71}
 72
 73impl PlatformDispatcher for WindowsDispatcher {
 74    fn is_main_thread(&self) -> bool {
 75        current().id() == self.main_thread_id
 76    }
 77
 78    fn dispatch(&self, runnable: Runnable, label: Option<TaskLabel>) {
 79        self.dispatch_on_threadpool(runnable);
 80        if let Some(label) = label {
 81            log::debug!("TaskLabel: {label:?}");
 82        }
 83    }
 84
 85    fn dispatch_on_main_thread(&self, runnable: Runnable) {
 86        match self.main_sender.send(runnable) {
 87            Ok(_) => {
 88                if !self.wake_posted.swap(true, Ordering::AcqRel) {
 89                    unsafe {
 90                        PostMessageW(
 91                            Some(self.platform_window_handle.as_raw()),
 92                            WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
 93                            WPARAM(self.validation_number),
 94                            LPARAM(0),
 95                        )
 96                        .log_err();
 97                    }
 98                }
 99            }
100            Err(runnable) => {
101                // NOTE: Runnable may wrap a Future that is !Send.
102                //
103                // This is usually safe because we only poll it on the main thread.
104                // However if the send fails, we know that:
105                // 1. main_receiver has been dropped (which implies the app is shutting down)
106                // 2. we are on a background thread.
107                // It is not safe to drop something !Send on the wrong thread, and
108                // the app will exit soon anyway, so we must forget the runnable.
109                std::mem::forget(runnable);
110            }
111        }
112    }
113
114    fn dispatch_after(&self, duration: Duration, runnable: Runnable) {
115        self.dispatch_on_threadpool_after(runnable, duration);
116    }
117}