1use anyhow::Result;
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
7use rpc::{proto, Notification, TypedEnvelope};
8use std::{ops::Range, sync::Arc};
9use sum_tree::{Bias, SumTree};
10use time::OffsetDateTime;
11use util::ResultExt;
12
13pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
14 let notification_store = cx.add_model(|cx| NotificationStore::new(client, user_store, cx));
15 cx.set_global(notification_store);
16}
17
/// Client-side cache of the current user's notifications, loaded from and updated by
/// the server through `client`.
pub struct NotificationStore {
    client: Arc<Client>,
    user_store: ModelHandle<UserStore>,
    /// Channel messages fetched for `ChannelMessageMention` notifications, keyed by
    /// their saved (server-side) message id.
    channel_messages: HashMap<u64, ChannelMessage>,
    channel_store: ModelHandle<ChannelStore>,
    /// Loaded notifications, kept in ascending notification-id order.
    notifications: SumTree<NotificationEntry>,
    /// Connection-status watcher task; owned here so it lives as long as the store.
    _watch_connection_status: Task<Option<()>>,
    /// Registrations of the new/delete notification message handlers; owned here so
    /// they live as long as the store.
    _subscriptions: Vec<client::Subscription>,
}
27
28pub enum NotificationEvent {
29 NotificationsUpdated {
30 old_range: Range<usize>,
31 new_count: usize,
32 },
33 NewNotification {
34 entry: NotificationEntry,
35 },
36 NotificationRemoved {
37 entry: NotificationEntry,
38 },
39 NotificationRead {
40 entry: NotificationEntry,
41 },
42}
43
/// One notification as held by [`NotificationStore`], decoded from a `proto::Notification`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct NotificationEntry {
    /// Id copied from the proto message; the store keeps entries sorted by it.
    pub id: u64,
    /// The decoded notification payload.
    pub notification: Notification,
    /// Converted from the proto's `timestamp`, interpreted as unix seconds.
    pub timestamp: OffsetDateTime,
    /// Copied from the proto's `is_read`; unread entries count toward the unread total.
    pub is_read: bool,
    /// Copied from the proto's `response` field. Presumably the user's accept/decline
    /// answer where applicable — TODO confirm against the server schema.
    pub response: Option<bool>,
}
52
/// `SumTree` summary aggregated over a run of [`NotificationEntry`] values.
#[derive(Clone, Debug, Default)]
pub struct NotificationSummary {
    /// Largest notification id in the run.
    max_id: u64,
    /// Number of entries in the run.
    count: usize,
    /// Number of entries in the run whose `is_read` flag is false.
    unread_count: usize,
}
59
/// `SumTree` dimension: position measured as a number of entries (an index).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct Count(usize);

/// `SumTree` dimension: number of unread entries up to a position.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct UnreadCount(usize);

