1use std::{
2 sync::atomic::{AtomicBool, Ordering},
3 thread::{ThreadId, current},
4 time::{Duration, Instant},
5};
6
7use anyhow::Context;
8use util::ResultExt;
9use windows::{
10 System::Threading::{
11 ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
12 },
13 Win32::{
14 Foundation::{LPARAM, WPARAM},
15 Media::{timeBeginPeriod, timeEndPeriod},
16 System::Threading::{
17 GetCurrentThread, HIGH_PRIORITY_CLASS, SetPriorityClass, SetThreadPriority,
18 THREAD_PRIORITY_TIME_CRITICAL,
19 },
20 UI::WindowsAndMessaging::PostMessageW,
21 },
22};
23
24use crate::{HWND, SafeHwnd, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD};
25use gpui::{
26 GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, PriorityQueueSender, RunnableVariant,
27 THREAD_TIMINGS, TaskTiming, ThreadTaskTimings, TimerResolutionGuard,
28};
29
30pub(crate) struct WindowsDispatcher {
31 pub(crate) wake_posted: AtomicBool,
32 main_sender: PriorityQueueSender<RunnableVariant>,
33 main_thread_id: ThreadId,
34 pub(crate) platform_window_handle: SafeHwnd,
35 validation_number: usize,
36}
37
38impl WindowsDispatcher {
39 pub(crate) fn new(
40 main_sender: PriorityQueueSender<RunnableVariant>,
41 platform_window_handle: HWND,
42 validation_number: usize,
43 ) -> Self {
44 let main_thread_id = current().id();
45 let platform_window_handle = platform_window_handle.into();
46
47 WindowsDispatcher {
48 main_sender,
49 main_thread_id,
50 platform_window_handle,
51 validation_number,
52 wake_posted: AtomicBool::new(false),
53 }
54 }
55
56 fn dispatch_on_threadpool(&self, priority: WorkItemPriority, runnable: RunnableVariant) {
57 let handler = {
58 let mut task_wrapper = Some(runnable);
59 WorkItemHandler::new(move |_| {
60 let runnable = task_wrapper.take().unwrap();
61 // Check if the executor that spawned this task was closed
62 if runnable.metadata().is_closed() {
63 return Ok(());
64 }
65 Self::execute_runnable(runnable);
66 Ok(())
67 })
68 };
69
70 ThreadPool::RunWithPriorityAsync(&handler, priority).log_err();
71 }
72
73 fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
74 let handler = {
75 let mut task_wrapper = Some(runnable);
76 TimerElapsedHandler::new(move |_| {
77 let runnable = task_wrapper.take().unwrap();
78 // Check if the executor that spawned this task was closed
79 if runnable.metadata().is_closed() {
80 return Ok(());
81 }
82 Self::execute_runnable(runnable);
83 Ok(())
84 })
85 };
86 ThreadPoolTimer::CreateTimer(&handler, duration.into()).log_err();
87 }
88
89 #[inline(always)]
90 pub(crate) fn execute_runnable(runnable: RunnableVariant) {
91 let start = Instant::now();
92
93 let location = runnable.metadata().location;
94 let mut timing = TaskTiming {
95 location,
96 start,
97 end: None,
98 };
99 gpui::profiler::add_task_timing(timing);
100
101 runnable.run();
102
103 let end = Instant::now();
104 timing.end = Some(end);
105
106 gpui::profiler::add_task_timing(timing);
107 }
108}
109
110impl PlatformDispatcher for WindowsDispatcher {
111 fn get_all_timings(&self) -> Vec<ThreadTaskTimings> {
112 let global_thread_timings = GLOBAL_THREAD_TIMINGS.lock();
113 ThreadTaskTimings::convert(&global_thread_timings)
114 }
115
116 fn get_current_thread_timings(&self) -> gpui::ThreadTaskTimings {
117 THREAD_TIMINGS.with(|timings| {
118 let timings = timings.lock();
119 let thread_name = timings.thread_name.clone();
120 let total_pushed = timings.total_pushed;
121 let timings = &timings.timings;
122
123 let mut vec = Vec::with_capacity(timings.len());
124
125 let (s1, s2) = timings.as_slices();
126 vec.extend_from_slice(s1);
127 vec.extend_from_slice(s2);
128
129 gpui::ThreadTaskTimings {
130 thread_name,
131 thread_id: std::thread::current().id(),
132 timings: vec,
133 total_pushed,
134 }
135 })
136 }
137
138 fn is_main_thread(&self) -> bool {
139 current().id() == self.main_thread_id
140 }
141
142 fn dispatch(&self, runnable: RunnableVariant, priority: Priority) {
143 let priority = match priority {
144 Priority::RealtimeAudio => {
145 panic!("RealtimeAudio priority should use spawn_realtime, not dispatch")
146 }
147 Priority::High => WorkItemPriority::High,
148 Priority::Medium => WorkItemPriority::Normal,
149 Priority::Low => WorkItemPriority::Low,
150 };
151 self.dispatch_on_threadpool(priority, runnable);
152 }
153
154 fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
155 match self.main_sender.send(priority, runnable) {
156 Ok(_) => {
157 if !self.wake_posted.swap(true, Ordering::AcqRel) {
158 unsafe {
159 PostMessageW(
160 Some(self.platform_window_handle.as_raw()),
161 WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
162 WPARAM(self.validation_number),
163 LPARAM(0),
164 )
165 .log_err();
166 }
167 }
168 }
169 Err(runnable) => {
170 // NOTE: Runnable may wrap a Future that is !Send.
171 //
172 // This is usually safe because we only poll it on the main thread.
173 // However if the send fails, we know that:
174 // 1. main_receiver has been dropped (which implies the app is shutting down)
175 // 2. we are on a background thread.
176 // It is not safe to drop something !Send on the wrong thread, and
177 // the app will exit soon anyway, so we must forget the runnable.
178 std::mem::forget(runnable);
179 }
180 }
181 }
182
183 fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
184 self.dispatch_on_threadpool_after(runnable, duration);
185 }
186
187 fn spawn_realtime(&self, f: Box<dyn FnOnce() + Send>) {
188 std::thread::spawn(move || {
189 // SAFETY: always safe to call
190 let thread_handle = unsafe { GetCurrentThread() };
191
192 // SAFETY: thread_handle is a valid handle to a thread
193 unsafe { SetPriorityClass(thread_handle, HIGH_PRIORITY_CLASS) }
194 .context("thread priority class")
195 .log_err();
196
197 // SAFETY: thread_handle is a valid handle to a thread
198 unsafe { SetThreadPriority(thread_handle, THREAD_PRIORITY_TIME_CRITICAL) }
199 .context("thread priority")
200 .log_err();
201
202 f();
203 });
204 }
205
206 fn increase_timer_resolution(&self) -> TimerResolutionGuard {
207 unsafe {
208 timeBeginPeriod(1);
209 }
210 util::defer(Box::new(|| unsafe {
211 timeEndPeriod(1);
212 }))
213 }
214}