notifications.rs

  1use super::*;
  2use rpc::Notification;
  3
  4impl Database {
  5    pub async fn initialize_notification_kinds(&mut self) -> Result<()> {
  6        notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map(
  7            |kind| notification_kind::ActiveModel {
  8                name: ActiveValue::Set(kind.to_string()),
  9                ..Default::default()
 10            },
 11        ))
 12        .on_conflict(OnConflict::new().do_nothing().to_owned())
 13        .exec_without_returning(&self.pool)
 14        .await?;
 15
 16        let mut rows = notification_kind::Entity::find().stream(&self.pool).await?;
 17        while let Some(row) = rows.next().await {
 18            let row = row?;
 19            self.notification_kinds_by_name.insert(row.name, row.id);
 20        }
 21
 22        for name in Notification::all_variant_names() {
 23            if let Some(id) = self.notification_kinds_by_name.get(*name).copied() {
 24                self.notification_kinds_by_id.insert(id, name);
 25            }
 26        }
 27
 28        Ok(())
 29    }
 30
 31    pub async fn get_notifications(
 32        &self,
 33        recipient_id: UserId,
 34        limit: usize,
 35        before_id: Option<NotificationId>,
 36    ) -> Result<Vec<proto::Notification>> {
 37        self.transaction(|tx| async move {
 38            let mut result = Vec::new();
 39            let mut condition =
 40                Condition::all().add(notification::Column::RecipientId.eq(recipient_id));
 41
 42            if let Some(before_id) = before_id {
 43                condition = condition.add(notification::Column::Id.lt(before_id));
 44            }
 45
 46            let mut rows = notification::Entity::find()
 47                .filter(condition)
 48                .order_by_desc(notification::Column::Id)
 49                .limit(limit as u64)
 50                .stream(&*tx)
 51                .await?;
 52            while let Some(row) = rows.next().await {
 53                let row = row?;
 54                let kind = row.kind;
 55                if let Some(proto) = model_to_proto(self, row) {
 56                    result.push(proto);
 57                } else {
 58                    log::warn!("unknown notification kind {:?}", kind);
 59                }
 60            }
 61            result.reverse();
 62            Ok(result)
 63        })
 64        .await
 65    }
 66
 67    /// Create a notification. If `avoid_duplicates` is set to true, then avoid
 68    /// creating a new notification if the given recipient already has an
 69    /// unread notification with the given kind and entity id.
 70    pub async fn create_notification(
 71        &self,
 72        recipient_id: UserId,
 73        notification: Notification,
 74        avoid_duplicates: bool,
 75        tx: &DatabaseTransaction,
 76    ) -> Result<Option<(UserId, proto::Notification)>> {
 77        if avoid_duplicates {
 78            if self
 79                .find_notification(recipient_id, &notification, tx)
 80                .await?
 81                .is_some()
 82            {
 83                return Ok(None);
 84            }
 85        }
 86
 87        let proto = notification.to_proto();
 88        let kind = notification_kind_from_proto(self, &proto)?;
 89        let model = notification::ActiveModel {
 90            recipient_id: ActiveValue::Set(recipient_id),
 91            kind: ActiveValue::Set(kind),
 92            entity_id: ActiveValue::Set(proto.entity_id.map(|id| id as i32)),
 93            content: ActiveValue::Set(proto.content.clone()),
 94            ..Default::default()
 95        }
 96        .save(&*tx)
 97        .await?;
 98
 99        Ok(Some((
100            recipient_id,
101            proto::Notification {
102                id: model.id.as_ref().to_proto(),
103                kind: proto.kind,
104                timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64,
105                is_read: false,
106                response: None,
107                content: proto.content,
108                entity_id: proto.entity_id,
109            },
110        )))
111    }
112
113    /// Remove an unread notification with the given recipient, kind and
114    /// entity id.
115    pub async fn remove_notification(
116        &self,
117        recipient_id: UserId,
118        notification: Notification,
119        tx: &DatabaseTransaction,
120    ) -> Result<Option<NotificationId>> {
121        let id = self
122            .find_notification(recipient_id, &notification, tx)
123            .await?;
124        if let Some(id) = id {
125            notification::Entity::delete_by_id(id).exec(tx).await?;
126        }
127        Ok(id)
128    }
129
130    /// Populate the response for the notification with the given kind and
131    /// entity id.
132    pub async fn respond_to_notification(
133        &self,
134        recipient_id: UserId,
135        notification: &Notification,
136        response: bool,
137        tx: &DatabaseTransaction,
138    ) -> Result<Option<(UserId, proto::Notification)>> {
139        if let Some(id) = self
140            .find_notification(recipient_id, notification, tx)
141            .await?
142        {
143            let row = notification::Entity::update(notification::ActiveModel {
144                id: ActiveValue::Unchanged(id),
145                recipient_id: ActiveValue::Unchanged(recipient_id),
146                response: ActiveValue::Set(Some(response)),
147                is_read: ActiveValue::Set(true),
148                ..Default::default()
149            })
150            .exec(tx)
151            .await?;
152            Ok(model_to_proto(self, row).map(|notification| (recipient_id, notification)))
153        } else {
154            Ok(None)
155        }
156    }
157
158    /// Find an unread notification by its recipient, kind and entity id.
159    async fn find_notification(
160        &self,
161        recipient_id: UserId,
162        notification: &Notification,
163        tx: &DatabaseTransaction,
164    ) -> Result<Option<NotificationId>> {
165        let proto = notification.to_proto();
166        let kind = notification_kind_from_proto(self, &proto)?;
167
168        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
169        enum QueryIds {
170            Id,
171        }
172
173        Ok(notification::Entity::find()
174            .select_only()
175            .column(notification::Column::Id)
176            .filter(
177                Condition::all()
178                    .add(notification::Column::RecipientId.eq(recipient_id))
179                    .add(notification::Column::IsRead.eq(false))
180                    .add(notification::Column::Kind.eq(kind))
181                    .add(if proto.entity_id.is_some() {
182                        notification::Column::EntityId.eq(proto.entity_id)
183                    } else {
184                        notification::Column::EntityId.is_null()
185                    }),
186            )
187            .into_values::<_, QueryIds>()
188            .one(&*tx)
189            .await?)
190    }
191}
192
193fn model_to_proto(this: &Database, row: notification::Model) -> Option<proto::Notification> {
194    let kind = this.notification_kinds_by_id.get(&row.kind)?;
195    Some(proto::Notification {
196        id: row.id.to_proto(),
197        kind: kind.to_string(),
198        timestamp: row.created_at.assume_utc().unix_timestamp() as u64,
199        is_read: row.is_read,
200        response: row.response,
201        content: row.content,
202        entity_id: row.entity_id.map(|id| id as u64),
203    })
204}
205
206fn notification_kind_from_proto(
207    this: &Database,
208    proto: &proto::Notification,
209) -> Result<NotificationKindId> {
210    Ok(this
211        .notification_kinds_by_name
212        .get(&proto.kind)
213        .copied()
214        .ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?)
215}