1use anyhow::{Context as _, Result};
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{ChannelId, Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task};
7use rpc::{proto, Notification, TypedEnvelope};
8use std::{ops::Range, sync::Arc};
9use sum_tree::{Bias, SumTree};
10use time::OffsetDateTime;
11use util::ResultExt;
12
13pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
14 let notification_store = cx.new(|cx| NotificationStore::new(client, user_store, cx));
15 cx.set_global(GlobalNotificationStore(notification_store));
16}
17
18struct GlobalNotificationStore(Entity<NotificationStore>);
19
20impl Global for GlobalNotificationStore {}
21
22pub struct NotificationStore {
23 client: Arc<Client>,
24 user_store: Entity<UserStore>,
25 channel_messages: HashMap<u64, ChannelMessage>,
26 channel_store: Entity<ChannelStore>,
27 notifications: SumTree<NotificationEntry>,
28 loaded_all_notifications: bool,
29 _watch_connection_status: Task<Option<()>>,
30 _subscriptions: Vec<client::Subscription>,
31}
32
33#[derive(Clone, PartialEq, Eq, Debug)]
34pub enum NotificationEvent {
35 NotificationsUpdated {
36 old_range: Range<usize>,
37 new_count: usize,
38 },
39 NewNotification {
40 entry: NotificationEntry,
41 },
42 NotificationRemoved {
43 entry: NotificationEntry,
44 },
45 NotificationRead {
46 entry: NotificationEntry,
47 },
48}
49
50#[derive(Debug, PartialEq, Eq, Clone)]
51pub struct NotificationEntry {
52 pub id: u64,
53 pub notification: Notification,
54 pub timestamp: OffsetDateTime,
55 pub is_read: bool,
56 pub response: Option<bool>,
57}
58
/// Aggregate data over a run of [`NotificationEntry`] items in the sum tree.
#[derive(Debug, Default, Clone)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of entries in the run.
    count: usize,
    /// Number of entries whose `is_read` flag is false.
    unread_count: usize,
}
65
/// Sum-tree dimension that counts entries, used for positional lookups.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);

