notification_store.rs

  1use anyhow::Result;
  2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
  3use client::{Client, UserStore};
  4use collections::HashMap;
  5use db::smol::stream::StreamExt;
  6use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
  7use rpc::{proto, Notification, TypedEnvelope};
  8use std::{ops::Range, sync::Arc};
  9use sum_tree::{Bias, SumTree};
 10use time::OffsetDateTime;
 11use util::ResultExt;
 12
 13pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
 14    let notification_store = cx.add_model(|cx| NotificationStore::new(client, user_store, cx));
 15    cx.set_global(notification_store);
 16}
 17
/// Client-side cache of the current user's notifications.
///
/// Entries live in a [`SumTree`] ordered by notification id. Channel messages referenced by
/// mention notifications are cached alongside them.
pub struct NotificationStore {
    // Used to send `GetNotifications` requests and to register the message handlers.
    client: Arc<Client>,
    // Loads the users referenced by notifications; also answers contact requests.
    user_store: ModelHandle<UserStore>,
    // Channel messages referenced by notifications, keyed by their saved message id.
    channel_messages: HashMap<u64, ChannelMessage>,
    // Fetches referenced channel messages; also answers channel invitations.
    channel_store: ModelHandle<ChannelStore>,
    // All loaded notifications, in ascending id order.
    notifications: SumTree<NotificationEntry>,
    // Set once the server reports that no older notifications remain.
    loaded_all_notifications: bool,
    // Connection-status watcher, held for the store's lifetime
    // (presumably dropping a gpui `Task` cancels it — TODO confirm).
    _watch_connection_status: Task<Option<()>>,
    // Registrations for the AddNotification / DeleteNotification message handlers.
    _subscriptions: Vec<client::Subscription>,
}
 28
/// Events emitted by [`NotificationStore`].
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum NotificationEvent {
    /// The entries at `old_range` were replaced by `new_count` entries.
    ///
    /// Indices are positions in the store's tree, which is ordered by ascending notification
    /// id. This is the reverse of the newest-first indexing used by
    /// `NotificationStore::notification_at`.
    NotificationsUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
    /// A notification pushed by the server was inserted into the store.
    NewNotification {
        entry: NotificationEntry,
    },
    /// An existing notification was removed; `entry` is the removed value.
    NotificationRemoved {
        entry: NotificationEntry,
    },
    /// An existing notification was replaced by a version whose `is_read` flag is set.
    /// NOTE(review): this is emitted even if the previous version was already read.
    NotificationRead {
        entry: NotificationEntry,
    },
}
 45
/// A single notification as cached by [`NotificationStore`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct NotificationEntry {
    /// Server-assigned id; the store keeps entries ordered by this value.
    pub id: u64,
    /// The decoded notification payload.
    pub notification: Notification,
    /// Decoded from the proto's `timestamp` via `OffsetDateTime::from_unix_timestamp`
    /// (i.e. interpreted as Unix seconds).
    pub timestamp: OffsetDateTime,
    /// Whether the notification has been read.
    pub is_read: bool,
    /// Copied from the proto's `response` field; presumably the user's accept/decline
    /// answer, if one was given — TODO confirm.
    pub response: Option<bool>,
}
 54
/// Aggregate over a run of [`NotificationEntry`] values in the store's [`SumTree`].
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
    // Largest notification id covered by this summary.
    max_id: u64,
    // Number of entries covered.
    count: usize,
    // Number of covered entries whose `is_read` is false.
    unread_count: usize,
}
 61
/// Seek dimension: number of entries, used for positional lookups.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);

/// Seek dimension: number of unread entries.
/// NOTE(review): not used as a seek target anywhere in this file.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct UnreadCount(usize);

