1use anyhow::{Context, Result};
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{ChannelId, Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{
7 AppContext, AsyncAppContext, Context as _, EventEmitter, Global, Model, ModelContext, Task,
8};
9use rpc::{proto, Notification, TypedEnvelope};
10use std::{ops::Range, sync::Arc};
11use sum_tree::{Bias, SumTree};
12use time::OffsetDateTime;
13use util::ResultExt;
14
15pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
16 let notification_store = cx.new_model(|cx| NotificationStore::new(client, user_store, cx));
17 cx.set_global(GlobalNotificationStore(notification_store));
18}
19
/// Newtype wrapper that lets the [`NotificationStore`] model be stored as a
/// gpui global; it is installed by `init` and read by `NotificationStore::global`.
struct GlobalNotificationStore(Model<NotificationStore>);

impl Global for GlobalNotificationStore {}
23
/// Client-side cache of the current user's notifications.
///
/// Entries live in a [`SumTree`] ordered by notification id, so the oldest
/// loaded notification comes first. The store reloads itself whenever the
/// client reports a `Connected` status and is kept up to date by the
/// server-pushed add/delete/update notification messages.
pub struct NotificationStore {
    /// RPC client used to request notifications and register message handlers.
    client: Arc<Client>,
    /// Used to load the users referenced by notifications.
    user_store: Model<UserStore>,
    /// Channel messages referenced by mention notifications, keyed by message id.
    channel_messages: HashMap<u64, ChannelMessage>,
    /// Used to fetch channel messages and to answer channel invitations.
    channel_store: Model<ChannelStore>,
    /// Loaded notifications, ordered by ascending id.
    notifications: SumTree<NotificationEntry>,
    /// Set once the server reports that no older notifications remain.
    loaded_all_notifications: bool,
    /// Connection-status watcher, held for the lifetime of the store.
    /// NOTE(review): presumably dropping the gpui `Task` cancels it — confirm.
    _watch_connection_status: Task<Option<()>>,
    /// Registrations for the add/delete/update notification message handlers,
    /// held for the lifetime of the store.
    _subscriptions: Vec<client::Subscription>,
}
34
/// Events emitted by [`NotificationStore`] as its contents change.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum NotificationEvent {
    /// The entries at positions `old_range` (counted from the oldest loaded
    /// notification) were replaced by `new_count` entries.
    NotificationsUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
    /// A notification pushed by the server was inserted into the store.
    NewNotification {
        entry: NotificationEntry,
    },
    /// An existing notification was deleted from the store.
    NotificationRemoved {
        entry: NotificationEntry,
    },
    /// An existing notification was replaced by a version whose `is_read` is set.
    NotificationRead {
        entry: NotificationEntry,
    },
}
51
/// A single notification as held by [`NotificationStore`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct NotificationEntry {
    /// Server-assigned id; the store keeps entries ordered by it.
    pub id: u64,
    /// Decoded notification payload.
    pub notification: Notification,
    /// Timestamp converted from the proto's Unix timestamp (seconds).
    pub timestamp: OffsetDateTime,
    /// Whether the notification has been marked as read.
    pub is_read: bool,
    /// Response copied from the proto.
    /// NOTE(review): presumably the accept/decline answer for actionable
    /// notifications — confirm against the server schema.
    pub response: Option<bool>,
}
60
/// Aggregate over a run of [`NotificationEntry`] values in the sum tree.
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of notifications in the run.
    count: usize,
    /// Number of those notifications whose `is_read` flag is unset.
    unread_count: usize,
}
67
/// Sum-tree dimension that counts notifications; used for positional seeks.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);

