1use anyhow::{Context, Result};
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{ChannelId, Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{
7 AppContext, AsyncAppContext, Context as _, EventEmitter, Global, Model, ModelContext, Task,
8};
9use rpc::{proto, Notification, TypedEnvelope};
10use std::{ops::Range, sync::Arc};
11use sum_tree::{Bias, SumTree};
12use time::OffsetDateTime;
13use util::ResultExt;
14
15pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
16 let notification_store = cx.new_model(|cx| NotificationStore::new(client, user_store, cx));
17 cx.set_global(GlobalNotificationStore(notification_store));
18}
19
20struct GlobalNotificationStore(Model<NotificationStore>);
21
22impl Global for GlobalNotificationStore {}
23
24pub struct NotificationStore {
25 client: Arc<Client>,
26 user_store: Model<UserStore>,
27 channel_messages: HashMap<u64, ChannelMessage>,
28 channel_store: Model<ChannelStore>,
29 notifications: SumTree<NotificationEntry>,
30 loaded_all_notifications: bool,
31 _watch_connection_status: Task<Option<()>>,
32 _subscriptions: Vec<client::Subscription>,
33}
34
35#[derive(Clone, PartialEq, Eq, Debug)]
36pub enum NotificationEvent {
37 NotificationsUpdated {
38 old_range: Range<usize>,
39 new_count: usize,
40 },
41 NewNotification {
42 entry: NotificationEntry,
43 },
44 NotificationRemoved {
45 entry: NotificationEntry,
46 },
47 NotificationRead {
48 entry: NotificationEntry,
49 },
50}
51
52#[derive(Debug, PartialEq, Eq, Clone)]
53pub struct NotificationEntry {
54 pub id: u64,
55 pub notification: Notification,
56 pub timestamp: OffsetDateTime,
57 pub is_read: bool,
58 pub response: Option<bool>,
59}
60
/// Aggregate of a run of [`NotificationEntry`] values inside the sum tree.
#[derive(Default, Debug, Clone)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of notifications in the run.
    count: usize,
    /// Number of notifications in the run that are not yet read.
    unread_count: usize,
}
67
/// Seek dimension: number of notifications before a tree position.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);

/// Seek dimension: number of unread notifications before a tree position.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct UnreadCount(usize);

