1use anyhow::{Context as _, Result};
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{ChannelId, Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task};
7use rpc::{Notification, TypedEnvelope, proto};
8use std::{ops::Range, sync::Arc};
9use sum_tree::{Bias, Dimensions, SumTree};
10use time::OffsetDateTime;
11use util::ResultExt;
12
13pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
14 let notification_store = cx.new(|cx| NotificationStore::new(client, user_store, cx));
15 cx.set_global(GlobalNotificationStore(notification_store));
16}
17
18struct GlobalNotificationStore(Entity<NotificationStore>);
19
20impl Global for GlobalNotificationStore {}
21
22pub struct NotificationStore {
23 client: Arc<Client>,
24 user_store: Entity<UserStore>,
25 channel_messages: HashMap<u64, ChannelMessage>,
26 channel_store: Entity<ChannelStore>,
27 notifications: SumTree<NotificationEntry>,
28 loaded_all_notifications: bool,
29 _watch_connection_status: Task<Option<()>>,
30 _subscriptions: Vec<client::Subscription>,
31}
32
33#[derive(Clone, PartialEq, Eq, Debug)]
34pub enum NotificationEvent {
35 NotificationsUpdated {
36 old_range: Range<usize>,
37 new_count: usize,
38 },
39 NewNotification {
40 entry: NotificationEntry,
41 },
42 NotificationRemoved {
43 entry: NotificationEntry,
44 },
45 NotificationRead {
46 entry: NotificationEntry,
47 },
48}
49
50#[derive(Debug, PartialEq, Eq, Clone)]
51pub struct NotificationEntry {
52 pub id: u64,
53 pub notification: Notification,
54 pub timestamp: OffsetDateTime,
55 pub is_read: bool,
56 pub response: Option<bool>,
57}
58
/// Aggregate over a run of [`NotificationEntry`] values in the sum tree.
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of entries in the run.
    count: usize,
    /// Number of entries in the run that are not yet read.
    unread_count: usize,
}
65
/// Sum-tree dimension: number of entries before a position.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);

