1use anyhow::{Context as _, Result};
2use channel::ChannelStore;
3use client::{ChannelId, Client, UserStore};
4use db::smol::stream::StreamExt;
5use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task};
6use rpc::{Notification, TypedEnvelope, proto};
7use std::{ops::Range, sync::Arc};
8use sum_tree::{Bias, Dimensions, SumTree};
9use time::OffsetDateTime;
10use util::ResultExt;
11
12pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
13 let notification_store = cx.new(|cx| NotificationStore::new(client, user_store, cx));
14 cx.set_global(GlobalNotificationStore(notification_store));
15}
16
17struct GlobalNotificationStore(Entity<NotificationStore>);
18
19impl Global for GlobalNotificationStore {}
20
21pub struct NotificationStore {
22 client: Arc<Client>,
23 user_store: Entity<UserStore>,
24 channel_store: Entity<ChannelStore>,
25 notifications: SumTree<NotificationEntry>,
26 loaded_all_notifications: bool,
27 _watch_connection_status: Task<Option<()>>,
28 _subscriptions: Vec<client::Subscription>,
29}
30
31#[derive(Clone, PartialEq, Eq, Debug)]
32pub enum NotificationEvent {
33 NotificationsUpdated {
34 old_range: Range<usize>,
35 new_count: usize,
36 },
37 NewNotification {
38 entry: NotificationEntry,
39 },
40 NotificationRemoved {
41 entry: NotificationEntry,
42 },
43 NotificationRead {
44 entry: NotificationEntry,
45 },
46}
47
48#[derive(Debug, PartialEq, Eq, Clone)]
49pub struct NotificationEntry {
50 pub id: u64,
51 pub notification: Notification,
52 pub timestamp: OffsetDateTime,
53 pub is_read: bool,
54 pub response: Option<bool>,
55}
56
/// Aggregate that the sum tree keeps for a run of [`NotificationEntry`] items.
#[derive(Debug, Default, Clone)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of entries in the run.
    count: usize,
    /// Number of entries in the run whose `is_read` flag is `false`.
    unread_count: usize,
}
63
/// Seek dimension: the number of entries that precede a position in the tree.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
66
/// Seek dimension: the largest notification id seen up to a position in the tree.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
69
70impl NotificationStore {
71 pub fn global(cx: &App) -> Entity<Self> {
72 cx.global::<GlobalNotificationStore>().0.clone()
73 }
74
75 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
76 let mut connection_status = client.status();
77 let watch_connection_status = cx.spawn(async move |this, cx| {
78 while let Some(status) = connection_status.next().await {
79 let this = this.upgrade()?;
80 match status {
81 client::Status::Connected { .. } => {
82 if let Some(task) = this
83 .update(cx, |this, cx| this.handle_connect(cx))
84 .log_err()?
85 {
86 task.await.log_err()?;
87 }
88 }
89 _ => this
90 .update(cx, |this, cx| this.handle_disconnect(cx))
91 .log_err()?,
92 }
93 }
94 Some(())
95 });
96
97 Self {
98 channel_store: ChannelStore::global(cx),
99 notifications: Default::default(),
100 loaded_all_notifications: false,
101 _watch_connection_status: watch_connection_status,
102 _subscriptions: vec![
103 client.add_message_handler(cx.weak_entity(), Self::handle_new_notification),
104 client.add_message_handler(cx.weak_entity(), Self::handle_delete_notification),
105 ],
106 user_store,
107 client,
108 }
109 }
110
111 pub fn notification_count(&self) -> usize {
112 self.notifications.summary().count
113 }
114
115 pub fn unread_notification_count(&self) -> usize {
116 self.notifications.summary().unread_count
117 }
118
119 // Get the nth newest notification.
120 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
121 let count = self.notifications.summary().count;
122 if ix >= count {
123 return None;
124 }
125 let ix = count - 1 - ix;
126 let mut cursor = self.notifications.cursor::<Count>(());
127 cursor.seek(&Count(ix), Bias::Right);
128 cursor.item()
129 }
130 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
131 let mut cursor = self.notifications.cursor::<NotificationId>(());
132 cursor.seek(&NotificationId(id), Bias::Left);
133 if let Some(item) = cursor.item()
134 && item.id == id
135 {
136 return Some(item);
137 }
138 None
139 }
140
141 pub fn load_more_notifications(
142 &self,
143 clear_old: bool,
144 cx: &mut Context<Self>,
145 ) -> Option<Task<Result<()>>> {
146 if self.loaded_all_notifications && !clear_old {
147 return None;
148 }
149
150 let before_id = if clear_old {
151 None
152 } else {
153 self.notifications.first().map(|entry| entry.id)
154 };
155 let request = self.client.request(proto::GetNotifications { before_id });
156 Some(cx.spawn(async move |this, cx| {
157 let this = this
158 .upgrade()
159 .context("Notification store was dropped while loading notifications")?;
160
161 let response = request.await?;
162 this.update(cx, |this, _| this.loaded_all_notifications = response.done)?;
163 Self::add_notifications(
164 this,
165 response.notifications,
166 AddNotificationsOptions {
167 is_new: false,
168 clear_old,
169 includes_first: response.done,
170 },
171 cx,
172 )
173 .await?;
174 Ok(())
175 }))
176 }
177
178 fn handle_connect(&mut self, cx: &mut Context<Self>) -> Option<Task<Result<()>>> {
179 self.notifications = Default::default();
180 cx.notify();
181 self.load_more_notifications(true, cx)
182 }
183
184 fn handle_disconnect(&mut self, cx: &mut Context<Self>) {
185 cx.notify()
186 }
187
188 async fn handle_new_notification(
189 this: Entity<Self>,
190 envelope: TypedEnvelope<proto::AddNotification>,
191 mut cx: AsyncApp,
192 ) -> Result<()> {
193 Self::add_notifications(
194 this,
195 envelope.payload.notification.into_iter().collect(),
196 AddNotificationsOptions {
197 is_new: true,
198 clear_old: false,
199 includes_first: false,
200 },
201 &mut cx,
202 )
203 .await
204 }
205
206 async fn handle_delete_notification(
207 this: Entity<Self>,
208 envelope: TypedEnvelope<proto::DeleteNotification>,
209 mut cx: AsyncApp,
210 ) -> Result<()> {
211 this.update(&mut cx, |this, cx| {
212 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
213 Ok(())
214 })?
215 }
216
217 async fn add_notifications(
218 this: Entity<Self>,
219 notifications: Vec<proto::Notification>,
220 options: AddNotificationsOptions,
221 cx: &mut AsyncApp,
222 ) -> Result<()> {
223 let mut user_ids = Vec::new();
224
225 let notifications = notifications
226 .into_iter()
227 .filter_map(|message| {
228 Some(NotificationEntry {
229 id: message.id,
230 is_read: message.is_read,
231 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
232 .ok()?,
233 notification: Notification::from_proto(&message)?,
234 response: message.response,
235 })
236 })
237 .collect::<Vec<_>>();
238 if notifications.is_empty() {
239 return Ok(());
240 }
241
242 for entry in ¬ifications {
243 match entry.notification {
244 Notification::ChannelInvitation { inviter_id, .. } => {
245 user_ids.push(inviter_id);
246 }
247 Notification::ContactRequest {
248 sender_id: requester_id,
249 } => {
250 user_ids.push(requester_id);
251 }
252 Notification::ContactRequestAccepted {
253 responder_id: contact_id,
254 } => {
255 user_ids.push(contact_id);
256 }
257 }
258 }
259
260 let user_store = this.read_with(cx, |this, _| this.user_store.clone())?;
261
262 user_store
263 .update(cx, |store, cx| store.get_users(user_ids, cx))?
264 .await?;
265 this.update(cx, |this, cx| {
266 if options.clear_old {
267 cx.emit(NotificationEvent::NotificationsUpdated {
268 old_range: 0..this.notifications.summary().count,
269 new_count: 0,
270 });
271 this.notifications = SumTree::default();
272 this.loaded_all_notifications = false;
273 }
274
275 if options.includes_first {
276 this.loaded_all_notifications = true;
277 }
278
279 this.splice_notifications(
280 notifications
281 .into_iter()
282 .map(|notification| (notification.id, Some(notification))),
283 options.is_new,
284 cx,
285 );
286 })
287 .log_err();
288
289 Ok(())
290 }
291
292 fn splice_notifications(
293 &mut self,
294 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
295 is_new: bool,
296 cx: &mut Context<NotificationStore>,
297 ) {
298 let mut cursor = self
299 .notifications
300 .cursor::<Dimensions<NotificationId, Count>>(());
301 let mut new_notifications = SumTree::default();
302 let mut old_range = 0..0;
303
304 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
305 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left), ());
306
307 if i == 0 {
308 old_range.start = cursor.start().1.0;
309 }
310
311 let old_notification = cursor.item();
312 if let Some(old_notification) = old_notification {
313 if old_notification.id == id {
314 cursor.next();
315
316 if let Some(new_notification) = &new_notification {
317 if new_notification.is_read {
318 cx.emit(NotificationEvent::NotificationRead {
319 entry: new_notification.clone(),
320 });
321 }
322 } else {
323 cx.emit(NotificationEvent::NotificationRemoved {
324 entry: old_notification.clone(),
325 });
326 }
327 }
328 } else if let Some(new_notification) = &new_notification
329 && is_new
330 {
331 cx.emit(NotificationEvent::NewNotification {
332 entry: new_notification.clone(),
333 });
334 }
335
336 if let Some(notification) = new_notification {
337 new_notifications.push(notification, ());
338 }
339 }
340
341 old_range.end = cursor.start().1.0;
342 let new_count = new_notifications.summary().count - old_range.start;
343 new_notifications.append(cursor.suffix(), ());
344 drop(cursor);
345
346 self.notifications = new_notifications;
347 cx.emit(NotificationEvent::NotificationsUpdated {
348 old_range,
349 new_count,
350 });
351 }
352
353 pub fn respond_to_notification(
354 &mut self,
355 notification: Notification,
356 response: bool,
357 cx: &mut Context<Self>,
358 ) {
359 match notification {
360 Notification::ContactRequest { sender_id } => {
361 self.user_store
362 .update(cx, |store, cx| {
363 store.respond_to_contact_request(sender_id, response, cx)
364 })
365 .detach();
366 }
367 Notification::ChannelInvitation { channel_id, .. } => {
368 self.channel_store
369 .update(cx, |store, cx| {
370 store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
371 })
372 .detach();
373 }
374 _ => {}
375 }
376 }
377}
378
379impl EventEmitter<NotificationEvent> for NotificationStore {}
380
381impl sum_tree::Item for NotificationEntry {
382 type Summary = NotificationSummary;
383
384 fn summary(&self, _cx: ()) -> Self::Summary {
385 NotificationSummary {
386 max_id: self.id,
387 count: 1,
388 unread_count: if self.is_read { 0 } else { 1 },
389 }
390 }
391}
392
393impl sum_tree::ContextLessSummary for NotificationSummary {
394 fn zero() -> Self {
395 Default::default()
396 }
397
398 fn add_summary(&mut self, summary: &Self) {
399 self.max_id = self.max_id.max(summary.max_id);
400 self.count += summary.count;
401 self.unread_count += summary.unread_count;
402 }
403}
404
405impl sum_tree::Dimension<'_, NotificationSummary> for NotificationId {
406 fn zero(_cx: ()) -> Self {
407 Default::default()
408 }
409
410 fn add_summary(&mut self, summary: &NotificationSummary, _: ()) {
411 debug_assert!(summary.max_id > self.0);
412 self.0 = summary.max_id;
413 }
414}
415
416impl sum_tree::Dimension<'_, NotificationSummary> for Count {
417 fn zero(_cx: ()) -> Self {
418 Default::default()
419 }
420
421 fn add_summary(&mut self, summary: &NotificationSummary, _: ()) {
422 self.0 += summary.count;
423 }
424}
425
/// Flags that control how `NotificationStore::add_notifications` applies a batch.
struct AddNotificationsOptions {
    /// Passed on to `splice_notifications`, where it enables `NewNotification`
    /// events for inserted entries.
    is_new: bool,
    /// Discard all stored entries before inserting the batch.
    clear_old: bool,
    /// Mark the store as having loaded all notifications.
    includes_first: bool,
}