1use anyhow::{Context as _, Result};
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{ChannelId, Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task};
7use rpc::{Notification, TypedEnvelope, proto};
8use std::{ops::Range, sync::Arc};
9use sum_tree::{Bias, Dimensions, SumTree};
10use time::OffsetDateTime;
11use util::ResultExt;
12
13pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
14 let notification_store = cx.new(|cx| NotificationStore::new(client, user_store, cx));
15 cx.set_global(GlobalNotificationStore(notification_store));
16}
17
18struct GlobalNotificationStore(Entity<NotificationStore>);
19
20impl Global for GlobalNotificationStore {}
21
22pub struct NotificationStore {
23 client: Arc<Client>,
24 user_store: Entity<UserStore>,
25 channel_messages: HashMap<u64, ChannelMessage>,
26 channel_store: Entity<ChannelStore>,
27 notifications: SumTree<NotificationEntry>,
28 loaded_all_notifications: bool,
29 _watch_connection_status: Task<Option<()>>,
30 _subscriptions: Vec<client::Subscription>,
31}
32
33#[derive(Clone, PartialEq, Eq, Debug)]
34pub enum NotificationEvent {
35 NotificationsUpdated {
36 old_range: Range<usize>,
37 new_count: usize,
38 },
39 NewNotification {
40 entry: NotificationEntry,
41 },
42 NotificationRemoved {
43 entry: NotificationEntry,
44 },
45 NotificationRead {
46 entry: NotificationEntry,
47 },
48}
49
50#[derive(Debug, PartialEq, Eq, Clone)]
51pub struct NotificationEntry {
52 pub id: u64,
53 pub notification: Notification,
54 pub timestamp: OffsetDateTime,
55 pub is_read: bool,
56 pub response: Option<bool>,
57}
58
/// Aggregate data kept for each subtree of the notification tree.
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
    /// Largest notification id covered by the summary.
    max_id: u64,
    /// Number of notifications covered by the summary.
    count: usize,
    /// Number of covered notifications that are still unread.
    unread_count: usize,
}
65
/// Seek dimension: the number of entries preceding a position in the tree.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);

