1use anyhow::{Context, Result};
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{ChannelId, Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{
7 AppContext, AsyncAppContext, Context as _, EventEmitter, Global, Model, ModelContext, Task,
8};
9use rpc::{proto, Notification, TypedEnvelope};
10use std::{ops::Range, sync::Arc};
11use sum_tree::{Bias, SumTree};
12use time::OffsetDateTime;
13use util::ResultExt;
14
15pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
16 let notification_store = cx.new_model(|cx| NotificationStore::new(client, user_store, cx));
17 cx.set_global(GlobalNotificationStore(notification_store));
18}
19
/// Newtype wrapper that lets the store's model handle be registered as a gpui
/// global (see `init` and `NotificationStore::global`).
struct GlobalNotificationStore(Model<NotificationStore>);

impl Global for GlobalNotificationStore {}
23
/// Client-side cache of the current user's notifications, kept in sync with
/// the server through RPC messages and reloaded whenever the client connects.
pub struct NotificationStore {
    /// Used to request notification pages and to register message handlers.
    client: Arc<Client>,
    /// Used to fetch users referenced by notifications and to answer contact requests.
    user_store: Model<UserStore>,
    /// Channel messages referenced by mention notifications, keyed by message id.
    channel_messages: HashMap<u64, ChannelMessage>,
    /// Used to fetch mentioned channel messages and to answer channel invitations.
    channel_store: Model<ChannelStore>,
    /// Loaded notifications, ordered by ascending id (oldest first).
    notifications: SumTree<NotificationEntry>,
    /// Set once the server reports there are no older notifications to load.
    loaded_all_notifications: bool,
    /// Task that follows the client's connection status; kept alive by this field.
    _watch_connection_status: Task<Option<()>>,
    /// Handlers for the add/delete/update notification messages; kept alive by this field.
    _subscriptions: Vec<client::Subscription>,
}
34
35#[derive(Clone, PartialEq, Eq, Debug)]
36pub enum NotificationEvent {
37 NotificationsUpdated {
38 old_range: Range<usize>,
39 new_count: usize,
40 },
41 NewNotification {
42 entry: NotificationEntry,
43 },
44 NotificationRemoved {
45 entry: NotificationEntry,
46 },
47 NotificationRead {
48 entry: NotificationEntry,
49 },
50}
51
52#[derive(Debug, PartialEq, Eq, Clone)]
53pub struct NotificationEntry {
54 pub id: u64,
55 pub notification: Notification,
56 pub timestamp: OffsetDateTime,
57 pub is_read: bool,
58 pub response: Option<bool>,
59}
60
/// Aggregate over a run of [`NotificationEntry`] values in the sum tree.
#[derive(Debug, Default, Clone)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of notifications in the run.
    count: usize,
    /// Number of unread notifications in the run.
    unread_count: usize,
}
67
/// Sum-tree dimension: number of notifications before a position.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);

/// Sum-tree dimension: number of unread notifications before a position.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct UnreadCount(usize);

