1use anyhow::Result;
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
7use rpc::{proto, Notification, TypedEnvelope};
8use std::{ops::Range, sync::Arc};
9use sum_tree::{Bias, SumTree};
10use time::OffsetDateTime;
11use util::ResultExt;
12
13pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
14 let notification_store = cx.add_model(|cx| NotificationStore::new(client, user_store, cx));
15 cx.set_global(notification_store);
16}
17
/// Client-side cache of the current user's notifications, stored in a [`SumTree`]
/// ordered by ascending notification id, together with the channel messages that
/// mention notifications refer to.
pub struct NotificationStore {
    client: Arc<Client>,
    user_store: ModelHandle<UserStore>,
    /// Channel messages fetched for `ChannelMessageMention` notifications, keyed by
    /// the message's saved id.
    channel_messages: HashMap<u64, ChannelMessage>,
    channel_store: ModelHandle<ChannelStore>,
    /// Loaded notifications, kept sorted by ascending id.
    notifications: SumTree<NotificationEntry>,
    /// Background task reacting to connection-status changes of `client`.
    _watch_connection_status: Task<Option<()>>,
    /// Message-handler subscriptions held for the store's lifetime
    /// (presumably dropping them unregisters the handlers — confirm in `client`).
    _subscriptions: Vec<client::Subscription>,
}
27
/// Events emitted by [`NotificationStore`].
pub enum NotificationEvent {
    /// The entries at indices `old_range` of the previous list were replaced by
    /// `new_count` entries starting at `old_range.start`.
    NotificationsUpdated {
        old_range: Range<usize>,
        new_count: usize,
    },
    /// A notification pushed by the server was added to the store.
    NewNotification {
        entry: NotificationEntry,
    },
    /// A previously loaded notification was deleted from the store.
    NotificationRemoved {
        entry: NotificationEntry,
    },
}
40
/// A single decoded notification as held by the store.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct NotificationEntry {
    /// Server-assigned id; the store keeps entries ordered by this value.
    pub id: u64,
    /// The decoded notification payload.
    pub notification: Notification,
    /// Creation time, decoded from the proto's unix timestamp (seconds).
    pub timestamp: OffsetDateTime,
    /// Whether the user has already read this notification.
    pub is_read: bool,
}
48
/// Aggregate over a subtree of [`NotificationEntry`] items.
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
    /// Largest notification id in the subtree.
    max_id: u64,
    /// Number of entries in the subtree.
    count: usize,
    /// Number of entries in the subtree whose `is_read` is false.
    unread_count: usize,
}
55
/// Sum-tree dimension counting entries; used to seek to a notification by index.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);
58
/// Sum-tree dimension counting unread entries.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct UnreadCount(usize);
61
/// Sum-tree dimension tracking the largest id seen so far; used to seek by id.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
64
65impl NotificationStore {
66 pub fn global(cx: &AppContext) -> ModelHandle<Self> {
67 cx.global::<ModelHandle<Self>>().clone()
68 }
69
70 pub fn new(
71 client: Arc<Client>,
72 user_store: ModelHandle<UserStore>,
73 cx: &mut ModelContext<Self>,
74 ) -> Self {
75 let mut connection_status = client.status();
76 let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
77 while let Some(status) = connection_status.next().await {
78 let this = this.upgrade(&cx)?;
79 match status {
80 client::Status::Connected { .. } => {
81 this.update(&mut cx, |this, cx| this.handle_connect(cx))
82 .await
83 .log_err()?;
84 }
85 _ => this.update(&mut cx, |this, cx| this.handle_disconnect(cx)),
86 }
87 }
88 Some(())
89 });
90
91 Self {
92 channel_store: ChannelStore::global(cx),
93 notifications: Default::default(),
94 channel_messages: Default::default(),
95 _watch_connection_status: watch_connection_status,
96 _subscriptions: vec![
97 client.add_message_handler(cx.handle(), Self::handle_new_notification),
98 client.add_message_handler(cx.handle(), Self::handle_delete_notification),
99 ],
100 user_store,
101 client,
102 }
103 }
104
105 pub fn notification_count(&self) -> usize {
106 self.notifications.summary().count
107 }
108
109 pub fn unread_notification_count(&self) -> usize {
110 self.notifications.summary().unread_count
111 }
112
113 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
114 self.channel_messages.get(&id)
115 }
116
117 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
118 let mut cursor = self.notifications.cursor::<Count>();
119 cursor.seek(&Count(ix), Bias::Right, &());
120 cursor.item()
121 }
122
123 pub fn load_more_notifications(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
124 let request = self
125 .client
126 .request(proto::GetNotifications { before_id: None });
127 cx.spawn(|this, cx| async move {
128 let response = request.await?;
129 Self::add_notifications(this, false, response.notifications, cx).await?;
130 Ok(())
131 })
132 }
133
134 fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
135 self.notifications = Default::default();
136 self.channel_messages = Default::default();
137 self.load_more_notifications(cx)
138 }
139
140 fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
141 cx.notify()
142 }
143
144 async fn handle_new_notification(
145 this: ModelHandle<Self>,
146 envelope: TypedEnvelope<proto::NewNotification>,
147 _: Arc<Client>,
148 cx: AsyncAppContext,
149 ) -> Result<()> {
150 Self::add_notifications(
151 this,
152 true,
153 envelope.payload.notification.into_iter().collect(),
154 cx,
155 )
156 .await
157 }
158
159 async fn handle_delete_notification(
160 this: ModelHandle<Self>,
161 envelope: TypedEnvelope<proto::DeleteNotification>,
162 _: Arc<Client>,
163 mut cx: AsyncAppContext,
164 ) -> Result<()> {
165 this.update(&mut cx, |this, cx| {
166 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
167 Ok(())
168 })
169 }
170
171 async fn add_notifications(
172 this: ModelHandle<Self>,
173 is_new: bool,
174 notifications: Vec<proto::Notification>,
175 mut cx: AsyncAppContext,
176 ) -> Result<()> {
177 let mut user_ids = Vec::new();
178 let mut message_ids = Vec::new();
179
180 let notifications = notifications
181 .into_iter()
182 .filter_map(|message| {
183 Some(NotificationEntry {
184 id: message.id,
185 is_read: message.is_read,
186 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
187 .ok()?,
188 notification: Notification::from_proto(&message)?,
189 })
190 })
191 .collect::<Vec<_>>();
192 if notifications.is_empty() {
193 return Ok(());
194 }
195
196 for entry in ¬ifications {
197 match entry.notification {
198 Notification::ChannelInvitation {
199 actor_id: inviter_id,
200 ..
201 } => {
202 user_ids.push(inviter_id);
203 }
204 Notification::ContactRequest {
205 actor_id: requester_id,
206 } => {
207 user_ids.push(requester_id);
208 }
209 Notification::ContactRequestAccepted {
210 actor_id: contact_id,
211 } => {
212 user_ids.push(contact_id);
213 }
214 Notification::ChannelMessageMention {
215 actor_id: sender_id,
216 message_id,
217 ..
218 } => {
219 user_ids.push(sender_id);
220 message_ids.push(message_id);
221 }
222 }
223 }
224
225 let (user_store, channel_store) = this.read_with(&cx, |this, _| {
226 (this.user_store.clone(), this.channel_store.clone())
227 });
228
229 user_store
230 .update(&mut cx, |store, cx| store.get_users(user_ids, cx))
231 .await?;
232 let messages = channel_store
233 .update(&mut cx, |store, cx| {
234 store.fetch_channel_messages(message_ids, cx)
235 })
236 .await?;
237 this.update(&mut cx, |this, cx| {
238 this.channel_messages
239 .extend(messages.into_iter().filter_map(|message| {
240 if let ChannelMessageId::Saved(id) = message.id {
241 Some((id, message))
242 } else {
243 None
244 }
245 }));
246
247 this.splice_notifications(
248 notifications
249 .into_iter()
250 .map(|notification| (notification.id, Some(notification))),
251 is_new,
252 cx,
253 );
254 });
255
256 Ok(())
257 }
258
259 fn splice_notifications(
260 &mut self,
261 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
262 is_new: bool,
263 cx: &mut ModelContext<'_, NotificationStore>,
264 ) {
265 let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
266 let mut new_notifications = SumTree::new();
267 let mut old_range = 0..0;
268
269 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
270 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
271
272 if i == 0 {
273 old_range.start = cursor.start().1 .0;
274 }
275
276 if let Some(existing_notification) = cursor.item() {
277 if existing_notification.id == id {
278 if new_notification.is_none() {
279 cx.emit(NotificationEvent::NotificationRemoved {
280 entry: existing_notification.clone(),
281 });
282 }
283 cursor.next(&());
284 }
285 }
286
287 if let Some(notification) = new_notification {
288 if is_new {
289 cx.emit(NotificationEvent::NewNotification {
290 entry: notification.clone(),
291 });
292 }
293
294 new_notifications.push(notification, &());
295 }
296 }
297
298 old_range.end = cursor.start().1 .0;
299 let new_count = new_notifications.summary().count - old_range.start;
300 new_notifications.append(cursor.suffix(&()), &());
301 drop(cursor);
302
303 self.notifications = new_notifications;
304 cx.emit(NotificationEvent::NotificationsUpdated {
305 old_range,
306 new_count,
307 });
308 }
309}
310
/// The store's model events are [`NotificationEvent`]s, emitted via `cx.emit`.
impl Entity for NotificationStore {
    type Event = NotificationEvent;
}
314
315impl sum_tree::Item for NotificationEntry {
316 type Summary = NotificationSummary;
317
318 fn summary(&self) -> Self::Summary {
319 NotificationSummary {
320 max_id: self.id,
321 count: 1,
322 unread_count: if self.is_read { 0 } else { 1 },
323 }
324 }
325}
326
327impl sum_tree::Summary for NotificationSummary {
328 type Context = ();
329
330 fn add_summary(&mut self, summary: &Self, _: &()) {
331 self.max_id = self.max_id.max(summary.max_id);
332 self.count += summary.count;
333 self.unread_count += summary.unread_count;
334 }
335}
336
337impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
338 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
339 debug_assert!(summary.max_id > self.0);
340 self.0 = summary.max_id;
341 }
342}
343
344impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
345 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
346 self.0 += summary.count;
347 }
348}
349
350impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
351 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
352 self.0 += summary.unread_count;
353 }
354}