1mod error;
  2
  3pub use error::*;
  4use parking_lot::{RwLock, RwLockReadGuard, RwLockUpgradableReadGuard};
  5use std::{
  6    collections::BTreeMap,
  7    mem,
  8    pin::Pin,
  9    sync::Arc,
 10    task::{Context, Poll, Waker},
 11};
 12
 13pub fn channel<T>(value: T) -> (Sender<T>, Receiver<T>) {
 14    let state = Arc::new(RwLock::new(State {
 15        value,
 16        wakers: BTreeMap::new(),
 17        next_waker_id: WakerId::default(),
 18        version: 0,
 19        closed: false,
 20    }));
 21
 22    (
 23        Sender {
 24            state: state.clone(),
 25        },
 26        Receiver { state, version: 0 },
 27    )
 28}
 29
 30#[derive(Default, Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
 31struct WakerId(usize);
 32
 33impl WakerId {
 34    const fn post_inc(&mut self) -> Self {
 35        let id = *self;
 36        self.0 = id.0.wrapping_add(1);
 37        *self
 38    }
 39}
 40
 41struct State<T> {
 42    value: T,
 43    wakers: BTreeMap<WakerId, Waker>,
 44    next_waker_id: WakerId,
 45    version: usize,
 46    closed: bool,
 47}
 48
 49pub struct Sender<T> {
 50    state: Arc<RwLock<State<T>>>,
 51}
 52
 53impl<T> Sender<T> {
 54    pub fn receiver(&self) -> Receiver<T> {
 55        let version = self.state.read().version;
 56        Receiver {
 57            state: self.state.clone(),
 58            version,
 59        }
 60    }
 61
 62    pub fn send(&mut self, value: T) -> Result<(), NoReceiverError> {
 63        if let Some(state) = Arc::get_mut(&mut self.state) {
 64            let state = state.get_mut();
 65            state.value = value;
 66            debug_assert_eq!(state.wakers.len(), 0);
 67            Err(NoReceiverError)
 68        } else {
 69            let mut state = self.state.write();
 70            state.value = value;
 71            state.version = state.version.wrapping_add(1);
 72            let wakers = mem::take(&mut state.wakers);
 73            drop(state);
 74
 75            for (_, waker) in wakers {
 76                waker.wake();
 77            }
 78
 79            Ok(())
 80        }
 81    }
 82}
 83
 84impl<T> Drop for Sender<T> {
 85    fn drop(&mut self) {
 86        let mut state = self.state.write();
 87        state.closed = true;
 88        for (_, waker) in mem::take(&mut state.wakers) {
 89            waker.wake();
 90        }
 91    }
 92}
 93
 94#[derive(Clone)]
 95pub struct Receiver<T> {
 96    state: Arc<RwLock<State<T>>>,
 97    version: usize,
 98}
 99
