1use super::*;
2use anyhow::Context as _;
3use rpc::Notification;
4use util::ResultExt;
5
6impl Database {
7 /// Initializes the different kinds of notifications by upserting records for them.
8 pub async fn initialize_notification_kinds(&mut self) -> Result<()> {
9 let all_kinds = Notification::all_variant_names();
10 let existing_kinds = notification_kind::Entity::find().all(&self.pool).await?;
11
12 let kinds_to_create: Vec<_> = all_kinds
13 .iter()
14 .filter(|&kind| {
15 !existing_kinds
16 .iter()
17 .any(|existing| existing.name == **kind)
18 })
19 .map(|kind| notification_kind::ActiveModel {
20 name: ActiveValue::Set(kind.to_string()),
21 ..Default::default()
22 })
23 .collect();
24
25 if !kinds_to_create.is_empty() {
26 notification_kind::Entity::insert_many(kinds_to_create)
27 .exec_without_returning(&self.pool)
28 .await?;
29 }
30
31 let mut rows = notification_kind::Entity::find().stream(&self.pool).await?;
32 while let Some(row) = rows.next().await {
33 let row = row?;
34 self.notification_kinds_by_name.insert(row.name, row.id);
35 }
36
37 for name in Notification::all_variant_names() {
38 if let Some(id) = self.notification_kinds_by_name.get(*name).copied() {
39 self.notification_kinds_by_id.insert(id, name);
40 }
41 }
42
43 Ok(())
44 }
45
46 /// Returns the notifications for the given recipient.
47 pub async fn get_notifications(
48 &self,
49 recipient_id: UserId,
50 limit: usize,
51 before_id: Option<NotificationId>,
52 ) -> Result<Vec<proto::Notification>> {
53 self.transaction(|tx| async move {
54 let mut result = Vec::new();
55 let mut condition =
56 Condition::all().add(notification::Column::RecipientId.eq(recipient_id));
57
58 if let Some(before_id) = before_id {
59 condition = condition.add(notification::Column::Id.lt(before_id));
60 }
61
62 let mut rows = notification::Entity::find()
63 .filter(condition)
64 .order_by_desc(notification::Column::Id)
65 .limit(limit as u64)
66 .stream(&*tx)
67 .await?;
68 while let Some(row) = rows.next().await {
69 let row = row?;
70 if let Some(proto) = model_to_proto(self, row).log_err() {
71 result.push(proto);
72 }
73 }
74 result.reverse();
75 Ok(result)
76 })
77 .await
78 }
79
80 /// Creates a notification. If `avoid_duplicates` is set to true, then avoid
81 /// creating a new notification if the given recipient already has an
82 /// unread notification with the given kind and entity id.
83 pub async fn create_notification(
84 &self,
85 recipient_id: UserId,
86 notification: Notification,
87 avoid_duplicates: bool,
88 tx: &DatabaseTransaction,
89 ) -> Result<Option<(UserId, proto::Notification)>> {
90 if avoid_duplicates
91 && self
92 .find_notification(recipient_id, ¬ification, tx)
93 .await?
94 .is_some()
95 {
96 return Ok(None);
97 }
98
99 let proto = notification.to_proto();
100 let kind = notification_kind_from_proto(self, &proto)?;
101 let model = notification::ActiveModel {
102 recipient_id: ActiveValue::Set(recipient_id),
103 kind: ActiveValue::Set(kind),
104 entity_id: ActiveValue::Set(proto.entity_id.map(|id| id as i32)),
105 content: ActiveValue::Set(proto.content.clone()),
106 ..Default::default()
107 }
108 .save(tx)
109 .await?;
110
111 Ok(Some((
112 recipient_id,
113 proto::Notification {
114 id: model.id.as_ref().to_proto(),
115 kind: proto.kind,
116 timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64,
117 is_read: false,
118 response: None,
119 content: proto.content,
120 entity_id: proto.entity_id,
121 },
122 )))
123 }
124
125 /// Remove an unread notification with the given recipient, kind and
126 /// entity id.
127 pub async fn remove_notification(
128 &self,
129 recipient_id: UserId,
130 notification: Notification,
131 tx: &DatabaseTransaction,
132 ) -> Result<Option<NotificationId>> {
133 let id = self
134 .find_notification(recipient_id, ¬ification, tx)
135 .await?;
136 if let Some(id) = id {
137 notification::Entity::delete_by_id(id).exec(tx).await?;
138 }
139 Ok(id)
140 }
141
142 /// Populate the response for the notification with the given kind and
143 /// entity id.
144 pub async fn mark_notification_as_read_with_response(
145 &self,
146 recipient_id: UserId,
147 notification: &Notification,
148 response: bool,
149 tx: &DatabaseTransaction,
150 ) -> Result<Option<(UserId, proto::Notification)>> {
151 self.mark_notification_as_read_internal(recipient_id, notification, Some(response), tx)
152 .await
153 }
154
155 /// Marks the given notification as read.
156 pub async fn mark_notification_as_read(
157 &self,
158 recipient_id: UserId,
159 notification: &Notification,
160 tx: &DatabaseTransaction,
161 ) -> Result<Option<(UserId, proto::Notification)>> {
162 self.mark_notification_as_read_internal(recipient_id, notification, None, tx)
163 .await
164 }
165
166 /// Marks the notification with the given ID as read.
167 pub async fn mark_notification_as_read_by_id(
168 &self,
169 recipient_id: UserId,
170 notification_id: NotificationId,
171 ) -> Result<NotificationBatch> {
172 self.transaction(|tx| async move {
173 let row = notification::Entity::update(notification::ActiveModel {
174 id: ActiveValue::Unchanged(notification_id),
175 recipient_id: ActiveValue::Unchanged(recipient_id),
176 is_read: ActiveValue::Set(true),
177 ..Default::default()
178 })
179 .exec(&*tx)
180 .await?;
181 Ok(model_to_proto(self, row)
182 .map(|notification| (recipient_id, notification))
183 .into_iter()
184 .collect())
185 })
186 .await
187 }
188
189 async fn mark_notification_as_read_internal(
190 &self,
191 recipient_id: UserId,
192 notification: &Notification,
193 response: Option<bool>,
194 tx: &DatabaseTransaction,
195 ) -> Result<Option<(UserId, proto::Notification)>> {
196 if let Some(id) = self
197 .find_notification(recipient_id, notification, tx)
198 .await?
199 {
200 let row = notification::Entity::update(notification::ActiveModel {
201 id: ActiveValue::Unchanged(id),
202 recipient_id: ActiveValue::Unchanged(recipient_id),
203 is_read: ActiveValue::Set(true),
204 response: if let Some(response) = response {
205 ActiveValue::Set(Some(response))
206 } else {
207 ActiveValue::NotSet
208 },
209 ..Default::default()
210 })
211 .exec(tx)
212 .await?;
213 Ok(model_to_proto(self, row)
214 .map(|notification| (recipient_id, notification))
215 .ok())
216 } else {
217 Ok(None)
218 }
219 }
220
221 /// Find an unread notification by its recipient, kind and entity id.
222 async fn find_notification(
223 &self,
224 recipient_id: UserId,
225 notification: &Notification,
226 tx: &DatabaseTransaction,
227 ) -> Result<Option<NotificationId>> {
228 let proto = notification.to_proto();
229 let kind = notification_kind_from_proto(self, &proto)?;
230
231 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
232 enum QueryIds {
233 Id,
234 }
235
236 Ok(notification::Entity::find()
237 .select_only()
238 .column(notification::Column::Id)
239 .filter(
240 Condition::all()
241 .add(notification::Column::RecipientId.eq(recipient_id))
242 .add(notification::Column::IsRead.eq(false))
243 .add(notification::Column::Kind.eq(kind))
244 .add(if proto.entity_id.is_some() {
245 notification::Column::EntityId.eq(proto.entity_id)
246 } else {
247 notification::Column::EntityId.is_null()
248 }),
249 )
250 .into_values::<_, QueryIds>()
251 .one(tx)
252 .await?)
253 }
254}
255
256pub fn model_to_proto(this: &Database, row: notification::Model) -> Result<proto::Notification> {
257 let kind = this
258 .notification_kinds_by_id
259 .get(&row.kind)
260 .context("Unknown notification kind")?;
261 Ok(proto::Notification {
262 id: row.id.to_proto(),
263 kind: kind.to_string(),
264 timestamp: row.created_at.assume_utc().unix_timestamp() as u64,
265 is_read: row.is_read,
266 response: row.response,
267 content: row.content,
268 entity_id: row.entity_id.map(|id| id as u64),
269 })
270}
271
272fn notification_kind_from_proto(
273 this: &Database,
274 proto: &proto::Notification,
275) -> Result<NotificationKindId> {
276 Ok(this
277 .notification_kinds_by_name
278 .get(&proto.kind)
279 .copied()
280 .with_context(|| format!("invalid notification kind {:?}", proto.kind))?)
281}