/// `SumTree` dimension: largest notification id up to a position; used to seek by id.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
struct NotificationId(u64);
68
69impl NotificationStore {
70 pub fn global(cx: &AppContext) -> ModelHandle<Self> {
71 cx.global::<ModelHandle<Self>>().clone()
72 }
73
74 pub fn new(
75 client: Arc<Client>,
76 user_store: ModelHandle<UserStore>,
77 cx: &mut ModelContext<Self>,
78 ) -> Self {
79 let mut connection_status = client.status();
80 let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
81 while let Some(status) = connection_status.next().await {
82 let this = this.upgrade(&cx)?;
83 match status {
84 client::Status::Connected { .. } => {
85 this.update(&mut cx, |this, cx| this.handle_connect(cx))
86 .await
87 .log_err()?;
88 }
89 _ => this.update(&mut cx, |this, cx| this.handle_disconnect(cx)),
90 }
91 }
92 Some(())
93 });
94
95 Self {
96 channel_store: ChannelStore::global(cx),
97 notifications: Default::default(),
98 channel_messages: Default::default(),
99 _watch_connection_status: watch_connection_status,
100 _subscriptions: vec![
101 client.add_message_handler(cx.handle(), Self::handle_new_notification),
102 client.add_message_handler(cx.handle(), Self::handle_delete_notification),
103 ],
104 user_store,
105 client,
106 }
107 }
108
109 pub fn notification_count(&self) -> usize {
110 self.notifications.summary().count
111 }
112
113 pub fn unread_notification_count(&self) -> usize {
114 self.notifications.summary().unread_count
115 }
116
117 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
118 self.channel_messages.get(&id)
119 }
120
121 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
122 let mut cursor = self.notifications.cursor::<Count>();
123 cursor.seek(&Count(ix), Bias::Right, &());
124 cursor.item()
125 }
126
127 pub fn load_more_notifications(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
128 let request = self
129 .client
130 .request(proto::GetNotifications { before_id: None });
131 cx.spawn(|this, cx| async move {
132 let response = request.await?;
133 Self::add_notifications(this, false, response.notifications, cx).await?;
134 Ok(())
135 })
136 }
137
138 fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
139 self.notifications = Default::default();
140 self.channel_messages = Default::default();
141 self.load_more_notifications(cx)
142 }
143
144 fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
145 cx.notify()
146 }
147
148 async fn handle_new_notification(
149 this: ModelHandle<Self>,
150 envelope: TypedEnvelope<proto::NewNotification>,
151 _: Arc<Client>,
152 cx: AsyncAppContext,
153 ) -> Result<()> {
154 Self::add_notifications(
155 this,
156 true,
157 envelope.payload.notification.into_iter().collect(),
158 cx,
159 )
160 .await
161 }
162
163 async fn handle_delete_notification(
164 this: ModelHandle<Self>,
165 envelope: TypedEnvelope<proto::DeleteNotification>,
166 _: Arc<Client>,
167 mut cx: AsyncAppContext,
168 ) -> Result<()> {
169 this.update(&mut cx, |this, cx| {
170 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
171 Ok(())
172 })
173 }
174
175 async fn add_notifications(
176 this: ModelHandle<Self>,
177 is_new: bool,
178 notifications: Vec<proto::Notification>,
179 mut cx: AsyncAppContext,
180 ) -> Result<()> {
181 let mut user_ids = Vec::new();
182 let mut message_ids = Vec::new();
183
184 let notifications = notifications
185 .into_iter()
186 .filter_map(|message| {
187 Some(NotificationEntry {
188 id: message.id,
189 is_read: message.is_read,
190 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
191 .ok()?,
192 notification: Notification::from_proto(&message)?,
193 response: message.response,
194 })
195 })
196 .collect::<Vec<_>>();
197 if notifications.is_empty() {
198 return Ok(());
199 }
200
201 for entry in ¬ifications {
202 match entry.notification {
203 Notification::ChannelInvitation { .. } => {}
204 Notification::ContactRequest {
205 sender_id: requester_id,
206 } => {
207 user_ids.push(requester_id);
208 }
209 Notification::ContactRequestAccepted {
210 responder_id: contact_id,
211 } => {
212 user_ids.push(contact_id);
213 }
214 Notification::ChannelMessageMention {
215 sender_id,
216 message_id,
217 ..
218 } => {
219 user_ids.push(sender_id);
220 message_ids.push(message_id);
221 }
222 }
223 }
224
225 let (user_store, channel_store) = this.read_with(&cx, |this, _| {
226 (this.user_store.clone(), this.channel_store.clone())
227 });
228
229 user_store
230 .update(&mut cx, |store, cx| store.get_users(user_ids, cx))
231 .await?;
232 let messages = channel_store
233 .update(&mut cx, |store, cx| {
234 store.fetch_channel_messages(message_ids, cx)
235 })
236 .await?;
237 this.update(&mut cx, |this, cx| {
238 this.channel_messages
239 .extend(messages.into_iter().filter_map(|message| {
240 if let ChannelMessageId::Saved(id) = message.id {
241 Some((id, message))
242 } else {
243 None
244 }
245 }));
246
247 this.splice_notifications(
248 notifications
249 .into_iter()
250 .map(|notification| (notification.id, Some(notification))),
251 is_new,
252 cx,
253 );
254 });
255
256 Ok(())
257 }
258
259 fn splice_notifications(
260 &mut self,
261 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
262 is_new: bool,
263 cx: &mut ModelContext<'_, NotificationStore>,
264 ) {
265 let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
266 let mut new_notifications = SumTree::new();
267 let mut old_range = 0..0;
268
269 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
270 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
271
272 if i == 0 {
273 old_range.start = cursor.start().1 .0;
274 }
275
276 if let Some(existing_notification) = cursor.item() {
277 if existing_notification.id == id {
278 if let Some(new_notification) = &new_notification {
279 if new_notification.is_read {
280 cx.emit(NotificationEvent::NotificationRead {
281 entry: new_notification.clone(),
282 });
283 }
284 } else {
285 cx.emit(NotificationEvent::NotificationRemoved {
286 entry: existing_notification.clone(),
287 });
288 }
289 cursor.next(&());
290 }
291 }
292
293 if let Some(notification) = new_notification {
294 if is_new {
295 cx.emit(NotificationEvent::NewNotification {
296 entry: notification.clone(),
297 });
298 }
299
300 new_notifications.push(notification, &());
301 }
302 }
303
304 old_range.end = cursor.start().1 .0;
305 let new_count = new_notifications.summary().count - old_range.start;
306 new_notifications.append(cursor.suffix(&()), &());
307 drop(cursor);
308
309 self.notifications = new_notifications;
310 cx.emit(NotificationEvent::NotificationsUpdated {
311 old_range,
312 new_count,
313 });
314 }
315
316 pub fn respond_to_notification(
317 &mut self,
318 notification: Notification,
319 response: bool,
320 cx: &mut ModelContext<Self>,
321 ) {
322 match notification {
323 Notification::ContactRequest { sender_id } => {
324 self.user_store
325 .update(cx, |store, cx| {
326 store.respond_to_contact_request(sender_id, response, cx)
327 })
328 .detach();
329 }
330 Notification::ChannelInvitation { channel_id, .. } => {
331 self.channel_store
332 .update(cx, |store, cx| {
333 store.respond_to_channel_invite(channel_id, response, cx)
334 })
335 .detach();
336 }
337 _ => {}
338 }
339 }
340}
341
342impl Entity for NotificationStore {
343 type Event = NotificationEvent;
344}
345
346impl sum_tree::Item for NotificationEntry {
347 type Summary = NotificationSummary;
348
349 fn summary(&self) -> Self::Summary {
350 NotificationSummary {
351 max_id: self.id,
352 count: 1,
353 unread_count: if self.is_read { 0 } else { 1 },
354 }
355 }
356}
357
358impl sum_tree::Summary for NotificationSummary {
359 type Context = ();
360
361 fn add_summary(&mut self, summary: &Self, _: &()) {
362 self.max_id = self.max_id.max(summary.max_id);
363 self.count += summary.count;
364 self.unread_count += summary.unread_count;
365 }
366}
367
368impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
369 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
370 debug_assert!(summary.max_id > self.0);
371 self.0 = summary.max_id;
372 }
373}
374
375impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
376 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
377 self.0 += summary.count;
378 }
379}
380
381impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
382 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
383 self.0 += summary.unread_count;
384 }
385}