/// Seek dimension: the largest notification id seen up to a position in the tree.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
71
72impl NotificationStore {
73 pub fn global(cx: &App) -> Entity<Self> {
74 cx.global::<GlobalNotificationStore>().0.clone()
75 }
76
77 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
78 let mut connection_status = client.status();
79 let watch_connection_status = cx.spawn(async move |this, cx| {
80 while let Some(status) = connection_status.next().await {
81 let this = this.upgrade()?;
82 match status {
83 client::Status::Connected { .. } => {
84 if let Some(task) = this
85 .update(cx, |this, cx| this.handle_connect(cx))
86 .log_err()?
87 {
88 task.await.log_err()?;
89 }
90 }
91 _ => this
92 .update(cx, |this, cx| this.handle_disconnect(cx))
93 .log_err()?,
94 }
95 }
96 Some(())
97 });
98
99 Self {
100 channel_store: ChannelStore::global(cx),
101 notifications: Default::default(),
102 loaded_all_notifications: false,
103 channel_messages: Default::default(),
104 _watch_connection_status: watch_connection_status,
105 _subscriptions: vec![
106 client.add_message_handler(cx.weak_entity(), Self::handle_new_notification),
107 client.add_message_handler(cx.weak_entity(), Self::handle_delete_notification),
108 client.add_message_handler(cx.weak_entity(), Self::handle_update_notification),
109 ],
110 user_store,
111 client,
112 }
113 }
114
115 pub fn notification_count(&self) -> usize {
116 self.notifications.summary().count
117 }
118
119 pub fn unread_notification_count(&self) -> usize {
120 self.notifications.summary().unread_count
121 }
122
123 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
124 self.channel_messages.get(&id)
125 }
126
127 // Get the nth newest notification.
128 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
129 let count = self.notifications.summary().count;
130 if ix >= count {
131 return None;
132 }
133 let ix = count - 1 - ix;
134 let mut cursor = self.notifications.cursor::<Count>(&());
135 cursor.seek(&Count(ix), Bias::Right);
136 cursor.item()
137 }
138 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
139 let mut cursor = self.notifications.cursor::<NotificationId>(&());
140 cursor.seek(&NotificationId(id), Bias::Left);
141 if let Some(item) = cursor.item()
142 && item.id == id
143 {
144 return Some(item);
145 }
146 None
147 }
148
149 pub fn load_more_notifications(
150 &self,
151 clear_old: bool,
152 cx: &mut Context<Self>,
153 ) -> Option<Task<Result<()>>> {
154 if self.loaded_all_notifications && !clear_old {
155 return None;
156 }
157
158 let before_id = if clear_old {
159 None
160 } else {
161 self.notifications.first().map(|entry| entry.id)
162 };
163 let request = self.client.request(proto::GetNotifications { before_id });
164 Some(cx.spawn(async move |this, cx| {
165 let this = this
166 .upgrade()
167 .context("Notification store was dropped while loading notifications")?;
168
169 let response = request.await?;
170 this.update(cx, |this, _| this.loaded_all_notifications = response.done)?;
171 Self::add_notifications(
172 this,
173 response.notifications,
174 AddNotificationsOptions {
175 is_new: false,
176 clear_old,
177 includes_first: response.done,
178 },
179 cx,
180 )
181 .await?;
182 Ok(())
183 }))
184 }
185
186 fn handle_connect(&mut self, cx: &mut Context<Self>) -> Option<Task<Result<()>>> {
187 self.notifications = Default::default();
188 self.channel_messages = Default::default();
189 cx.notify();
190 self.load_more_notifications(true, cx)
191 }
192
193 fn handle_disconnect(&mut self, cx: &mut Context<Self>) {
194 cx.notify()
195 }
196
197 async fn handle_new_notification(
198 this: Entity<Self>,
199 envelope: TypedEnvelope<proto::AddNotification>,
200 mut cx: AsyncApp,
201 ) -> Result<()> {
202 Self::add_notifications(
203 this,
204 envelope.payload.notification.into_iter().collect(),
205 AddNotificationsOptions {
206 is_new: true,
207 clear_old: false,
208 includes_first: false,
209 },
210 &mut cx,
211 )
212 .await
213 }
214
215 async fn handle_delete_notification(
216 this: Entity<Self>,
217 envelope: TypedEnvelope<proto::DeleteNotification>,
218 mut cx: AsyncApp,
219 ) -> Result<()> {
220 this.update(&mut cx, |this, cx| {
221 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
222 Ok(())
223 })?
224 }
225
226 async fn handle_update_notification(
227 this: Entity<Self>,
228 envelope: TypedEnvelope<proto::UpdateNotification>,
229 mut cx: AsyncApp,
230 ) -> Result<()> {
231 this.update(&mut cx, |this, cx| {
232 if let Some(notification) = envelope.payload.notification
233 && let Some(rpc::Notification::ChannelMessageMention { message_id, .. }) =
234 Notification::from_proto(¬ification)
235 {
236 let fetch_message_task = this.channel_store.update(cx, |this, cx| {
237 this.fetch_channel_messages(vec![message_id], cx)
238 });
239
240 cx.spawn(async move |this, cx| {
241 let messages = fetch_message_task.await?;
242 this.update(cx, move |this, cx| {
243 for message in messages {
244 this.channel_messages.insert(message_id, message);
245 }
246 cx.notify();
247 })
248 })
249 .detach_and_log_err(cx)
250 }
251 Ok(())
252 })?
253 }
254
255 async fn add_notifications(
256 this: Entity<Self>,
257 notifications: Vec<proto::Notification>,
258 options: AddNotificationsOptions,
259 cx: &mut AsyncApp,
260 ) -> Result<()> {
261 let mut user_ids = Vec::new();
262 let mut message_ids = Vec::new();
263
264 let notifications = notifications
265 .into_iter()
266 .filter_map(|message| {
267 Some(NotificationEntry {
268 id: message.id,
269 is_read: message.is_read,
270 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
271 .ok()?,
272 notification: Notification::from_proto(&message)?,
273 response: message.response,
274 })
275 })
276 .collect::<Vec<_>>();
277 if notifications.is_empty() {
278 return Ok(());
279 }
280
281 for entry in ¬ifications {
282 match entry.notification {
283 Notification::ChannelInvitation { inviter_id, .. } => {
284 user_ids.push(inviter_id);
285 }
286 Notification::ContactRequest {
287 sender_id: requester_id,
288 } => {
289 user_ids.push(requester_id);
290 }
291 Notification::ContactRequestAccepted {
292 responder_id: contact_id,
293 } => {
294 user_ids.push(contact_id);
295 }
296 Notification::ChannelMessageMention {
297 sender_id,
298 message_id,
299 ..
300 } => {
301 user_ids.push(sender_id);
302 message_ids.push(message_id);
303 }
304 }
305 }
306
307 let (user_store, channel_store) = this.read_with(cx, |this, _| {
308 (this.user_store.clone(), this.channel_store.clone())
309 })?;
310
311 user_store
312 .update(cx, |store, cx| store.get_users(user_ids, cx))?
313 .await?;
314 let messages = channel_store
315 .update(cx, |store, cx| {
316 store.fetch_channel_messages(message_ids, cx)
317 })?
318 .await?;
319 this.update(cx, |this, cx| {
320 if options.clear_old {
321 cx.emit(NotificationEvent::NotificationsUpdated {
322 old_range: 0..this.notifications.summary().count,
323 new_count: 0,
324 });
325 this.notifications = SumTree::default();
326 this.channel_messages.clear();
327 this.loaded_all_notifications = false;
328 }
329
330 if options.includes_first {
331 this.loaded_all_notifications = true;
332 }
333
334 this.channel_messages
335 .extend(messages.into_iter().filter_map(|message| {
336 if let ChannelMessageId::Saved(id) = message.id {
337 Some((id, message))
338 } else {
339 None
340 }
341 }));
342
343 this.splice_notifications(
344 notifications
345 .into_iter()
346 .map(|notification| (notification.id, Some(notification))),
347 options.is_new,
348 cx,
349 );
350 })
351 .log_err();
352
353 Ok(())
354 }
355
356 fn splice_notifications(
357 &mut self,
358 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
359 is_new: bool,
360 cx: &mut Context<NotificationStore>,
361 ) {
362 let mut cursor = self
363 .notifications
364 .cursor::<Dimensions<NotificationId, Count>>(&());
365 let mut new_notifications = SumTree::default();
366 let mut old_range = 0..0;
367
368 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
369 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left), &());
370
371 if i == 0 {
372 old_range.start = cursor.start().1.0;
373 }
374
375 let old_notification = cursor.item();
376 if let Some(old_notification) = old_notification {
377 if old_notification.id == id {
378 cursor.next();
379
380 if let Some(new_notification) = &new_notification {
381 if new_notification.is_read {
382 cx.emit(NotificationEvent::NotificationRead {
383 entry: new_notification.clone(),
384 });
385 }
386 } else {
387 cx.emit(NotificationEvent::NotificationRemoved {
388 entry: old_notification.clone(),
389 });
390 }
391 }
392 } else if let Some(new_notification) = &new_notification
393 && is_new
394 {
395 cx.emit(NotificationEvent::NewNotification {
396 entry: new_notification.clone(),
397 });
398 }
399
400 if let Some(notification) = new_notification {
401 new_notifications.push(notification, &());
402 }
403 }
404
405 old_range.end = cursor.start().1.0;
406 let new_count = new_notifications.summary().count - old_range.start;
407 new_notifications.append(cursor.suffix(), &());
408 drop(cursor);
409
410 self.notifications = new_notifications;
411 cx.emit(NotificationEvent::NotificationsUpdated {
412 old_range,
413 new_count,
414 });
415 }
416
417 pub fn respond_to_notification(
418 &mut self,
419 notification: Notification,
420 response: bool,
421 cx: &mut Context<Self>,
422 ) {
423 match notification {
424 Notification::ContactRequest { sender_id } => {
425 self.user_store
426 .update(cx, |store, cx| {
427 store.respond_to_contact_request(sender_id, response, cx)
428 })
429 .detach();
430 }
431 Notification::ChannelInvitation { channel_id, .. } => {
432 self.channel_store
433 .update(cx, |store, cx| {
434 store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
435 })
436 .detach();
437 }
438 _ => {}
439 }
440 }
441}
442
443impl EventEmitter<NotificationEvent> for NotificationStore {}
444
445impl sum_tree::Item for NotificationEntry {
446 type Summary = NotificationSummary;
447
448 fn summary(&self, _cx: &()) -> Self::Summary {
449 NotificationSummary {
450 max_id: self.id,
451 count: 1,
452 unread_count: if self.is_read { 0 } else { 1 },
453 }
454 }
455}
456
457impl sum_tree::Summary for NotificationSummary {
458 type Context = ();
459
460 fn zero(_cx: &()) -> Self {
461 Default::default()
462 }
463
464 fn add_summary(&mut self, summary: &Self, _: &()) {
465 self.max_id = self.max_id.max(summary.max_id);
466 self.count += summary.count;
467 self.unread_count += summary.unread_count;
468 }
469}
470
471impl sum_tree::Dimension<'_, NotificationSummary> for NotificationId {
472 fn zero(_cx: &()) -> Self {
473 Default::default()
474 }
475
476 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
477 debug_assert!(summary.max_id > self.0);
478 self.0 = summary.max_id;
479 }
480}
481
482impl sum_tree::Dimension<'_, NotificationSummary> for Count {
483 fn zero(_cx: &()) -> Self {
484 Default::default()
485 }
486
487 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
488 self.0 += summary.count;
489 }
490}
491
/// Flags controlling how a batch of notifications is merged into the store.
struct AddNotificationsOptions {
    /// The batch was pushed by the server rather than loaded on request, so
    /// `NewNotification` events may be emitted for it.
    is_new: bool,
    /// Discard everything currently held before merging the batch.
    clear_old: bool,
    /// The server reported that nothing further remains to be loaded.
    includes_first: bool,
}