Avoid creating duplicate invite notifications

Max Brunsfeld created

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql |  2 
crates/collab/src/db/queries/channels.rs                       | 13 
crates/collab/src/db/queries/contacts.rs                       |  8 
crates/collab/src/db/queries/notifications.rs                  | 83 ++-
crates/collab/src/rpc.rs                                       | 40 +
crates/collab_ui/src/notification_panel.rs                     |  7 
6 files changed, 109 insertions(+), 44 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -330,4 +330,4 @@ CREATE TABLE "notifications" (
     "content" TEXT
 );
 
-CREATE INDEX "index_notifications_on_recipient_id" ON "notifications" ("recipient_id");
+CREATE INDEX "index_notifications_on_recipient_id_is_read" ON "notifications" ("recipient_id", "is_read");

crates/collab/src/db/queries/channels.rs 🔗

@@ -161,7 +161,7 @@ impl Database {
         invitee_id: UserId,
         inviter_id: UserId,
         is_admin: bool,
-    ) -> Result<()> {
+    ) -> Result<Option<proto::Notification>> {
         self.transaction(move |tx| async move {
             self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
                 .await?;
@@ -176,7 +176,16 @@ impl Database {
             .insert(&*tx)
             .await?;
 
-            Ok(())
+            self.create_notification(
+                invitee_id,
+                rpc::Notification::ChannelInvitation {
+                    actor_id: inviter_id.to_proto(),
+                    channel_id: channel_id.to_proto(),
+                },
+                true,
+                &*tx,
+            )
+            .await
         })
         .await
     }

crates/collab/src/db/queries/contacts.rs 🔗

@@ -123,7 +123,7 @@ impl Database {
         &self,
         sender_id: UserId,
         receiver_id: UserId,
-    ) -> Result<proto::Notification> {
+    ) -> Result<Option<proto::Notification>> {
         self.transaction(|tx| async move {
             let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
                 (sender_id, receiver_id, true)
@@ -169,6 +169,7 @@ impl Database {
                 rpc::Notification::ContactRequest {
                     actor_id: sender_id.to_proto(),
                 },
+                true,
                 &*tx,
             )
             .await
@@ -212,7 +213,7 @@ impl Database {
             let mut deleted_notification_id = None;
             if !contact.accepted {
                 deleted_notification_id = self
-                    .delete_notification(
+                    .remove_notification(
                         responder_id,
                         rpc::Notification::ContactRequest {
                             actor_id: requester_id.to_proto(),
@@ -273,7 +274,7 @@ impl Database {
         responder_id: UserId,
         requester_id: UserId,
         accept: bool,
-    ) -> Result<proto::Notification> {
+    ) -> Result<Option<proto::Notification>> {
         self.transaction(|tx| async move {
             let (id_a, id_b, a_to_b) = if responder_id < requester_id {
                 (responder_id, requester_id, false)
@@ -320,6 +321,7 @@ impl Database {
                 rpc::Notification::ContactRequestAccepted {
                     actor_id: responder_id.to_proto(),
                 },
+                true,
                 &*tx,
             )
             .await

crates/collab/src/db/queries/notifications.rs 🔗

@@ -51,18 +51,12 @@ impl Database {
                 .await?;
             while let Some(row) = rows.next().await {
                 let row = row?;
-                let Some(kind) = self.notification_kinds_by_id.get(&row.kind) else {
-                    log::warn!("unknown notification kind {:?}", row.kind);
-                    continue;
-                };
-                result.push(proto::Notification {
-                    id: row.id.to_proto(),
-                    kind: kind.to_string(),
-                    timestamp: row.created_at.assume_utc().unix_timestamp() as u64,
-                    is_read: row.is_read,
-                    content: row.content,
-                    actor_id: row.actor_id.map(|id| id.to_proto()),
-                });
+                let kind = row.kind;
+                if let Some(proto) = self.model_to_proto(row) {
+                    result.push(proto);
+                } else {
+                    log::warn!("unknown notification kind {:?}", kind);
+                }
             }
             result.reverse();
             Ok(result)
@@ -74,19 +68,48 @@ impl Database {
         &self,
         recipient_id: UserId,
         notification: Notification,
+        avoid_duplicates: bool,
         tx: &DatabaseTransaction,
-    ) -> Result<proto::Notification> {
-        let notification = notification.to_proto();
+    ) -> Result<Option<proto::Notification>> {
+        let notification_proto = notification.to_proto();
         let kind = *self
             .notification_kinds_by_name
-            .get(&notification.kind)
-            .ok_or_else(|| anyhow!("invalid notification kind {:?}", notification.kind))?;
+            .get(&notification_proto.kind)
+            .ok_or_else(|| anyhow!("invalid notification kind {:?}", notification_proto.kind))?;
+        let actor_id = notification_proto.actor_id.map(|id| UserId::from_proto(id));
+
+        if avoid_duplicates {
+            let mut existing_notifications = notification::Entity::find()
+                .filter(
+                    Condition::all()
+                        .add(notification::Column::RecipientId.eq(recipient_id))
+                        .add(notification::Column::IsRead.eq(false))
+                        .add(notification::Column::Kind.eq(kind))
+                        .add(notification::Column::ActorId.eq(actor_id)),
+                )
+                .stream(&*tx)
+                .await?;
+
+            // Check if this notification already exists. Don't rely on the
+            // JSON serialization being identical, in case the notification enum
+            // is changed in backward-compatible ways over time.
+            while let Some(row) = existing_notifications.next().await {
+                let row = row?;
+                if let Some(proto) = self.model_to_proto(row) {
+                    if let Some(existing) = Notification::from_proto(&proto) {
+                        if existing == notification {
+                            return Ok(None);
+                        }
+                    }
+                }
+            }
+        }
 
         let model = notification::ActiveModel {
             recipient_id: ActiveValue::Set(recipient_id),
             kind: ActiveValue::Set(kind),
-            content: ActiveValue::Set(notification.content.clone()),
-            actor_id: ActiveValue::Set(notification.actor_id.map(|id| UserId::from_proto(id))),
+            content: ActiveValue::Set(notification_proto.content.clone()),
+            actor_id: ActiveValue::Set(actor_id),
             is_read: ActiveValue::NotSet,
             created_at: ActiveValue::NotSet,
             id: ActiveValue::NotSet,
@@ -94,17 +117,17 @@ impl Database {
         .save(&*tx)
         .await?;
 
-        Ok(proto::Notification {
+        Ok(Some(proto::Notification {
             id: model.id.as_ref().to_proto(),
-            kind: notification.kind.to_string(),
+            kind: notification_proto.kind.to_string(),
             timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64,
             is_read: false,
-            content: notification.content,
-            actor_id: notification.actor_id,
-        })
+            content: notification_proto.content,
+            actor_id: notification_proto.actor_id,
+        }))
     }
 
-    pub async fn delete_notification(
+    pub async fn remove_notification(
         &self,
         recipient_id: UserId,
         notification: Notification,
@@ -133,4 +156,16 @@ impl Database {
         }
         Ok(notification.map(|notification| notification.id))
     }
+
+    fn model_to_proto(&self, row: notification::Model) -> Option<proto::Notification> {
+        let kind = self.notification_kinds_by_id.get(&row.kind)?;
+        Some(proto::Notification {
+            id: row.id.to_proto(),
+            kind: kind.to_string(),
+            timestamp: row.created_at.assume_utc().unix_timestamp() as u64,
+            is_read: row.is_read,
+            content: row.content,
+            actor_id: row.actor_id.map(|id| id.to_proto()),
+        })
+    }
 }

crates/collab/src/rpc.rs 🔗

@@ -2097,12 +2097,14 @@ async fn request_contact(
         .user_connection_ids(responder_id)
     {
         session.peer.send(connection_id, update.clone())?;
-        session.peer.send(
-            connection_id,
-            proto::NewNotification {
-                notification: Some(notification.clone()),
-            },
-        )?;
+        if let Some(notification) = &notification {
+            session.peer.send(
+                connection_id,
+                proto::NewNotification {
+                    notification: Some(notification.clone()),
+                },
+            )?;
+        }
     }
 
     response.send(proto::Ack {})?;
@@ -2156,12 +2158,14 @@ async fn respond_to_contact_request(
             .push(responder_id.to_proto());
         for connection_id in pool.user_connection_ids(requester_id) {
             session.peer.send(connection_id, update.clone())?;
-            session.peer.send(
-                connection_id,
-                proto::NewNotification {
-                    notification: Some(notification.clone()),
-                },
-            )?;
+            if let Some(notification) = &notification {
+                session.peer.send(
+                    connection_id,
+                    proto::NewNotification {
+                        notification: Some(notification.clone()),
+                    },
+                )?;
+            }
         }
     }
 
@@ -2306,7 +2310,8 @@ async fn invite_channel_member(
     let db = session.db().await;
     let channel_id = ChannelId::from_proto(request.channel_id);
     let invitee_id = UserId::from_proto(request.user_id);
-    db.invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
+    let notification = db
+        .invite_channel_member(channel_id, invitee_id, session.user_id, request.admin)
         .await?;
 
     let (channel, _) = db
@@ -2319,12 +2324,21 @@ async fn invite_channel_member(
         id: channel.id.to_proto(),
         name: channel.name,
     });
+
     for connection_id in session
         .connection_pool()
         .await
         .user_connection_ids(invitee_id)
     {
         session.peer.send(connection_id, update.clone())?;
+        if let Some(notification) = &notification {
+            session.peer.send(
+                connection_id,
+                proto::NewNotification {
+                    notification: Some(notification.clone()),
+                },
+            )?;
+        }
     }
 
     response.send(proto::Ack {})?;

crates/collab_ui/src/notification_panel.rs 🔗

@@ -209,7 +209,12 @@ impl NotificationPanel {
                 channel_id,
             } => {
                 actor = user_store.get_cached_user(inviter_id)?;
-                let channel = channel_store.channel_for_id(channel_id)?;
+                let channel = channel_store.channel_for_id(channel_id).or_else(|| {
+                    channel_store
+                        .channel_invitations()
+                        .iter()
+                        .find(|c| c.id == channel_id)
+                })?;
 
                 icon = "icons/hash.svg";
                 text = format!(