1use anyhow::Result;
2use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
3use client::{Client, UserStore};
4use collections::HashMap;
5use db::smol::stream::StreamExt;
6use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
7use rpc::{proto, Notification, TypedEnvelope};
8use std::{ops::Range, sync::Arc};
9use sum_tree::{Bias, SumTree};
10use time::OffsetDateTime;
11use util::ResultExt;
12
13pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
14 let notification_store = cx.add_model(|cx| NotificationStore::new(client, user_store, cx));
15 cx.set_global(notification_store);
16}
17
/// Client-side mirror of the current user's notifications, kept in sync with the server.
// NOTE: field declaration order is also drop order; keep it unchanged.
pub struct NotificationStore {
    /// Used to send notification requests and to register the push-message handlers.
    client: Arc<Client>,
    /// Used to fetch users referenced by notifications and to answer contact requests.
    user_store: ModelHandle<UserStore>,
    /// Channel messages referenced by mention notifications, keyed by saved message id.
    channel_messages: HashMap<u64, ChannelMessage>,
    /// Used to fetch mentioned channel messages and to answer channel invitations.
    channel_store: ModelHandle<ChannelStore>,
    /// Loaded notifications, ordered by ascending id (oldest first).
    notifications: SumTree<NotificationEntry>,
    /// Task reacting to connection status changes; held so the store owns it.
    _watch_connection_status: Task<Option<()>>,
    /// Registrations of the `NewNotification` / `DeleteNotification` message handlers.
    _subscriptions: Vec<client::Subscription>,
}
27
28#[derive(Clone, PartialEq, Eq, Debug)]
29pub enum NotificationEvent {
30 NotificationsUpdated {
31 old_range: Range<usize>,
32 new_count: usize,
33 },
34 NewNotification {
35 entry: NotificationEntry,
36 },
37 NotificationRemoved {
38 entry: NotificationEntry,
39 },
40 NotificationRead {
41 entry: NotificationEntry,
42 },
43}
44
45#[derive(Debug, PartialEq, Eq, Clone)]
46pub struct NotificationEntry {
47 pub id: u64,
48 pub notification: Notification,
49 pub timestamp: OffsetDateTime,
50 pub is_read: bool,
51 pub response: Option<bool>,
52}
53
/// Aggregate over a run of [`NotificationEntry`] values in the sum tree.
#[derive(Default, Debug, Clone)]
pub struct NotificationSummary {
    /// Largest entry id in the run.
    max_id: u64,
    /// Number of entries in the run.
    count: usize,
    /// Number of entries in the run that are not yet read.
    unread_count: usize,
}
60
/// Sum-tree dimension counting entries; lets a cursor seek by position.
#[derive(Debug, Default, Clone, Copy, PartialOrd, Ord, PartialEq, Eq)]
struct Count(usize);
63
/// Sum-tree dimension counting entries that are not yet read.
#[derive(Debug, Default, Clone, Copy, PartialOrd, Ord, PartialEq, Eq)]
struct UnreadCount(usize);
66
/// Sum-tree dimension tracking the largest notification id seen so far; lets a cursor
/// seek to a given id.
#[derive(Debug, Default, Clone, Copy, PartialOrd, Ord, PartialEq, Eq)]
struct NotificationId(u64);
69
70impl NotificationStore {
71 pub fn global(cx: &AppContext) -> ModelHandle<Self> {
72 cx.global::<ModelHandle<Self>>().clone()
73 }
74
75 pub fn new(
76 client: Arc<Client>,
77 user_store: ModelHandle<UserStore>,
78 cx: &mut ModelContext<Self>,
79 ) -> Self {
80 let mut connection_status = client.status();
81 let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
82 while let Some(status) = connection_status.next().await {
83 let this = this.upgrade(&cx)?;
84 match status {
85 client::Status::Connected { .. } => {
86 this.update(&mut cx, |this, cx| this.handle_connect(cx))
87 .await
88 .log_err()?;
89 }
90 _ => this.update(&mut cx, |this, cx| this.handle_disconnect(cx)),
91 }
92 }
93 Some(())
94 });
95
96 Self {
97 channel_store: ChannelStore::global(cx),
98 notifications: Default::default(),
99 channel_messages: Default::default(),
100 _watch_connection_status: watch_connection_status,
101 _subscriptions: vec![
102 client.add_message_handler(cx.handle(), Self::handle_new_notification),
103 client.add_message_handler(cx.handle(), Self::handle_delete_notification),
104 ],
105 user_store,
106 client,
107 }
108 }
109
110 pub fn notification_count(&self) -> usize {
111 self.notifications.summary().count
112 }
113
114 pub fn unread_notification_count(&self) -> usize {
115 self.notifications.summary().unread_count
116 }
117
118 pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
119 self.channel_messages.get(&id)
120 }
121
122 // Get the nth newest notification.
123 pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
124 let count = self.notifications.summary().count;
125 if ix >= count {
126 return None;
127 }
128 let ix = count - 1 - ix;
129 let mut cursor = self.notifications.cursor::<Count>();
130 cursor.seek(&Count(ix), Bias::Right, &());
131 cursor.item()
132 }
133
134 pub fn load_more_notifications(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
135 let request = self
136 .client
137 .request(proto::GetNotifications { before_id: None });
138 cx.spawn(|this, cx| async move {
139 let response = request.await?;
140 Self::add_notifications(this, false, response.notifications, cx).await?;
141 Ok(())
142 })
143 }
144
145 fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
146 self.notifications = Default::default();
147 self.channel_messages = Default::default();
148 self.load_more_notifications(cx)
149 }
150
151 fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
152 cx.notify()
153 }
154
155 async fn handle_new_notification(
156 this: ModelHandle<Self>,
157 envelope: TypedEnvelope<proto::NewNotification>,
158 _: Arc<Client>,
159 cx: AsyncAppContext,
160 ) -> Result<()> {
161 Self::add_notifications(
162 this,
163 true,
164 envelope.payload.notification.into_iter().collect(),
165 cx,
166 )
167 .await
168 }
169
170 async fn handle_delete_notification(
171 this: ModelHandle<Self>,
172 envelope: TypedEnvelope<proto::DeleteNotification>,
173 _: Arc<Client>,
174 mut cx: AsyncAppContext,
175 ) -> Result<()> {
176 this.update(&mut cx, |this, cx| {
177 this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
178 Ok(())
179 })
180 }
181
182 async fn add_notifications(
183 this: ModelHandle<Self>,
184 is_new: bool,
185 notifications: Vec<proto::Notification>,
186 mut cx: AsyncAppContext,
187 ) -> Result<()> {
188 let mut user_ids = Vec::new();
189 let mut message_ids = Vec::new();
190
191 let notifications = notifications
192 .into_iter()
193 .filter_map(|message| {
194 Some(NotificationEntry {
195 id: message.id,
196 is_read: message.is_read,
197 timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
198 .ok()?,
199 notification: Notification::from_proto(&message)?,
200 response: message.response,
201 })
202 })
203 .collect::<Vec<_>>();
204 if notifications.is_empty() {
205 return Ok(());
206 }
207
208 for entry in ¬ifications {
209 match entry.notification {
210 Notification::ChannelInvitation { inviter_id, .. } => {
211 user_ids.push(inviter_id);
212 }
213 Notification::ContactRequest {
214 sender_id: requester_id,
215 } => {
216 user_ids.push(requester_id);
217 }
218 Notification::ContactRequestAccepted {
219 responder_id: contact_id,
220 } => {
221 user_ids.push(contact_id);
222 }
223 Notification::ChannelMessageMention {
224 sender_id,
225 message_id,
226 ..
227 } => {
228 user_ids.push(sender_id);
229 message_ids.push(message_id);
230 }
231 }
232 }
233
234 let (user_store, channel_store) = this.read_with(&cx, |this, _| {
235 (this.user_store.clone(), this.channel_store.clone())
236 });
237
238 user_store
239 .update(&mut cx, |store, cx| store.get_users(user_ids, cx))
240 .await?;
241 let messages = channel_store
242 .update(&mut cx, |store, cx| {
243 store.fetch_channel_messages(message_ids, cx)
244 })
245 .await?;
246 this.update(&mut cx, |this, cx| {
247 this.channel_messages
248 .extend(messages.into_iter().filter_map(|message| {
249 if let ChannelMessageId::Saved(id) = message.id {
250 Some((id, message))
251 } else {
252 None
253 }
254 }));
255
256 this.splice_notifications(
257 notifications
258 .into_iter()
259 .map(|notification| (notification.id, Some(notification))),
260 is_new,
261 cx,
262 );
263 });
264
265 Ok(())
266 }
267
268 fn splice_notifications(
269 &mut self,
270 notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
271 is_new: bool,
272 cx: &mut ModelContext<'_, NotificationStore>,
273 ) {
274 let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
275 let mut new_notifications = SumTree::new();
276 let mut old_range = 0..0;
277
278 for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
279 new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
280
281 if i == 0 {
282 old_range.start = cursor.start().1 .0;
283 }
284
285 let old_notification = cursor.item();
286 if let Some(old_notification) = old_notification {
287 if old_notification.id == id {
288 cursor.next(&());
289
290 if let Some(new_notification) = &new_notification {
291 if new_notification.is_read {
292 cx.emit(NotificationEvent::NotificationRead {
293 entry: new_notification.clone(),
294 });
295 }
296 } else {
297 cx.emit(NotificationEvent::NotificationRemoved {
298 entry: old_notification.clone(),
299 });
300 }
301 }
302 } else if let Some(new_notification) = &new_notification {
303 if is_new {
304 cx.emit(NotificationEvent::NewNotification {
305 entry: new_notification.clone(),
306 });
307 }
308 }
309
310 if let Some(notification) = new_notification {
311 new_notifications.push(notification, &());
312 }
313 }
314
315 old_range.end = cursor.start().1 .0;
316 let new_count = new_notifications.summary().count - old_range.start;
317 new_notifications.append(cursor.suffix(&()), &());
318 drop(cursor);
319
320 self.notifications = new_notifications;
321 cx.emit(NotificationEvent::NotificationsUpdated {
322 old_range,
323 new_count,
324 });
325 }
326
327 pub fn respond_to_notification(
328 &mut self,
329 notification: Notification,
330 response: bool,
331 cx: &mut ModelContext<Self>,
332 ) {
333 match notification {
334 Notification::ContactRequest { sender_id } => {
335 self.user_store
336 .update(cx, |store, cx| {
337 store.respond_to_contact_request(sender_id, response, cx)
338 })
339 .detach();
340 }
341 Notification::ChannelInvitation { channel_id, .. } => {
342 self.channel_store
343 .update(cx, |store, cx| {
344 store.respond_to_channel_invite(channel_id, response, cx)
345 })
346 .detach();
347 }
348 _ => {}
349 }
350 }
351}
352
353impl Entity for NotificationStore {
354 type Event = NotificationEvent;
355}
356
357impl sum_tree::Item for NotificationEntry {
358 type Summary = NotificationSummary;
359
360 fn summary(&self) -> Self::Summary {
361 NotificationSummary {
362 max_id: self.id,
363 count: 1,
364 unread_count: if self.is_read { 0 } else { 1 },
365 }
366 }
367}
368
369impl sum_tree::Summary for NotificationSummary {
370 type Context = ();
371
372 fn add_summary(&mut self, summary: &Self, _: &()) {
373 self.max_id = self.max_id.max(summary.max_id);
374 self.count += summary.count;
375 self.unread_count += summary.unread_count;
376 }
377}
378
379impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
380 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
381 debug_assert!(summary.max_id > self.0);
382 self.0 = summary.max_id;
383 }
384}
385
386impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
387 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
388 self.0 += summary.count;
389 }
390}
391
392impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
393 fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
394 self.0 += summary.unread_count;
395 }
396}