/// Seek dimension: largest notification id before a tree position.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
76
77impl NotificationStore {
78 pub fn global(cx: &AppContext) -> Model<Self> {
79 cx.global::<GlobalNotificationStore>().0.clone()
80 }
81
82 pub fn new(
83 client: Arc<Client>,
84 user_store: Model<UserStore>,
85 cx: &mut ModelContext<Self>,
86 ) -> Self {
87 let mut connection_status = client.status();
88 let watch_connection_status = cx.spawn(|this, mut cx| async move {
89 while let Some(status) = connection_status.next().await {
90 let this = this.upgrade()?;
91 match status {
92 client::Status::Connected { .. } => {
93 if let Some(task) = this
94 .update(&mut cx, |this, cx| this.handle_connect(cx))
95 .log_err()?
96 {
97 task.await.log_err()?;
98 }
99 }
100 _ => this
101 .update(&mut cx, |this, cx| this.handle_disconnect(cx))
102 .log_err()?,
103 }
104 }
105 Some(())
106 });
107
108 Self {
109 channel_store: ChannelStore::global(cx),
110 notifications: Default::default(),
111 loaded_all_notifications: false,
112 channel_messages: Default::default(),
113 _watch_connection_status: watch_connection_status,
114 _subscriptions: vec![
115 client.add_message_handler(cx.weak_model(), Self::handle_new_notification),
116 client.add_message_handler(cx.weak_model(), Self::handle_delete_notification),
117 ],
118 user_store,
119 client,
120 }
121 }
122
123 pub fn notification_count(&self) -> usize {
124 self.notifications.summary().count
125 }
126
127 pub fn unread_notification_count(&self) -> usize {
128 self.notifications.summary().unread_count
129 }
130
131 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
132 self.channel_messages.get(&id)
133 }
134
135 // Get the nth newest notification.
136 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
137 let count = self.notifications.summary().count;
138 if ix >= count {
139 return None;
140 }
141 let ix = count - 1 - ix;
142 let mut cursor = self.notifications.cursor::<Count>();
143 cursor.seek(&Count(ix), Bias::Right, &());
144 cursor.item()
145 }
146
147 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
148 let mut cursor = self.notifications.cursor::<NotificationId>();
149 cursor.seek(&NotificationId(id), Bias::Left, &());
150 if let Some(item) = cursor.item() {
151 if item.id == id {
152 return Some(item);
153 }
154 }
155 None
156 }
157
158 pub fn load_more_notifications(
159 &self,
160 clear_old: bool,
161 cx: &mut ModelContext<Self>,
162 ) -> Option<Task<Result<()>>> {
163 if self.loaded_all_notifications && !clear_old {
164 return None;
165 }
166
167 let before_id = if clear_old {
168 None
169 } else {
170 self.notifications.first().map(|entry| entry.id)
171 };
172 let request = self.client.request(proto::GetNotifications { before_id });
173 Some(cx.spawn(|this, mut cx| async move {
174 let this = this
175 .upgrade()
176 .context("Notification store was dropped while loading notifications")?;
177
178 let response = request.await?;
179 this.update(&mut cx, |this, _| {
180 this.loaded_all_notifications = response.done
181 })?;
182 Self::add_notifications(
183 this,
184 response.notifications,
185 AddNotificationsOptions {
186 is_new: false,
187 clear_old,
188 includes_first: response.done,
189 },
190 cx,
191 )
192 .await?;
193 Ok(())
194 }))
195 }
196
197 fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
198 self.notifications = Default::default();
199 self.channel_messages = Default::default();
200 cx.notify();
201 self.load_more_notifications(true, cx)
202 }
203
204 fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
205 cx.notify()
206 }
207
208 async fn handle_new_notification(
209 this: Model<Self>,
210 envelope: TypedEnvelope<proto::AddNotification>,
211 _: Arc<Client>,
212 cx: AsyncAppContext,
213 ) -> Result<()> {
214 Self::add_notifications(
215 this,
216 envelope.payload.notification.into_iter().collect(),
217 AddNotificationsOptions {
218 is_new: true,
219 clear_old: false,
220 includes_first: false,
221 },
222 cx,
223 )
224 .await
225 }
226
227 async fn handle_delete_notification(
228 this: Model<Self>,
229 envelope: TypedEnvelope<proto::DeleteNotification>,
230 _: Arc<Client>,
231 mut cx: AsyncAppContext,
232 ) -> Result<()> {
233 this.update(&mut cx, |this, cx| {
234 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
235 Ok(())
236 })?
237 }
238
239 async fn add_notifications(
240 this: Model<Self>,
241 notifications: Vec<proto::Notification>,
242 options: AddNotificationsOptions,
243 mut cx: AsyncAppContext,
244 ) -> Result<()> {
245 let mut user_ids = Vec::new();
246 let mut message_ids = Vec::new();
247
248 let notifications = notifications
249 .into_iter()
250 .filter_map(|message| {
251 Some(NotificationEntry {
252 id: message.id,
253 is_read: message.is_read,
254 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
255 .ok()?,
256 notification: Notification::from_proto(&message)?,
257 response: message.response,
258 })
259 })
260 .collect::<Vec<_>>();
261 if notifications.is_empty() {
262 return Ok(());
263 }
264
265 for entry in ¬ifications {
266 match entry.notification {
267 Notification::ChannelInvitation { inviter_id, .. } => {
268 user_ids.push(inviter_id);
269 }
270 Notification::ContactRequest {
271 sender_id: requester_id,
272 } => {
273 user_ids.push(requester_id);
274 }
275 Notification::ContactRequestAccepted {
276 responder_id: contact_id,
277 } => {
278 user_ids.push(contact_id);
279 }
280 Notification::ChannelMessageMention {
281 sender_id,
282 message_id,
283 ..
284 } => {
285 user_ids.push(sender_id);
286 message_ids.push(message_id);
287 }
288 }
289 }
290
291 let (user_store, channel_store) = this.read_with(&cx, |this, _| {
292 (this.user_store.clone(), this.channel_store.clone())
293 })?;
294
295 user_store
296 .update(&mut cx, |store, cx| store.get_users(user_ids, cx))?
297 .await?;
298 let messages = channel_store
299 .update(&mut cx, |store, cx| {
300 store.fetch_channel_messages(message_ids, cx)
301 })?
302 .await?;
303 this.update(&mut cx, |this, cx| {
304 if options.clear_old {
305 cx.emit(NotificationEvent::NotificationsUpdated {
306 old_range: 0..this.notifications.summary().count,
307 new_count: 0,
308 });
309 this.notifications = SumTree::default();
310 this.channel_messages.clear();
311 this.loaded_all_notifications = false;
312 }
313
314 if options.includes_first {
315 this.loaded_all_notifications = true;
316 }
317
318 this.channel_messages
319 .extend(messages.into_iter().filter_map(|message| {
320 if let ChannelMessageId::Saved(id) = message.id {
321 Some((id, message))
322 } else {
323 None
324 }
325 }));
326
327 this.splice_notifications(
328 notifications
329 .into_iter()
330 .map(|notification| (notification.id, Some(notification))),
331 options.is_new,
332 cx,
333 );
334 })
335 .log_err();
336
337 Ok(())
338 }
339
340 fn splice_notifications(
341 &mut self,
342 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
343 is_new: bool,
344 cx: &mut ModelContext<'_, NotificationStore>,
345 ) {
346 let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
347 let mut new_notifications = SumTree::new();
348 let mut old_range = 0..0;
349
350 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
351 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
352
353 if i == 0 {
354 old_range.start = cursor.start().1 .0;
355 }
356
357 let old_notification = cursor.item();
358 if let Some(old_notification) = old_notification {
359 if old_notification.id == id {
360 cursor.next(&());
361
362 if let Some(new_notification) = &new_notification {
363 if new_notification.is_read {
364 cx.emit(NotificationEvent::NotificationRead {
365 entry: new_notification.clone(),
366 });
367 }
368 } else {
369 cx.emit(NotificationEvent::NotificationRemoved {
370 entry: old_notification.clone(),
371 });
372 }
373 }
374 } else if let Some(new_notification) = &new_notification {
375 if is_new {
376 cx.emit(NotificationEvent::NewNotification {
377 entry: new_notification.clone(),
378 });
379 }
380 }
381
382 if let Some(notification) = new_notification {
383 new_notifications.push(notification, &());
384 }
385 }
386
387 old_range.end = cursor.start().1 .0;
388 let new_count = new_notifications.summary().count - old_range.start;
389 new_notifications.append(cursor.suffix(&()), &());
390 drop(cursor);
391
392 self.notifications = new_notifications;
393 cx.emit(NotificationEvent::NotificationsUpdated {
394 old_range,
395 new_count,
396 });
397 }
398
399 pub fn respond_to_notification(
400 &mut self,
401 notification: Notification,
402 response: bool,
403 cx: &mut ModelContext<Self>,
404 ) {
405 match notification {
406 Notification::ContactRequest { sender_id } => {
407 self.user_store
408 .update(cx, |store, cx| {
409 store.respond_to_contact_request(sender_id, response, cx)
410 })
411 .detach();
412 }
413 Notification::ChannelInvitation { channel_id, .. } => {
414 self.channel_store
415 .update(cx, |store, cx| {
416 store.respond_to_channel_invite(ChannelId(channel_id), response, cx)
417 })
418 .detach();
419 }
420 _ => {}
421 }
422 }
423}
424
425impl EventEmitter<NotificationEvent> for NotificationStore {}
426
427impl sum_tree::Item for NotificationEntry {
428 type Summary = NotificationSummary;
429
430 fn summary(&self) -> Self::Summary {
431 NotificationSummary {
432 max_id: self.id,
433 count: 1,
434 unread_count: if self.is_read { 0 } else { 1 },
435 }
436 }
437}
438
439impl sum_tree::Summary for NotificationSummary {
440 type Context = ();
441
442 fn add_summary(&mut self, summary: &Self, _: &()) {
443 self.max_id = self.max_id.max(summary.max_id);
444 self.count += summary.count;
445 self.unread_count += summary.unread_count;
446 }
447}
448
449impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
450 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
451 debug_assert!(summary.max_id > self.0);
452 self.0 = summary.max_id;
453 }
454}
455
456impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
457 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
458 self.0 += summary.count;
459 }
460}
461
462impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
463 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
464 self.0 += summary.unread_count;
465 }
466}
467
/// Controls how a batch is merged by `NotificationStore::add_notifications`.
struct AddNotificationsOptions {
    /// Whether inserted entries may produce `NewNotification` events.
    is_new: bool,
    /// Discard all previously loaded notifications before inserting the batch.
    clear_old: bool,
    /// The batch reaches the oldest notification, so nothing older remains.
    includes_first: bool,
}