/// Sum-tree dimension: highest notification id before a position.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
71
72impl NotificationStore {
73 pub fn global(cx: &App) -> Entity<Self> {
74 cx.global::<GlobalNotificationStore>().0.clone()
75 }
76
77 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
78 let mut connection_status = client.status();
79 let watch_connection_status = cx.spawn(async move |this, cx| {
80 while let Some(status) = connection_status.next().await {
81 let this = this.upgrade()?;
82 match status {
83 client::Status::Connected { .. } => {
84 if let Some(task) = this
85 .update(cx, |this, cx| this.handle_connect(cx))
86 .log_err()?
87 {
88 task.await.log_err()?;
89 }
90 }
91 _ => this
92 .update(cx, |this, cx| this.handle_disconnect(cx))
93 .log_err()?,
94 }
95 }
96 Some(())
97 });
98
99 Self {
100 channel_store: ChannelStore::global(cx),
101 notifications: Default::default(),
102 loaded_all_notifications: false,
103 channel_messages: Default::default(),
104 _watch_connection_status: watch_connection_status,
105 _subscriptions: vec![
106 client.add_message_handler(cx.weak_entity(), Self::handle_new_notification),
107 client.add_message_handler(cx.weak_entity(), Self::handle_delete_notification),
108 client.add_message_handler(cx.weak_entity(), Self::handle_update_notification),
109 ],
110 user_store,
111 client,
112 }
113 }
114
115 pub fn notification_count(&self) -> usize {
116 self.notifications.summary().count
117 }
118
119 pub fn unread_notification_count(&self) -> usize {
120 self.notifications.summary().unread_count
121 }
122
123 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
124 self.channel_messages.get(&id)
125 }
126
127 // Get the nth newest notification.
128 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
129 let count = self.notifications.summary().count;
130 if ix >= count {
131 return None;
132 }
133 let ix = count - 1 - ix;
134 let mut cursor = self.notifications.cursor::<Count>(&());
135 cursor.seek(&Count(ix), Bias::Right);
136 cursor.item()
137 }
138 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
139 let mut cursor = self.notifications.cursor::<NotificationId>(&());
140 cursor.seek(&NotificationId(id), Bias::Left);
141 if let Some(item) = cursor.item()
142 && item.id == id {
143 return Some(item);
144 }
145 None
146 }
147
148 pub fn load_more_notifications(
149 &self,
150 clear_old: bool,
151 cx: &mut Context<Self>,
152 ) -> Option<Task<Result<()>>> {
153 if self.loaded_all_notifications && !clear_old {
154 return None;
155 }
156
157 let before_id = if clear_old {
158 None
159 } else {
160 self.notifications.first().map(|entry| entry.id)
161 };
162 let request = self.client.request(proto::GetNotifications { before_id });
163 Some(cx.spawn(async move |this, cx| {
164 let this = this
165 .upgrade()
166 .context("Notification store was dropped while loading notifications")?;
167
168 let response = request.await?;
169 this.update(cx, |this, _| this.loaded_all_notifications = response.done)?;
170 Self::add_notifications(
171 this,
172 response.notifications,
173 AddNotificationsOptions {
174 is_new: false,
175 clear_old,
176 includes_first: response.done,
177 },
178 cx,
179 )
180 .await?;
181 Ok(())
182 }))
183 }
184
185 fn handle_connect(&mut self, cx: &mut Context<Self>) -> Option<Task<Result<()>>> {
186 self.notifications = Default::default();
187 self.channel_messages = Default::default();
188 cx.notify();
189 self.load_more_notifications(true, cx)
190 }
191
192 fn handle_disconnect(&mut self, cx: &mut Context<Self>) {
193 cx.notify()
194 }
195
196 async fn handle_new_notification(
197 this: Entity<Self>,
198 envelope: TypedEnvelope<proto::AddNotification>,
199 mut cx: AsyncApp,
200 ) -> Result<()> {
201 Self::add_notifications(
202 this,
203 envelope.payload.notification.into_iter().collect(),
204 AddNotificationsOptions {
205 is_new: true,
206 clear_old: false,
207 includes_first: false,
208 },
209 &mut cx,
210 )
211 .await
212 }
213
214 async fn handle_delete_notification(
215 this: Entity<Self>,
216 envelope: TypedEnvelope<proto::DeleteNotification>,
217 mut cx: AsyncApp,
218 ) -> Result<()> {
219 this.update(&mut cx, |this, cx| {
220 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
221 Ok(())
222 })?
223 }
224
225 async fn handle_update_notification(
226 this: Entity<Self>,
227 envelope: TypedEnvelope<proto::UpdateNotification>,
228 mut cx: AsyncApp,
229 ) -> Result<()> {
230 this.update(&mut cx, |this, cx| {
231 if let Some(notification) = envelope.payload.notification
232 && let Some(rpc::Notification::ChannelMessageMention { message_id, .. }) =
233 Notification::from_proto(¬ification)
234 {
235 let fetch_message_task = this.channel_store.update(cx, |this, cx| {
236 this.fetch_channel_messages(vec![message_id], cx)
237 });
238
239 cx.spawn(async move |this, cx| {
240 let messages = fetch_message_task.await?;
241 this.update(cx, move |this, cx| {
242 for message in messages {
243 this.channel_messages.insert(message_id, message);
244 }
245 cx.notify();
246 })
247 })
248 .detach_and_log_err(cx)
249 }
250 Ok(())
251 })?
252 }
253
254 async fn add_notifications(
255 this: Entity<Self>,
256 notifications: Vec<proto::Notification>,
257 options: AddNotificationsOptions,
258 cx: &mut AsyncApp,
259 ) -> Result<()> {
260 let mut user_ids = Vec::new();
261 let mut message_ids = Vec::new();
262
263 let notifications = notifications
264 .into_iter()
265 .filter_map(|message| {
266 Some(NotificationEntry {
267 id: message.id,
268 is_read: message.is_read,
269 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
270 .ok()?,
271 notification: Notification::from_proto(&message)?,
272 response: message.response,
273 })
274 })
275 .collect::<Vec<_>>();
276 if notifications.is_empty() {
277 return Ok(());
278 }
279
280 for entry in ¬ifications {
281 match entry.notification {
282 Notification::ChannelInvitation { inviter_id, .. } => {
283 user_ids.push(inviter_id);
284 }
285 Notification::ContactRequest {
286 sender_id: requester_id,
287 } => {
288 user_ids.push(requester_id);
289 }
290 Notification::ContactRequestAccepted {
291 responder_id: contact_id,
292 } => {
293 user_ids.push(contact_id);
294 }
295 Notification::ChannelMessageMention {
296 sender_id,
297 message_id,
298 ..
299 } => {
300 user_ids.push(sender_id);
301 message_ids.push(message_id);
302 }
303 }
304 }
305
306 let (user_store, channel_store) = this.read_with(cx, |this, _| {
307 (this.user_store.clone(), this.channel_store.clone())
308 })?;
309
310 user_store
311 .update(cx, |store, cx| store.get_users(user_ids, cx))?
312 .await?;
313 let messages = channel_store
314 .update(cx, |store, cx| {
315 store.fetch_channel_messages(message_ids, cx)
316 })?
317 .await?;
318 this.update(cx, |this, cx| {
319 if options.clear_old {
320 cx.emit(NotificationEvent::NotificationsUpdated {
321 old_range: 0..this.notifications.summary().count,
322 new_count: 0,
323 });
324 this.notifications = SumTree::default();
325 this.channel_messages.clear();
326 this.loaded_all_notifications = false;
327 }
328
329 if options.includes_first {
330 this.loaded_all_notifications = true;
331 }
332
333 this.channel_messages
334 .extend(messages.into_iter().filter_map(|message| {
335 if let ChannelMessageId::Saved(id) = message.id {
336 Some((id, message))
337 } else {
338 None
339 }
340 }));
341
342 this.splice_notifications(
343 notifications
344 .into_iter()
345 .map(|notification| (notification.id, Some(notification))),
346 options.is_new,
347 cx,
348 );
349 })
350 .log_err();
351
352 Ok(())
353 }
354
355 fn splice_notifications(
356 &mut self,
357 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
358 is_new: bool,
359 cx: &mut Context<NotificationStore>,
360 ) {
361 let mut cursor = self
362 .notifications
363 .cursor::<Dimensions<NotificationId, Count>>(&());
364 let mut new_notifications = SumTree::default();
365 let mut old_range = 0..0;
366
367 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
368 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left), &());
369
370 if i == 0 {
371 old_range.start = cursor.start().1.0;
372 }
373
374 let old_notification = cursor.item();
375 if let Some(old_notification) = old_notification {
376 if old_notification.id == id {
377 cursor.next();
378
379 if let Some(new_notification) = &new_notification {
380 if new_notification.is_read {
381 cx.emit(NotificationEvent::NotificationRead {
382 entry: new_notification.clone(),
383 });
384 }
385 } else {
386 cx.emit(NotificationEvent::NotificationRemoved {
387 entry: old_notification.clone(),
388 });
389 }
390 }
391 } else if let Some(new_notification) = &new_notification
392 && is_new {
393 cx.emit(NotificationEvent::NewNotification {
394 entry: new_notification.clone(),
395 });
396 }
397
398 if let Some(notification) = new_notification {
399 new_notifications.push(notification, &());
400 }
401 }
402
403 old_range.end = cursor.start().1.0;
404 let new_count = new_notifications.summary().count - old_range.start;
405 new_notifications.append(cursor.suffix(), &());
406 drop(cursor);
407
408 self.notifications = new_notifications;
409 cx.emit(NotificationEvent::NotificationsUpdated {
410 old_range,
411 new_count,
412 });
413 }
414
415 pub fn respond_to_notification(
416 &mut self,
417 notification: Notification,
418 response: bool,
419 cx: &mut Context<Self>,
420 ) {
421 match notification {
422 Notification::ContactRequest { sender_id } => {
423 self.user_store
424 .update(cx, |store, cx| {
425 store.respond_to_contact_request(sender_id, response, cx)
426 })
427 .detach();
428 }
429 Notification::ChannelInvitation { channel_id, .. } => {
430 self.channel_store
431 .update(cx, |store, cx| {
432 store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
433 })
434 .detach();
435 }
436 _ => {}
437 }
438 }
439}
440
441impl EventEmitter<NotificationEvent> for NotificationStore {}
442
443impl sum_tree::Item for NotificationEntry {
444 type Summary = NotificationSummary;
445
446 fn summary(&self, _cx: &()) -> Self::Summary {
447 NotificationSummary {
448 max_id: self.id,
449 count: 1,
450 unread_count: if self.is_read { 0 } else { 1 },
451 }
452 }
453}
454
455impl sum_tree::Summary for NotificationSummary {
456 type Context = ();
457
458 fn zero(_cx: &()) -> Self {
459 Default::default()
460 }
461
462 fn add_summary(&mut self, summary: &Self, _: &()) {
463 self.max_id = self.max_id.max(summary.max_id);
464 self.count += summary.count;
465 self.unread_count += summary.unread_count;
466 }
467}
468
469impl sum_tree::Dimension<'_, NotificationSummary> for NotificationId {
470 fn zero(_cx: &()) -> Self {
471 Default::default()
472 }
473
474 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
475 debug_assert!(summary.max_id > self.0);
476 self.0 = summary.max_id;
477 }
478}
479
480impl sum_tree::Dimension<'_, NotificationSummary> for Count {
481 fn zero(_cx: &()) -> Self {
482 Default::default()
483 }
484
485 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
486 self.0 += summary.count;
487 }
488}
489
/// Controls how a batch of notifications is merged into the store.
struct AddNotificationsOptions {
    /// The batch was pushed by the server rather than paged in; inserted
    /// entries then trigger `NewNotification` events.
    is_new: bool,
    /// Discard all current entries and cached messages before merging.
    clear_old: bool,
    /// The batch reaches the oldest notification, so nothing more to page in.
    includes_first: bool,
}