1use anyhow::{Context, Result};
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{AppContext, AsyncAppContext, Context as _, EventEmitter, Model, ModelContext, Task};
7use rpc::{proto, Notification, TypedEnvelope};
8use std::{ops::Range, sync::Arc};
9use sum_tree::{Bias, SumTree};
10use time::OffsetDateTime;
11use util::ResultExt;
12
13pub fn init(client: Arc<Client>, user_store: Model<UserStore>, cx: &mut AppContext) {
14 let notification_store = cx.build_model(|cx| NotificationStore::new(client, user_store, cx));
15 cx.set_global(notification_store);
16}
17
/// Client-side cache of the current user's notifications.
///
/// Registered as a gpui global by [`init`]; retrieve it with [`NotificationStore::global`].
pub struct NotificationStore {
    client: Arc<Client>,
    user_store: Model<UserStore>,
    /// Channel messages referenced by loaded notifications, keyed by their saved message id.
    channel_messages: HashMap<u64, ChannelMessage>,
    channel_store: Model<ChannelStore>,
    /// Loaded notifications, ordered by ascending notification id.
    notifications: SumTree<NotificationEntry>,
    /// Set once a page reports that the oldest notification has been loaded; further
    /// non-clearing page requests are then skipped.
    loaded_all_notifications: bool,
    /// Task reacting to connection status changes; held for the store's lifetime
    /// (presumably dropping a gpui `Task` cancels it — confirm).
    _watch_connection_status: Task<Option<()>>,
    /// RPC message handler registrations, held for the store's lifetime
    /// (presumably dropping them unregisters the handlers — confirm).
    _subscriptions: Vec<client::Subscription>,
}
28
/// Events emitted by [`NotificationStore`] when its contents change.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum NotificationEvent {
    /// The entries at positions `old_range` were replaced by `new_count` entries.
    ///
    /// Positions are in the store's ascending-id order, which is the reverse of the
    /// indexing used by `NotificationStore::notification_at`.
    NotificationsUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
    /// A notification delivered as "new" (not part of a page load) was inserted.
    NewNotification {
        entry: NotificationEntry,
    },
    /// A loaded notification was deleted; `entry` is the removed value.
    NotificationRemoved {
        entry: NotificationEntry,
    },
    /// A loaded notification was replaced by a version marked read; `entry` is the new value.
    /// Emitted whenever the replacement is read, even if the old entry was already read.
    NotificationRead {
        entry: NotificationEntry,
    },
}
45
/// A single decoded notification as held by [`NotificationStore`].
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct NotificationEntry {
    /// Server-assigned id; the store orders entries by this value.
    pub id: u64,
    /// Decoded notification payload.
    pub notification: Notification,
    /// Creation time, decoded from a Unix timestamp in seconds.
    pub timestamp: OffsetDateTime,
    /// Whether the notification has been marked read.
    pub is_read: bool,
    /// Copied from the server message; presumably the user's accept/decline answer for
    /// actionable notifications, `None` when unanswered — confirm against the protocol.
    pub response: Option<bool>,
}
54
/// Aggregated sum-tree data for a run of [`NotificationEntry`] values.
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of entries in the run.
    count: usize,
    /// Number of entries in the run that are not marked read.
    unread_count: usize,
}
61
/// Sum-tree dimension counting entries; used to seek by position in ascending-id order.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
64
/// Sum-tree dimension counting entries that are not marked read.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct UnreadCount(usize);
67
/// Sum-tree dimension tracking the largest notification id seen so far; used to seek to a
/// given id. Only meaningful because the tree is ordered by ascending id.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
70
71impl NotificationStore {
72 pub fn global(cx: &AppContext) -> Model<Self> {
73 cx.global::<Model<Self>>().clone()
74 }
75
76 pub fn new(
77 client: Arc<Client>,
78 user_store: Model<UserStore>,
79 cx: &mut ModelContext<Self>,
80 ) -> Self {
81 let mut connection_status = client.status();
82 let watch_connection_status = cx.spawn(|this, mut cx| async move {
83 while let Some(status) = connection_status.next().await {
84 let this = this.upgrade()?;
85 match status {
86 client::Status::Connected { .. } => {
87 if let Some(task) = this
88 .update(&mut cx, |this, cx| this.handle_connect(cx))
89 .log_err()?
90 {
91 task.await.log_err()?;
92 }
93 }
94 _ => this
95 .update(&mut cx, |this, cx| this.handle_disconnect(cx))
96 .log_err()?,
97 }
98 }
99 Some(())
100 });
101
102 Self {
103 channel_store: ChannelStore::global(cx),
104 notifications: Default::default(),
105 loaded_all_notifications: false,
106 channel_messages: Default::default(),
107 _watch_connection_status: watch_connection_status,
108 _subscriptions: vec![
109 client.add_message_handler(cx.weak_model(), Self::handle_new_notification),
110 client.add_message_handler(cx.weak_model(), Self::handle_delete_notification),
111 ],
112 user_store,
113 client,
114 }
115 }
116
117 pub fn notification_count(&self) -> usize {
118 self.notifications.summary().count
119 }
120
121 pub fn unread_notification_count(&self) -> usize {
122 self.notifications.summary().unread_count
123 }
124
125 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
126 self.channel_messages.get(&id)
127 }
128
129 // Get the nth newest notification.
130 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
131 let count = self.notifications.summary().count;
132 if ix >= count {
133 return None;
134 }
135 let ix = count - 1 - ix;
136 let mut cursor = self.notifications.cursor::<Count>();
137 cursor.seek(&Count(ix), Bias::Right, &());
138 cursor.item()
139 }
140
141 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
142 let mut cursor = self.notifications.cursor::<NotificationId>();
143 cursor.seek(&NotificationId(id), Bias::Left, &());
144 if let Some(item) = cursor.item() {
145 if item.id == id {
146 return Some(item);
147 }
148 }
149 None
150 }
151
152 pub fn load_more_notifications(
153 &self,
154 clear_old: bool,
155 cx: &mut ModelContext<Self>,
156 ) -> Option<Task<Result<()>>> {
157 if self.loaded_all_notifications && !clear_old {
158 return None;
159 }
160
161 let before_id = if clear_old {
162 None
163 } else {
164 self.notifications.first().map(|entry| entry.id)
165 };
166 let request = self.client.request(proto::GetNotifications { before_id });
167 Some(cx.spawn(|this, mut cx| async move {
168 let this = this
169 .upgrade()
170 .context("Notification store was dropped while loading notifications")?;
171
172 let response = request.await?;
173 this.update(&mut cx, |this, _| {
174 this.loaded_all_notifications = response.done
175 })?;
176 Self::add_notifications(
177 this,
178 response.notifications,
179 AddNotificationsOptions {
180 is_new: false,
181 clear_old,
182 includes_first: response.done,
183 },
184 cx,
185 )
186 .await?;
187 Ok(())
188 }))
189 }
190
191 fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
192 self.notifications = Default::default();
193 self.channel_messages = Default::default();
194 cx.notify();
195 self.load_more_notifications(true, cx)
196 }
197
198 fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
199 cx.notify()
200 }
201
202 async fn handle_new_notification(
203 this: Model<Self>,
204 envelope: TypedEnvelope<proto::AddNotification>,
205 _: Arc<Client>,
206 cx: AsyncAppContext,
207 ) -> Result<()> {
208 Self::add_notifications(
209 this,
210 envelope.payload.notification.into_iter().collect(),
211 AddNotificationsOptions {
212 is_new: true,
213 clear_old: false,
214 includes_first: false,
215 },
216 cx,
217 )
218 .await
219 }
220
221 async fn handle_delete_notification(
222 this: Model<Self>,
223 envelope: TypedEnvelope<proto::DeleteNotification>,
224 _: Arc<Client>,
225 mut cx: AsyncAppContext,
226 ) -> Result<()> {
227 this.update(&mut cx, |this, cx| {
228 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
229 Ok(())
230 })?
231 }
232
233 async fn add_notifications(
234 this: Model<Self>,
235 notifications: Vec<proto::Notification>,
236 options: AddNotificationsOptions,
237 mut cx: AsyncAppContext,
238 ) -> Result<()> {
239 let mut user_ids = Vec::new();
240 let mut message_ids = Vec::new();
241
242 let notifications = notifications
243 .into_iter()
244 .filter_map(|message| {
245 Some(NotificationEntry {
246 id: message.id,
247 is_read: message.is_read,
248 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
249 .ok()?,
250 notification: Notification::from_proto(&message)?,
251 response: message.response,
252 })
253 })
254 .collect::<Vec<_>>();
255 if notifications.is_empty() {
256 return Ok(());
257 }
258
259 for entry in ¬ifications {
260 match entry.notification {
261 Notification::ChannelInvitation { inviter_id, .. } => {
262 user_ids.push(inviter_id);
263 }
264 Notification::ContactRequest {
265 sender_id: requester_id,
266 } => {
267 user_ids.push(requester_id);
268 }
269 Notification::ContactRequestAccepted {
270 responder_id: contact_id,
271 } => {
272 user_ids.push(contact_id);
273 }
274 Notification::ChannelMessageMention {
275 sender_id,
276 message_id,
277 ..
278 } => {
279 user_ids.push(sender_id);
280 message_ids.push(message_id);
281 }
282 }
283 }
284
285 let (user_store, channel_store) = this.read_with(&cx, |this, _| {
286 (this.user_store.clone(), this.channel_store.clone())
287 })?;
288
289 user_store
290 .update(&mut cx, |store, cx| store.get_users(user_ids, cx))?
291 .await?;
292 let messages = channel_store
293 .update(&mut cx, |store, cx| {
294 store.fetch_channel_messages(message_ids, cx)
295 })?
296 .await?;
297 this.update(&mut cx, |this, cx| {
298 if options.clear_old {
299 cx.emit(NotificationEvent::NotificationsUpdated {
300 old_range: 0..this.notifications.summary().count,
301 new_count: 0,
302 });
303 this.notifications = SumTree::default();
304 this.channel_messages.clear();
305 this.loaded_all_notifications = false;
306 }
307
308 if options.includes_first {
309 this.loaded_all_notifications = true;
310 }
311
312 this.channel_messages
313 .extend(messages.into_iter().filter_map(|message| {
314 if let ChannelMessageId::Saved(id) = message.id {
315 Some((id, message))
316 } else {
317 None
318 }
319 }));
320
321 this.splice_notifications(
322 notifications
323 .into_iter()
324 .map(|notification| (notification.id, Some(notification))),
325 options.is_new,
326 cx,
327 );
328 })
329 .log_err();
330
331 Ok(())
332 }
333
334 fn splice_notifications(
335 &mut self,
336 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
337 is_new: bool,
338 cx: &mut ModelContext<'_, NotificationStore>,
339 ) {
340 let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
341 let mut new_notifications = SumTree::new();
342 let mut old_range = 0..0;
343
344 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
345 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
346
347 if i == 0 {
348 old_range.start = cursor.start().1 .0;
349 }
350
351 let old_notification = cursor.item();
352 if let Some(old_notification) = old_notification {
353 if old_notification.id == id {
354 cursor.next(&());
355
356 if let Some(new_notification) = &new_notification {
357 if new_notification.is_read {
358 cx.emit(NotificationEvent::NotificationRead {
359 entry: new_notification.clone(),
360 });
361 }
362 } else {
363 cx.emit(NotificationEvent::NotificationRemoved {
364 entry: old_notification.clone(),
365 });
366 }
367 }
368 } else if let Some(new_notification) = &new_notification {
369 if is_new {
370 cx.emit(NotificationEvent::NewNotification {
371 entry: new_notification.clone(),
372 });
373 }
374 }
375
376 if let Some(notification) = new_notification {
377 new_notifications.push(notification, &());
378 }
379 }
380
381 old_range.end = cursor.start().1 .0;
382 let new_count = new_notifications.summary().count - old_range.start;
383 new_notifications.append(cursor.suffix(&()), &());
384 drop(cursor);
385
386 self.notifications = new_notifications;
387 cx.emit(NotificationEvent::NotificationsUpdated {
388 old_range,
389 new_count,
390 });
391 }
392
393 pub fn respond_to_notification(
394 &mut self,
395 notification: Notification,
396 response: bool,
397 cx: &mut ModelContext<Self>,
398 ) {
399 match notification {
400 Notification::ContactRequest { sender_id } => {
401 self.user_store
402 .update(cx, |store, cx| {
403 store.respond_to_contact_request(sender_id, response, cx)
404 })
405 .detach();
406 }
407 Notification::ChannelInvitation { channel_id, .. } => {
408 self.channel_store
409 .update(cx, |store, cx| {
410 store.respond_to_channel_invite(channel_id, response, cx)
411 })
412 .detach();
413 }
414 _ => {}
415 }
416 }
417}
418
// Allows `NotificationStore` to emit `NotificationEvent`s through `cx.emit`.
impl EventEmitter<NotificationEvent> for NotificationStore {}
420
421impl sum_tree::Item for NotificationEntry {
422 type Summary = NotificationSummary;
423
424 fn summary(&self) -> Self::Summary {
425 NotificationSummary {
426 max_id: self.id,
427 count: 1,
428 unread_count: if self.is_read { 0 } else { 1 },
429 }
430 }
431}
432
433impl sum_tree::Summary for NotificationSummary {
434 type Context = ();
435
436 fn add_summary(&mut self, summary: &Self, _: &()) {
437 self.max_id = self.max_id.max(summary.max_id);
438 self.count += summary.count;
439 self.unread_count += summary.unread_count;
440 }
441}
442
443impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
444 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
445 debug_assert!(summary.max_id > self.0);
446 self.0 = summary.max_id;
447 }
448}
449
450impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
451 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
452 self.0 += summary.count;
453 }
454}
455
456impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
457 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
458 self.0 += summary.unread_count;
459 }
460}
461
/// Controls how a batch of notifications is merged into the store.
#[derive(Clone, Copy, Debug)]
struct AddNotificationsOptions {
    /// The batch was pushed by the server rather than loaded as a page. Inserted entries may
    /// then be reported as new.
    is_new: bool,
    /// Discard all currently loaded notifications before merging the batch.
    clear_old: bool,
    /// The batch reaches the oldest notification, so no further pages need loading.
    includes_first: bool,
}