notification_store.rs

  1use anyhow::{Context, Result};
  2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
  3use client::{ChannelId, Client, UserStore};
  4use collections::HashMap;
  5use db::smol::stream::StreamExt;
  6use gpui::{
  7    AppContext, AsyncAppContext, Context as _, EventEmitter, Global, Model, ModelContext, Task,
  8};
  9use rpc::{proto, Notification, TypedEnvelope};
 10use std::{ops::Range, sync::Arc};
 11use sum_tree::{Bias, SumTree};
 12use time::OffsetDateTime;
 13use util::ResultExt;
 14
 15pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
 16    let notification_store = cx.new_model(|cx| NotificationStore::new(client, user_store, cx));
 17    cx.set_global(GlobalNotificationStore(notification_store));
 18}
 19
 20struct GlobalNotificationStore(Model<NotificationStore>);
 21
 22impl Global for GlobalNotificationStore {}
 23
 24pub struct NotificationStore {
 25    client: Arc<Client>,
 26    user_store: Model<UserStore>,
 27    channel_messages: HashMap<u64, ChannelMessage>,
 28    channel_store: Model<ChannelStore>,
 29    notifications: SumTree<NotificationEntry>,
 30    loaded_all_notifications: bool,
 31    _watch_connection_status: Task<Option<()>>,
 32    _subscriptions: Vec<client::Subscription>,
 33}
 34
 35#[derive(Clone, PartialEq, Eq, Debug)]
 36pub enum NotificationEvent {
 37    NotificationsUpdated {
 38        old_range: Range<usize>,
 39        new_count: usize,
 40    },
 41    NewNotification {
 42        entry: NotificationEntry,
 43    },
 44    NotificationRemoved {
 45        entry: NotificationEntry,
 46    },
 47    NotificationRead {
 48        entry: NotificationEntry,
 49    },
 50}
 51
 52#[derive(Debug, PartialEq, Eq, Clone)]
 53pub struct NotificationEntry {
 54    pub id: u64,
 55    pub notification: Notification,
 56    pub timestamp: OffsetDateTime,
 57    pub is_read: bool,
 58    pub response: Option<bool>,
 59}
 60
 61#[derive(Clone, Debug, Default)]
 62pub struct NotificationSummary {
 63    max_id: u64,
 64    count: usize,
 65    unread_count: usize,
 66}
 67
 68#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 69struct Count(usize);
 70
 71#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 72struct NotificationId(u64);
 73
 74impl NotificationStore {
 75    pub fn global(cx: &AppContext) -> Model<Self> {
 76        cx.global::<GlobalNotificationStore>().0.clone()
 77    }
 78
 79    pub fn new(
 80        client: Arc<Client>,
 81        user_store: Model<UserStore>,
 82        cx: &mut ModelContext<Self>,
 83    ) -> Self {
 84        let mut connection_status = client.status();
 85        let watch_connection_status = cx.spawn(|this, mut cx| async move {
 86            while let Some(status) = connection_status.next().await {
 87                let this = this.upgrade()?;
 88                match status {
 89                    client::Status::Connected { .. } => {
 90                        if let Some(task) = this
 91                            .update(&mut cx, |this, cx| this.handle_connect(cx))
 92                            .log_err()?
 93                        {
 94                            task.await.log_err()?;
 95                        }
 96                    }
 97                    _ => this
 98                        .update(&mut cx, |this, cx| this.handle_disconnect(cx))
 99                        .log_err()?,
100                }
101            }
102            Some(())
103        });
104
105        Self {
106            channel_store: ChannelStore::global(cx),
107            notifications: Default::default(),
108            loaded_all_notifications: false,
109            channel_messages: Default::default(),
110            _watch_connection_status: watch_connection_status,
111            _subscriptions: vec![
112                client.add_message_handler(cx.weak_model(), Self::handle_new_notification),
113                client.add_message_handler(cx.weak_model(), Self::handle_delete_notification),
114                client.add_message_handler(cx.weak_model(), Self::handle_update_notification),
115            ],
116            user_store,
117            client,
118        }
119    }
120
121    pub fn notification_count(&self) -> usize {
122        self.notifications.summary().count
123    }
124
125    pub fn unread_notification_count(&self) -> usize {
126        self.notifications.summary().unread_count
127    }
128
129    pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
130        self.channel_messages.get(&id)
131    }
132
133    // Get the nth newest notification.
134    pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
135        let count = self.notifications.summary().count;
136        if ix >= count {
137            return None;
138        }
139        let ix = count - 1 - ix;
140        let mut cursor = self.notifications.cursor::<Count>(&());
141        cursor.seek(&Count(ix), Bias::Right, &());
142        cursor.item()
143    }
144    pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
145        let mut cursor = self.notifications.cursor::<NotificationId>(&());
146        cursor.seek(&NotificationId(id), Bias::Left, &());
147        if let Some(item) = cursor.item() {
148            if item.id == id {
149                return Some(item);
150            }
151        }
152        None
153    }
154
155    pub fn load_more_notifications(
156        &self,
157        clear_old: bool,
158        cx: &mut ModelContext<Self>,
159    ) -> Option<Task<Result<()>>> {
160        if self.loaded_all_notifications && !clear_old {
161            return None;
162        }
163
164        let before_id = if clear_old {
165            None
166        } else {
167            self.notifications.first().map(|entry| entry.id)
168        };
169        let request = self.client.request(proto::GetNotifications { before_id });
170        Some(cx.spawn(|this, mut cx| async move {
171            let this = this
172                .upgrade()
173                .context("Notification store was dropped while loading notifications")?;
174
175            let response = request.await?;
176            this.update(&mut cx, |this, _| {
177                this.loaded_all_notifications = response.done
178            })?;
179            Self::add_notifications(
180                this,
181                response.notifications,
182                AddNotificationsOptions {
183                    is_new: false,
184                    clear_old,
185                    includes_first: response.done,
186                },
187                cx,
188            )
189            .await?;
190            Ok(())
191        }))
192    }
193
194    fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
195        self.notifications = Default::default();
196        self.channel_messages = Default::default();
197        cx.notify();
198        self.load_more_notifications(true, cx)
199    }
200
201    fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
202        cx.notify()
203    }
204
205    async fn handle_new_notification(
206        this: Model<Self>,
207        envelope: TypedEnvelope<proto::AddNotification>,
208        cx: AsyncAppContext,
209    ) -> Result<()> {
210        Self::add_notifications(
211            this,
212            envelope.payload.notification.into_iter().collect(),
213            AddNotificationsOptions {
214                is_new: true,
215                clear_old: false,
216                includes_first: false,
217            },
218            cx,
219        )
220        .await
221    }
222
223    async fn handle_delete_notification(
224        this: Model<Self>,
225        envelope: TypedEnvelope<proto::DeleteNotification>,
226        mut cx: AsyncAppContext,
227    ) -> Result<()> {
228        this.update(&mut cx, |this, cx| {
229            this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
230            Ok(())
231        })?
232    }
233
234    async fn handle_update_notification(
235        this: Model<Self>,
236        envelope: TypedEnvelope<proto::UpdateNotification>,
237        mut cx: AsyncAppContext,
238    ) -> Result<()> {
239        this.update(&mut cx, |this, cx| {
240            if let Some(notification) = envelope.payload.notification {
241                if let Some(rpc::Notification::ChannelMessageMention {
242                    message_id,
243                    sender_id: _,
244                    channel_id: _,
245                }) = Notification::from_proto(&notification)
246                {
247                    let fetch_message_task = this.channel_store.update(cx, |this, cx| {
248                        this.fetch_channel_messages(vec![message_id], cx)
249                    });
250
251                    cx.spawn(|this, mut cx| async move {
252                        let messages = fetch_message_task.await?;
253                        this.update(&mut cx, move |this, cx| {
254                            for message in messages {
255                                this.channel_messages.insert(message_id, message);
256                            }
257                            cx.notify();
258                        })
259                    })
260                    .detach_and_log_err(cx)
261                }
262            }
263            Ok(())
264        })?
265    }
266
267    async fn add_notifications(
268        this: Model<Self>,
269        notifications: Vec<proto::Notification>,
270        options: AddNotificationsOptions,
271        mut cx: AsyncAppContext,
272    ) -> Result<()> {
273        let mut user_ids = Vec::new();
274        let mut message_ids = Vec::new();
275
276        let notifications = notifications
277            .into_iter()
278            .filter_map(|message| {
279                Some(NotificationEntry {
280                    id: message.id,
281                    is_read: message.is_read,
282                    timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
283                        .ok()?,
284                    notification: Notification::from_proto(&message)?,
285                    response: message.response,
286                })
287            })
288            .collect::<Vec<_>>();
289        if notifications.is_empty() {
290            return Ok(());
291        }
292
293        for entry in &notifications {
294            match entry.notification {
295                Notification::ChannelInvitation { inviter_id, .. } => {
296                    user_ids.push(inviter_id);
297                }
298                Notification::ContactRequest {
299                    sender_id: requester_id,
300                } => {
301                    user_ids.push(requester_id);
302                }
303                Notification::ContactRequestAccepted {
304                    responder_id: contact_id,
305                } => {
306                    user_ids.push(contact_id);
307                }
308                Notification::ChannelMessageMention {
309                    sender_id,
310                    message_id,
311                    ..
312                } => {
313                    user_ids.push(sender_id);
314                    message_ids.push(message_id);
315                }
316            }
317        }
318
319        let (user_store, channel_store) = this.read_with(&cx, |this, _| {
320            (this.user_store.clone(), this.channel_store.clone())
321        })?;
322
323        user_store
324            .update(&mut cx, |store, cx| store.get_users(user_ids, cx))?
325            .await?;
326        let messages = channel_store
327            .update(&mut cx, |store, cx| {
328                store.fetch_channel_messages(message_ids, cx)
329            })?
330            .await?;
331        this.update(&mut cx, |this, cx| {
332            if options.clear_old {
333                cx.emit(NotificationEvent::NotificationsUpdated {
334                    old_range: 0..this.notifications.summary().count,
335                    new_count: 0,
336                });
337                this.notifications = SumTree::default();
338                this.channel_messages.clear();
339                this.loaded_all_notifications = false;
340            }
341
342            if options.includes_first {
343                this.loaded_all_notifications = true;
344            }
345
346            this.channel_messages
347                .extend(messages.into_iter().filter_map(|message| {
348                    if let ChannelMessageId::Saved(id) = message.id {
349                        Some((id, message))
350                    } else {
351                        None
352                    }
353                }));
354
355            this.splice_notifications(
356                notifications
357                    .into_iter()
358                    .map(|notification| (notification.id, Some(notification))),
359                options.is_new,
360                cx,
361            );
362        })
363        .log_err();
364
365        Ok(())
366    }
367
368    fn splice_notifications(
369        &mut self,
370        notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
371        is_new: bool,
372        cx: &mut ModelContext<'_, NotificationStore>,
373    ) {
374        let mut cursor = self.notifications.cursor::<(NotificationId, Count)>(&());
375        let mut new_notifications = SumTree::default();
376        let mut old_range = 0..0;
377
378        for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
379            new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
380
381            if i == 0 {
382                old_range.start = cursor.start().1 .0;
383            }
384
385            let old_notification = cursor.item();
386            if let Some(old_notification) = old_notification {
387                if old_notification.id == id {
388                    cursor.next(&());
389
390                    if let Some(new_notification) = &new_notification {
391                        if new_notification.is_read {
392                            cx.emit(NotificationEvent::NotificationRead {
393                                entry: new_notification.clone(),
394                            });
395                        }
396                    } else {
397                        cx.emit(NotificationEvent::NotificationRemoved {
398                            entry: old_notification.clone(),
399                        });
400                    }
401                }
402            } else if let Some(new_notification) = &new_notification {
403                if is_new {
404                    cx.emit(NotificationEvent::NewNotification {
405                        entry: new_notification.clone(),
406                    });
407                }
408            }
409
410            if let Some(notification) = new_notification {
411                new_notifications.push(notification, &());
412            }
413        }
414
415        old_range.end = cursor.start().1 .0;
416        let new_count = new_notifications.summary().count - old_range.start;
417        new_notifications.append(cursor.suffix(&()), &());
418        drop(cursor);
419
420        self.notifications = new_notifications;
421        cx.emit(NotificationEvent::NotificationsUpdated {
422            old_range,
423            new_count,
424        });
425    }
426
427    pub fn respond_to_notification(
428        &mut self,
429        notification: Notification,
430        response: bool,
431        cx: &mut ModelContext<Self>,
432    ) {
433        match notification {
434            Notification::ContactRequest { sender_id } => {
435                self.user_store
436                    .update(cx, |store, cx| {
437                        store.respond_to_contact_request(sender_id, response, cx)
438                    })
439                    .detach();
440            }
441            Notification::ChannelInvitation { channel_id, .. } => {
442                self.channel_store
443                    .update(cx, |store, cx| {
444                        store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
445                    })
446                    .detach();
447            }
448            _ => {}
449        }
450    }
451}
452
453impl EventEmitter<NotificationEvent> for NotificationStore {}
454
455impl sum_tree::Item for NotificationEntry {
456    type Summary = NotificationSummary;
457
458    fn summary(&self, _cx: &()) -> Self::Summary {
459        NotificationSummary {
460            max_id: self.id,
461            count: 1,
462            unread_count: if self.is_read { 0 } else { 1 },
463        }
464    }
465}
466
467impl sum_tree::Summary for NotificationSummary {
468    type Context = ();
469
470    fn zero(_cx: &()) -> Self {
471        Default::default()
472    }
473
474    fn add_summary(&mut self, summary: &Self, _: &()) {
475        self.max_id = self.max_id.max(summary.max_id);
476        self.count += summary.count;
477        self.unread_count += summary.unread_count;
478    }
479}
480
481impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
482    fn zero(_cx: &()) -> Self {
483        Default::default()
484    }
485
486    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
487        debug_assert!(summary.max_id > self.0);
488        self.0 = summary.max_id;
489    }
490}
491
492impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
493    fn zero(_cx: &()) -> Self {
494        Default::default()
495    }
496
497    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
498        self.0 += summary.count;
499    }
500}
501
/// Controls how `NotificationStore::add_notifications` merges a batch.
struct AddNotificationsOptions {
    /// Emit `NewNotification` for entries that were not already loaded.
    is_new: bool,
    /// Discard all previously loaded notifications before inserting.
    clear_old: bool,
    /// The batch reaches the oldest notification, so nothing older remains.
    includes_first: bool,
}