1mod error;
2
3pub use error::*;
4use parking_lot::{RwLock, RwLockReadGuard, RwLockUpgradableReadGuard};
5use std::{
6 collections::BTreeMap,
7 mem,
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll, Waker},
11};
12
13pub fn channel<T>(value: T) -> (Sender<T>, Receiver<T>) {
14 let state = Arc::new(RwLock::new(State {
15 value,
16 wakers: BTreeMap::new(),
17 next_waker_id: WakerId::default(),
18 version: 0,
19 closed: false,
20 }));
21
22 (
23 Sender {
24 state: state.clone(),
25 },
26 Receiver { state, version: 0 },
27 )
28}
29
/// Identifier of a waker registered in the channel state.
///
/// IDs are handed out sequentially so that a pending `changed()` future can
/// later remove exactly the waker it registered.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct WakerId(usize);

impl WakerId {
    /// Post-increment: advances `self` to the next ID (wrapping on overflow)
    /// and returns the ID it held *before* the call.
    fn post_inc(&mut self) -> Self {
        let current = *self;
        self.0 = current.0.wrapping_add(1);
        current
    }
}
40
41struct State<T> {
42 value: T,
43 wakers: BTreeMap<WakerId, Waker>,
44 next_waker_id: WakerId,
45 version: usize,
46 closed: bool,
47}
48
49pub struct Sender<T> {
50 state: Arc<RwLock<State<T>>>,
51}
52
53impl<T> Sender<T> {
54 pub fn receiver(&self) -> Receiver<T> {
55 let version = self.state.read().version;
56 Receiver {
57 state: self.state.clone(),
58 version,
59 }
60 }
61
62 pub fn send(&mut self, value: T) -> Result<(), NoReceiverError> {
63 if let Some(state) = Arc::get_mut(&mut self.state) {
64 let state = state.get_mut();
65 state.value = value;
66 debug_assert_eq!(state.wakers.len(), 0);
67 Err(NoReceiverError)
68 } else {
69 let mut state = self.state.write();
70 state.value = value;
71 state.version = state.version.wrapping_add(1);
72 let wakers = mem::take(&mut state.wakers);
73 drop(state);
74
75 for (_, waker) in wakers {
76 waker.wake();
77 }
78
79 Ok(())
80 }
81 }
82}
83
84impl<T> Drop for Sender<T> {
85 fn drop(&mut self) {
86 let mut state = self.state.write();
87 state.closed = true;
88 for (_, waker) in mem::take(&mut state.wakers) {
89 waker.wake();
90 }
91 }
92}
93
94#[derive(Clone)]
95pub struct Receiver<T> {
96 state: Arc<RwLock<State<T>>>,
97 version: usize,
98}
99
100struct Changed<'a, T> {
101 receiver: &'a mut Receiver<T>,
102 pending_waker_id: Option<WakerId>,
103}
104
105impl<T> Future for Changed<'_, T> {
106 type Output = Result<(), NoSenderError>;
107
108 fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
109 let this = &mut *self;
110
111 let state = this.receiver.state.upgradable_read();
112 if state.version != this.receiver.version {
113 // The sender produced a new value. Avoid unregistering the pending
114 // waker, because the sender has already done so.
115 this.pending_waker_id = None;
116 this.receiver.version = state.version;
117 Poll::Ready(Ok(()))
118 } else if state.closed {
119 Poll::Ready(Err(NoSenderError))
120 } else {
121 let mut state = RwLockUpgradableReadGuard::upgrade(state);
122
123 // Unregister the pending waker. This should happen automatically
124 // when the waker gets awoken by the sender, but if this future was
125 // polled again without an explicit call to `wake` (e.g., a spurious
126 // wake by the executor), we need to remove it manually.
127 if let Some(pending_waker_id) = this.pending_waker_id.take() {
128 state.wakers.remove(&pending_waker_id);
129 }
130
131 // Register the waker for this future.
132 let waker_id = state.next_waker_id.post_inc();
133 state.wakers.insert(waker_id, cx.waker().clone());
134 this.pending_waker_id = Some(waker_id);
135
136 Poll::Pending
137 }
138 }
139}
140
141impl<T> Drop for Changed<'_, T> {
142 fn drop(&mut self) {
143 // If this future gets dropped before the waker has a chance of being
144 // awoken, we need to clear it to avoid a memory leak.
145 if let Some(waker_id) = self.pending_waker_id {
146 let mut state = self.receiver.state.write();
147 state.wakers.remove(&waker_id);
148 }
149 }
150}
151
152impl<T> Receiver<T> {
153 pub fn borrow(&mut self) -> parking_lot::MappedRwLockReadGuard<'_, T> {
154 let state = self.state.read();
155 self.version = state.version;
156 RwLockReadGuard::map(state, |state| &state.value)
157 }
158
159 pub fn changed(&mut self) -> impl Future<Output = Result<(), NoSenderError>> {
160 Changed {
161 receiver: self,
162 pending_waker_id: None,
163 }
164 }
165
166 /// Creates a new [`Receiver`] holding an initial value that will never change.
167 pub fn constant(value: T) -> Self {
168 let state = Arc::new(RwLock::new(State {
169 value,
170 wakers: BTreeMap::new(),
171 next_waker_id: WakerId::default(),
172 version: 0,
173 closed: false,
174 }));
175
176 Self { state, version: 0 }
177 }
178}
179
180impl<T: Clone> Receiver<T> {
181 pub async fn recv(&mut self) -> Result<T, NoSenderError> {
182 self.changed().await?;
183 Ok(self.borrow().clone())
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt, select_biased};
    use gpui::{AppContext, TestAppContext};
    use std::{
        pin::pin,
        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
    };

    // Deterministic walkthrough of the channel contract. It covers delivery,
    // coalescing of unread values, sending with no receivers, `borrow` marking
    // a value as seen, and closing when the sender is dropped.
    #[gpui::test]
    async fn test_basic_watch() {
        let (mut sender, mut receiver) = channel(0);
        assert_eq!(sender.send(1), Ok(()));
        assert_eq!(receiver.recv().await, Ok(1));

        // Two sends before a `recv`: only the latest value is observed.
        assert_eq!(sender.send(2), Ok(()));
        assert_eq!(sender.send(3), Ok(()));
        assert_eq!(receiver.recv().await, Ok(3));

        // With the last receiver gone, `send` reports `NoReceiverError`.
        drop(receiver);
        assert_eq!(sender.send(4), Err(NoReceiverError));

        // A receiver created from the sender sees values sent afterwards.
        let mut receiver = sender.receiver();
        assert_eq!(sender.send(5), Ok(()));
        assert_eq!(receiver.recv().await, Ok(5));

        // Ensure `changed` doesn't resolve if we just read the latest value
        // using `borrow`.
        assert_eq!(sender.send(6), Ok(()));
        assert_eq!(*receiver.borrow(), 6);
        assert_eq!(receiver.changed().now_or_never(), None);

        // A value sent right before the sender is dropped is still delivered;
        // only the following `recv` reports `NoSenderError`.
        assert_eq!(sender.send(7), Ok(()));
        drop(sender);
        assert_eq!(receiver.recv().await, Ok(7));
        assert_eq!(receiver.recv().await, Err(NoSenderError));
    }

    // Randomized stress test. One task sends 16 increasing ids with random
    // delays, while 16 receiver tasks repeatedly race `recv` against a random
    // timeout. When the timeout wins, the in-flight `recv` future is dropped,
    // which exercises waker deregistration in `Changed`'s `Drop`.
    #[gpui::test(iterations = 1000)]
    async fn test_watch_random(cx: &mut TestAppContext) {
        // Next id to hand out; `next_id - 1` is the most recently allocated id.
        let next_id = Arc::new(AtomicUsize::new(1));
        // Set by the sender task after its final send.
        let closed = Arc::new(AtomicBool::new(false));
        let (mut tx, rx) = channel(0);
        let mut tasks = Vec::new();

        // Sender task. `tx` is moved into the async block and dropped when it
        // finishes, i.e. after `closed` has been set.
        tasks.push(cx.background_spawn({
            let executor = cx.executor();
            let next_id = next_id.clone();
            let closed = closed.clone();
            async move {
                for _ in 0..16 {
                    executor.simulate_random_delay().await;
                    let id = next_id.fetch_add(1, SeqCst);
                    zlog::info!("sending {}", id);
                    tx.send(id).ok();
                }
                closed.store(true, SeqCst);
            }
        }));

        for receiver_id in 0..16 {
            let executor = cx.executor().clone();
            let next_id = next_id.clone();
            let closed = closed.clone();
            let mut rx = rx.clone();
            // `borrow` also marks the initial value as seen by this clone.
            let mut prev_observed_value = *rx.borrow();
            tasks.push(cx.background_spawn(async move {
                for _ in 0..16 {
                    executor.simulate_random_delay().await;

                    zlog::info!("{}: receiving", receiver_id);
                    let mut timeout = executor.simulate_random_delay().fuse();
                    let mut recv = pin!(rx.recv().fuse());
                    // `select_biased!` checks its branches in order, so the
                    // timeout is preferred whenever both are ready.
                    select_biased! {
                        _ = timeout => {
                            zlog::info!("{}: dropping recv future", receiver_id);
                        }
                        result = recv => {
                            match result {
                                Ok(value) => {
                                    zlog::info!("{}: received {}", receiver_id, value);
                                    // Must be the latest id allocated, and
                                    // never a repeat of what this receiver saw.
                                    assert_eq!(value, next_id.load(SeqCst) - 1);
                                    assert_ne!(value, prev_observed_value);
                                    prev_observed_value = value;
                                }
                                Err(NoSenderError) => {
                                    zlog::info!("{}: closed", receiver_id);
                                    // Closure is only visible after the sender
                                    // task has finished.
                                    assert!(closed.load(SeqCst));
                                    break;
                                }
                            }
                        }
                    }
                }
            }));
        }

        futures::future::join_all(tasks).await;
    }

    // Initializes test logging once, before any test in this binary runs.
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}