1use anyhow::{Context as _, Result};
2use channel::ChannelStore;
3use client::{ChannelId, Client, UserStore};
4use db::smol::stream::StreamExt;
5use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task};
6use rpc::{Notification, TypedEnvelope, proto};
7use std::{ops::Range, sync::Arc};
8use sum_tree::{Bias, Dimensions, SumTree};
9use time::OffsetDateTime;
10use util::ResultExt;
11
12pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
13 let notification_store = cx.new(|cx| NotificationStore::new(client, user_store, cx));
14 cx.set_global(GlobalNotificationStore(notification_store));
15}
16
17struct GlobalNotificationStore(Entity<NotificationStore>);
18
19impl Global for GlobalNotificationStore {}
20
21pub struct NotificationStore {
22 client: Arc<Client>,
23 user_store: Entity<UserStore>,
24 channel_store: Entity<ChannelStore>,
25 notifications: SumTree<NotificationEntry>,
26 loaded_all_notifications: bool,
27 _watch_connection_status: Task<Option<()>>,
28 _subscriptions: Vec<client::Subscription>,
29}
30
31#[derive(Clone, PartialEq, Eq, Debug)]
32pub enum NotificationEvent {
33 NotificationsUpdated {
34 old_range: Range<usize>,
35 new_count: usize,
36 },
37 NewNotification {
38 entry: NotificationEntry,
39 },
40 NotificationRemoved {
41 entry: NotificationEntry,
42 },
43 NotificationRead {
44 entry: NotificationEntry,
45 },
46}
47
48#[derive(Debug, PartialEq, Eq, Clone)]
49pub struct NotificationEntry {
50 pub id: u64,
51 pub notification: Notification,
52 pub timestamp: OffsetDateTime,
53 pub is_read: bool,
54 pub response: Option<bool>,
55}
56
/// Aggregate data stored at each node of the notification sum tree.
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
    /// Largest notification id in the subtree.
    max_id: u64,
    /// Number of entries in the subtree.
    count: usize,
    /// Number of entries in the subtree whose `is_read` is false.
    unread_count: usize,
}

/// Tree dimension that measures position as a number of entries.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);

