notification_store2.rs

  1use anyhow::{Context, Result};
  2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
  3use client::{Client, UserStore};
  4use collections::HashMap;
  5use db::smol::stream::StreamExt;
  6use gpui::{AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task};
  7use rpc::{proto, Notification, TypedEnvelope};
  8use std::{ops::Range, sync::Arc};
  9use sum_tree::{Bias, SumTree};
 10use time::OffsetDateTime;
 11use util::ResultExt;
 12
 13pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
 14    let notification_store = cx.build_model(|cx| NotificationStore::new(client, user_store, cx));
 15    cx.set_global(notification_store);
 16}
 17
 18pub struct NotificationStore {
 19    client: Arc<Client>,
 20    user_store: Model<UserStore>,
 21    channel_messages: HashMap<u64, ChannelMessage>,
 22    channel_store: Model<ChannelStore>,
 23    notifications: SumTree<NotificationEntry>,
 24    loaded_all_notifications: bool,
 25    _watch_connection_status: Task<Option<()>>,
 26    _subscriptions: Vec<client::Subscription>,
 27}
 28
 29#[derive(Clone, PartialEq, Eq, Debug)]
 30pub enum NotificationEvent {
 31    NotificationsUpdated {
 32        old_range: Range<usize>,
 33        new_count: usize,
 34    },
 35    NewNotification {
 36        entry: NotificationEntry,
 37    },
 38    NotificationRemoved {
 39        entry: NotificationEntry,
 40    },
 41    NotificationRead {
 42        entry: NotificationEntry,
 43    },
 44}
 45
 46#[derive(Debug, PartialEq, Eq, Clone)]
 47pub struct NotificationEntry {
 48    pub id: u64,
 49    pub notification: Notification,
 50    pub timestamp: OffsetDateTime,
 51    pub is_read: bool,
 52    pub response: Option<bool>,
 53}
 54
 55#[derive(Clone, Debug, Default)]
 56pub struct NotificationSummary {
 57    max_id: u64,
 58    count: usize,
 59    unread_count: usize,
 60}
 61
 62#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 63struct Count(usize);
 64
 65#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 66struct UnreadCount(usize);
 67
 68#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
 69struct NotificationId(u64);
 70
 71impl NotificationStore {
 72    pub fn global(cx: &AppContext) -> Model<Self> {
 73        cx.global::<Model<Self>>().clone()
 74    }
 75
 76    pub fn new(
 77        client: Arc<Client>,
 78        user_store: Model<UserStore>,
 79        cx: &mut ModelContext<Self>,
 80    ) -> Self {
 81        let mut connection_status = client.status();
 82        let watch_connection_status = cx.spawn(|this, mut cx| async move {
 83            while let Some(status) = connection_status.next().await {
 84                let this = this.upgrade()?;
 85                match status {
 86                    client::Status::Connected { .. } => {
 87                        if let Some(task) = this
 88                            .update(&mut cx, |this, cx| this.handle_connect(cx))
 89                            .log_err()?
 90                        {
 91                            task.await.log_err()?;
 92                        }
 93                    }
 94                    _ => this
 95                        .update(&mut cx, |this, cx| this.handle_disconnect(cx))
 96                        .log_err()?,
 97                }
 98            }
 99            Some(())
100        });
101
102        Self {
103            channel_store: ChannelStore::global(cx),
104            notifications: Default::default(),
105            loaded_all_notifications: false,
106            channel_messages: Default::default(),
107            _watch_connection_status: watch_connection_status,
108            _subscriptions: vec![
109                client.add_message_handler(cx.weak_model(), Self::handle_new_notification),
110                client.add_message_handler(cx.weak_model(), Self::handle_delete_notification),
111            ],
112            user_store,
113            client,
114        }
115    }
116
117    pub fn notification_count(&self) -> usize {
118        self.notifications.summary().count
119    }
120
121    pub fn unread_notification_count(&self) -> usize {
122        self.notifications.summary().unread_count
123    }
124
125    pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
126        self.channel_messages.get(&id)
127    }
128
129    // Get the nth newest notification.
130    pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
131        let count = self.notifications.summary().count;
132        if ix >= count {
133            return None;
134        }
135        let ix = count - 1 - ix;
136        let mut cursor = self.notifications.cursor::<Count>();
137        cursor.seek(&Count(ix), Bias::Right, &());
138        cursor.item()
139    }
140
141    pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
142        let mut cursor = self.notifications.cursor::<NotificationId>();
143        cursor.seek(&NotificationId(id), Bias::Left, &());
144        if let Some(item) = cursor.item() {
145            if item.id == id {
146                return Some(item);
147            }
148        }
149        None
150    }
151
152    pub fn load_more_notifications(
153        &self,
154        clear_old: bool,
155        cx: &mut ModelContext<Self>,
156    ) -> Option<Task<Result<()>>> {
157        if self.loaded_all_notifications && !clear_old {
158            return None;
159        }
160
161        let before_id = if clear_old {
162            None
163        } else {
164            self.notifications.first().map(|entry| entry.id)
165        };
166        let request = self.client.request(proto::GetNotifications { before_id });
167        Some(cx.spawn(|this, mut cx| async move {
168            let this = this
169                .upgrade()
170                .context("Notification store was dropped while loading notifications")?;
171
172            let response = request.await?;
173            this.update(&mut cx, |this, _| {
174                this.loaded_all_notifications = response.done
175            })?;
176            Self::add_notifications(
177                this,
178                response.notifications,
179                AddNotificationsOptions {
180                    is_new: false,
181                    clear_old,
182                    includes_first: response.done,
183                },
184                cx,
185            )
186            .await?;
187            Ok(())
188        }))
189    }
190
191    fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
192        self.notifications = Default::default();
193        self.channel_messages = Default::default();
194        cx.notify();
195        self.load_more_notifications(true, cx)
196    }
197
198    fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
199        cx.notify()
200    }
201
202    async fn handle_new_notification(
203        this: Model<Self>,
204        envelope: TypedEnvelope<proto::AddNotification>,
205        _: Arc<Client>,
206        cx: AsyncAppContext,
207    ) -> Result<()> {
208        Self::add_notifications(
209            this,
210            envelope.payload.notification.into_iter().collect(),
211            AddNotificationsOptions {
212                is_new: true,
213                clear_old: false,
214                includes_first: false,
215            },
216            cx,
217        )
218        .await
219    }
220
221    async fn handle_delete_notification(
222        this: Model<Self>,
223        envelope: TypedEnvelope<proto::DeleteNotification>,
224        _: Arc<Client>,
225        mut cx: AsyncAppContext,
226    ) -> Result<()> {
227        this.update(&mut cx, |this, cx| {
228            this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
229            Ok(())
230        })?
231    }
232
233    async fn add_notifications(
234        this: Model<Self>,
235        notifications: Vec<proto::Notification>,
236        options: AddNotificationsOptions,
237        mut cx: AsyncAppContext,
238    ) -> Result<()> {
239        let mut user_ids = Vec::new();
240        let mut message_ids = Vec::new();
241
242        let notifications = notifications
243            .into_iter()
244            .filter_map(|message| {
245                Some(NotificationEntry {
246                    id: message.id,
247                    is_read: message.is_read,
248                    timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
249                        .ok()?,
250                    notification: Notification::from_proto(&message)?,
251                    response: message.response,
252                })
253            })
254            .collect::<Vec<_>>();
255        if notifications.is_empty() {
256            return Ok(());
257        }
258
259        for entry in &notifications {
260            match entry.notification {
261                Notification::ChannelInvitation { inviter_id, .. } => {
262                    user_ids.push(inviter_id);
263                }
264                Notification::ContactRequest {
265                    sender_id: requester_id,
266                } => {
267                    user_ids.push(requester_id);
268                }
269                Notification::ContactRequestAccepted {
270                    responder_id: contact_id,
271                } => {
272                    user_ids.push(contact_id);
273                }
274                Notification::ChannelMessageMention {
275                    sender_id,
276                    message_id,
277                    ..
278                } => {
279                    user_ids.push(sender_id);
280                    message_ids.push(message_id);
281                }
282            }
283        }
284
285        let (user_store, channel_store) = this.read_with(&cx, |this, _| {
286            (this.user_store.clone(), this.channel_store.clone())
287        })?;
288
289        user_store
290            .update(&mut cx, |store, cx| store.get_users(user_ids, cx))?
291            .await?;
292        let messages = channel_store
293            .update(&mut cx, |store, cx| {
294                store.fetch_channel_messages(message_ids, cx)
295            })?
296            .await?;
297        this.update(&mut cx, |this, cx| {
298            if options.clear_old {
299                cx.emit(NotificationEvent::NotificationsUpdated {
300                    old_range: 0..this.notifications.summary().count,
301                    new_count: 0,
302                });
303                this.notifications = SumTree::default();
304                this.channel_messages.clear();
305                this.loaded_all_notifications = false;
306            }
307
308            if options.includes_first {
309                this.loaded_all_notifications = true;
310            }
311
312            this.channel_messages
313                .extend(messages.into_iter().filter_map(|message| {
314                    if let ChannelMessageId::Saved(id) = message.id {
315                        Some((id, message))
316                    } else {
317                        None
318                    }
319                }));
320
321            this.splice_notifications(
322                notifications
323                    .into_iter()
324                    .map(|notification| (notification.id, Some(notification))),
325                options.is_new,
326                cx,
327            );
328        })
329        .log_err();
330
331        Ok(())
332    }
333
334    fn splice_notifications(
335        &mut self,
336        notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
337        is_new: bool,
338        cx: &mut ModelContext<'_, NotificationStore>,
339    ) {
340        let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
341        let mut new_notifications = SumTree::new();
342        let mut old_range = 0..0;
343
344        for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
345            new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
346
347            if i == 0 {
348                old_range.start = cursor.start().1 .0;
349            }
350
351            let old_notification = cursor.item();
352            if let Some(old_notification) = old_notification {
353                if old_notification.id == id {
354                    cursor.next(&());
355
356                    if let Some(new_notification) = &new_notification {
357                        if new_notification.is_read {
358                            cx.emit(NotificationEvent::NotificationRead {
359                                entry: new_notification.clone(),
360                            });
361                        }
362                    } else {
363                        cx.emit(NotificationEvent::NotificationRemoved {
364                            entry: old_notification.clone(),
365                        });
366                    }
367                }
368            } else if let Some(new_notification) = &new_notification {
369                if is_new {
370                    cx.emit(NotificationEvent::NewNotification {
371                        entry: new_notification.clone(),
372                    });
373                }
374            }
375
376            if let Some(notification) = new_notification {
377                new_notifications.push(notification, &());
378            }
379        }
380
381        old_range.end = cursor.start().1 .0;
382        let new_count = new_notifications.summary().count - old_range.start;
383        new_notifications.append(cursor.suffix(&()), &());
384        drop(cursor);
385
386        self.notifications = new_notifications;
387        cx.emit(NotificationEvent::NotificationsUpdated {
388            old_range,
389            new_count,
390        });
391    }
392
393    pub fn respond_to_notification(
394        &mut self,
395        notification: Notification,
396        response: bool,
397        cx: &mut ModelContext<Self>,
398    ) {
399        match notification {
400            Notification::ContactRequest { sender_id } => {
401                self.user_store
402                    .update(cx, |store, cx| {
403                        store.respond_to_contact_request(sender_id, response, cx)
404                    })
405                    .detach();
406            }
407            Notification::ChannelInvitation { channel_id, .. } => {
408                self.channel_store
409                    .update(cx, |store, cx| {
410                        store.respond_to_channel_invite(channel_id, response, cx)
411                    })
412                    .detach();
413            }
414            _ => {}
415        }
416    }
417}
418
419impl EventEmitter<NotificationEvent> for NotificationStore {}
420
421impl sum_tree::Item for NotificationEntry {
422    type Summary = NotificationSummary;
423
424    fn summary(&self) -> Self::Summary {
425        NotificationSummary {
426            max_id: self.id,
427            count: 1,
428            unread_count: if self.is_read { 0 } else { 1 },
429        }
430    }
431}
432
433impl sum_tree::Summary for NotificationSummary {
434    type Context = ();
435
436    fn add_summary(&mut self, summary: &Self, _: &()) {
437        self.max_id = self.max_id.max(summary.max_id);
438        self.count += summary.count;
439        self.unread_count += summary.unread_count;
440    }
441}
442
443impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
444    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
445        debug_assert!(summary.max_id > self.0);
446        self.0 = summary.max_id;
447    }
448}
449
450impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
451    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
452        self.0 += summary.count;
453    }
454}
455
456impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
457    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
458        self.0 += summary.unread_count;
459    }
460}
461
/// Controls how a batch of notifications is merged into the store.
struct AddNotificationsOptions {
    /// The batch was pushed by the server (rather than paged in); inserts may emit
    /// `NewNotification`.
    is_new: bool,
    /// Replace everything currently loaded with this batch.
    clear_old: bool,
    /// The batch reaches the oldest notification, so nothing older remains to load.
    includes_first: bool,
}