/// Sum-tree dimension tracking the largest id seen so far, used for lookups by id.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
71
72impl NotificationStore {
73 pub fn global(cx: &App) -> Entity<Self> {
74 cx.global::<GlobalNotificationStore>().0.clone()
75 }
76
77 pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
78 let mut connection_status = client.status();
79 let watch_connection_status = cx.spawn(|this, mut cx| async move {
80 while let Some(status) = connection_status.next().await {
81 let this = this.upgrade()?;
82 match status {
83 client::Status::Connected { .. } => {
84 if let Some(task) = this
85 .update(&mut cx, |this, cx| this.handle_connect(cx))
86 .log_err()?
87 {
88 task.await.log_err()?;
89 }
90 }
91 _ => this
92 .update(&mut cx, |this, cx| this.handle_disconnect(cx))
93 .log_err()?,
94 }
95 }
96 Some(())
97 });
98
99 Self {
100 channel_store: ChannelStore::global(cx),
101 notifications: Default::default(),
102 loaded_all_notifications: false,
103 channel_messages: Default::default(),
104 _watch_connection_status: watch_connection_status,
105 _subscriptions: vec![
106 client.add_message_handler(cx.weak_entity(), Self::handle_new_notification),
107 client.add_message_handler(cx.weak_entity(), Self::handle_delete_notification),
108 client.add_message_handler(cx.weak_entity(), Self::handle_update_notification),
109 ],
110 user_store,
111 client,
112 }
113 }
114
115 pub fn notification_count(&self) -> usize {
116 self.notifications.summary().count
117 }
118
119 pub fn unread_notification_count(&self) -> usize {
120 self.notifications.summary().unread_count
121 }
122
123 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
124 self.channel_messages.get(&id)
125 }
126
127 // Get the nth newest notification.
128 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
129 let count = self.notifications.summary().count;
130 if ix >= count {
131 return None;
132 }
133 let ix = count - 1 - ix;
134 let mut cursor = self.notifications.cursor::<Count>(&());
135 cursor.seek(&Count(ix), Bias::Right, &());
136 cursor.item()
137 }
138 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
139 let mut cursor = self.notifications.cursor::<NotificationId>(&());
140 cursor.seek(&NotificationId(id), Bias::Left, &());
141 if let Some(item) = cursor.item() {
142 if item.id == id {
143 return Some(item);
144 }
145 }
146 None
147 }
148
149 pub fn load_more_notifications(
150 &self,
151 clear_old: bool,
152 cx: &mut Context<Self>,
153 ) -> Option<Task<Result<()>>> {
154 if self.loaded_all_notifications && !clear_old {
155 return None;
156 }
157
158 let before_id = if clear_old {
159 None
160 } else {
161 self.notifications.first().map(|entry| entry.id)
162 };
163 let request = self.client.request(proto::GetNotifications { before_id });
164 Some(cx.spawn(|this, mut cx| async move {
165 let this = this
166 .upgrade()
167 .context("Notification store was dropped while loading notifications")?;
168
169 let response = request.await?;
170 this.update(&mut cx, |this, _| {
171 this.loaded_all_notifications = response.done
172 })?;
173 Self::add_notifications(
174 this,
175 response.notifications,
176 AddNotificationsOptions {
177 is_new: false,
178 clear_old,
179 includes_first: response.done,
180 },
181 cx,
182 )
183 .await?;
184 Ok(())
185 }))
186 }
187
188 fn handle_connect(&mut self, cx: &mut Context<Self>) -> Option<Task<Result<()>>> {
189 self.notifications = Default::default();
190 self.channel_messages = Default::default();
191 cx.notify();
192 self.load_more_notifications(true, cx)
193 }
194
195 fn handle_disconnect(&mut self, cx: &mut Context<Self>) {
196 cx.notify()
197 }
198
199 async fn handle_new_notification(
200 this: Entity<Self>,
201 envelope: TypedEnvelope<proto::AddNotification>,
202 cx: AsyncApp,
203 ) -> Result<()> {
204 Self::add_notifications(
205 this,
206 envelope.payload.notification.into_iter().collect(),
207 AddNotificationsOptions {
208 is_new: true,
209 clear_old: false,
210 includes_first: false,
211 },
212 cx,
213 )
214 .await
215 }
216
217 async fn handle_delete_notification(
218 this: Entity<Self>,
219 envelope: TypedEnvelope<proto::DeleteNotification>,
220 mut cx: AsyncApp,
221 ) -> Result<()> {
222 this.update(&mut cx, |this, cx| {
223 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
224 Ok(())
225 })?
226 }
227
228 async fn handle_update_notification(
229 this: Entity<Self>,
230 envelope: TypedEnvelope<proto::UpdateNotification>,
231 mut cx: AsyncApp,
232 ) -> Result<()> {
233 this.update(&mut cx, |this, cx| {
234 if let Some(notification) = envelope.payload.notification {
235 if let Some(rpc::Notification::ChannelMessageMention { message_id, .. }) =
236 Notification::from_proto(¬ification)
237 {
238 let fetch_message_task = this.channel_store.update(cx, |this, cx| {
239 this.fetch_channel_messages(vec![message_id], cx)
240 });
241
242 cx.spawn(|this, mut cx| async move {
243 let messages = fetch_message_task.await?;
244 this.update(&mut cx, move |this, cx| {
245 for message in messages {
246 this.channel_messages.insert(message_id, message);
247 }
248 cx.notify();
249 })
250 })
251 .detach_and_log_err(cx)
252 }
253 }
254 Ok(())
255 })?
256 }
257
258 async fn add_notifications(
259 this: Entity<Self>,
260 notifications: Vec<proto::Notification>,
261 options: AddNotificationsOptions,
262 mut cx: AsyncApp,
263 ) -> Result<()> {
264 let mut user_ids = Vec::new();
265 let mut message_ids = Vec::new();
266
267 let notifications = notifications
268 .into_iter()
269 .filter_map(|message| {
270 Some(NotificationEntry {
271 id: message.id,
272 is_read: message.is_read,
273 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
274 .ok()?,
275 notification: Notification::from_proto(&message)?,
276 response: message.response,
277 })
278 })
279 .collect::<Vec<_>>();
280 if notifications.is_empty() {
281 return Ok(());
282 }
283
284 for entry in ¬ifications {
285 match entry.notification {
286 Notification::ChannelInvitation { inviter_id, .. } => {
287 user_ids.push(inviter_id);
288 }
289 Notification::ContactRequest {
290 sender_id: requester_id,
291 } => {
292 user_ids.push(requester_id);
293 }
294 Notification::ContactRequestAccepted {
295 responder_id: contact_id,
296 } => {
297 user_ids.push(contact_id);
298 }
299 Notification::ChannelMessageMention {
300 sender_id,
301 message_id,
302 ..
303 } => {
304 user_ids.push(sender_id);
305 message_ids.push(message_id);
306 }
307 }
308 }
309
310 let (user_store, channel_store) = this.read_with(&cx, |this, _| {
311 (this.user_store.clone(), this.channel_store.clone())
312 })?;
313
314 user_store
315 .update(&mut cx, |store, cx| store.get_users(user_ids, cx))?
316 .await?;
317 let messages = channel_store
318 .update(&mut cx, |store, cx| {
319 store.fetch_channel_messages(message_ids, cx)
320 })?
321 .await?;
322 this.update(&mut cx, |this, cx| {
323 if options.clear_old {
324 cx.emit(NotificationEvent::NotificationsUpdated {
325 old_range: 0..this.notifications.summary().count,
326 new_count: 0,
327 });
328 this.notifications = SumTree::default();
329 this.channel_messages.clear();
330 this.loaded_all_notifications = false;
331 }
332
333 if options.includes_first {
334 this.loaded_all_notifications = true;
335 }
336
337 this.channel_messages
338 .extend(messages.into_iter().filter_map(|message| {
339 if let ChannelMessageId::Saved(id) = message.id {
340 Some((id, message))
341 } else {
342 None
343 }
344 }));
345
346 this.splice_notifications(
347 notifications
348 .into_iter()
349 .map(|notification| (notification.id, Some(notification))),
350 options.is_new,
351 cx,
352 );
353 })
354 .log_err();
355
356 Ok(())
357 }
358
359 fn splice_notifications(
360 &mut self,
361 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
362 is_new: bool,
363 cx: &mut Context<'_, NotificationStore>,
364 ) {
365 let mut cursor = self.notifications.cursor::<(NotificationId, Count)>(&());
366 let mut new_notifications = SumTree::default();
367 let mut old_range = 0..0;
368
369 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
370 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
371
372 if i == 0 {
373 old_range.start = cursor.start().1 .0;
374 }
375
376 let old_notification = cursor.item();
377 if let Some(old_notification) = old_notification {
378 if old_notification.id == id {
379 cursor.next(&());
380
381 if let Some(new_notification) = &new_notification {
382 if new_notification.is_read {
383 cx.emit(NotificationEvent::NotificationRead {
384 entry: new_notification.clone(),
385 });
386 }
387 } else {
388 cx.emit(NotificationEvent::NotificationRemoved {
389 entry: old_notification.clone(),
390 });
391 }
392 }
393 } else if let Some(new_notification) = &new_notification {
394 if is_new {
395 cx.emit(NotificationEvent::NewNotification {
396 entry: new_notification.clone(),
397 });
398 }
399 }
400
401 if let Some(notification) = new_notification {
402 new_notifications.push(notification, &());
403 }
404 }
405
406 old_range.end = cursor.start().1 .0;
407 let new_count = new_notifications.summary().count - old_range.start;
408 new_notifications.append(cursor.suffix(&()), &());
409 drop(cursor);
410
411 self.notifications = new_notifications;
412 cx.emit(NotificationEvent::NotificationsUpdated {
413 old_range,
414 new_count,
415 });
416 }
417
418 pub fn respond_to_notification(
419 &mut self,
420 notification: Notification,
421 response: bool,
422 cx: &mut Context<Self>,
423 ) {
424 match notification {
425 Notification::ContactRequest { sender_id } => {
426 self.user_store
427 .update(cx, |store, cx| {
428 store.respond_to_contact_request(sender_id, response, cx)
429 })
430 .detach();
431 }
432 Notification::ChannelInvitation { channel_id, .. } => {
433 self.channel_store
434 .update(cx, |store, cx| {
435 store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
436 })
437 .detach();
438 }
439 _ => {}
440 }
441 }
442}
443
444impl EventEmitter<NotificationEvent> for NotificationStore {}
445
446impl sum_tree::Item for NotificationEntry {
447 type Summary = NotificationSummary;
448
449 fn summary(&self, _cx: &()) -> Self::Summary {
450 NotificationSummary {
451 max_id: self.id,
452 count: 1,
453 unread_count: if self.is_read { 0 } else { 1 },
454 }
455 }
456}
457
458impl sum_tree::Summary for NotificationSummary {
459 type Context = ();
460
461 fn zero(_cx: &()) -> Self {
462 Default::default()
463 }
464
465 fn add_summary(&mut self, summary: &Self, _: &()) {
466 self.max_id = self.max_id.max(summary.max_id);
467 self.count += summary.count;
468 self.unread_count += summary.unread_count;
469 }
470}
471
472impl sum_tree::Dimension<'_, NotificationSummary> for NotificationId {
473 fn zero(_cx: &()) -> Self {
474 Default::default()
475 }
476
477 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
478 debug_assert!(summary.max_id > self.0);
479 self.0 = summary.max_id;
480 }
481}
482
483impl sum_tree::Dimension<'_, NotificationSummary> for Count {
484 fn zero(_cx: &()) -> Self {
485 Default::default()
486 }
487
488 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
489 self.0 += summary.count;
490 }
491}
492
/// Controls how a batch is merged into the store by `add_notifications`.
struct AddNotificationsOptions {
    /// The batch was pushed by the server; allows `NewNotification` events to be emitted.
    is_new: bool,
    /// Discard every existing notification before inserting the batch.
    clear_old: bool,
    /// The server reported that nothing older than this batch remains.
    includes_first: bool,
}