notification_store.rs

  1use anyhow::{Context, Result};
  2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
  3use client::{ChannelId, Client, UserStore};
  4use collections::HashMap;
  5use db::smol::stream::StreamExt;
  6use gpui::{
  7    AppContext, AsyncAppContext, Context as _, EventEmitter, Global, Model, ModelContext, Task,
  8};
  9use rpc::{proto, Notification, TypedEnvelope};
 10use std::{ops::Range, sync::Arc};
 11use sum_tree::{Bias, SumTree};
 12use time::OffsetDateTime;
 13use util::ResultExt;
 14
 15pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
 16    let notification_store = cx.new_model(|cx| NotificationStore::new(client, user_store, cx));
 17    cx.set_global(GlobalNotificationStore(notification_store));
 18}
 19
/// Newtype wrapper that lets the app-wide [`NotificationStore`] model be kept
/// as a gpui global (registered by [`init`], read by
/// [`NotificationStore::global`]).
struct GlobalNotificationStore(Model<NotificationStore>);

impl Global for GlobalNotificationStore {}
 23
/// Client-side cache of the current user's notifications, kept in sync with
/// the server through RPC message handlers and reloaded on every reconnect.
pub struct NotificationStore {
    /// Used to request notification pages and to register message handlers.
    client: Arc<Client>,
    /// Used to fetch users referenced by notifications and to answer contact
    /// requests.
    user_store: Model<UserStore>,
    /// Channel messages referenced by mention notifications, keyed by the
    /// saved message id.
    channel_messages: HashMap<u64, ChannelMessage>,
    /// Used to fetch mentioned channel messages and to answer channel invites.
    channel_store: Model<ChannelStore>,
    /// Loaded notifications, ordered by ascending id (oldest first).
    notifications: SumTree<NotificationEntry>,
    /// Set once the page containing the oldest notification has been loaded.
    loaded_all_notifications: bool,
    /// Connection-status watcher; held so the task stays alive (presumably
    /// dropping it cancels the watcher).
    _watch_connection_status: Task<Option<()>>,
    /// Registrations for the add / delete / update notification handlers.
    _subscriptions: Vec<client::Subscription>,
}
 34
/// Events emitted by [`NotificationStore`] when its contents change.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum NotificationEvent {
    /// The entries at `old_range` were replaced by `new_count` entries.
    /// Positions count from the oldest notification (ascending id order).
    NotificationsUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
    /// A notification delivered as new by the server was added.
    NewNotification {
        entry: NotificationEntry,
    },
    /// An existing notification was deleted; `entry` is the removed value.
    NotificationRemoved {
        entry: NotificationEntry,
    },
    /// An existing notification was replaced by a version marked as read.
    NotificationRead {
        entry: NotificationEntry,
    },
}
 51
/// A single notification as stored by [`NotificationStore`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct NotificationEntry {
    /// Server-assigned id; entries are ordered by it.
    pub id: u64,
    /// The decoded notification payload.
    pub notification: Notification,
    /// Creation time, decoded from a unix timestamp in seconds.
    pub timestamp: OffsetDateTime,
    /// Whether the recipient has read this notification.
    pub is_read: bool,
    /// The recipient's response, if any (presumably accept/decline for
    /// requests and invitations — confirm against the proto definition).
    pub response: Option<bool>,
}
 60
/// Sum-tree summary aggregated over a run of [`NotificationEntry`] values.
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of entries in the run.
    count: usize,
    /// Number of entries in the run that are not yet read.
    unread_count: usize,
}
 67
/// Sum-tree dimension counting entries; used for positional seeks.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);

/// Sum-tree dimension counting unread entries (no seek visible in this file
/// uses it).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct UnreadCount(usize);