/// Sum-tree dimension tracking the largest notification id covered so far;
/// used to seek to a notification by id.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
73
74impl NotificationStore {
75 pub fn global(cx: &AppContext) -> Model<Self> {
76 cx.global::<GlobalNotificationStore>().0.clone()
77 }
78
79 pub fn new(
80 client: Arc<Client>,
81 user_store: Model<UserStore>,
82 cx: &mut ModelContext<Self>,
83 ) -> Self {
84 let mut connection_status = client.status();
85 let watch_connection_status = cx.spawn(|this, mut cx| async move {
86 while let Some(status) = connection_status.next().await {
87 let this = this.upgrade()?;
88 match status {
89 client::Status::Connected { .. } => {
90 if let Some(task) = this
91 .update(&mut cx, |this, cx| this.handle_connect(cx))
92 .log_err()?
93 {
94 task.await.log_err()?;
95 }
96 }
97 _ => this
98 .update(&mut cx, |this, cx| this.handle_disconnect(cx))
99 .log_err()?,
100 }
101 }
102 Some(())
103 });
104
105 Self {
106 channel_store: ChannelStore::global(cx),
107 notifications: Default::default(),
108 loaded_all_notifications: false,
109 channel_messages: Default::default(),
110 _watch_connection_status: watch_connection_status,
111 _subscriptions: vec![
112 client.add_message_handler(cx.weak_model(), Self::handle_new_notification),
113 client.add_message_handler(cx.weak_model(), Self::handle_delete_notification),
114 client.add_message_handler(cx.weak_model(), Self::handle_update_notification),
115 ],
116 user_store,
117 client,
118 }
119 }
120
121 pub fn notification_count(&self) -> usize {
122 self.notifications.summary().count
123 }
124
125 pub fn unread_notification_count(&self) -> usize {
126 self.notifications.summary().unread_count
127 }
128
129 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
130 self.channel_messages.get(&id)
131 }
132
133 // Get the nth newest notification.
134 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
135 let count = self.notifications.summary().count;
136 if ix >= count {
137 return None;
138 }
139 let ix = count - 1 - ix;
140 let mut cursor = self.notifications.cursor::<Count>(&());
141 cursor.seek(&Count(ix), Bias::Right, &());
142 cursor.item()
143 }
144 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
145 let mut cursor = self.notifications.cursor::<NotificationId>(&());
146 cursor.seek(&NotificationId(id), Bias::Left, &());
147 if let Some(item) = cursor.item() {
148 if item.id == id {
149 return Some(item);
150 }
151 }
152 None
153 }
154
155 pub fn load_more_notifications(
156 &self,
157 clear_old: bool,
158 cx: &mut ModelContext<Self>,
159 ) -> Option<Task<Result<()>>> {
160 if self.loaded_all_notifications && !clear_old {
161 return None;
162 }
163
164 let before_id = if clear_old {
165 None
166 } else {
167 self.notifications.first().map(|entry| entry.id)
168 };
169 let request = self.client.request(proto::GetNotifications { before_id });
170 Some(cx.spawn(|this, mut cx| async move {
171 let this = this
172 .upgrade()
173 .context("Notification store was dropped while loading notifications")?;
174
175 let response = request.await?;
176 this.update(&mut cx, |this, _| {
177 this.loaded_all_notifications = response.done
178 })?;
179 Self::add_notifications(
180 this,
181 response.notifications,
182 AddNotificationsOptions {
183 is_new: false,
184 clear_old,
185 includes_first: response.done,
186 },
187 cx,
188 )
189 .await?;
190 Ok(())
191 }))
192 }
193
194 fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
195 self.notifications = Default::default();
196 self.channel_messages = Default::default();
197 cx.notify();
198 self.load_more_notifications(true, cx)
199 }
200
201 fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
202 cx.notify()
203 }
204
205 async fn handle_new_notification(
206 this: Model<Self>,
207 envelope: TypedEnvelope<proto::AddNotification>,
208 cx: AsyncAppContext,
209 ) -> Result<()> {
210 Self::add_notifications(
211 this,
212 envelope.payload.notification.into_iter().collect(),
213 AddNotificationsOptions {
214 is_new: true,
215 clear_old: false,
216 includes_first: false,
217 },
218 cx,
219 )
220 .await
221 }
222
223 async fn handle_delete_notification(
224 this: Model<Self>,
225 envelope: TypedEnvelope<proto::DeleteNotification>,
226 mut cx: AsyncAppContext,
227 ) -> Result<()> {
228 this.update(&mut cx, |this, cx| {
229 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
230 Ok(())
231 })?
232 }
233
234 async fn handle_update_notification(
235 this: Model<Self>,
236 envelope: TypedEnvelope<proto::UpdateNotification>,
237 mut cx: AsyncAppContext,
238 ) -> Result<()> {
239 this.update(&mut cx, |this, cx| {
240 if let Some(notification) = envelope.payload.notification {
241 if let Some(rpc::Notification::ChannelMessageMention { message_id, .. }) =
242 Notification::from_proto(¬ification)
243 {
244 let fetch_message_task = this.channel_store.update(cx, |this, cx| {
245 this.fetch_channel_messages(vec![message_id], cx)
246 });
247
248 cx.spawn(|this, mut cx| async move {
249 let messages = fetch_message_task.await?;
250 this.update(&mut cx, move |this, cx| {
251 for message in messages {
252 this.channel_messages.insert(message_id, message);
253 }
254 cx.notify();
255 })
256 })
257 .detach_and_log_err(cx)
258 }
259 }
260 Ok(())
261 })?
262 }
263
264 async fn add_notifications(
265 this: Model<Self>,
266 notifications: Vec<proto::Notification>,
267 options: AddNotificationsOptions,
268 mut cx: AsyncAppContext,
269 ) -> Result<()> {
270 let mut user_ids = Vec::new();
271 let mut message_ids = Vec::new();
272
273 let notifications = notifications
274 .into_iter()
275 .filter_map(|message| {
276 Some(NotificationEntry {
277 id: message.id,
278 is_read: message.is_read,
279 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
280 .ok()?,
281 notification: Notification::from_proto(&message)?,
282 response: message.response,
283 })
284 })
285 .collect::<Vec<_>>();
286 if notifications.is_empty() {
287 return Ok(());
288 }
289
290 for entry in ¬ifications {
291 match entry.notification {
292 Notification::ChannelInvitation { inviter_id, .. } => {
293 user_ids.push(inviter_id);
294 }
295 Notification::ContactRequest {
296 sender_id: requester_id,
297 } => {
298 user_ids.push(requester_id);
299 }
300 Notification::ContactRequestAccepted {
301 responder_id: contact_id,
302 } => {
303 user_ids.push(contact_id);
304 }
305 Notification::ChannelMessageMention {
306 sender_id,
307 message_id,
308 ..
309 } => {
310 user_ids.push(sender_id);
311 message_ids.push(message_id);
312 }
313 }
314 }
315
316 let (user_store, channel_store) = this.read_with(&cx, |this, _| {
317 (this.user_store.clone(), this.channel_store.clone())
318 })?;
319
320 user_store
321 .update(&mut cx, |store, cx| store.get_users(user_ids, cx))?
322 .await?;
323 let messages = channel_store
324 .update(&mut cx, |store, cx| {
325 store.fetch_channel_messages(message_ids, cx)
326 })?
327 .await?;
328 this.update(&mut cx, |this, cx| {
329 if options.clear_old {
330 cx.emit(NotificationEvent::NotificationsUpdated {
331 old_range: 0..this.notifications.summary().count,
332 new_count: 0,
333 });
334 this.notifications = SumTree::default();
335 this.channel_messages.clear();
336 this.loaded_all_notifications = false;
337 }
338
339 if options.includes_first {
340 this.loaded_all_notifications = true;
341 }
342
343 this.channel_messages
344 .extend(messages.into_iter().filter_map(|message| {
345 if let ChannelMessageId::Saved(id) = message.id {
346 Some((id, message))
347 } else {
348 None
349 }
350 }));
351
352 this.splice_notifications(
353 notifications
354 .into_iter()
355 .map(|notification| (notification.id, Some(notification))),
356 options.is_new,
357 cx,
358 );
359 })
360 .log_err();
361
362 Ok(())
363 }
364
365 fn splice_notifications(
366 &mut self,
367 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
368 is_new: bool,
369 cx: &mut ModelContext<'_, NotificationStore>,
370 ) {
371 let mut cursor = self.notifications.cursor::<(NotificationId, Count)>(&());
372 let mut new_notifications = SumTree::default();
373 let mut old_range = 0..0;
374
375 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
376 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
377
378 if i == 0 {
379 old_range.start = cursor.start().1 .0;
380 }
381
382 let old_notification = cursor.item();
383 if let Some(old_notification) = old_notification {
384 if old_notification.id == id {
385 cursor.next(&());
386
387 if let Some(new_notification) = &new_notification {
388 if new_notification.is_read {
389 cx.emit(NotificationEvent::NotificationRead {
390 entry: new_notification.clone(),
391 });
392 }
393 } else {
394 cx.emit(NotificationEvent::NotificationRemoved {
395 entry: old_notification.clone(),
396 });
397 }
398 }
399 } else if let Some(new_notification) = &new_notification {
400 if is_new {
401 cx.emit(NotificationEvent::NewNotification {
402 entry: new_notification.clone(),
403 });
404 }
405 }
406
407 if let Some(notification) = new_notification {
408 new_notifications.push(notification, &());
409 }
410 }
411
412 old_range.end = cursor.start().1 .0;
413 let new_count = new_notifications.summary().count - old_range.start;
414 new_notifications.append(cursor.suffix(&()), &());
415 drop(cursor);
416
417 self.notifications = new_notifications;
418 cx.emit(NotificationEvent::NotificationsUpdated {
419 old_range,
420 new_count,
421 });
422 }
423
424 pub fn respond_to_notification(
425 &mut self,
426 notification: Notification,
427 response: bool,
428 cx: &mut ModelContext<Self>,
429 ) {
430 match notification {
431 Notification::ContactRequest { sender_id } => {
432 self.user_store
433 .update(cx, |store, cx| {
434 store.respond_to_contact_request(sender_id, response, cx)
435 })
436 .detach();
437 }
438 Notification::ChannelInvitation { channel_id, .. } => {
439 self.channel_store
440 .update(cx, |store, cx| {
441 store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
442 })
443 .detach();
444 }
445 _ => {}
446 }
447 }
448}
449
// Allows `NotificationStore` to emit `NotificationEvent`s via `cx.emit`.
impl EventEmitter<NotificationEvent> for NotificationStore {}
451
452impl sum_tree::Item for NotificationEntry {
453 type Summary = NotificationSummary;
454
455 fn summary(&self, _cx: &()) -> Self::Summary {
456 NotificationSummary {
457 max_id: self.id,
458 count: 1,
459 unread_count: if self.is_read { 0 } else { 1 },
460 }
461 }
462}
463
464impl sum_tree::Summary for NotificationSummary {
465 type Context = ();
466
467 fn zero(_cx: &()) -> Self {
468 Default::default()
469 }
470
471 fn add_summary(&mut self, summary: &Self, _: &()) {
472 self.max_id = self.max_id.max(summary.max_id);
473 self.count += summary.count;
474 self.unread_count += summary.unread_count;
475 }
476}
477
478impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
479 fn zero(_cx: &()) -> Self {
480 Default::default()
481 }
482
483 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
484 debug_assert!(summary.max_id > self.0);
485 self.0 = summary.max_id;
486 }
487}
488
489impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
490 fn zero(_cx: &()) -> Self {
491 Default::default()
492 }
493
494 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
495 self.0 += summary.count;
496 }
497}
498
/// Controls how `NotificationStore::add_notifications` merges a batch.
struct AddNotificationsOptions {
    /// The batch was pushed live by the server; entries not yet present emit
    /// `NotificationEvent::NewNotification`.
    is_new: bool,
    /// Discard all previously loaded notifications and channel messages
    /// before inserting the batch.
    clear_old: bool,
    /// The batch reaches the oldest notification, so `loaded_all_notifications`
    /// is set to true.
    includes_first: bool,
}