/// Sum-tree dimension: largest notification id before a position.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
76
77impl NotificationStore {
78 pub fn global(cx: &AppContext) -> Model<Self> {
79 cx.global::<GlobalNotificationStore>().0.clone()
80 }
81
82 pub fn new(
83 client: Arc<Client>,
84 user_store: Model<UserStore>,
85 cx: &mut ModelContext<Self>,
86 ) -> Self {
87 let mut connection_status = client.status();
88 let watch_connection_status = cx.spawn(|this, mut cx| async move {
89 while let Some(status) = connection_status.next().await {
90 let this = this.upgrade()?;
91 match status {
92 client::Status::Connected { .. } => {
93 if let Some(task) = this
94 .update(&mut cx, |this, cx| this.handle_connect(cx))
95 .log_err()?
96 {
97 task.await.log_err()?;
98 }
99 }
100 _ => this
101 .update(&mut cx, |this, cx| this.handle_disconnect(cx))
102 .log_err()?,
103 }
104 }
105 Some(())
106 });
107
108 Self {
109 channel_store: ChannelStore::global(cx),
110 notifications: Default::default(),
111 loaded_all_notifications: false,
112 channel_messages: Default::default(),
113 _watch_connection_status: watch_connection_status,
114 _subscriptions: vec![
115 client.add_message_handler(cx.weak_model(), Self::handle_new_notification),
116 client.add_message_handler(cx.weak_model(), Self::handle_delete_notification),
117 client.add_message_handler(cx.weak_model(), Self::handle_update_notification),
118 ],
119 user_store,
120 client,
121 }
122 }
123
124 pub fn notification_count(&self) -> usize {
125 self.notifications.summary().count
126 }
127
128 pub fn unread_notification_count(&self) -> usize {
129 self.notifications.summary().unread_count
130 }
131
132 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
133 self.channel_messages.get(&id)
134 }
135
136 // Get the nth newest notification.
137 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
138 let count = self.notifications.summary().count;
139 if ix >= count {
140 return None;
141 }
142 let ix = count - 1 - ix;
143 let mut cursor = self.notifications.cursor::<Count>();
144 cursor.seek(&Count(ix), Bias::Right, &());
145 cursor.item()
146 }
147
148 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
149 let mut cursor = self.notifications.cursor::<NotificationId>();
150 cursor.seek(&NotificationId(id), Bias::Left, &());
151 if let Some(item) = cursor.item() {
152 if item.id == id {
153 return Some(item);
154 }
155 }
156 None
157 }
158
159 pub fn load_more_notifications(
160 &self,
161 clear_old: bool,
162 cx: &mut ModelContext<Self>,
163 ) -> Option<Task<Result<()>>> {
164 if self.loaded_all_notifications && !clear_old {
165 return None;
166 }
167
168 let before_id = if clear_old {
169 None
170 } else {
171 self.notifications.first().map(|entry| entry.id)
172 };
173 let request = self.client.request(proto::GetNotifications { before_id });
174 Some(cx.spawn(|this, mut cx| async move {
175 let this = this
176 .upgrade()
177 .context("Notification store was dropped while loading notifications")?;
178
179 let response = request.await?;
180 this.update(&mut cx, |this, _| {
181 this.loaded_all_notifications = response.done
182 })?;
183 Self::add_notifications(
184 this,
185 response.notifications,
186 AddNotificationsOptions {
187 is_new: false,
188 clear_old,
189 includes_first: response.done,
190 },
191 cx,
192 )
193 .await?;
194 Ok(())
195 }))
196 }
197
198 fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
199 self.notifications = Default::default();
200 self.channel_messages = Default::default();
201 cx.notify();
202 self.load_more_notifications(true, cx)
203 }
204
205 fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
206 cx.notify()
207 }
208
209 async fn handle_new_notification(
210 this: Model<Self>,
211 envelope: TypedEnvelope<proto::AddNotification>,
212 _: Arc<Client>,
213 cx: AsyncAppContext,
214 ) -> Result<()> {
215 Self::add_notifications(
216 this,
217 envelope.payload.notification.into_iter().collect(),
218 AddNotificationsOptions {
219 is_new: true,
220 clear_old: false,
221 includes_first: false,
222 },
223 cx,
224 )
225 .await
226 }
227
228 async fn handle_delete_notification(
229 this: Model<Self>,
230 envelope: TypedEnvelope<proto::DeleteNotification>,
231 _: Arc<Client>,
232 mut cx: AsyncAppContext,
233 ) -> Result<()> {
234 this.update(&mut cx, |this, cx| {
235 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
236 Ok(())
237 })?
238 }
239
240 async fn handle_update_notification(
241 this: Model<Self>,
242 envelope: TypedEnvelope<proto::UpdateNotification>,
243 _: Arc<Client>,
244 mut cx: AsyncAppContext,
245 ) -> Result<()> {
246 this.update(&mut cx, |this, cx| {
247 if let Some(notification) = envelope.payload.notification {
248 if let Some(rpc::Notification::ChannelMessageMention {
249 message_id,
250 sender_id: _,
251 channel_id: _,
252 }) = Notification::from_proto(¬ification)
253 {
254 let fetch_message_task = this.channel_store.update(cx, |this, cx| {
255 this.fetch_channel_messages(vec![message_id], cx)
256 });
257
258 cx.spawn(|this, mut cx| async move {
259 let messages = fetch_message_task.await?;
260 this.update(&mut cx, move |this, cx| {
261 for message in messages {
262 this.channel_messages.insert(message_id, message);
263 }
264 cx.notify();
265 })
266 })
267 .detach_and_log_err(cx)
268 }
269 }
270 Ok(())
271 })?
272 }
273
274 async fn add_notifications(
275 this: Model<Self>,
276 notifications: Vec<proto::Notification>,
277 options: AddNotificationsOptions,
278 mut cx: AsyncAppContext,
279 ) -> Result<()> {
280 let mut user_ids = Vec::new();
281 let mut message_ids = Vec::new();
282
283 let notifications = notifications
284 .into_iter()
285 .filter_map(|message| {
286 Some(NotificationEntry {
287 id: message.id,
288 is_read: message.is_read,
289 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
290 .ok()?,
291 notification: Notification::from_proto(&message)?,
292 response: message.response,
293 })
294 })
295 .collect::<Vec<_>>();
296 if notifications.is_empty() {
297 return Ok(());
298 }
299
300 for entry in ¬ifications {
301 match entry.notification {
302 Notification::ChannelInvitation { inviter_id, .. } => {
303 user_ids.push(inviter_id);
304 }
305 Notification::ContactRequest {
306 sender_id: requester_id,
307 } => {
308 user_ids.push(requester_id);
309 }
310 Notification::ContactRequestAccepted {
311 responder_id: contact_id,
312 } => {
313 user_ids.push(contact_id);
314 }
315 Notification::ChannelMessageMention {
316 sender_id,
317 message_id,
318 ..
319 } => {
320 user_ids.push(sender_id);
321 message_ids.push(message_id);
322 }
323 }
324 }
325
326 let (user_store, channel_store) = this.read_with(&cx, |this, _| {
327 (this.user_store.clone(), this.channel_store.clone())
328 })?;
329
330 user_store
331 .update(&mut cx, |store, cx| store.get_users(user_ids, cx))?
332 .await?;
333 let messages = channel_store
334 .update(&mut cx, |store, cx| {
335 store.fetch_channel_messages(message_ids, cx)
336 })?
337 .await?;
338 this.update(&mut cx, |this, cx| {
339 if options.clear_old {
340 cx.emit(NotificationEvent::NotificationsUpdated {
341 old_range: 0..this.notifications.summary().count,
342 new_count: 0,
343 });
344 this.notifications = SumTree::default();
345 this.channel_messages.clear();
346 this.loaded_all_notifications = false;
347 }
348
349 if options.includes_first {
350 this.loaded_all_notifications = true;
351 }
352
353 this.channel_messages
354 .extend(messages.into_iter().filter_map(|message| {
355 if let ChannelMessageId::Saved(id) = message.id {
356 Some((id, message))
357 } else {
358 None
359 }
360 }));
361
362 this.splice_notifications(
363 notifications
364 .into_iter()
365 .map(|notification| (notification.id, Some(notification))),
366 options.is_new,
367 cx,
368 );
369 })
370 .log_err();
371
372 Ok(())
373 }
374
375 fn splice_notifications(
376 &mut self,
377 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
378 is_new: bool,
379 cx: &mut ModelContext<'_, NotificationStore>,
380 ) {
381 let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
382 let mut new_notifications = SumTree::new();
383 let mut old_range = 0..0;
384
385 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
386 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
387
388 if i == 0 {
389 old_range.start = cursor.start().1 .0;
390 }
391
392 let old_notification = cursor.item();
393 if let Some(old_notification) = old_notification {
394 if old_notification.id == id {
395 cursor.next(&());
396
397 if let Some(new_notification) = &new_notification {
398 if new_notification.is_read {
399 cx.emit(NotificationEvent::NotificationRead {
400 entry: new_notification.clone(),
401 });
402 }
403 } else {
404 cx.emit(NotificationEvent::NotificationRemoved {
405 entry: old_notification.clone(),
406 });
407 }
408 }
409 } else if let Some(new_notification) = &new_notification {
410 if is_new {
411 cx.emit(NotificationEvent::NewNotification {
412 entry: new_notification.clone(),
413 });
414 }
415 }
416
417 if let Some(notification) = new_notification {
418 new_notifications.push(notification, &());
419 }
420 }
421
422 old_range.end = cursor.start().1 .0;
423 let new_count = new_notifications.summary().count - old_range.start;
424 new_notifications.append(cursor.suffix(&()), &());
425 drop(cursor);
426
427 self.notifications = new_notifications;
428 cx.emit(NotificationEvent::NotificationsUpdated {
429 old_range,
430 new_count,
431 });
432 }
433
434 pub fn respond_to_notification(
435 &mut self,
436 notification: Notification,
437 response: bool,
438 cx: &mut ModelContext<Self>,
439 ) {
440 match notification {
441 Notification::ContactRequest { sender_id } => {
442 self.user_store
443 .update(cx, |store, cx| {
444 store.respond_to_contact_request(sender_id, response, cx)
445 })
446 .detach();
447 }
448 Notification::ChannelInvitation { channel_id, .. } => {
449 self.channel_store
450 .update(cx, |store, cx| {
451 store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
452 })
453 .detach();
454 }
455 _ => {}
456 }
457 }
458}
459
/// `NotificationStore` emits [`NotificationEvent`]s as its notification list changes.
impl EventEmitter<NotificationEvent> for NotificationStore {}
461
462impl sum_tree::Item for NotificationEntry {
463 type Summary = NotificationSummary;
464
465 fn summary(&self) -> Self::Summary {
466 NotificationSummary {
467 max_id: self.id,
468 count: 1,
469 unread_count: if self.is_read { 0 } else { 1 },
470 }
471 }
472}
473
474impl sum_tree::Summary for NotificationSummary {
475 type Context = ();
476
477 fn add_summary(&mut self, summary: &Self, _: &()) {
478 self.max_id = self.max_id.max(summary.max_id);
479 self.count += summary.count;
480 self.unread_count += summary.unread_count;
481 }
482}
483
484impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
485 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
486 debug_assert!(summary.max_id > self.0);
487 self.0 = summary.max_id;
488 }
489}
490
491impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
492 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
493 self.0 += summary.count;
494 }
495}
496
497impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
498 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
499 self.0 += summary.unread_count;
500 }
501}
502
/// Flags controlling how `NotificationStore::add_notifications` merges a batch.
struct AddNotificationsOptions {
    /// Emit `NotificationEvent::NewNotification` for ids not already loaded.
    is_new: bool,
    /// Discard everything loaded before applying the batch.
    clear_old: bool,
    /// The batch reaches the oldest notification, so nothing older remains to load.
    includes_first: bool,
}