/// Sum-tree dimension holding the largest notification id seen so far; used
/// to seek entries by id.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
 76
 77impl NotificationStore {
 78    pub fn global(cx: &AppContext) -> Model<Self> {
 79        cx.global::<GlobalNotificationStore>().0.clone()
 80    }
 81
 82    pub fn new(
 83        client: Arc<Client>,
 84        user_store: Model<UserStore>,
 85        cx: &mut ModelContext<Self>,
 86    ) -> Self {
 87        let mut connection_status = client.status();
 88        let watch_connection_status = cx.spawn(|this, mut cx| async move {
 89            while let Some(status) = connection_status.next().await {
 90                let this = this.upgrade()?;
 91                match status {
 92                    client::Status::Connected { .. } => {
 93                        if let Some(task) = this
 94                            .update(&mut cx, |this, cx| this.handle_connect(cx))
 95                            .log_err()?
 96                        {
 97                            task.await.log_err()?;
 98                        }
 99                    }
100                    _ => this
101                        .update(&mut cx, |this, cx| this.handle_disconnect(cx))
102                        .log_err()?,
103                }
104            }
105            Some(())
106        });
107
108        Self {
109            channel_store: ChannelStore::global(cx),
110            notifications: Default::default(),
111            loaded_all_notifications: false,
112            channel_messages: Default::default(),
113            _watch_connection_status: watch_connection_status,
114            _subscriptions: vec![
115                client.add_message_handler(cx.weak_model(), Self::handle_new_notification),
116                client.add_message_handler(cx.weak_model(), Self::handle_delete_notification),
117                client.add_message_handler(cx.weak_model(), Self::handle_update_notification),
118            ],
119            user_store,
120            client,
121        }
122    }
123
124    pub fn notification_count(&self) -> usize {
125        self.notifications.summary().count
126    }
127
128    pub fn unread_notification_count(&self) -> usize {
129        self.notifications.summary().unread_count
130    }
131
132    pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
133        self.channel_messages.get(&id)
134    }
135
136    // Get the nth newest notification.
137    pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
138        let count = self.notifications.summary().count;
139        if ix >= count {
140            return None;
141        }
142        let ix = count - 1 - ix;
143        let mut cursor = self.notifications.cursor::<Count>();
144        cursor.seek(&Count(ix), Bias::Right, &());
145        cursor.item()
146    }
147
148    pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
149        let mut cursor = self.notifications.cursor::<NotificationId>();
150        cursor.seek(&NotificationId(id), Bias::Left, &());
151        if let Some(item) = cursor.item() {
152            if item.id == id {
153                return Some(item);
154            }
155        }
156        None
157    }
158
159    pub fn load_more_notifications(
160        &self,
161        clear_old: bool,
162        cx: &mut ModelContext<Self>,
163    ) -> Option<Task<Result<()>>> {
164        if self.loaded_all_notifications && !clear_old {
165            return None;
166        }
167
168        let before_id = if clear_old {
169            None
170        } else {
171            self.notifications.first().map(|entry| entry.id)
172        };
173        let request = self.client.request(proto::GetNotifications { before_id });
174        Some(cx.spawn(|this, mut cx| async move {
175            let this = this
176                .upgrade()
177                .context("Notification store was dropped while loading notifications")?;
178
179            let response = request.await?;
180            this.update(&mut cx, |this, _| {
181                this.loaded_all_notifications = response.done
182            })?;
183            Self::add_notifications(
184                this,
185                response.notifications,
186                AddNotificationsOptions {
187                    is_new: false,
188                    clear_old,
189                    includes_first: response.done,
190                },
191                cx,
192            )
193            .await?;
194            Ok(())
195        }))
196    }
197
198    fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
199        self.notifications = Default::default();
200        self.channel_messages = Default::default();
201        cx.notify();
202        self.load_more_notifications(true, cx)
203    }
204
205    fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
206        cx.notify()
207    }
208
209    async fn handle_new_notification(
210        this: Model<Self>,
211        envelope: TypedEnvelope<proto::AddNotification>,
212        _: Arc<Client>,
213        cx: AsyncAppContext,
214    ) -> Result<()> {
215        Self::add_notifications(
216            this,
217            envelope.payload.notification.into_iter().collect(),
218            AddNotificationsOptions {
219                is_new: true,
220                clear_old: false,
221                includes_first: false,
222            },
223            cx,
224        )
225        .await
226    }
227
228    async fn handle_delete_notification(
229        this: Model<Self>,
230        envelope: TypedEnvelope<proto::DeleteNotification>,
231        _: Arc<Client>,
232        mut cx: AsyncAppContext,
233    ) -> Result<()> {
234        this.update(&mut cx, |this, cx| {
235            this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
236            Ok(())
237        })?
238    }
239
240    async fn handle_update_notification(
241        this: Model<Self>,
242        envelope: TypedEnvelope<proto::UpdateNotification>,
243        _: Arc<Client>,
244        mut cx: AsyncAppContext,
245    ) -> Result<()> {
246        this.update(&mut cx, |this, cx| {
247            if let Some(notification) = envelope.payload.notification {
248                if let Some(rpc::Notification::ChannelMessageMention {
249                    message_id,
250                    sender_id: _,
251                    channel_id: _,
252                }) = Notification::from_proto(&notification)
253                {
254                    let fetch_message_task = this.channel_store.update(cx, |this, cx| {
255                        this.fetch_channel_messages(vec![message_id], cx)
256                    });
257
258                    cx.spawn(|this, mut cx| async move {
259                        let messages = fetch_message_task.await?;
260                        this.update(&mut cx, move |this, cx| {
261                            for message in messages {
262                                this.channel_messages.insert(message_id, message);
263                            }
264                            cx.notify();
265                        })
266                    })
267                    .detach_and_log_err(cx)
268                }
269            }
270            Ok(())
271        })?
272    }
273
274    async fn add_notifications(
275        this: Model<Self>,
276        notifications: Vec<proto::Notification>,
277        options: AddNotificationsOptions,
278        mut cx: AsyncAppContext,
279    ) -> Result<()> {
280        let mut user_ids = Vec::new();
281        let mut message_ids = Vec::new();
282
283        let notifications = notifications
284            .into_iter()
285            .filter_map(|message| {
286                Some(NotificationEntry {
287                    id: message.id,
288                    is_read: message.is_read,
289                    timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
290                        .ok()?,
291                    notification: Notification::from_proto(&message)?,
292                    response: message.response,
293                })
294            })
295            .collect::<Vec<_>>();
296        if notifications.is_empty() {
297            return Ok(());
298        }
299
300        for entry in &notifications {
301            match entry.notification {
302                Notification::ChannelInvitation { inviter_id, .. } => {
303                    user_ids.push(inviter_id);
304                }
305                Notification::ContactRequest {
306                    sender_id: requester_id,
307                } => {
308                    user_ids.push(requester_id);
309                }
310                Notification::ContactRequestAccepted {
311                    responder_id: contact_id,
312                } => {
313                    user_ids.push(contact_id);
314                }
315                Notification::ChannelMessageMention {
316                    sender_id,
317                    message_id,
318                    ..
319                } => {
320                    user_ids.push(sender_id);
321                    message_ids.push(message_id);
322                }
323            }
324        }
325
326        let (user_store, channel_store) = this.read_with(&cx, |this, _| {
327            (this.user_store.clone(), this.channel_store.clone())
328        })?;
329
330        user_store
331            .update(&mut cx, |store, cx| store.get_users(user_ids, cx))?
332            .await?;
333        let messages = channel_store
334            .update(&mut cx, |store, cx| {
335                store.fetch_channel_messages(message_ids, cx)
336            })?
337            .await?;
338        this.update(&mut cx, |this, cx| {
339            if options.clear_old {
340                cx.emit(NotificationEvent::NotificationsUpdated {
341                    old_range: 0..this.notifications.summary().count,
342                    new_count: 0,
343                });
344                this.notifications = SumTree::default();
345                this.channel_messages.clear();
346                this.loaded_all_notifications = false;
347            }
348
349            if options.includes_first {
350                this.loaded_all_notifications = true;
351            }
352
353            this.channel_messages
354                .extend(messages.into_iter().filter_map(|message| {
355                    if let ChannelMessageId::Saved(id) = message.id {
356                        Some((id, message))
357                    } else {
358                        None
359                    }
360                }));
361
362            this.splice_notifications(
363                notifications
364                    .into_iter()
365                    .map(|notification| (notification.id, Some(notification))),
366                options.is_new,
367                cx,
368            );
369        })
370        .log_err();
371
372        Ok(())
373    }
374
375    fn splice_notifications(
376        &mut self,
377        notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
378        is_new: bool,
379        cx: &mut ModelContext<'_, NotificationStore>,
380    ) {
381        let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
382        let mut new_notifications = SumTree::new();
383        let mut old_range = 0..0;
384
385        for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
386            new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
387
388            if i == 0 {
389                old_range.start = cursor.start().1 .0;
390            }
391
392            let old_notification = cursor.item();
393            if let Some(old_notification) = old_notification {
394                if old_notification.id == id {
395                    cursor.next(&());
396
397                    if let Some(new_notification) = &new_notification {
398                        if new_notification.is_read {
399                            cx.emit(NotificationEvent::NotificationRead {
400                                entry: new_notification.clone(),
401                            });
402                        }
403                    } else {
404                        cx.emit(NotificationEvent::NotificationRemoved {
405                            entry: old_notification.clone(),
406                        });
407                    }
408                }
409            } else if let Some(new_notification) = &new_notification {
410                if is_new {
411                    cx.emit(NotificationEvent::NewNotification {
412                        entry: new_notification.clone(),
413                    });
414                }
415            }
416
417            if let Some(notification) = new_notification {
418                new_notifications.push(notification, &());
419            }
420        }
421
422        old_range.end = cursor.start().1 .0;
423        let new_count = new_notifications.summary().count - old_range.start;
424        new_notifications.append(cursor.suffix(&()), &());
425        drop(cursor);
426
427        self.notifications = new_notifications;
428        cx.emit(NotificationEvent::NotificationsUpdated {
429            old_range,
430            new_count,
431        });
432    }
433
434    pub fn respond_to_notification(
435        &mut self,
436        notification: Notification,
437        response: bool,
438        cx: &mut ModelContext<Self>,
439    ) {
440        match notification {
441            Notification::ContactRequest { sender_id } => {
442                self.user_store
443                    .update(cx, |store, cx| {
444                        store.respond_to_contact_request(sender_id, response, cx)
445                    })
446                    .detach();
447            }
448            Notification::ChannelInvitation { channel_id, .. } => {
449                self.channel_store
450                    .update(cx, |store, cx| {
451                        store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
452                    })
453                    .detach();
454            }
455            _ => {}
456        }
457    }
458}
459
/// Lets `NotificationStore` emit [`NotificationEvent`]s via `cx.emit`.
impl EventEmitter<NotificationEvent> for NotificationStore {}
461
462impl sum_tree::Item for NotificationEntry {
463    type Summary = NotificationSummary;
464
465    fn summary(&self) -> Self::Summary {
466        NotificationSummary {
467            max_id: self.id,
468            count: 1,
469            unread_count: if self.is_read { 0 } else { 1 },
470        }
471    }
472}
473
474impl sum_tree::Summary for NotificationSummary {
475    type Context = ();
476
477    fn add_summary(&mut self, summary: &Self, _: &()) {
478        self.max_id = self.max_id.max(summary.max_id);
479        self.count += summary.count;
480        self.unread_count += summary.unread_count;
481    }
482}
483
/// Tracks the largest id seen while walking the tree, enabling seeks by id.
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
        // Ids must strictly increase in tree order. Because the dimension
        // starts at its `Default` (0), the first id must also be non-zero.
        debug_assert!(summary.max_id > self.0);
        self.0 = summary.max_id;
    }
}
490
/// Accumulates the number of entries passed, enabling positional seeks.
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
        self.0 += summary.count;
    }
}
496
/// Accumulates the number of unread entries passed.
impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
        self.0 += summary.unread_count;
    }
}
502
/// Controls how `NotificationStore::add_notifications` merges a batch.
struct AddNotificationsOptions {
    /// Passed to the splice as `is_new`: the batch was pushed by the server as
    /// newly created, so previously unseen entries may be announced with
    /// `NewNotification`.
    is_new: bool,
    /// Drop all loaded notifications and cached channel messages before
    /// inserting the batch.
    clear_old: bool,
    /// The batch reaches the oldest notification, so
    /// `loaded_all_notifications` is set afterwards.
    includes_first: bool,
}