1use super::*;
2use rpc::Notification;
3
4impl Database {
5 pub async fn initialize_notification_kinds(&mut self) -> Result<()> {
6 notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map(
7 |kind| notification_kind::ActiveModel {
8 name: ActiveValue::Set(kind.to_string()),
9 ..Default::default()
10 },
11 ))
12 .on_conflict(OnConflict::new().do_nothing().to_owned())
13 .exec_without_returning(&self.pool)
14 .await?;
15
16 let mut rows = notification_kind::Entity::find().stream(&self.pool).await?;
17 while let Some(row) = rows.next().await {
18 let row = row?;
19 self.notification_kinds_by_name.insert(row.name, row.id);
20 }
21
22 for name in Notification::all_variant_names() {
23 if let Some(id) = self.notification_kinds_by_name.get(*name).copied() {
24 self.notification_kinds_by_id.insert(id, name);
25 }
26 }
27
28 Ok(())
29 }
30
31 pub async fn get_notifications(
32 &self,
33 recipient_id: UserId,
34 limit: usize,
35 before_id: Option<NotificationId>,
36 ) -> Result<Vec<proto::Notification>> {
37 self.transaction(|tx| async move {
38 let mut result = Vec::new();
39 let mut condition =
40 Condition::all().add(notification::Column::RecipientId.eq(recipient_id));
41
42 if let Some(before_id) = before_id {
43 condition = condition.add(notification::Column::Id.lt(before_id));
44 }
45
46 let mut rows = notification::Entity::find()
47 .filter(condition)
48 .order_by_desc(notification::Column::Id)
49 .limit(limit as u64)
50 .stream(&*tx)
51 .await?;
52 while let Some(row) = rows.next().await {
53 let row = row?;
54 let kind = row.kind;
55 if let Some(proto) = model_to_proto(self, row) {
56 result.push(proto);
57 } else {
58 log::warn!("unknown notification kind {:?}", kind);
59 }
60 }
61 result.reverse();
62 Ok(result)
63 })
64 .await
65 }
66
67 /// Create a notification. If `avoid_duplicates` is set to true, then avoid
68 /// creating a new notification if the given recipient already has an
69 /// unread notification with the given kind and entity id.
70 pub async fn create_notification(
71 &self,
72 recipient_id: UserId,
73 notification: Notification,
74 avoid_duplicates: bool,
75 tx: &DatabaseTransaction,
76 ) -> Result<Option<(UserId, proto::Notification)>> {
77 if avoid_duplicates {
78 if self
79 .find_notification(recipient_id, ¬ification, tx)
80 .await?
81 .is_some()
82 {
83 return Ok(None);
84 }
85 }
86
87 let proto = notification.to_proto();
88 let kind = notification_kind_from_proto(self, &proto)?;
89 let model = notification::ActiveModel {
90 recipient_id: ActiveValue::Set(recipient_id),
91 kind: ActiveValue::Set(kind),
92 entity_id: ActiveValue::Set(proto.entity_id.map(|id| id as i32)),
93 content: ActiveValue::Set(proto.content.clone()),
94 ..Default::default()
95 }
96 .save(&*tx)
97 .await?;
98
99 Ok(Some((
100 recipient_id,
101 proto::Notification {
102 id: model.id.as_ref().to_proto(),
103 kind: proto.kind,
104 timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64,
105 is_read: false,
106 response: None,
107 content: proto.content,
108 entity_id: proto.entity_id,
109 },
110 )))
111 }
112
113 /// Remove an unread notification with the given recipient, kind and
114 /// entity id.
115 pub async fn remove_notification(
116 &self,
117 recipient_id: UserId,
118 notification: Notification,
119 tx: &DatabaseTransaction,
120 ) -> Result<Option<NotificationId>> {
121 let id = self
122 .find_notification(recipient_id, ¬ification, tx)
123 .await?;
124 if let Some(id) = id {
125 notification::Entity::delete_by_id(id).exec(tx).await?;
126 }
127 Ok(id)
128 }
129
130 /// Populate the response for the notification with the given kind and
131 /// entity id.
132 pub async fn respond_to_notification(
133 &self,
134 recipient_id: UserId,
135 notification: &Notification,
136 response: bool,
137 tx: &DatabaseTransaction,
138 ) -> Result<Option<(UserId, proto::Notification)>> {
139 if let Some(id) = self
140 .find_notification(recipient_id, notification, tx)
141 .await?
142 {
143 let row = notification::Entity::update(notification::ActiveModel {
144 id: ActiveValue::Unchanged(id),
145 recipient_id: ActiveValue::Unchanged(recipient_id),
146 response: ActiveValue::Set(Some(response)),
147 is_read: ActiveValue::Set(true),
148 ..Default::default()
149 })
150 .exec(tx)
151 .await?;
152 Ok(model_to_proto(self, row).map(|notification| (recipient_id, notification)))
153 } else {
154 Ok(None)
155 }
156 }
157
158 /// Find an unread notification by its recipient, kind and entity id.
159 async fn find_notification(
160 &self,
161 recipient_id: UserId,
162 notification: &Notification,
163 tx: &DatabaseTransaction,
164 ) -> Result<Option<NotificationId>> {
165 let proto = notification.to_proto();
166 let kind = notification_kind_from_proto(self, &proto)?;
167
168 #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
169 enum QueryIds {
170 Id,
171 }
172
173 Ok(notification::Entity::find()
174 .select_only()
175 .column(notification::Column::Id)
176 .filter(
177 Condition::all()
178 .add(notification::Column::RecipientId.eq(recipient_id))
179 .add(notification::Column::IsRead.eq(false))
180 .add(notification::Column::Kind.eq(kind))
181 .add(if proto.entity_id.is_some() {
182 notification::Column::EntityId.eq(proto.entity_id)
183 } else {
184 notification::Column::EntityId.is_null()
185 }),
186 )
187 .into_values::<_, QueryIds>()
188 .one(&*tx)
189 .await?)
190 }
191}
192
193fn model_to_proto(this: &Database, row: notification::Model) -> Option<proto::Notification> {
194 let kind = this.notification_kinds_by_id.get(&row.kind)?;
195 Some(proto::Notification {
196 id: row.id.to_proto(),
197 kind: kind.to_string(),
198 timestamp: row.created_at.assume_utc().unix_timestamp() as u64,
199 is_read: row.is_read,
200 response: row.response,
201 content: row.content,
202 entity_id: row.entity_id.map(|id| id as u64),
203 })
204}
205
206fn notification_kind_from_proto(
207 this: &Database,
208 proto: &proto::Notification,
209) -> Result<NotificationKindId> {
210 Ok(this
211 .notification_kinds_by_name
212 .get(&proto.kind)
213 .copied()
214 .ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?)
215}