100struct Changed<'a, T> {
101    receiver: &'a mut Receiver<T>,
102    pending_waker_id: Option<WakerId>,
103}
104
105impl<T> Future for Changed<'_, T> {
106    type Output = Result<(), NoSenderError>;
107
108    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
109        let this = &mut *self;
110
111        let state = this.receiver.state.upgradable_read();
112        if state.version != this.receiver.version {
113            // The sender produced a new value. Avoid unregistering the pending
114            // waker, because the sender has already done so.
115            this.pending_waker_id = None;
116            this.receiver.version = state.version;
117            Poll::Ready(Ok(()))
118        } else if state.closed {
119            Poll::Ready(Err(NoSenderError))
120        } else {
121            let mut state = RwLockUpgradableReadGuard::upgrade(state);
122
123            // Unregister the pending waker. This should happen automatically
124            // when the waker gets awoken by the sender, but if this future was
125            // polled again without an explicit call to `wake` (e.g., a spurious
126            // wake by the executor), we need to remove it manually.
127            if let Some(pending_waker_id) = this.pending_waker_id.take() {
128                state.wakers.remove(&pending_waker_id);
129            }
130
131            // Register the waker for this future.
132            let waker_id = state.next_waker_id.post_inc();
133            state.wakers.insert(waker_id, cx.waker().clone());
134            this.pending_waker_id = Some(waker_id);
135
136            Poll::Pending
137        }
138    }
139}
140
141impl<T> Drop for Changed<'_, T> {
142    fn drop(&mut self) {
143        // If this future gets dropped before the waker has a chance of being
144        // awoken, we need to clear it to avoid a memory leak.
145        if let Some(waker_id) = self.pending_waker_id {
146            let mut state = self.receiver.state.write();
147            state.wakers.remove(&waker_id);
148        }
149    }
150}
151
152impl<T> Receiver<T> {
153    pub fn borrow(&mut self) -> parking_lot::MappedRwLockReadGuard<'_, T> {
154        let state = self.state.read();
155        self.version = state.version;
156        RwLockReadGuard::map(state, |state| &state.value)
157    }
158
159    pub fn changed(&mut self) -> impl Future<Output = Result<(), NoSenderError>> {
160        Changed {
161            receiver: self,
162            pending_waker_id: None,
163        }
164    }
165
166    /// Creates a new [`Receiver`] holding an initial value that will never change.
167    pub fn constant(value: T) -> Self {
168        let state = Arc::new(RwLock::new(State {
169            value,
170            wakers: BTreeMap::new(),
171            next_waker_id: WakerId::default(),
172            version: 0,
173            closed: false,
174        }));
175
176        Self { state, version: 0 }
177    }
178}
179
180impl<T: Clone> Receiver<T> {
181    pub async fn recv(&mut self) -> Result<T, NoSenderError> {
182        self.changed().await?;
183        Ok(self.borrow().clone())
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{FutureExt, select_biased};
    use gpui::{AppContext, TestAppContext};
    use std::{
        pin::pin,
        sync::atomic::{AtomicBool, AtomicUsize, Ordering::SeqCst},
    };

    // Single-receiver lifecycle: coalescing of unread sends, `NoReceiverError`
    // once every receiver is gone, re-attaching via `Sender::receiver`,
    // `borrow` marking a value as seen, and delivery of the final value before
    // `NoSenderError` after the sender is dropped.
    #[gpui::test]
    async fn test_basic_watch() {
        let (mut sender, mut receiver) = channel(0);
        assert_eq!(sender.send(1), Ok(()));
        assert_eq!(receiver.recv().await, Ok(1));

        // Two sends before a `recv`: only the latest value is observed.
        assert_eq!(sender.send(2), Ok(()));
        assert_eq!(sender.send(3), Ok(()));
        assert_eq!(receiver.recv().await, Ok(3));

        // With no receiver left, `send` reports `NoReceiverError`.
        drop(receiver);
        assert_eq!(sender.send(4), Err(NoReceiverError));

        // A receiver created from the sender sees subsequent sends.
        let mut receiver = sender.receiver();
        assert_eq!(sender.send(5), Ok(()));
        assert_eq!(receiver.recv().await, Ok(5));

        // Ensure `changed` doesn't resolve if we just read the latest value
        // using `borrow`.
        assert_eq!(sender.send(6), Ok(()));
        assert_eq!(*receiver.borrow(), 6);
        assert_eq!(receiver.changed().now_or_never(), None);

        // A value sent right before the sender is dropped is still delivered;
        // only the following `recv` reports the closure.
        assert_eq!(sender.send(7), Ok(()));
        drop(sender);
        assert_eq!(receiver.recv().await, Ok(7));
        assert_eq!(receiver.recv().await, Err(NoSenderError));
    }

    // Randomized test. One task sends 16 increasing ids with random delays.
    // Each of 16 receiver tasks repeatedly races `recv` against a random
    // timeout, so pending `changed` futures are frequently dropped mid-wait.
    // Receivers assert that:
    // - each received value is the most recently sent id;
    // - no value is received twice in a row;
    // - `NoSenderError` only appears after the sender task finished.
    #[gpui::test(iterations = 1000)]
    async fn test_watch_random(cx: &mut TestAppContext) {
        // Source of sent ids; `next_id - 1` is the most recently sent id.
        let next_id = Arc::new(AtomicUsize::new(1));
        // Set by the sender task after its last send, before `tx` is dropped.
        let closed = Arc::new(AtomicBool::new(false));
        let (mut tx, rx) = channel(0);
        let mut tasks = Vec::new();

        tasks.push(cx.background_spawn({
            let executor = cx.executor();
            let next_id = next_id.clone();
            let closed = closed.clone();
            async move {
                for _ in 0..16 {
                    executor.simulate_random_delay().await;
                    let id = next_id.fetch_add(1, SeqCst);
                    zlog::info!("sending {}", id);
                    tx.send(id).ok();
                }
                closed.store(true, SeqCst);
            }
        }));

        for receiver_id in 0..16 {
            let executor = cx.executor().clone();
            let next_id = next_id.clone();
            let closed = closed.clone();
            let mut rx = rx.clone();
            let mut prev_observed_value = *rx.borrow();
            tasks.push(cx.background_spawn(async move {
                for _ in 0..16 {
                    executor.simulate_random_delay().await;

                    zlog::info!("{}: receiving", receiver_id);
                    let mut timeout = executor.simulate_random_delay().fuse();
                    let mut recv = pin!(rx.recv().fuse());
                    // `select_biased!` checks `timeout` first. When it wins,
                    // the unfinished `recv` future is dropped at the end of
                    // this iteration, exercising `Changed`'s `Drop` cleanup.
                    select_biased! {
                        _ = timeout => {
                            zlog::info!("{}: dropping recv future", receiver_id);
                        }
                        result = recv => {
                            match result {
                                Ok(value) => {
                                    zlog::info!("{}: received {}", receiver_id, value);
                                    assert_eq!(value, next_id.load(SeqCst) - 1);
                                    assert_ne!(value, prev_observed_value);
                                    prev_observed_value = value;
                                }
                                Err(NoSenderError) => {
                                    zlog::info!("{}: closed", receiver_id);
                                    assert!(closed.load(SeqCst));
                                    break;
                                }
                            }
                        }
                    }
                }
            }));
        }

        futures::future::join_all(tasks).await;
    }

    // Initializes test logging when the test binary is loaded (via `ctor`).
    #[ctor::ctor]
    fn init_logger() {
        zlog::init_test();
    }
}