dispatcher.rs

  1use std::{
  2    sync::atomic::{AtomicBool, Ordering},
  3    thread::{ThreadId, current},
  4    time::{Duration, Instant},
  5};
  6
  7use anyhow::Context;
  8use util::ResultExt;
  9use windows::{
 10    System::Threading::{
 11        ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
 12    },
 13    Win32::{
 14        Foundation::{LPARAM, WPARAM},
 15        Media::{timeBeginPeriod, timeEndPeriod},
 16        System::Threading::{
 17            GetCurrentThread, HIGH_PRIORITY_CLASS, SetPriorityClass, SetThreadPriority,
 18            THREAD_PRIORITY_TIME_CRITICAL,
 19        },
 20        UI::WindowsAndMessaging::PostMessageW,
 21    },
 22};
 23
 24use crate::{HWND, SafeHwnd, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD};
 25use gpui::{
 26    GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, PriorityQueueSender, RunnableVariant,
 27    THREAD_TIMINGS, TaskTiming, ThreadTaskTimings, TimerResolutionGuard,
 28};
 29
 30pub(crate) struct WindowsDispatcher {
 31    pub(crate) wake_posted: AtomicBool,
 32    main_sender: PriorityQueueSender<RunnableVariant>,
 33    main_thread_id: ThreadId,
 34    pub(crate) platform_window_handle: SafeHwnd,
 35    validation_number: usize,
 36}
 37
 38impl WindowsDispatcher {
 39    pub(crate) fn new(
 40        main_sender: PriorityQueueSender<RunnableVariant>,
 41        platform_window_handle: HWND,
 42        validation_number: usize,
 43    ) -> Self {
 44        let main_thread_id = current().id();
 45        let platform_window_handle = platform_window_handle.into();
 46
 47        WindowsDispatcher {
 48            main_sender,
 49            main_thread_id,
 50            platform_window_handle,
 51            validation_number,
 52            wake_posted: AtomicBool::new(false),
 53        }
 54    }
 55
 56    fn dispatch_on_threadpool(&self, priority: WorkItemPriority, runnable: RunnableVariant) {
 57        let handler = {
 58            let mut task_wrapper = Some(runnable);
 59            WorkItemHandler::new(move |_| {
 60                let runnable = task_wrapper.take().unwrap();
 61                // Check if the executor that spawned this task was closed
 62                if runnable.metadata().is_closed() {
 63                    return Ok(());
 64                }
 65                Self::execute_runnable(runnable);
 66                Ok(())
 67            })
 68        };
 69
 70        ThreadPool::RunWithPriorityAsync(&handler, priority).log_err();
 71    }
 72
 73    fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
 74        let handler = {
 75            let mut task_wrapper = Some(runnable);
 76            TimerElapsedHandler::new(move |_| {
 77                let runnable = task_wrapper.take().unwrap();
 78                // Check if the executor that spawned this task was closed
 79                if runnable.metadata().is_closed() {
 80                    return Ok(());
 81                }
 82                Self::execute_runnable(runnable);
 83                Ok(())
 84            })
 85        };
 86        ThreadPoolTimer::CreateTimer(&handler, duration.into()).log_err();
 87    }
 88
 89    #[inline(always)]
 90    pub(crate) fn execute_runnable(runnable: RunnableVariant) {
 91        let start = Instant::now();
 92
 93        let location = runnable.metadata().location;
 94        let mut timing = TaskTiming {
 95            location,
 96            start,
 97            end: None,
 98        };
 99        gpui::profiler::add_task_timing(timing);
100
101        runnable.run();
102
103        let end = Instant::now();
104        timing.end = Some(end);
105
106        gpui::profiler::add_task_timing(timing);
107    }
108}
109
110impl PlatformDispatcher for WindowsDispatcher {
111    fn get_all_timings(&self) -> Vec<ThreadTaskTimings> {
112        let global_thread_timings = GLOBAL_THREAD_TIMINGS.lock();
113        ThreadTaskTimings::convert(&global_thread_timings)
114    }
115
116    fn get_current_thread_timings(&self) -> gpui::ThreadTaskTimings {
117        THREAD_TIMINGS.with(|timings| {
118            let timings = timings.lock();
119            let thread_name = timings.thread_name.clone();
120            let total_pushed = timings.total_pushed;
121            let timings = &timings.timings;
122
123            let mut vec = Vec::with_capacity(timings.len());
124
125            let (s1, s2) = timings.as_slices();
126            vec.extend_from_slice(s1);
127            vec.extend_from_slice(s2);
128
129            gpui::ThreadTaskTimings {
130                thread_name,
131                thread_id: std::thread::current().id(),
132                timings: vec,
133                total_pushed,
134            }
135        })
136    }
137
138    fn is_main_thread(&self) -> bool {
139        current().id() == self.main_thread_id
140    }
141
142    fn dispatch(&self, runnable: RunnableVariant, priority: Priority) {
143        let priority = match priority {
144            Priority::RealtimeAudio => {
145                panic!("RealtimeAudio priority should use spawn_realtime, not dispatch")
146            }
147            Priority::High => WorkItemPriority::High,
148            Priority::Medium => WorkItemPriority::Normal,
149            Priority::Low => WorkItemPriority::Low,
150        };
151        self.dispatch_on_threadpool(priority, runnable);
152    }
153
154    fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
155        match self.main_sender.send(priority, runnable) {
156            Ok(_) => {
157                if !self.wake_posted.swap(true, Ordering::AcqRel) {
158                    unsafe {
159                        PostMessageW(
160                            Some(self.platform_window_handle.as_raw()),
161                            WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
162                            WPARAM(self.validation_number),
163                            LPARAM(0),
164                        )
165                        .log_err();
166                    }
167                }
168            }
169            Err(runnable) => {
170                // NOTE: Runnable may wrap a Future that is !Send.
171                //
172                // This is usually safe because we only poll it on the main thread.
173                // However if the send fails, we know that:
174                // 1. main_receiver has been dropped (which implies the app is shutting down)
175                // 2. we are on a background thread.
176                // It is not safe to drop something !Send on the wrong thread, and
177                // the app will exit soon anyway, so we must forget the runnable.
178                std::mem::forget(runnable);
179            }
180        }
181    }
182
183    fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
184        self.dispatch_on_threadpool_after(runnable, duration);
185    }
186
187    fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>) {
188        std::thread::spawn(move || {
189            // SAFETY: always safe to call
190            let thread_handle = unsafe { GetCurrentThread() };
191
192            // SAFETY: thread_handle is a valid handle to a thread
193            unsafe { SetPriorityClass(thread_handle, HIGH_PRIORITY_CLASS) }
194                .context("thread priority class")
195                .log_err();
196
197            // SAFETY: thread_handle is a valid handle to a thread
198            unsafe { SetThreadPriority(thread_handle, THREAD_PRIORITY_TIME_CRITICAL) }
199                .context("thread priority")
200                .log_err();
201
202            f();
203        });
204    }
205
206    fn increase_timer_resolution(&self) -> TimerResolutionGuard {
207        unsafe {
208            timeBeginPeriod(1);
209        }
210        util::defer(Box::new(|| unsafe {
211            timeEndPeriod(1);
212        }))
213    }
214}