/// Tree dimension that measures position by the largest id seen so far.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
69
70impl NotificationStore {
71 pub fn global(cx: &App) -> Entity<Self> {
72 cx.global::<GlobalNotificationStore>().0.clone()
73 }
74
75 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
76 let mut connection_status = client.status();
77 let watch_connection_status = cx.spawn(async move |this, cx| {
78 while let Some(status) = connection_status.next().await {
79 let this = this.upgrade()?;
80 match status {
81 client::Status::Connected { .. } => {
82 if let Some(task) = this
83 .update(cx, |this, cx| this.handle_connect(cx))
84 .log_err()?
85 {
86 task.await.log_err()?;
87 }
88 }
89 _ => this
90 .update(cx, |this, cx| this.handle_disconnect(cx))
91 .log_err()?,
92 }
93 }
94 Some(())
95 });
96
97 Self {
98 channel_store: ChannelStore::global(cx),
99 notifications: Default::default(),
100 loaded_all_notifications: false,
101 _watch_connection_status: watch_connection_status,
102 _subscriptions: vec![
103 client.add_message_handler(cx.weak_entity(), Self::handle_new_notification),
104 client.add_message_handler(cx.weak_entity(), Self::handle_delete_notification),
105 ],
106 user_store,
107 client,
108 }
109 }
110
111 pub fn notification_count(&self) -> usize {
112 self.notifications.summary().count
113 }
114
115 pub fn unread_notification_count(&self) -> usize {
116 self.notifications.summary().unread_count
117 }
118
119 // Get the nth newest notification.
120 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
121 let count = self.notifications.summary().count;
122 if ix >= count {
123 return None;
124 }
125 let ix = count - 1 - ix;
126 let (.., item) = self
127 .notifications
128 .find::<Count, _>((), &Count(ix), Bias::Right);
129 item
130 }
131 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
132 let (.., item) =
133 self.notifications
134 .find::<NotificationId, _>((), &NotificationId(id), Bias::Left);
135 if let Some(item) = item
136 && item.id == id
137 {
138 return Some(item);
139 }
140 None
141 }
142
143 pub fn load_more_notifications(
144 &self,
145 clear_old: bool,
146 cx: &mut Context<Self>,
147 ) -> Option<Task<Result<()>>> {
148 if self.loaded_all_notifications && !clear_old {
149 return None;
150 }
151
152 let before_id = if clear_old {
153 None
154 } else {
155 self.notifications.first().map(|entry| entry.id)
156 };
157 let request = self.client.request(proto::GetNotifications { before_id });
158 Some(cx.spawn(async move |this, cx| {
159 let this = this
160 .upgrade()
161 .context("Notification store was dropped while loading notifications")?;
162
163 let response = request.await?;
164 this.update(cx, |this, _| this.loaded_all_notifications = response.done)?;
165 Self::add_notifications(
166 this,
167 response.notifications,
168 AddNotificationsOptions {
169 is_new: false,
170 clear_old,
171 includes_first: response.done,
172 },
173 cx,
174 )
175 .await?;
176 Ok(())
177 }))
178 }
179
180 fn handle_connect(&mut self, cx: &mut Context<Self>) -> Option<Task<Result<()>>> {
181 self.notifications = Default::default();
182 cx.notify();
183 self.load_more_notifications(true, cx)
184 }
185
186 fn handle_disconnect(&mut self, cx: &mut Context<Self>) {
187 cx.notify()
188 }
189
190 async fn handle_new_notification(
191 this: Entity<Self>,
192 envelope: TypedEnvelope<proto::AddNotification>,
193 mut cx: AsyncApp,
194 ) -> Result<()> {
195 Self::add_notifications(
196 this,
197 envelope.payload.notification.into_iter().collect(),
198 AddNotificationsOptions {
199 is_new: true,
200 clear_old: false,
201 includes_first: false,
202 },
203 &mut cx,
204 )
205 .await
206 }
207
208 async fn handle_delete_notification(
209 this: Entity<Self>,
210 envelope: TypedEnvelope<proto::DeleteNotification>,
211 mut cx: AsyncApp,
212 ) -> Result<()> {
213 this.update(&mut cx, |this, cx| {
214 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
215 Ok(())
216 })?
217 }
218
219 async fn add_notifications(
220 this: Entity<Self>,
221 notifications: Vec<proto::Notification>,
222 options: AddNotificationsOptions,
223 cx: &mut AsyncApp,
224 ) -> Result<()> {
225 let mut user_ids = Vec::new();
226
227 let notifications = notifications
228 .into_iter()
229 .filter_map(|message| {
230 Some(NotificationEntry {
231 id: message.id,
232 is_read: message.is_read,
233 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
234 .ok()?,
235 notification: Notification::from_proto(&message)?,
236 response: message.response,
237 })
238 })
239 .collect::<Vec<_>>();
240 if notifications.is_empty() {
241 return Ok(());
242 }
243
244 for entry in ¬ifications {
245 match entry.notification {
246 Notification::ChannelInvitation { inviter_id, .. } => {
247 user_ids.push(inviter_id);
248 }
249 Notification::ContactRequest {
250 sender_id: requester_id,
251 } => {
252 user_ids.push(requester_id);
253 }
254 Notification::ContactRequestAccepted {
255 responder_id: contact_id,
256 } => {
257 user_ids.push(contact_id);
258 }
259 }
260 }
261
262 let user_store = this.read_with(cx, |this, _| this.user_store.clone())?;
263
264 user_store
265 .update(cx, |store, cx| store.get_users(user_ids, cx))?
266 .await?;
267 this.update(cx, |this, cx| {
268 if options.clear_old {
269 cx.emit(NotificationEvent::NotificationsUpdated {
270 old_range: 0..this.notifications.summary().count,
271 new_count: 0,
272 });
273 this.notifications = SumTree::default();
274 this.loaded_all_notifications = false;
275 }
276
277 if options.includes_first {
278 this.loaded_all_notifications = true;
279 }
280
281 this.splice_notifications(
282 notifications
283 .into_iter()
284 .map(|notification| (notification.id, Some(notification))),
285 options.is_new,
286 cx,
287 );
288 })
289 .log_err();
290
291 Ok(())
292 }
293
294 fn splice_notifications(
295 &mut self,
296 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
297 is_new: bool,
298 cx: &mut Context<NotificationStore>,
299 ) {
300 let mut cursor = self
301 .notifications
302 .cursor::<Dimensions<NotificationId, Count>>(());
303 let mut new_notifications = SumTree::default();
304 let mut old_range = 0..0;
305
306 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
307 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left), ());
308
309 if i == 0 {
310 old_range.start = cursor.start().1.0;
311 }
312
313 let old_notification = cursor.item();
314 if let Some(old_notification) = old_notification {
315 if old_notification.id == id {
316 cursor.next();
317
318 if let Some(new_notification) = &new_notification {
319 if new_notification.is_read {
320 cx.emit(NotificationEvent::NotificationRead {
321 entry: new_notification.clone(),
322 });
323 }
324 } else {
325 cx.emit(NotificationEvent::NotificationRemoved {
326 entry: old_notification.clone(),
327 });
328 }
329 }
330 } else if let Some(new_notification) = &new_notification
331 && is_new
332 {
333 cx.emit(NotificationEvent::NewNotification {
334 entry: new_notification.clone(),
335 });
336 }
337
338 if let Some(notification) = new_notification {
339 new_notifications.push(notification, ());
340 }
341 }
342
343 old_range.end = cursor.start().1.0;
344 let new_count = new_notifications.summary().count - old_range.start;
345 new_notifications.append(cursor.suffix(), ());
346 drop(cursor);
347
348 self.notifications = new_notifications;
349 cx.emit(NotificationEvent::NotificationsUpdated {
350 old_range,
351 new_count,
352 });
353 }
354
355 pub fn respond_to_notification(
356 &mut self,
357 notification: Notification,
358 response: bool,
359 cx: &mut Context<Self>,
360 ) {
361 match notification {
362 Notification::ContactRequest { sender_id } => {
363 self.user_store
364 .update(cx, |store, cx| {
365 store.respond_to_contact_request(sender_id, response, cx)
366 })
367 .detach();
368 }
369 Notification::ChannelInvitation { channel_id, .. } => {
370 self.channel_store
371 .update(cx, |store, cx| {
372 store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
373 })
374 .detach();
375 }
376 _ => {}
377 }
378 }
379}
380
381impl EventEmitter<NotificationEvent> for NotificationStore {}
382
383impl sum_tree::Item for NotificationEntry {
384 type Summary = NotificationSummary;
385
386 fn summary(&self, _cx: ()) -> Self::Summary {
387 NotificationSummary {
388 max_id: self.id,
389 count: 1,
390 unread_count: if self.is_read { 0 } else { 1 },
391 }
392 }
393}
394
395impl sum_tree::ContextLessSummary for NotificationSummary {
396 fn zero() -> Self {
397 Default::default()
398 }
399
400 fn add_summary(&mut self, summary: &Self) {
401 self.max_id = self.max_id.max(summary.max_id);
402 self.count += summary.count;
403 self.unread_count += summary.unread_count;
404 }
405}
406
407impl sum_tree::Dimension<'_, NotificationSummary> for NotificationId {
408 fn zero(_cx: ()) -> Self {
409 Default::default()
410 }
411
412 fn add_summary(&mut self, summary: &NotificationSummary, _: ()) {
413 debug_assert!(summary.max_id > self.0);
414 self.0 = summary.max_id;
415 }
416}
417
418impl sum_tree::Dimension<'_, NotificationSummary> for Count {
419 fn zero(_cx: ()) -> Self {
420 Default::default()
421 }
422
423 fn add_summary(&mut self, summary: &NotificationSummary, _: ()) {
424 self.0 += summary.count;
425 }
426}
427
/// Controls how a batch of notifications is merged into the store.
struct AddNotificationsOptions {
    /// The batch was pushed by the server; newly inserted entries emit
    /// `NewNotification`.
    is_new: bool,
    /// Drop every stored entry before inserting the batch.
    clear_old: bool,
    /// The batch reaches the oldest notification, so there is nothing older to load.
    includes_first: bool,
}