1use anyhow::{Context as _, Result};
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{ChannelId, Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task};
7use rpc::{Notification, TypedEnvelope, proto};
8use std::{ops::Range, sync::Arc};
9use sum_tree::{Bias, Dimensions, SumTree};
10use time::OffsetDateTime;
11use util::ResultExt;
12
13pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
14 let notification_store = cx.new(|cx| NotificationStore::new(client, user_store, cx));
15 cx.set_global(GlobalNotificationStore(notification_store));
16}
17
/// Newtype wrapper that lets the [`NotificationStore`] entity be stored as a gpui global
/// (written by `init`, read by `NotificationStore::global`).
struct GlobalNotificationStore(Entity<NotificationStore>);

impl Global for GlobalNotificationStore {}
21
/// Client-side cache of the current user's notifications.
///
/// Notifications are fetched from the server page by page and kept in a [`SumTree`]
/// ordered by ascending notification id. Changes are reported to observers as
/// [`NotificationEvent`]s.
pub struct NotificationStore {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Channel messages referenced by `ChannelMessageMention` notifications, keyed by
    /// message id.
    channel_messages: HashMap<u64, ChannelMessage>,
    channel_store: Entity<ChannelStore>,
    /// Loaded notifications, ordered by ascending id (oldest first).
    notifications: SumTree<NotificationEntry>,
    /// Set once the page that contains the oldest notification has been loaded, i.e.
    /// there is nothing further back to request.
    loaded_all_notifications: bool,
    /// Task that reacts to connection-status changes (reload on connect). It is held
    /// here so it lives as long as the store; presumably dropping it cancels the task.
    _watch_connection_status: Task<Option<()>>,
    /// Registrations of the RPC handlers for added, deleted and updated notifications.
    _subscriptions: Vec<client::Subscription>,
}
32
/// Events emitted by [`NotificationStore`] when its loaded notifications change.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum NotificationEvent {
    /// The entries at `old_range` were replaced by `new_count` entries. Positions index
    /// the store's id-ordered (oldest-first) sequence, not the newest-first order used
    /// by `NotificationStore::notification_at`.
    NotificationsUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
    /// A live notification arrived that was not already loaded.
    NewNotification {
        entry: NotificationEntry,
    },
    /// A loaded notification was deleted; `entry` is its last known state.
    NotificationRemoved {
        entry: NotificationEntry,
    },
    /// A loaded notification was replaced by a version marked as read; `entry` is the
    /// new version.
    NotificationRead {
        entry: NotificationEntry,
    },
}
49
/// A single notification as held by the [`NotificationStore`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct NotificationEntry {
    /// Id received from the server; the store keeps entries sorted by it.
    pub id: u64,
    /// Decoded notification payload.
    pub notification: Notification,
    /// Timestamp decoded from the server's Unix timestamp (seconds).
    pub timestamp: OffsetDateTime,
    /// Whether the notification has been read.
    pub is_read: bool,
    /// Response recorded for this notification, if any. Presumably the user's
    /// accept/decline answer for actionable notifications; confirm against the server.
    pub response: Option<bool>,
}
58
/// [`SumTree`] summary over a run of [`NotificationEntry`] values.
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of entries in the run.
    count: usize,
    /// Number of entries in the run whose `is_read` flag is false.
    unread_count: usize,
}
65
/// [`SumTree`] dimension counting entries; used to seek to a position in the tree.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
68
/// [`SumTree`] dimension tracking the largest notification id seen so far; used to seek
/// to a notification by id.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
71
72impl NotificationStore {
73 pub fn global(cx: &App) -> Entity<Self> {
74 cx.global::<GlobalNotificationStore>().0.clone()
75 }
76
77 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
78 let mut connection_status = client.status();
79 let watch_connection_status = cx.spawn(async move |this, cx| {
80 while let Some(status) = connection_status.next().await {
81 let this = this.upgrade()?;
82 match status {
83 client::Status::Connected { .. } => {
84 if let Some(task) = this
85 .update(cx, |this, cx| this.handle_connect(cx))
86 .log_err()?
87 {
88 task.await.log_err()?;
89 }
90 }
91 _ => this
92 .update(cx, |this, cx| this.handle_disconnect(cx))
93 .log_err()?,
94 }
95 }
96 Some(())
97 });
98
99 Self {
100 channel_store: ChannelStore::global(cx),
101 notifications: Default::default(),
102 loaded_all_notifications: false,
103 channel_messages: Default::default(),
104 _watch_connection_status: watch_connection_status,
105 _subscriptions: vec![
106 client.add_message_handler(cx.weak_entity(), Self::handle_new_notification),
107 client.add_message_handler(cx.weak_entity(), Self::handle_delete_notification),
108 client.add_message_handler(cx.weak_entity(), Self::handle_update_notification),
109 ],
110 user_store,
111 client,
112 }
113 }
114
115 pub fn notification_count(&self) -> usize {
116 self.notifications.summary().count
117 }
118
119 pub fn unread_notification_count(&self) -> usize {
120 self.notifications.summary().unread_count
121 }
122
123 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
124 self.channel_messages.get(&id)
125 }
126
127 // Get the nth newest notification.
128 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
129 let count = self.notifications.summary().count;
130 if ix >= count {
131 return None;
132 }
133 let ix = count - 1 - ix;
134 let mut cursor = self.notifications.cursor::<Count>(&());
135 cursor.seek(&Count(ix), Bias::Right);
136 cursor.item()
137 }
138 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
139 let mut cursor = self.notifications.cursor::<NotificationId>(&());
140 cursor.seek(&NotificationId(id), Bias::Left);
141 if let Some(item) = cursor.item() {
142 if item.id == id {
143 return Some(item);
144 }
145 }
146 None
147 }
148
149 pub fn load_more_notifications(
150 &self,
151 clear_old: bool,
152 cx: &mut Context<Self>,
153 ) -> Option<Task<Result<()>>> {
154 if self.loaded_all_notifications && !clear_old {
155 return None;
156 }
157
158 let before_id = if clear_old {
159 None
160 } else {
161 self.notifications.first().map(|entry| entry.id)
162 };
163 let request = self.client.request(proto::GetNotifications { before_id });
164 Some(cx.spawn(async move |this, cx| {
165 let this = this
166 .upgrade()
167 .context("Notification store was dropped while loading notifications")?;
168
169 let response = request.await?;
170 this.update(cx, |this, _| this.loaded_all_notifications = response.done)?;
171 Self::add_notifications(
172 this,
173 response.notifications,
174 AddNotificationsOptions {
175 is_new: false,
176 clear_old,
177 includes_first: response.done,
178 },
179 cx,
180 )
181 .await?;
182 Ok(())
183 }))
184 }
185
186 fn handle_connect(&mut self, cx: &mut Context<Self>) -> Option<Task<Result<()>>> {
187 self.notifications = Default::default();
188 self.channel_messages = Default::default();
189 cx.notify();
190 self.load_more_notifications(true, cx)
191 }
192
193 fn handle_disconnect(&mut self, cx: &mut Context<Self>) {
194 cx.notify()
195 }
196
197 async fn handle_new_notification(
198 this: Entity<Self>,
199 envelope: TypedEnvelope<proto::AddNotification>,
200 mut cx: AsyncApp,
201 ) -> Result<()> {
202 Self::add_notifications(
203 this,
204 envelope.payload.notification.into_iter().collect(),
205 AddNotificationsOptions {
206 is_new: true,
207 clear_old: false,
208 includes_first: false,
209 },
210 &mut cx,
211 )
212 .await
213 }
214
215 async fn handle_delete_notification(
216 this: Entity<Self>,
217 envelope: TypedEnvelope<proto::DeleteNotification>,
218 mut cx: AsyncApp,
219 ) -> Result<()> {
220 this.update(&mut cx, |this, cx| {
221 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
222 Ok(())
223 })?
224 }
225
226 async fn handle_update_notification(
227 this: Entity<Self>,
228 envelope: TypedEnvelope<proto::UpdateNotification>,
229 mut cx: AsyncApp,
230 ) -> Result<()> {
231 this.update(&mut cx, |this, cx| {
232 if let Some(notification) = envelope.payload.notification {
233 if let Some(rpc::Notification::ChannelMessageMention { message_id, .. }) =
234 Notification::from_proto(¬ification)
235 {
236 let fetch_message_task = this.channel_store.update(cx, |this, cx| {
237 this.fetch_channel_messages(vec![message_id], cx)
238 });
239
240 cx.spawn(async move |this, cx| {
241 let messages = fetch_message_task.await?;
242 this.update(cx, move |this, cx| {
243 for message in messages {
244 this.channel_messages.insert(message_id, message);
245 }
246 cx.notify();
247 })
248 })
249 .detach_and_log_err(cx)
250 }
251 }
252 Ok(())
253 })?
254 }
255
256 async fn add_notifications(
257 this: Entity<Self>,
258 notifications: Vec<proto::Notification>,
259 options: AddNotificationsOptions,
260 cx: &mut AsyncApp,
261 ) -> Result<()> {
262 let mut user_ids = Vec::new();
263 let mut message_ids = Vec::new();
264
265 let notifications = notifications
266 .into_iter()
267 .filter_map(|message| {
268 Some(NotificationEntry {
269 id: message.id,
270 is_read: message.is_read,
271 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
272 .ok()?,
273 notification: Notification::from_proto(&message)?,
274 response: message.response,
275 })
276 })
277 .collect::<Vec<_>>();
278 if notifications.is_empty() {
279 return Ok(());
280 }
281
282 for entry in ¬ifications {
283 match entry.notification {
284 Notification::ChannelInvitation { inviter_id, .. } => {
285 user_ids.push(inviter_id);
286 }
287 Notification::ContactRequest {
288 sender_id: requester_id,
289 } => {
290 user_ids.push(requester_id);
291 }
292 Notification::ContactRequestAccepted {
293 responder_id: contact_id,
294 } => {
295 user_ids.push(contact_id);
296 }
297 Notification::ChannelMessageMention {
298 sender_id,
299 message_id,
300 ..
301 } => {
302 user_ids.push(sender_id);
303 message_ids.push(message_id);
304 }
305 }
306 }
307
308 let (user_store, channel_store) = this.read_with(cx, |this, _| {
309 (this.user_store.clone(), this.channel_store.clone())
310 })?;
311
312 user_store
313 .update(cx, |store, cx| store.get_users(user_ids, cx))?
314 .await?;
315 let messages = channel_store
316 .update(cx, |store, cx| {
317 store.fetch_channel_messages(message_ids, cx)
318 })?
319 .await?;
320 this.update(cx, |this, cx| {
321 if options.clear_old {
322 cx.emit(NotificationEvent::NotificationsUpdated {
323 old_range: 0..this.notifications.summary().count,
324 new_count: 0,
325 });
326 this.notifications = SumTree::default();
327 this.channel_messages.clear();
328 this.loaded_all_notifications = false;
329 }
330
331 if options.includes_first {
332 this.loaded_all_notifications = true;
333 }
334
335 this.channel_messages
336 .extend(messages.into_iter().filter_map(|message| {
337 if let ChannelMessageId::Saved(id) = message.id {
338 Some((id, message))
339 } else {
340 None
341 }
342 }));
343
344 this.splice_notifications(
345 notifications
346 .into_iter()
347 .map(|notification| (notification.id, Some(notification))),
348 options.is_new,
349 cx,
350 );
351 })
352 .log_err();
353
354 Ok(())
355 }
356
357 fn splice_notifications(
358 &mut self,
359 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
360 is_new: bool,
361 cx: &mut Context<NotificationStore>,
362 ) {
363 let mut cursor = self
364 .notifications
365 .cursor::<Dimensions<NotificationId, Count>>(&());
366 let mut new_notifications = SumTree::default();
367 let mut old_range = 0..0;
368
369 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
370 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left), &());
371
372 if i == 0 {
373 old_range.start = cursor.start().1.0;
374 }
375
376 let old_notification = cursor.item();
377 if let Some(old_notification) = old_notification {
378 if old_notification.id == id {
379 cursor.next();
380
381 if let Some(new_notification) = &new_notification {
382 if new_notification.is_read {
383 cx.emit(NotificationEvent::NotificationRead {
384 entry: new_notification.clone(),
385 });
386 }
387 } else {
388 cx.emit(NotificationEvent::NotificationRemoved {
389 entry: old_notification.clone(),
390 });
391 }
392 }
393 } else if let Some(new_notification) = &new_notification {
394 if is_new {
395 cx.emit(NotificationEvent::NewNotification {
396 entry: new_notification.clone(),
397 });
398 }
399 }
400
401 if let Some(notification) = new_notification {
402 new_notifications.push(notification, &());
403 }
404 }
405
406 old_range.end = cursor.start().1.0;
407 let new_count = new_notifications.summary().count - old_range.start;
408 new_notifications.append(cursor.suffix(), &());
409 drop(cursor);
410
411 self.notifications = new_notifications;
412 cx.emit(NotificationEvent::NotificationsUpdated {
413 old_range,
414 new_count,
415 });
416 }
417
418 pub fn respond_to_notification(
419 &mut self,
420 notification: Notification,
421 response: bool,
422 cx: &mut Context<Self>,
423 ) {
424 match notification {
425 Notification::ContactRequest { sender_id } => {
426 self.user_store
427 .update(cx, |store, cx| {
428 store.respond_to_contact_request(sender_id, response, cx)
429 })
430 .detach();
431 }
432 Notification::ChannelInvitation { channel_id, .. } => {
433 self.channel_store
434 .update(cx, |store, cx| {
435 store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
436 })
437 .detach();
438 }
439 _ => {}
440 }
441 }
442}
443
/// Lets [`NotificationStore`] emit [`NotificationEvent`]s to its subscribers.
impl EventEmitter<NotificationEvent> for NotificationStore {}
445
446impl sum_tree::Item for NotificationEntry {
447 type Summary = NotificationSummary;
448
449 fn summary(&self, _cx: &()) -> Self::Summary {
450 NotificationSummary {
451 max_id: self.id,
452 count: 1,
453 unread_count: if self.is_read { 0 } else { 1 },
454 }
455 }
456}
457
458impl sum_tree::Summary for NotificationSummary {
459 type Context = ();
460
461 fn zero(_cx: &()) -> Self {
462 Default::default()
463 }
464
465 fn add_summary(&mut self, summary: &Self, _: &()) {
466 self.max_id = self.max_id.max(summary.max_id);
467 self.count += summary.count;
468 self.unread_count += summary.unread_count;
469 }
470}
471
472impl sum_tree::Dimension<'_, NotificationSummary> for NotificationId {
473 fn zero(_cx: &()) -> Self {
474 Default::default()
475 }
476
477 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
478 debug_assert!(summary.max_id > self.0);
479 self.0 = summary.max_id;
480 }
481}
482
483impl sum_tree::Dimension<'_, NotificationSummary> for Count {
484 fn zero(_cx: &()) -> Self {
485 Default::default()
486 }
487
488 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
489 self.0 += summary.count;
490 }
491}
492
/// Controls how `NotificationStore::add_notifications` merges a batch of notifications.
struct AddNotificationsOptions {
    /// Replace all previously loaded notifications (and cached channel messages) with
    /// the batch.
    clear_old: bool,
    /// The batch reaches back to the oldest notification, so nothing older remains.
    includes_first: bool,
    /// The batch was pushed live by the server rather than requested as a page.
    is_new: bool,
}