1use anyhow::Result;
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
7use rpc::{proto, Notification, TypedEnvelope};
8use std::{ops::Range, sync::Arc};
9use sum_tree::{Bias, SumTree};
10use time::OffsetDateTime;
11use util::ResultExt;
12
13pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
14 let notification_store = cx.add_model(|cx| NotificationStore::new(client, user_store, cx));
15 cx.set_global(notification_store);
16}
17
18pub struct NotificationStore {
19 client: Arc<Client>,
20 user_store: ModelHandle<UserStore>,
21 channel_messages: HashMap<u64, ChannelMessage>,
22 channel_store: ModelHandle<ChannelStore>,
23 notifications: SumTree<NotificationEntry>,
24 loaded_all_notifications: bool,
25 _watch_connection_status: Task<Option<()>>,
26 _subscriptions: Vec<client::Subscription>,
27}
28
29#[derive(Clone, PartialEq, Eq, Debug)]
30pub enum NotificationEvent {
31 NotificationsUpdated {
32 old_range: Range<usize>,
33 new_count: usize,
34 },
35 NewNotification {
36 entry: NotificationEntry,
37 },
38 NotificationRemoved {
39 entry: NotificationEntry,
40 },
41 NotificationRead {
42 entry: NotificationEntry,
43 },
44}
45
46#[derive(Debug, PartialEq, Eq, Clone)]
47pub struct NotificationEntry {
48 pub id: u64,
49 pub notification: Notification,
50 pub timestamp: OffsetDateTime,
51 pub is_read: bool,
52 pub response: Option<bool>,
53}
54
/// Aggregate over a run of [`NotificationEntry`] values in the sum tree.
#[derive(Default, Debug, Clone)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of entries in the run.
    count: usize,
    /// Number of entries in the run that are not yet read.
    unread_count: usize,
}
61
/// Sum-tree dimension: number of entries preceding a position.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
64
/// Sum-tree dimension: number of unread entries preceding a position.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct UnreadCount(usize);
67
/// Sum-tree dimension: highest notification id at or before a position.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
70
71impl NotificationStore {
72 pub fn global(cx: &AppContext) -> ModelHandle<Self> {
73 cx.global::<ModelHandle<Self>>().clone()
74 }
75
76 pub fn new(
77 client: Arc<Client>,
78 user_store: ModelHandle<UserStore>,
79 cx: &mut ModelContext<Self>,
80 ) -> Self {
81 let mut connection_status = client.status();
82 let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
83 while let Some(status) = connection_status.next().await {
84 let this = this.upgrade(&cx)?;
85 match status {
86 client::Status::Connected { .. } => {
87 if let Some(task) = this.update(&mut cx, |this, cx| this.handle_connect(cx))
88 {
89 task.await.log_err()?;
90 }
91 }
92 _ => this.update(&mut cx, |this, cx| this.handle_disconnect(cx)),
93 }
94 }
95 Some(())
96 });
97
98 Self {
99 channel_store: ChannelStore::global(cx),
100 notifications: Default::default(),
101 loaded_all_notifications: false,
102 channel_messages: Default::default(),
103 _watch_connection_status: watch_connection_status,
104 _subscriptions: vec![
105 client.add_message_handler(cx.handle(), Self::handle_new_notification),
106 client.add_message_handler(cx.handle(), Self::handle_delete_notification),
107 ],
108 user_store,
109 client,
110 }
111 }
112
113 pub fn notification_count(&self) -> usize {
114 self.notifications.summary().count
115 }
116
117 pub fn unread_notification_count(&self) -> usize {
118 self.notifications.summary().unread_count
119 }
120
121 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
122 self.channel_messages.get(&id)
123 }
124
125 // Get the nth newest notification.
126 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
127 let count = self.notifications.summary().count;
128 if ix >= count {
129 return None;
130 }
131 let ix = count - 1 - ix;
132 let mut cursor = self.notifications.cursor::<Count>();
133 cursor.seek(&Count(ix), Bias::Right, &());
134 cursor.item()
135 }
136
137 pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
138 let mut cursor = self.notifications.cursor::<NotificationId>();
139 cursor.seek(&NotificationId(id), Bias::Left, &());
140 if let Some(item) = cursor.item() {
141 if item.id == id {
142 return Some(item);
143 }
144 }
145 None
146 }
147
148 pub fn load_more_notifications(
149 &self,
150 clear_old: bool,
151 cx: &mut ModelContext<Self>,
152 ) -> Option<Task<Result<()>>> {
153 if self.loaded_all_notifications && !clear_old {
154 return None;
155 }
156
157 let before_id = if clear_old {
158 None
159 } else {
160 self.notifications.first().map(|entry| entry.id)
161 };
162 let request = self.client.request(proto::GetNotifications { before_id });
163 Some(cx.spawn(|this, mut cx| async move {
164 let response = request.await?;
165 this.update(&mut cx, |this, _| {
166 this.loaded_all_notifications = response.done
167 });
168 Self::add_notifications(
169 this,
170 response.notifications,
171 AddNotificationsOptions {
172 is_new: false,
173 clear_old,
174 includes_first: response.done,
175 },
176 cx,
177 )
178 .await?;
179 Ok(())
180 }))
181 }
182
183 fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
184 self.notifications = Default::default();
185 self.channel_messages = Default::default();
186 cx.notify();
187 self.load_more_notifications(true, cx)
188 }
189
190 fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
191 cx.notify()
192 }
193
194 async fn handle_new_notification(
195 this: ModelHandle<Self>,
196 envelope: TypedEnvelope<proto::AddNotification>,
197 _: Arc<Client>,
198 cx: AsyncAppContext,
199 ) -> Result<()> {
200 Self::add_notifications(
201 this,
202 envelope.payload.notification.into_iter().collect(),
203 AddNotificationsOptions {
204 is_new: true,
205 clear_old: false,
206 includes_first: false,
207 },
208 cx,
209 )
210 .await
211 }
212
213 async fn handle_delete_notification(
214 this: ModelHandle<Self>,
215 envelope: TypedEnvelope<proto::DeleteNotification>,
216 _: Arc<Client>,
217 mut cx: AsyncAppContext,
218 ) -> Result<()> {
219 this.update(&mut cx, |this, cx| {
220 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
221 Ok(())
222 })
223 }
224
225 async fn add_notifications(
226 this: ModelHandle<Self>,
227 notifications: Vec<proto::Notification>,
228 options: AddNotificationsOptions,
229 mut cx: AsyncAppContext,
230 ) -> Result<()> {
231 let mut user_ids = Vec::new();
232 let mut message_ids = Vec::new();
233
234 let notifications = notifications
235 .into_iter()
236 .filter_map(|message| {
237 Some(NotificationEntry {
238 id: message.id,
239 is_read: message.is_read,
240 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
241 .ok()?,
242 notification: Notification::from_proto(&message)?,
243 response: message.response,
244 })
245 })
246 .collect::<Vec<_>>();
247 if notifications.is_empty() {
248 return Ok(());
249 }
250
251 for entry in ¬ifications {
252 match entry.notification {
253 Notification::ChannelInvitation { inviter_id, .. } => {
254 user_ids.push(inviter_id);
255 }
256 Notification::ContactRequest {
257 sender_id: requester_id,
258 } => {
259 user_ids.push(requester_id);
260 }
261 Notification::ContactRequestAccepted {
262 responder_id: contact_id,
263 } => {
264 user_ids.push(contact_id);
265 }
266 Notification::ChannelMessageMention {
267 sender_id,
268 message_id,
269 ..
270 } => {
271 user_ids.push(sender_id);
272 message_ids.push(message_id);
273 }
274 }
275 }
276
277 let (user_store, channel_store) = this.read_with(&cx, |this, _| {
278 (this.user_store.clone(), this.channel_store.clone())
279 });
280
281 user_store
282 .update(&mut cx, |store, cx| store.get_users(user_ids, cx))
283 .await?;
284 let messages = channel_store
285 .update(&mut cx, |store, cx| {
286 store.fetch_channel_messages(message_ids, cx)
287 })
288 .await?;
289 this.update(&mut cx, |this, cx| {
290 if options.clear_old {
291 cx.emit(NotificationEvent::NotificationsUpdated {
292 old_range: 0..this.notifications.summary().count,
293 new_count: 0,
294 });
295 this.notifications = SumTree::default();
296 this.channel_messages.clear();
297 this.loaded_all_notifications = false;
298 }
299
300 if options.includes_first {
301 this.loaded_all_notifications = true;
302 }
303
304 this.channel_messages
305 .extend(messages.into_iter().filter_map(|message| {
306 if let ChannelMessageId::Saved(id) = message.id {
307 Some((id, message))
308 } else {
309 None
310 }
311 }));
312
313 this.splice_notifications(
314 notifications
315 .into_iter()
316 .map(|notification| (notification.id, Some(notification))),
317 options.is_new,
318 cx,
319 );
320 });
321
322 Ok(())
323 }
324
325 fn splice_notifications(
326 &mut self,
327 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
328 is_new: bool,
329 cx: &mut ModelContext<'_, NotificationStore>,
330 ) {
331 let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
332 let mut new_notifications = SumTree::new();
333 let mut old_range = 0..0;
334
335 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
336 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
337
338 if i == 0 {
339 old_range.start = cursor.start().1 .0;
340 }
341
342 let old_notification = cursor.item();
343 if let Some(old_notification) = old_notification {
344 if old_notification.id == id {
345 cursor.next(&());
346
347 if let Some(new_notification) = &new_notification {
348 if new_notification.is_read {
349 cx.emit(NotificationEvent::NotificationRead {
350 entry: new_notification.clone(),
351 });
352 }
353 } else {
354 cx.emit(NotificationEvent::NotificationRemoved {
355 entry: old_notification.clone(),
356 });
357 }
358 }
359 } else if let Some(new_notification) = &new_notification {
360 if is_new {
361 cx.emit(NotificationEvent::NewNotification {
362 entry: new_notification.clone(),
363 });
364 }
365 }
366
367 if let Some(notification) = new_notification {
368 new_notifications.push(notification, &());
369 }
370 }
371
372 old_range.end = cursor.start().1 .0;
373 let new_count = new_notifications.summary().count - old_range.start;
374 new_notifications.append(cursor.suffix(&()), &());
375 drop(cursor);
376
377 self.notifications = new_notifications;
378 cx.emit(NotificationEvent::NotificationsUpdated {
379 old_range,
380 new_count,
381 });
382 }
383
384 pub fn respond_to_notification(
385 &mut self,
386 notification: Notification,
387 response: bool,
388 cx: &mut ModelContext<Self>,
389 ) {
390 match notification {
391 Notification::ContactRequest { sender_id } => {
392 self.user_store
393 .update(cx, |store, cx| {
394 store.respond_to_contact_request(sender_id, response, cx)
395 })
396 .detach();
397 }
398 Notification::ChannelInvitation { channel_id, .. } => {
399 self.channel_store
400 .update(cx, |store, cx| {
401 store.respond_to_channel_invite(channel_id, response, cx)
402 })
403 .detach();
404 }
405 _ => {}
406 }
407 }
408}
409
410impl Entity for NotificationStore {
411 type Event = NotificationEvent;
412}
413
414impl sum_tree::Item for NotificationEntry {
415 type Summary = NotificationSummary;
416
417 fn summary(&self) -> Self::Summary {
418 NotificationSummary {
419 max_id: self.id,
420 count: 1,
421 unread_count: if self.is_read { 0 } else { 1 },
422 }
423 }
424}
425
426impl sum_tree::Summary for NotificationSummary {
427 type Context = ();
428
429 fn add_summary(&mut self, summary: &Self, _: &()) {
430 self.max_id = self.max_id.max(summary.max_id);
431 self.count += summary.count;
432 self.unread_count += summary.unread_count;
433 }
434}
435
436impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
437 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
438 debug_assert!(summary.max_id > self.0);
439 self.0 = summary.max_id;
440 }
441}
442
443impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
444 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
445 self.0 += summary.count;
446 }
447}
448
449impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
450 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
451 self.0 += summary.unread_count;
452 }
453}
454
/// Controls how `NotificationStore::add_notifications` merges a batch.
struct AddNotificationsOptions {
    /// Emit `NewNotification` for entries not already in the store.
    is_new: bool,
    /// Discard everything currently loaded before inserting the batch.
    clear_old: bool,
    /// The batch reaches the oldest notification, so nothing older remains to load.
    includes_first: bool,
}