1use anyhow::{Context, Result};
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{ChannelId, Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{
7 AppContext, AsyncAppContext, Context as _, EventEmitter, Global, Model, ModelContext, Task,
8};
9use rpc::{proto, Notification, TypedEnvelope};
10use std::{ops::Range, sync::Arc};
11use sum_tree::{Bias, SumTree};
12use time::OffsetDateTime;
13use util::ResultExt;
14
15pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
16 let notification_store = cx.new_model(|cx| NotificationStore::new(client, user_store, cx));
17 cx.set_global(GlobalNotificationStore(notification_store));
18}
19
20struct GlobalNotificationStore(Model<NotificationStore>);
21
22impl Global for GlobalNotificationStore {}
23
/// Client-side cache of the current user's notifications, kept in sync with the
/// server through RPC message handlers and reloaded whenever the client connects.
pub struct NotificationStore {
    /// Used to issue `GetNotifications` requests and to register message handlers.
    client: Arc<Client>,
    /// Used to fetch the users referenced by notifications and to answer contact requests.
    user_store: Model<UserStore>,
    /// Channel messages referenced by mention notifications, keyed by message id.
    channel_messages: HashMap<u64, ChannelMessage>,
    /// The global channel store, used to fetch messages and answer channel invites.
    channel_store: Model<ChannelStore>,
    /// Loaded notifications, ordered by ascending id.
    notifications: SumTree<NotificationEntry>,
    /// Set once the server reports that no older notifications remain.
    loaded_all_notifications: bool,
    // Connection-status watcher; held so it lives as long as the store.
    // NOTE(review): field order determines drop order — keep as is.
    _watch_connection_status: Task<Option<()>>,
    // Registrations for the add/delete/update notification message handlers.
    _subscriptions: Vec<client::Subscription>,
}
34
35#[derive(Clone, PartialEq, Eq, Debug)]
36pub enum NotificationEvent {
37 NotificationsUpdated {
38 old_range: Range<usize>,
39 new_count: usize,
40 },
41 NewNotification {
42 entry: NotificationEntry,
43 },
44 NotificationRemoved {
45 entry: NotificationEntry,
46 },
47 NotificationRead {
48 entry: NotificationEntry,
49 },
50}
51
52#[derive(Debug, PartialEq, Eq, Clone)]
53pub struct NotificationEntry {
54 pub id: u64,
55 pub notification: Notification,
56 pub timestamp: OffsetDateTime,
57 pub is_read: bool,
58 pub response: Option<bool>,
59}
60
/// Aggregate over a run of [`NotificationEntry`] values inside the sum tree.
#[derive(Debug, Default, Clone)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of entries in the run.
    count: usize,
    /// Number of entries in the run that are not yet read.
    unread_count: usize,
}
67
/// Sum-tree dimension counting entries, used for positional seeks.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
70
/// Sum-tree dimension holding the largest notification id seen so far, used to
/// seek by id.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
73
74impl NotificationStore {
75 pub fn global(cx: &AppContext) -> Model<Self> {
76 cx.global::<GlobalNotificationStore>().0.clone()
77 }
78
79 pub fn new(
80 client: Arc<Client>,
81 user_store: Model<UserStore>,
82 cx: &mut ModelContext<Self>,
83 ) -> Self {
84 let mut connection_status = client.status();
85 let watch_connection_status = cx.spawn(|this, mut cx| async move {
86 while let Some(status) = connection_status.next().await {
87 let this = this.upgrade()?;
88 match status {
89 client::Status::Connected { .. } => {
90 if let Some(task) = this
91 .update(&mut cx, |this, cx| this.handle_connect(cx))
92 .log_err()?
93 {
94 task.await.log_err()?;
95 }
96 }
97 _ => this
98 .update(&mut cx, |this, cx| this.handle_disconnect(cx))
99 .log_err()?,
100 }
101 }
102 Some(())
103 });
104
105 Self {
106 channel_store: ChannelStore::global(cx),
107 notifications: Default::default(),
108 loaded_all_notifications: false,
109 channel_messages: Default::default(),
110 _watch_connection_status: watch_connection_status,
111 _subscriptions: vec![
112 client.add_message_handler(cx.weak_model(), Self::handle_new_notification),
113 client.add_message_handler(cx.weak_model(), Self::handle_delete_notification),
114 client.add_message_handler(cx.weak_model(), Self::handle_update_notification),
115 ],
116 user_store,
117 client,
118 }
119 }
120
121 pub fn notification_count(&self) -> usize {
122 self.notifications.summary().count
123 }
124
125 pub fn unread_notification_count(&self) -> usize {
126 self.notifications.summary().unread_count
127 }
128
129 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
130 self.channel_messages.get(&id)
131 }
132
133 // Get the nth newest notification.
134 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
135 let count = self.notifications.summary().count;
136 if ix >= count {
137 return None;
138 }
139 let ix = count - 1 - ix;
140 let mut cursor = self.notifications.cursor::<Count>(&());
141 cursor.seek(&Count(ix), Bias::Right, &());
142 cursor.item()
143 }
144 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
145 let mut cursor = self.notifications.cursor::<NotificationId>(&());
146 cursor.seek(&NotificationId(id), Bias::Left, &());
147 if let Some(item) = cursor.item() {
148 if item.id == id {
149 return Some(item);
150 }
151 }
152 None
153 }
154
155 pub fn load_more_notifications(
156 &self,
157 clear_old: bool,
158 cx: &mut ModelContext<Self>,
159 ) -> Option<Task<Result<()>>> {
160 if self.loaded_all_notifications && !clear_old {
161 return None;
162 }
163
164 let before_id = if clear_old {
165 None
166 } else {
167 self.notifications.first().map(|entry| entry.id)
168 };
169 let request = self.client.request(proto::GetNotifications { before_id });
170 Some(cx.spawn(|this, mut cx| async move {
171 let this = this
172 .upgrade()
173 .context("Notification store was dropped while loading notifications")?;
174
175 let response = request.await?;
176 this.update(&mut cx, |this, _| {
177 this.loaded_all_notifications = response.done
178 })?;
179 Self::add_notifications(
180 this,
181 response.notifications,
182 AddNotificationsOptions {
183 is_new: false,
184 clear_old,
185 includes_first: response.done,
186 },
187 cx,
188 )
189 .await?;
190 Ok(())
191 }))
192 }
193
194 fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
195 self.notifications = Default::default();
196 self.channel_messages = Default::default();
197 cx.notify();
198 self.load_more_notifications(true, cx)
199 }
200
201 fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
202 cx.notify()
203 }
204
205 async fn handle_new_notification(
206 this: Model<Self>,
207 envelope: TypedEnvelope<proto::AddNotification>,
208 cx: AsyncAppContext,
209 ) -> Result<()> {
210 Self::add_notifications(
211 this,
212 envelope.payload.notification.into_iter().collect(),
213 AddNotificationsOptions {
214 is_new: true,
215 clear_old: false,
216 includes_first: false,
217 },
218 cx,
219 )
220 .await
221 }
222
223 async fn handle_delete_notification(
224 this: Model<Self>,
225 envelope: TypedEnvelope<proto::DeleteNotification>,
226 mut cx: AsyncAppContext,
227 ) -> Result<()> {
228 this.update(&mut cx, |this, cx| {
229 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
230 Ok(())
231 })?
232 }
233
234 async fn handle_update_notification(
235 this: Model<Self>,
236 envelope: TypedEnvelope<proto::UpdateNotification>,
237 mut cx: AsyncAppContext,
238 ) -> Result<()> {
239 this.update(&mut cx, |this, cx| {
240 if let Some(notification) = envelope.payload.notification {
241 if let Some(rpc::Notification::ChannelMessageMention {
242 message_id,
243 sender_id: _,
244 channel_id: _,
245 }) = Notification::from_proto(¬ification)
246 {
247 let fetch_message_task = this.channel_store.update(cx, |this, cx| {
248 this.fetch_channel_messages(vec![message_id], cx)
249 });
250
251 cx.spawn(|this, mut cx| async move {
252 let messages = fetch_message_task.await?;
253 this.update(&mut cx, move |this, cx| {
254 for message in messages {
255 this.channel_messages.insert(message_id, message);
256 }
257 cx.notify();
258 })
259 })
260 .detach_and_log_err(cx)
261 }
262 }
263 Ok(())
264 })?
265 }
266
267 async fn add_notifications(
268 this: Model<Self>,
269 notifications: Vec<proto::Notification>,
270 options: AddNotificationsOptions,
271 mut cx: AsyncAppContext,
272 ) -> Result<()> {
273 let mut user_ids = Vec::new();
274 let mut message_ids = Vec::new();
275
276 let notifications = notifications
277 .into_iter()
278 .filter_map(|message| {
279 Some(NotificationEntry {
280 id: message.id,
281 is_read: message.is_read,
282 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
283 .ok()?,
284 notification: Notification::from_proto(&message)?,
285 response: message.response,
286 })
287 })
288 .collect::<Vec<_>>();
289 if notifications.is_empty() {
290 return Ok(());
291 }
292
293 for entry in ¬ifications {
294 match entry.notification {
295 Notification::ChannelInvitation { inviter_id, .. } => {
296 user_ids.push(inviter_id);
297 }
298 Notification::ContactRequest {
299 sender_id: requester_id,
300 } => {
301 user_ids.push(requester_id);
302 }
303 Notification::ContactRequestAccepted {
304 responder_id: contact_id,
305 } => {
306 user_ids.push(contact_id);
307 }
308 Notification::ChannelMessageMention {
309 sender_id,
310 message_id,
311 ..
312 } => {
313 user_ids.push(sender_id);
314 message_ids.push(message_id);
315 }
316 }
317 }
318
319 let (user_store, channel_store) = this.read_with(&cx, |this, _| {
320 (this.user_store.clone(), this.channel_store.clone())
321 })?;
322
323 user_store
324 .update(&mut cx, |store, cx| store.get_users(user_ids, cx))?
325 .await?;
326 let messages = channel_store
327 .update(&mut cx, |store, cx| {
328 store.fetch_channel_messages(message_ids, cx)
329 })?
330 .await?;
331 this.update(&mut cx, |this, cx| {
332 if options.clear_old {
333 cx.emit(NotificationEvent::NotificationsUpdated {
334 old_range: 0..this.notifications.summary().count,
335 new_count: 0,
336 });
337 this.notifications = SumTree::default();
338 this.channel_messages.clear();
339 this.loaded_all_notifications = false;
340 }
341
342 if options.includes_first {
343 this.loaded_all_notifications = true;
344 }
345
346 this.channel_messages
347 .extend(messages.into_iter().filter_map(|message| {
348 if let ChannelMessageId::Saved(id) = message.id {
349 Some((id, message))
350 } else {
351 None
352 }
353 }));
354
355 this.splice_notifications(
356 notifications
357 .into_iter()
358 .map(|notification| (notification.id, Some(notification))),
359 options.is_new,
360 cx,
361 );
362 })
363 .log_err();
364
365 Ok(())
366 }
367
368 fn splice_notifications(
369 &mut self,
370 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
371 is_new: bool,
372 cx: &mut ModelContext<'_, NotificationStore>,
373 ) {
374 let mut cursor = self.notifications.cursor::<(NotificationId, Count)>(&());
375 let mut new_notifications = SumTree::default();
376 let mut old_range = 0..0;
377
378 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
379 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
380
381 if i == 0 {
382 old_range.start = cursor.start().1 .0;
383 }
384
385 let old_notification = cursor.item();
386 if let Some(old_notification) = old_notification {
387 if old_notification.id == id {
388 cursor.next(&());
389
390 if let Some(new_notification) = &new_notification {
391 if new_notification.is_read {
392 cx.emit(NotificationEvent::NotificationRead {
393 entry: new_notification.clone(),
394 });
395 }
396 } else {
397 cx.emit(NotificationEvent::NotificationRemoved {
398 entry: old_notification.clone(),
399 });
400 }
401 }
402 } else if let Some(new_notification) = &new_notification {
403 if is_new {
404 cx.emit(NotificationEvent::NewNotification {
405 entry: new_notification.clone(),
406 });
407 }
408 }
409
410 if let Some(notification) = new_notification {
411 new_notifications.push(notification, &());
412 }
413 }
414
415 old_range.end = cursor.start().1 .0;
416 let new_count = new_notifications.summary().count - old_range.start;
417 new_notifications.append(cursor.suffix(&()), &());
418 drop(cursor);
419
420 self.notifications = new_notifications;
421 cx.emit(NotificationEvent::NotificationsUpdated {
422 old_range,
423 new_count,
424 });
425 }
426
427 pub fn respond_to_notification(
428 &mut self,
429 notification: Notification,
430 response: bool,
431 cx: &mut ModelContext<Self>,
432 ) {
433 match notification {
434 Notification::ContactRequest { sender_id } => {
435 self.user_store
436 .update(cx, |store, cx| {
437 store.respond_to_contact_request(sender_id, response, cx)
438 })
439 .detach();
440 }
441 Notification::ChannelInvitation { channel_id, .. } => {
442 self.channel_store
443 .update(cx, |store, cx| {
444 store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
445 })
446 .detach();
447 }
448 _ => {}
449 }
450 }
451}
452
453impl EventEmitter<NotificationEvent> for NotificationStore {}
454
455impl sum_tree::Item for NotificationEntry {
456 type Summary = NotificationSummary;
457
458 fn summary(&self, _cx: &()) -> Self::Summary {
459 NotificationSummary {
460 max_id: self.id,
461 count: 1,
462 unread_count: if self.is_read { 0 } else { 1 },
463 }
464 }
465}
466
467impl sum_tree::Summary for NotificationSummary {
468 type Context = ();
469
470 fn zero(_cx: &()) -> Self {
471 Default::default()
472 }
473
474 fn add_summary(&mut self, summary: &Self, _: &()) {
475 self.max_id = self.max_id.max(summary.max_id);
476 self.count += summary.count;
477 self.unread_count += summary.unread_count;
478 }
479}
480
481impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
482 fn zero(_cx: &()) -> Self {
483 Default::default()
484 }
485
486 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
487 debug_assert!(summary.max_id > self.0);
488 self.0 = summary.max_id;
489 }
490}
491
492impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
493 fn zero(_cx: &()) -> Self {
494 Default::default()
495 }
496
497 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
498 self.0 += summary.count;
499 }
500}
501
/// Flags controlling how `NotificationStore::add_notifications` merges a batch.
struct AddNotificationsOptions {
    /// Emit `NewNotification` for entries not already stored.
    is_new: bool,
    /// Drop all existing entries and cached messages before inserting.
    clear_old: bool,
    /// The batch reaches the oldest notification, so mark the history fully loaded.
    includes_first: bool,
}