notification_store.rs

  1use anyhow::Result;
  2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
  3use client::{Client, UserStore};
  4use collections::HashMap;
  5use db::smol::stream::StreamExt;
  6use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
  7use rpc::{proto, Notification, TypedEnvelope};
  8use std::{ops::Range, sync::Arc};
  9use sum_tree::{Bias, SumTree};
 10use time::OffsetDateTime;
 11use util::ResultExt;
 12
 13pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
 14    let notification_store = cx.add_model(|cx| NotificationStore::new(client, user_store, cx));
 15    cx.set_global(notification_store);
 16}
 17
 18pub struct NotificationStore {
 19    client: Arc<Client>,
 20    user_store: ModelHandle<UserStore>,
 21    channel_messages: HashMap<u64, ChannelMessage>,
 22    channel_store: ModelHandle<ChannelStore>,
 23    notifications: SumTree<NotificationEntry>,
 24    _watch_connection_status: Task<Option<()>>,
 25    _subscriptions: Vec<client::Subscription>,
 26}
 27
 28#[derive(Clone, PartialEq, Eq, Debug)]
 29pub enum NotificationEvent {
 30    NotificationsUpdated {
 31        old_range: Range<usize>,
 32        new_count: usize,
 33    },
 34    NewNotification {
 35        entry: NotificationEntry,
 36    },
 37    NotificationRemoved {
 38        entry: NotificationEntry,
 39    },
 40    NotificationRead {
 41        entry: NotificationEntry,
 42    },
 43}
 44
 45#[derive(Debug, PartialEq, Eq, Clone)]
 46pub struct NotificationEntry {
 47    pub id: u64,
 48    pub notification: Notification,
 49    pub timestamp: OffsetDateTime,
 50    pub is_read: bool,
 51    pub response: Option<bool>,
 52}
 53
 54#[derive(Clone, Debug, Default)]
 55pub struct NotificationSummary {
 56    max_id: u64,
 57    count: usize,
 58    unread_count: usize,
 59}
 60
 61#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 62struct Count(usize);
 63
 64#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 65struct UnreadCount(usize);
 66
 67#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 68struct NotificationId(u64);
 69
 70impl NotificationStore {
 71    pub fn global(cx: &AppContext) -> ModelHandle<Self> {
 72        cx.global::<ModelHandle<Self>>().clone()
 73    }
 74
 75    pub fn new(
 76        client: Arc<Client>,
 77        user_store: ModelHandle<UserStore>,
 78        cx: &mut ModelContext<Self>,
 79    ) -> Self {
 80        let mut connection_status = client.status();
 81        let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
 82            while let Some(status) = connection_status.next().await {
 83                let this = this.upgrade(&cx)?;
 84                match status {
 85                    client::Status::Connected { .. } => {
 86                        this.update(&mut cx, |this, cx| this.handle_connect(cx))
 87                            .await
 88                            .log_err()?;
 89                    }
 90                    _ => this.update(&mut cx, |this, cx| this.handle_disconnect(cx)),
 91                }
 92            }
 93            Some(())
 94        });
 95
 96        Self {
 97            channel_store: ChannelStore::global(cx),
 98            notifications: Default::default(),
 99            channel_messages: Default::default(),
100            _watch_connection_status: watch_connection_status,
101            _subscriptions: vec![
102                client.add_message_handler(cx.handle(), Self::handle_new_notification),
103                client.add_message_handler(cx.handle(), Self::handle_delete_notification),
104            ],
105            user_store,
106            client,
107        }
108    }
109
110    pub fn notification_count(&self) -> usize {
111        self.notifications.summary().count
112    }
113
114    pub fn unread_notification_count(&self) -> usize {
115        self.notifications.summary().unread_count
116    }
117
118    pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
119        self.channel_messages.get(&id)
120    }
121
122    // Get the nth newest notification.
123    pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
124        let count = self.notifications.summary().count;
125        if ix >= count {
126            return None;
127        }
128        let ix = count - 1 - ix;
129        let mut cursor = self.notifications.cursor::<Count>();
130        cursor.seek(&Count(ix), Bias::Right, &());
131        cursor.item()
132    }
133
134    pub fn load_more_notifications(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
135        let request = self
136            .client
137            .request(proto::GetNotifications { before_id: None });
138        cx.spawn(|this, cx| async move {
139            let response = request.await?;
140            Self::add_notifications(this, false, response.notifications, cx).await?;
141            Ok(())
142        })
143    }
144
145    fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
146        self.notifications = Default::default();
147        self.channel_messages = Default::default();
148        self.load_more_notifications(cx)
149    }
150
151    fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
152        cx.notify()
153    }
154
155    async fn handle_new_notification(
156        this: ModelHandle<Self>,
157        envelope: TypedEnvelope<proto::NewNotification>,
158        _: Arc<Client>,
159        cx: AsyncAppContext,
160    ) -> Result<()> {
161        Self::add_notifications(
162            this,
163            true,
164            envelope.payload.notification.into_iter().collect(),
165            cx,
166        )
167        .await
168    }
169
170    async fn handle_delete_notification(
171        this: ModelHandle<Self>,
172        envelope: TypedEnvelope<proto::DeleteNotification>,
173        _: Arc<Client>,
174        mut cx: AsyncAppContext,
175    ) -> Result<()> {
176        this.update(&mut cx, |this, cx| {
177            this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
178            Ok(())
179        })
180    }
181
182    async fn add_notifications(
183        this: ModelHandle<Self>,
184        is_new: bool,
185        notifications: Vec<proto::Notification>,
186        mut cx: AsyncAppContext,
187    ) -> Result<()> {
188        let mut user_ids = Vec::new();
189        let mut message_ids = Vec::new();
190
191        let notifications = notifications
192            .into_iter()
193            .filter_map(|message| {
194                Some(NotificationEntry {
195                    id: message.id,
196                    is_read: message.is_read,
197                    timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
198                        .ok()?,
199                    notification: Notification::from_proto(&message)?,
200                    response: message.response,
201                })
202            })
203            .collect::<Vec<_>>();
204        if notifications.is_empty() {
205            return Ok(());
206        }
207
208        for entry in &notifications {
209            match entry.notification {
210                Notification::ChannelInvitation { inviter_id, .. } => {
211                    user_ids.push(inviter_id);
212                }
213                Notification::ContactRequest {
214                    sender_id: requester_id,
215                } => {
216                    user_ids.push(requester_id);
217                }
218                Notification::ContactRequestAccepted {
219                    responder_id: contact_id,
220                } => {
221                    user_ids.push(contact_id);
222                }
223                Notification::ChannelMessageMention {
224                    sender_id,
225                    message_id,
226                    ..
227                } => {
228                    user_ids.push(sender_id);
229                    message_ids.push(message_id);
230                }
231            }
232        }
233
234        let (user_store, channel_store) = this.read_with(&cx, |this, _| {
235            (this.user_store.clone(), this.channel_store.clone())
236        });
237
238        user_store
239            .update(&mut cx, |store, cx| store.get_users(user_ids, cx))
240            .await?;
241        let messages = channel_store
242            .update(&mut cx, |store, cx| {
243                store.fetch_channel_messages(message_ids, cx)
244            })
245            .await?;
246        this.update(&mut cx, |this, cx| {
247            this.channel_messages
248                .extend(messages.into_iter().filter_map(|message| {
249                    if let ChannelMessageId::Saved(id) = message.id {
250                        Some((id, message))
251                    } else {
252                        None
253                    }
254                }));
255
256            this.splice_notifications(
257                notifications
258                    .into_iter()
259                    .map(|notification| (notification.id, Some(notification))),
260                is_new,
261                cx,
262            );
263        });
264
265        Ok(())
266    }
267
268    fn splice_notifications(
269        &mut self,
270        notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
271        is_new: bool,
272        cx: &mut ModelContext<'_, NotificationStore>,
273    ) {
274        let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
275        let mut new_notifications = SumTree::new();
276        let mut old_range = 0..0;
277
278        for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
279            new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
280
281            if i == 0 {
282                old_range.start = cursor.start().1 .0;
283            }
284
285            let old_notification = cursor.item();
286            if let Some(old_notification) = old_notification {
287                if old_notification.id == id {
288                    cursor.next(&());
289
290                    if let Some(new_notification) = &new_notification {
291                        if new_notification.is_read {
292                            cx.emit(NotificationEvent::NotificationRead {
293                                entry: new_notification.clone(),
294                            });
295                        }
296                    } else {
297                        cx.emit(NotificationEvent::NotificationRemoved {
298                            entry: old_notification.clone(),
299                        });
300                    }
301                }
302            } else if let Some(new_notification) = &new_notification {
303                if is_new {
304                    cx.emit(NotificationEvent::NewNotification {
305                        entry: new_notification.clone(),
306                    });
307                }
308            }
309
310            if let Some(notification) = new_notification {
311                new_notifications.push(notification, &());
312            }
313        }
314
315        old_range.end = cursor.start().1 .0;
316        let new_count = new_notifications.summary().count - old_range.start;
317        new_notifications.append(cursor.suffix(&()), &());
318        drop(cursor);
319
320        self.notifications = new_notifications;
321        cx.emit(NotificationEvent::NotificationsUpdated {
322            old_range,
323            new_count,
324        });
325    }
326
327    pub fn respond_to_notification(
328        &mut self,
329        notification: Notification,
330        response: bool,
331        cx: &mut ModelContext<Self>,
332    ) {
333        match notification {
334            Notification::ContactRequest { sender_id } => {
335                self.user_store
336                    .update(cx, |store, cx| {
337                        store.respond_to_contact_request(sender_id, response, cx)
338                    })
339                    .detach();
340            }
341            Notification::ChannelInvitation { channel_id, .. } => {
342                self.channel_store
343                    .update(cx, |store, cx| {
344                        store.respond_to_channel_invite(channel_id, response, cx)
345                    })
346                    .detach();
347            }
348            _ => {}
349        }
350    }
351}
352
353impl Entity for NotificationStore {
354    type Event = NotificationEvent;
355}
356
357impl sum_tree::Item for NotificationEntry {
358    type Summary = NotificationSummary;
359
360    fn summary(&self) -> Self::Summary {
361        NotificationSummary {
362            max_id: self.id,
363            count: 1,
364            unread_count: if self.is_read { 0 } else { 1 },
365        }
366    }
367}
368
369impl sum_tree::Summary for NotificationSummary {
370    type Context = ();
371
372    fn add_summary(&mut self, summary: &Self, _: &()) {
373        self.max_id = self.max_id.max(summary.max_id);
374        self.count += summary.count;
375        self.unread_count += summary.unread_count;
376    }
377}
378
379impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
380    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
381        debug_assert!(summary.max_id > self.0);
382        self.0 = summary.max_id;
383    }
384}
385
386impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
387    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
388        self.0 += summary.count;
389    }
390}
391
392impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
393    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
394        self.0 += summary.unread_count;
395    }
396}