notification_store.rs

  1use anyhow::{Context as _, Result};
  2use channel::ChannelStore;
  3use client::{ChannelId, Client, UserStore};
  4use db::smol::stream::StreamExt;
  5use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task};
  6use rpc::{Notification, TypedEnvelope, proto};
  7use std::{ops::Range, sync::Arc};
  8use sum_tree::{Bias, Dimensions, SumTree};
  9use time::OffsetDateTime;
 10use util::ResultExt;
 11
 12pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
 13    let notification_store = cx.new(|cx| NotificationStore::new(client, user_store, cx));
 14    cx.set_global(GlobalNotificationStore(notification_store));
 15}
 16
 17struct GlobalNotificationStore(Entity<NotificationStore>);
 18
 19impl Global for GlobalNotificationStore {}
 20
 21pub struct NotificationStore {
 22    client: Arc<Client>,
 23    user_store: Entity<UserStore>,
 24    channel_store: Entity<ChannelStore>,
 25    notifications: SumTree<NotificationEntry>,
 26    loaded_all_notifications: bool,
 27    _watch_connection_status: Task<Option<()>>,
 28    _subscriptions: Vec<client::Subscription>,
 29}
 30
 31#[derive(Clone, PartialEq, Eq, Debug)]
 32pub enum NotificationEvent {
 33    NotificationsUpdated {
 34        old_range: Range<usize>,
 35        new_count: usize,
 36    },
 37    NewNotification {
 38        entry: NotificationEntry,
 39    },
 40    NotificationRemoved {
 41        entry: NotificationEntry,
 42    },
 43    NotificationRead {
 44        entry: NotificationEntry,
 45    },
 46}
 47
 48#[derive(Debug, PartialEq, Eq, Clone)]
 49pub struct NotificationEntry {
 50    pub id: u64,
 51    pub notification: Notification,
 52    pub timestamp: OffsetDateTime,
 53    pub is_read: bool,
 54    pub response: Option<bool>,
 55}
 56
 /// Aggregate kept by the tree for each run of [`NotificationEntry`] items.
 #[derive(Debug, Default, Clone)]
 pub struct NotificationSummary {
     /// Largest notification id within the summarized run.
     max_id: u64,
     /// Number of entries within the summarized run.
     count: usize,
     /// Number of entries within the summarized run that are not marked as read.
     unread_count: usize,
 }
 63
 /// Tree dimension that measures a position as a number of entries.
 #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
 struct Count(usize);
 66
 /// Tree dimension that measures a position by the largest notification id seen so far.
 #[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
 struct NotificationId(u64);
 69
 70impl NotificationStore {
 71    pub fn global(cx: &App) -> Entity<Self> {
 72        cx.global::<GlobalNotificationStore>().0.clone()
 73    }
 74
 75    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 76        let mut connection_status = client.status();
 77        let watch_connection_status = cx.spawn(async move |this, cx| {
 78            while let Some(status) = connection_status.next().await {
 79                let this = this.upgrade()?;
 80                match status {
 81                    client::Status::Connected { .. } => {
 82                        if let Some(task) = this.update(cx, |this, cx| this.handle_connect(cx)) {
 83                            task.await.log_err()?;
 84                        }
 85                    }
 86                    _ => {
 87                        this.update(cx, |this, cx| this.handle_disconnect(cx));
 88                    }
 89                }
 90            }
 91            Some(())
 92        });
 93
 94        Self {
 95            channel_store: ChannelStore::global(cx),
 96            notifications: Default::default(),
 97            loaded_all_notifications: false,
 98            _watch_connection_status: watch_connection_status,
 99            _subscriptions: vec![
100                client.add_message_handler(cx.weak_entity(), Self::handle_new_notification),
101                client.add_message_handler(cx.weak_entity(), Self::handle_delete_notification),
102            ],
103            user_store,
104            client,
105        }
106    }
107
108    pub fn notification_count(&self) -> usize {
109        self.notifications.summary().count
110    }
111
112    pub fn unread_notification_count(&self) -> usize {
113        self.notifications.summary().unread_count
114    }
115
116    // Get the nth newest notification.
117    pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
118        let count = self.notifications.summary().count;
119        if ix >= count {
120            return None;
121        }
122        let ix = count - 1 - ix;
123        let (.., item) = self
124            .notifications
125            .find::<Count, _>((), &Count(ix), Bias::Right);
126        item
127    }
128    pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
129        let (.., item) =
130            self.notifications
131                .find::<NotificationId, _>((), &NotificationId(id), Bias::Left);
132        if let Some(item) = item
133            && item.id == id
134        {
135            return Some(item);
136        }
137        None
138    }
139
140    pub fn load_more_notifications(
141        &self,
142        clear_old: bool,
143        cx: &mut Context<Self>,
144    ) -> Option<Task<Result<()>>> {
145        if self.loaded_all_notifications && !clear_old {
146            return None;
147        }
148
149        let before_id = if clear_old {
150            None
151        } else {
152            self.notifications.first().map(|entry| entry.id)
153        };
154        let request = self.client.request(proto::GetNotifications { before_id });
155        Some(cx.spawn(async move |this, cx| {
156            let this = this
157                .upgrade()
158                .context("Notification store was dropped while loading notifications")?;
159
160            let response = request.await?;
161            this.update(cx, |this, _| this.loaded_all_notifications = response.done);
162            Self::add_notifications(
163                this,
164                response.notifications,
165                AddNotificationsOptions {
166                    is_new: false,
167                    clear_old,
168                    includes_first: response.done,
169                },
170                cx,
171            )
172            .await?;
173            Ok(())
174        }))
175    }
176
177    fn handle_connect(&mut self, cx: &mut Context<Self>) -> Option<Task<Result<()>>> {
178        self.notifications = Default::default();
179        cx.notify();
180        self.load_more_notifications(true, cx)
181    }
182
183    fn handle_disconnect(&mut self, cx: &mut Context<Self>) {
184        cx.notify()
185    }
186
187    async fn handle_new_notification(
188        this: Entity<Self>,
189        envelope: TypedEnvelope<proto::AddNotification>,
190        mut cx: AsyncApp,
191    ) -> Result<()> {
192        Self::add_notifications(
193            this,
194            envelope.payload.notification.into_iter().collect(),
195            AddNotificationsOptions {
196                is_new: true,
197                clear_old: false,
198                includes_first: false,
199            },
200            &mut cx,
201        )
202        .await
203    }
204
205    async fn handle_delete_notification(
206        this: Entity<Self>,
207        envelope: TypedEnvelope<proto::DeleteNotification>,
208        mut cx: AsyncApp,
209    ) -> Result<()> {
210        this.update(&mut cx, |this, cx| {
211            this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
212        });
213        Ok(())
214    }
215
216    async fn add_notifications(
217        this: Entity<Self>,
218        notifications: Vec<proto::Notification>,
219        options: AddNotificationsOptions,
220        cx: &mut AsyncApp,
221    ) -> Result<()> {
222        let mut user_ids = Vec::new();
223
224        let notifications = notifications
225            .into_iter()
226            .filter_map(|message| {
227                Some(NotificationEntry {
228                    id: message.id,
229                    is_read: message.is_read,
230                    timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
231                        .ok()?,
232                    notification: Notification::from_proto(&message)?,
233                    response: message.response,
234                })
235            })
236            .collect::<Vec<_>>();
237        if notifications.is_empty() {
238            return Ok(());
239        }
240
241        for entry in &notifications {
242            match entry.notification {
243                Notification::ChannelInvitation { inviter_id, .. } => {
244                    user_ids.push(inviter_id);
245                }
246                Notification::ContactRequest {
247                    sender_id: requester_id,
248                } => {
249                    user_ids.push(requester_id);
250                }
251                Notification::ContactRequestAccepted {
252                    responder_id: contact_id,
253                } => {
254                    user_ids.push(contact_id);
255                }
256            }
257        }
258
259        let user_store = this.read_with(cx, |this, _| this.user_store.clone());
260
261        user_store
262            .update(cx, |store, cx| store.get_users(user_ids, cx))
263            .await?;
264        this.update(cx, |this, cx| {
265            if options.clear_old {
266                cx.emit(NotificationEvent::NotificationsUpdated {
267                    old_range: 0..this.notifications.summary().count,
268                    new_count: 0,
269                });
270                this.notifications = SumTree::default();
271                this.loaded_all_notifications = false;
272            }
273
274            if options.includes_first {
275                this.loaded_all_notifications = true;
276            }
277
278            this.splice_notifications(
279                notifications
280                    .into_iter()
281                    .map(|notification| (notification.id, Some(notification))),
282                options.is_new,
283                cx,
284            );
285        });
286
287        Ok(())
288    }
289
290    fn splice_notifications(
291        &mut self,
292        notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
293        is_new: bool,
294        cx: &mut Context<NotificationStore>,
295    ) {
296        let mut cursor = self
297            .notifications
298            .cursor::<Dimensions<NotificationId, Count>>(());
299        let mut new_notifications = SumTree::default();
300        let mut old_range = 0..0;
301
302        for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
303            new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left), ());
304
305            if i == 0 {
306                old_range.start = cursor.start().1.0;
307            }
308
309            let old_notification = cursor.item();
310            if let Some(old_notification) = old_notification {
311                if old_notification.id == id {
312                    cursor.next();
313
314                    if let Some(new_notification) = &new_notification {
315                        if new_notification.is_read {
316                            cx.emit(NotificationEvent::NotificationRead {
317                                entry: new_notification.clone(),
318                            });
319                        }
320                    } else {
321                        cx.emit(NotificationEvent::NotificationRemoved {
322                            entry: old_notification.clone(),
323                        });
324                    }
325                }
326            } else if let Some(new_notification) = &new_notification
327                && is_new
328            {
329                cx.emit(NotificationEvent::NewNotification {
330                    entry: new_notification.clone(),
331                });
332            }
333
334            if let Some(notification) = new_notification {
335                new_notifications.push(notification, ());
336            }
337        }
338
339        old_range.end = cursor.start().1.0;
340        let new_count = new_notifications.summary().count - old_range.start;
341        new_notifications.append(cursor.suffix(), ());
342        drop(cursor);
343
344        self.notifications = new_notifications;
345        cx.emit(NotificationEvent::NotificationsUpdated {
346            old_range,
347            new_count,
348        });
349    }
350
351    pub fn respond_to_notification(
352        &mut self,
353        notification: Notification,
354        response: bool,
355        cx: &mut Context<Self>,
356    ) {
357        match notification {
358            Notification::ContactRequest { sender_id } => {
359                self.user_store
360                    .update(cx, |store, cx| {
361                        store.respond_to_contact_request(sender_id, response, cx)
362                    })
363                    .detach();
364            }
365            Notification::ChannelInvitation { channel_id, .. } => {
366                self.channel_store
367                    .update(cx, |store, cx| {
368                        store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
369                    })
370                    .detach();
371            }
372            _ => {}
373        }
374    }
375}
376
377impl EventEmitter<NotificationEvent> for NotificationStore {}
378
379impl sum_tree::Item for NotificationEntry {
380    type Summary = NotificationSummary;
381
382    fn summary(&self, _cx: ()) -> Self::Summary {
383        NotificationSummary {
384            max_id: self.id,
385            count: 1,
386            unread_count: if self.is_read { 0 } else { 1 },
387        }
388    }
389}
390
391impl sum_tree::ContextLessSummary for NotificationSummary {
392    fn zero() -> Self {
393        Default::default()
394    }
395
396    fn add_summary(&mut self, summary: &Self) {
397        self.max_id = self.max_id.max(summary.max_id);
398        self.count += summary.count;
399        self.unread_count += summary.unread_count;
400    }
401}
402
403impl sum_tree::Dimension<'_, NotificationSummary> for NotificationId {
404    fn zero(_cx: ()) -> Self {
405        Default::default()
406    }
407
408    fn add_summary(&mut self, summary: &NotificationSummary, _: ()) {
409        debug_assert!(summary.max_id > self.0);
410        self.0 = summary.max_id;
411    }
412}
413
414impl sum_tree::Dimension<'_, NotificationSummary> for Count {
415    fn zero(_cx: ()) -> Self {
416        Default::default()
417    }
418
419    fn add_summary(&mut self, summary: &NotificationSummary, _: ()) {
420        self.0 += summary.count;
421    }
422}
423
/// Flags that control how `NotificationStore::add_notifications` merges a batch.
struct AddNotificationsOptions {
    /// Allow a `NewNotification` event for entries added past the end of the stored ones.
    is_new: bool,
    /// Discard every held notification before the batch is inserted.
    clear_old: bool,
    /// Mark the store as having loaded all notifications.
    includes_first: bool,
}