notifications.rs

  1use super::*;
  2use rpc::Notification;
  3
  4impl Database {
  5    pub async fn initialize_notification_enum(&mut self) -> Result<()> {
  6        notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map(
  7            |kind| notification_kind::ActiveModel {
  8                name: ActiveValue::Set(kind.to_string()),
  9                ..Default::default()
 10            },
 11        ))
 12        .on_conflict(OnConflict::new().do_nothing().to_owned())
 13        .exec_without_returning(&self.pool)
 14        .await?;
 15
 16        let mut rows = notification_kind::Entity::find().stream(&self.pool).await?;
 17        while let Some(row) = rows.next().await {
 18            let row = row?;
 19            self.notification_kinds_by_name.insert(row.name, row.id);
 20        }
 21
 22        for name in Notification::all_variant_names() {
 23            if let Some(id) = self.notification_kinds_by_name.get(*name).copied() {
 24                self.notification_kinds_by_id.insert(id, name);
 25            }
 26        }
 27
 28        Ok(())
 29    }
 30
 31    pub async fn get_notifications(
 32        &self,
 33        recipient_id: UserId,
 34        limit: usize,
 35        before_id: Option<NotificationId>,
 36    ) -> Result<Vec<proto::Notification>> {
 37        self.transaction(|tx| async move {
 38            let mut result = Vec::new();
 39            let mut condition =
 40                Condition::all().add(notification::Column::RecipientId.eq(recipient_id));
 41
 42            if let Some(before_id) = before_id {
 43                condition = condition.add(notification::Column::Id.lt(before_id));
 44            }
 45
 46            let mut rows = notification::Entity::find()
 47                .filter(condition)
 48                .order_by_desc(notification::Column::Id)
 49                .limit(limit as u64)
 50                .stream(&*tx)
 51                .await?;
 52            while let Some(row) = rows.next().await {
 53                let row = row?;
 54                let kind = row.kind;
 55                if let Some(proto) = self.model_to_proto(row) {
 56                    result.push(proto);
 57                } else {
 58                    log::warn!("unknown notification kind {:?}", kind);
 59                }
 60            }
 61            result.reverse();
 62            Ok(result)
 63        })
 64        .await
 65    }
 66
 67    pub async fn create_notification(
 68        &self,
 69        recipient_id: UserId,
 70        notification: Notification,
 71        avoid_duplicates: bool,
 72        tx: &DatabaseTransaction,
 73    ) -> Result<Option<proto::Notification>> {
 74        if avoid_duplicates {
 75            if self
 76                .find_notification(recipient_id, &notification, tx)
 77                .await?
 78                .is_some()
 79            {
 80                return Ok(None);
 81            }
 82        }
 83
 84        let notification_proto = notification.to_proto();
 85        let kind = *self
 86            .notification_kinds_by_name
 87            .get(&notification_proto.kind)
 88            .ok_or_else(|| anyhow!("invalid notification kind {:?}", notification_proto.kind))?;
 89        let actor_id = notification_proto.actor_id.map(|id| UserId::from_proto(id));
 90
 91        let model = notification::ActiveModel {
 92            recipient_id: ActiveValue::Set(recipient_id),
 93            kind: ActiveValue::Set(kind),
 94            content: ActiveValue::Set(notification_proto.content.clone()),
 95            actor_id: ActiveValue::Set(actor_id),
 96            is_read: ActiveValue::NotSet,
 97            created_at: ActiveValue::NotSet,
 98            id: ActiveValue::NotSet,
 99        }
100        .save(&*tx)
101        .await?;
102
103        Ok(Some(proto::Notification {
104            id: model.id.as_ref().to_proto(),
105            kind: notification_proto.kind,
106            timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64,
107            is_read: false,
108            content: notification_proto.content,
109            actor_id: notification_proto.actor_id,
110        }))
111    }
112
113    pub async fn remove_notification(
114        &self,
115        recipient_id: UserId,
116        notification: Notification,
117        tx: &DatabaseTransaction,
118    ) -> Result<Option<NotificationId>> {
119        let id = self
120            .find_notification(recipient_id, &notification, tx)
121            .await?;
122        if let Some(id) = id {
123            notification::Entity::delete_by_id(id).exec(tx).await?;
124        }
125        Ok(id)
126    }
127
128    pub async fn find_notification(
129        &self,
130        recipient_id: UserId,
131        notification: &Notification,
132        tx: &DatabaseTransaction,
133    ) -> Result<Option<NotificationId>> {
134        let proto = notification.to_proto();
135        let kind = *self
136            .notification_kinds_by_name
137            .get(&proto.kind)
138            .ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?;
139        let mut rows = notification::Entity::find()
140            .filter(
141                Condition::all()
142                    .add(notification::Column::RecipientId.eq(recipient_id))
143                    .add(notification::Column::IsRead.eq(false))
144                    .add(notification::Column::Kind.eq(kind))
145                    .add(notification::Column::ActorId.eq(proto.actor_id)),
146            )
147            .stream(&*tx)
148            .await?;
149
150        // Don't rely on the JSON serialization being identical, in case the
151        // notification type is changed in backward-compatible ways.
152        while let Some(row) = rows.next().await {
153            let row = row?;
154            let id = row.id;
155            if let Some(proto) = self.model_to_proto(row) {
156                if let Some(existing) = Notification::from_proto(&proto) {
157                    if existing == *notification {
158                        return Ok(Some(id));
159                    }
160                }
161            }
162        }
163
164        Ok(None)
165    }
166
167    fn model_to_proto(&self, row: notification::Model) -> Option<proto::Notification> {
168        let kind = self.notification_kinds_by_id.get(&row.kind)?;
169        Some(proto::Notification {
170            id: row.id.to_proto(),
171            kind: kind.to_string(),
172            timestamp: row.created_at.assume_utc().unix_timestamp() as u64,
173            is_read: row.is_read,
174            content: row.content,
175            actor_id: row.actor_id.map(|id| id.to_proto()),
176        })
177    }
178}