/// Seek dimension: the largest notification id seen so far, used to seek by id.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
 70
 71impl NotificationStore {
 72    pub fn global(cx: &AppContext) -> ModelHandle<Self> {
 73        cx.global::<ModelHandle<Self>>().clone()
 74    }
 75
 76    pub fn new(
 77        client: Arc<Client>,
 78        user_store: ModelHandle<UserStore>,
 79        cx: &mut ModelContext<Self>,
 80    ) -> Self {
 81        let mut connection_status = client.status();
 82        let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
 83            while let Some(status) = connection_status.next().await {
 84                let this = this.upgrade(&cx)?;
 85                match status {
 86                    client::Status::Connected { .. } => {
 87                        if let Some(task) = this.update(&mut cx, |this, cx| this.handle_connect(cx))
 88                        {
 89                            task.await.log_err()?;
 90                        }
 91                    }
 92                    _ => this.update(&mut cx, |this, cx| this.handle_disconnect(cx)),
 93                }
 94            }
 95            Some(())
 96        });
 97
 98        Self {
 99            channel_store: ChannelStore::global(cx),
100            notifications: Default::default(),
101            loaded_all_notifications: false,
102            channel_messages: Default::default(),
103            _watch_connection_status: watch_connection_status,
104            _subscriptions: vec![
105                client.add_message_handler(cx.handle(), Self::handle_new_notification),
106                client.add_message_handler(cx.handle(), Self::handle_delete_notification),
107            ],
108            user_store,
109            client,
110        }
111    }
112
113    pub fn notification_count(&self) -> usize {
114        self.notifications.summary().count
115    }
116
117    pub fn unread_notification_count(&self) -> usize {
118        self.notifications.summary().unread_count
119    }
120
121    pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
122        self.channel_messages.get(&id)
123    }
124
125    // Get the nth newest notification.
126    pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
127        let count = self.notifications.summary().count;
128        if ix >= count {
129            return None;
130        }
131        let ix = count - 1 - ix;
132        let mut cursor = self.notifications.cursor::<Count>();
133        cursor.seek(&Count(ix), Bias::Right, &());
134        cursor.item()
135    }
136
137    pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
138        let mut cursor = self.notifications.cursor::<NotificationId>();
139        cursor.seek(&NotificationId(id), Bias::Left, &());
140        if let Some(item) = cursor.item() {
141            if item.id == id {
142                return Some(item);
143            }
144        }
145        None
146    }
147
148    pub fn load_more_notifications(
149        &self,
150        clear_old: bool,
151        cx: &mut ModelContext<Self>,
152    ) -> Option<Task<Result<()>>> {
153        if self.loaded_all_notifications && !clear_old {
154            return None;
155        }
156
157        let before_id = if clear_old {
158            None
159        } else {
160            self.notifications.first().map(|entry| entry.id)
161        };
162        let request = self.client.request(proto::GetNotifications { before_id });
163        Some(cx.spawn(|this, mut cx| async move {
164            let response = request.await?;
165            this.update(&mut cx, |this, _| {
166                this.loaded_all_notifications = response.done
167            });
168            Self::add_notifications(
169                this,
170                response.notifications,
171                AddNotificationsOptions {
172                    is_new: false,
173                    clear_old,
174                    includes_first: response.done,
175                },
176                cx,
177            )
178            .await?;
179            Ok(())
180        }))
181    }
182
183    fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
184        self.notifications = Default::default();
185        self.channel_messages = Default::default();
186        cx.notify();
187        self.load_more_notifications(true, cx)
188    }
189
190    fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
191        cx.notify()
192    }
193
194    async fn handle_new_notification(
195        this: ModelHandle<Self>,
196        envelope: TypedEnvelope<proto::AddNotification>,
197        _: Arc<Client>,
198        cx: AsyncAppContext,
199    ) -> Result<()> {
200        Self::add_notifications(
201            this,
202            envelope.payload.notification.into_iter().collect(),
203            AddNotificationsOptions {
204                is_new: true,
205                clear_old: false,
206                includes_first: false,
207            },
208            cx,
209        )
210        .await
211    }
212
213    async fn handle_delete_notification(
214        this: ModelHandle<Self>,
215        envelope: TypedEnvelope<proto::DeleteNotification>,
216        _: Arc<Client>,
217        mut cx: AsyncAppContext,
218    ) -> Result<()> {
219        this.update(&mut cx, |this, cx| {
220            this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
221            Ok(())
222        })
223    }
224
225    async fn add_notifications(
226        this: ModelHandle<Self>,
227        notifications: Vec<proto::Notification>,
228        options: AddNotificationsOptions,
229        mut cx: AsyncAppContext,
230    ) -> Result<()> {
231        let mut user_ids = Vec::new();
232        let mut message_ids = Vec::new();
233
234        let notifications = notifications
235            .into_iter()
236            .filter_map(|message| {
237                Some(NotificationEntry {
238                    id: message.id,
239                    is_read: message.is_read,
240                    timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
241                        .ok()?,
242                    notification: Notification::from_proto(&message)?,
243                    response: message.response,
244                })
245            })
246            .collect::<Vec<_>>();
247        if notifications.is_empty() {
248            return Ok(());
249        }
250
251        for entry in &notifications {
252            match entry.notification {
253                Notification::ChannelInvitation { inviter_id, .. } => {
254                    user_ids.push(inviter_id);
255                }
256                Notification::ContactRequest {
257                    sender_id: requester_id,
258                } => {
259                    user_ids.push(requester_id);
260                }
261                Notification::ContactRequestAccepted {
262                    responder_id: contact_id,
263                } => {
264                    user_ids.push(contact_id);
265                }
266                Notification::ChannelMessageMention {
267                    sender_id,
268                    message_id,
269                    ..
270                } => {
271                    user_ids.push(sender_id);
272                    message_ids.push(message_id);
273                }
274            }
275        }
276
277        let (user_store, channel_store) = this.read_with(&cx, |this, _| {
278            (this.user_store.clone(), this.channel_store.clone())
279        });
280
281        user_store
282            .update(&mut cx, |store, cx| store.get_users(user_ids, cx))
283            .await?;
284        let messages = channel_store
285            .update(&mut cx, |store, cx| {
286                store.fetch_channel_messages(message_ids, cx)
287            })
288            .await?;
289        this.update(&mut cx, |this, cx| {
290            if options.clear_old {
291                cx.emit(NotificationEvent::NotificationsUpdated {
292                    old_range: 0..this.notifications.summary().count,
293                    new_count: 0,
294                });
295                this.notifications = SumTree::default();
296                this.channel_messages.clear();
297                this.loaded_all_notifications = false;
298            }
299
300            if options.includes_first {
301                this.loaded_all_notifications = true;
302            }
303
304            this.channel_messages
305                .extend(messages.into_iter().filter_map(|message| {
306                    if let ChannelMessageId::Saved(id) = message.id {
307                        Some((id, message))
308                    } else {
309                        None
310                    }
311                }));
312
313            this.splice_notifications(
314                notifications
315                    .into_iter()
316                    .map(|notification| (notification.id, Some(notification))),
317                options.is_new,
318                cx,
319            );
320        });
321
322        Ok(())
323    }
324
325    fn splice_notifications(
326        &mut self,
327        notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
328        is_new: bool,
329        cx: &mut ModelContext<'_, NotificationStore>,
330    ) {
331        let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
332        let mut new_notifications = SumTree::new();
333        let mut old_range = 0..0;
334
335        for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
336            new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
337
338            if i == 0 {
339                old_range.start = cursor.start().1 .0;
340            }
341
342            let old_notification = cursor.item();
343            if let Some(old_notification) = old_notification {
344                if old_notification.id == id {
345                    cursor.next(&());
346
347                    if let Some(new_notification) = &new_notification {
348                        if new_notification.is_read {
349                            cx.emit(NotificationEvent::NotificationRead {
350                                entry: new_notification.clone(),
351                            });
352                        }
353                    } else {
354                        cx.emit(NotificationEvent::NotificationRemoved {
355                            entry: old_notification.clone(),
356                        });
357                    }
358                }
359            } else if let Some(new_notification) = &new_notification {
360                if is_new {
361                    cx.emit(NotificationEvent::NewNotification {
362                        entry: new_notification.clone(),
363                    });
364                }
365            }
366
367            if let Some(notification) = new_notification {
368                new_notifications.push(notification, &());
369            }
370        }
371
372        old_range.end = cursor.start().1 .0;
373        let new_count = new_notifications.summary().count - old_range.start;
374        new_notifications.append(cursor.suffix(&()), &());
375        drop(cursor);
376
377        self.notifications = new_notifications;
378        cx.emit(NotificationEvent::NotificationsUpdated {
379            old_range,
380            new_count,
381        });
382    }
383
384    pub fn respond_to_notification(
385        &mut self,
386        notification: Notification,
387        response: bool,
388        cx: &mut ModelContext<Self>,
389    ) {
390        match notification {
391            Notification::ContactRequest { sender_id } => {
392                self.user_store
393                    .update(cx, |store, cx| {
394                        store.respond_to_contact_request(sender_id, response, cx)
395                    })
396                    .detach();
397            }
398            Notification::ChannelInvitation { channel_id, .. } => {
399                self.channel_store
400                    .update(cx, |store, cx| {
401                        store.respond_to_channel_invite(channel_id, response, cx)
402                    })
403                    .detach();
404            }
405            _ => {}
406        }
407    }
408}
409
/// Subscribers to the store receive [`NotificationEvent`]s emitted via `cx.emit`.
impl Entity for NotificationStore {
    type Event = NotificationEvent;
}
413
414impl sum_tree::Item for NotificationEntry {
415    type Summary = NotificationSummary;
416
417    fn summary(&self) -> Self::Summary {
418        NotificationSummary {
419            max_id: self.id,
420            count: 1,
421            unread_count: if self.is_read { 0 } else { 1 },
422        }
423    }
424}
425
426impl sum_tree::Summary for NotificationSummary {
427    type Context = ();
428
429    fn add_summary(&mut self, summary: &Self, _: &()) {
430        self.max_id = self.max_id.max(summary.max_id);
431        self.count += summary.count;
432        self.unread_count += summary.unread_count;
433    }
434}
435
436impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
437    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
438        debug_assert!(summary.max_id > self.0);
439        self.0 = summary.max_id;
440    }
441}
442
443impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
444    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
445        self.0 += summary.count;
446    }
447}
448
449impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
450    fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
451        self.0 += summary.unread_count;
452    }
453}
454
/// Controls how `NotificationStore::add_notifications` merges a batch into the store.
struct AddNotificationsOptions {
    // The batch was pushed by the server; enables `NewNotification` events for inserted ids.
    is_new: bool,
    // Drop all existing entries and cached channel messages before inserting the batch.
    clear_old: bool,
    // The batch reaches the oldest notification; marks the store as fully loaded.
    includes_first: bool,
}