Fix mention notifications are not updated after message change and not removed after a message is deleted (#9847)

Remco Smits and Bennet Bo Fenner created

@ConradIrwin This is a followup for #9035 as agreed.

Release Notes:

- Fixed mention notifications are updated when channel message is
updated. And mention notifications are removed when message is removed.

---------

Co-authored-by: Bennet Bo Fenner <53836821+bennetbo@users.noreply.github.com>

Change summary

crates/collab/src/db.rs                          |   2 
crates/collab/src/db/queries/messages.rs         |  78 +++++++++
crates/collab/src/db/queries/notifications.rs    |  19 +-
crates/collab/src/rpc.rs                         |  48 +++++
crates/collab/src/tests/channel_message_tests.rs | 130 +++++++++++++++++
crates/notifications/src/notification_store.rs   |  35 ++++
crates/rpc/proto/zed.proto                       |   8 
crates/rpc/src/proto.rs                          |   1 
8 files changed, 297 insertions(+), 24 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -460,6 +460,8 @@ pub struct UpdatedChannelMessage {
     pub notifications: NotificationBatch,
     pub reply_to_message_id: Option<MessageId>,
     pub timestamp: PrimitiveDateTime,
+    pub deleted_mention_notification_ids: Vec<NotificationId>,
+    pub updated_mention_notifications: Vec<rpc::proto::Notification>,
 }
 
 #[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]

crates/collab/src/db/queries/messages.rs 🔗

@@ -1,7 +1,8 @@
 use super::*;
 use rpc::Notification;
-use sea_orm::TryInsertResult;
+use sea_orm::{SelectColumns, TryInsertResult};
 use time::OffsetDateTime;
+use util::ResultExt;
 
 impl Database {
     /// Inserts a record representing a user joining the chat for a given channel.
@@ -480,13 +481,20 @@ impl Database {
         Ok(results)
     }
 
+    fn get_notification_kind_id_by_name(&self, notification_kind: &str) -> Option<i32> {
+        self.notification_kinds_by_id
+            .iter()
+            .find(|(_, kind)| **kind == notification_kind)
+            .map(|kind| kind.0 .0)
+    }
+
     /// Removes the channel message with the given ID.
     pub async fn remove_channel_message(
         &self,
         channel_id: ChannelId,
         message_id: MessageId,
         user_id: UserId,
-    ) -> Result<Vec<ConnectionId>> {
+    ) -> Result<(Vec<ConnectionId>, Vec<NotificationId>)> {
         self.transaction(|tx| async move {
             let mut rows = channel_chat_participant::Entity::find()
                 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
@@ -531,7 +539,29 @@ impl Database {
                 }
             }
 
-            Ok(participant_connection_ids)
+            let notification_kind_id =
+                self.get_notification_kind_id_by_name("ChannelMessageMention");
+
+            let existing_notifications = notification::Entity::find()
+                .filter(notification::Column::EntityId.eq(message_id))
+                .filter(notification::Column::Kind.eq(notification_kind_id))
+                .select_column(notification::Column::Id)
+                .all(&*tx)
+                .await?;
+
+            let existing_notification_ids = existing_notifications
+                .into_iter()
+                .map(|notification| notification.id)
+                .collect();
+
+            // remove all the mention notifications for this message
+            notification::Entity::delete_many()
+                .filter(notification::Column::EntityId.eq(message_id))
+                .filter(notification::Column::Kind.eq(notification_kind_id))
+                .exec(&*tx)
+                .await?;
+
+            Ok((participant_connection_ids, existing_notification_ids))
         })
         .await
     }
@@ -629,14 +659,44 @@ impl Database {
                     .await?;
             }
 
-            let mut mentioned_user_ids = mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
+            let mut update_mention_user_ids = HashSet::default();
+            let mut new_mention_user_ids =
+                mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
             // Filter out users that were mentioned before
-            for mention in old_mentions {
-                mentioned_user_ids.remove(&mention.user_id.to_proto());
+            for mention in &old_mentions {
+                if new_mention_user_ids.contains(&mention.user_id.to_proto()) {
+                    update_mention_user_ids.insert(mention.user_id.to_proto());
+                }
+
+                new_mention_user_ids.remove(&mention.user_id.to_proto());
+            }
+
+            let notification_kind_id =
+                self.get_notification_kind_id_by_name("ChannelMessageMention");
+
+            let existing_notifications = notification::Entity::find()
+                .filter(notification::Column::EntityId.eq(message_id))
+                .filter(notification::Column::Kind.eq(notification_kind_id))
+                .all(&*tx)
+                .await?;
+
+            // determine which notifications should be updated or deleted
+            let mut deleted_notification_ids = HashSet::default();
+            let mut updated_mention_notifications = Vec::new();
+            for notification in existing_notifications {
+                if update_mention_user_ids.contains(&notification.recipient_id.to_proto()) {
+                    if let Some(notification) =
+                        self::notifications::model_to_proto(self, notification).log_err()
+                    {
+                        updated_mention_notifications.push(notification);
+                    }
+                } else {
+                    deleted_notification_ids.insert(notification.id);
+                }
             }
 
             let mut notifications = Vec::new();
-            for mentioned_user in mentioned_user_ids {
+            for mentioned_user in new_mention_user_ids {
                 notifications.extend(
                     self.create_notification(
                         UserId::from_proto(mentioned_user),
@@ -658,6 +718,10 @@ impl Database {
                 notifications,
                 reply_to_message_id: channel_message.reply_to_message_id,
                 timestamp: channel_message.sent_at,
+                deleted_mention_notification_ids: deleted_notification_ids
+                    .into_iter()
+                    .collect::<Vec<_>>(),
+                updated_mention_notifications,
             })
         })
         .await

crates/collab/src/db/queries/notifications.rs 🔗

@@ -1,5 +1,6 @@
 use super::*;
 use rpc::Notification;
+use util::ResultExt;
 
 impl Database {
     /// Initializes the different kinds of notifications by upserting records for them.
@@ -53,11 +54,8 @@ impl Database {
                 .await?;
             while let Some(row) = rows.next().await {
                 let row = row?;
-                let kind = row.kind;
-                if let Some(proto) = model_to_proto(self, row) {
+                if let Some(proto) = model_to_proto(self, row).log_err() {
                     result.push(proto);
-                } else {
-                    log::warn!("unknown notification kind {:?}", kind);
                 }
             }
             result.reverse();
@@ -200,7 +198,9 @@ impl Database {
             })
             .exec(tx)
             .await?;
-            Ok(model_to_proto(self, row).map(|notification| (recipient_id, notification)))
+            Ok(model_to_proto(self, row)
+                .map(|notification| (recipient_id, notification))
+                .ok())
         } else {
             Ok(None)
         }
@@ -241,9 +241,12 @@ impl Database {
     }
 }
 
-fn model_to_proto(this: &Database, row: notification::Model) -> Option<proto::Notification> {
-    let kind = this.notification_kinds_by_id.get(&row.kind)?;
-    Some(proto::Notification {
+pub fn model_to_proto(this: &Database, row: notification::Model) -> Result<proto::Notification> {
+    let kind = this
+        .notification_kinds_by_id
+        .get(&row.kind)
+        .ok_or_else(|| anyhow!("Unknown notification kind"))?;
+    Ok(proto::Notification {
         id: row.id.to_proto(),
         kind: kind.to_string(),
         timestamp: row.created_at.assume_utc().unix_timestamp() as u64,

crates/collab/src/rpc.rs 🔗

@@ -3388,14 +3388,30 @@ async fn remove_channel_message(
 ) -> Result<()> {
     let channel_id = ChannelId::from_proto(request.channel_id);
     let message_id = MessageId::from_proto(request.message_id);
-    let connection_ids = session
+    let (connection_ids, existing_notification_ids) = session
         .db()
         .await
         .remove_channel_message(channel_id, message_id, session.user_id())
         .await?;
-    broadcast(Some(session.connection_id), connection_ids, |connection| {
-        session.peer.send(connection, request.clone())
-    });
+
+    broadcast(
+        Some(session.connection_id),
+        connection_ids,
+        move |connection| {
+            session.peer.send(connection, request.clone())?;
+
+            for notification_id in &existing_notification_ids {
+                session.peer.send(
+                    connection,
+                    proto::DeleteNotification {
+                        notification_id: (*notification_id).to_proto(),
+                    },
+                )?;
+            }
+
+            Ok(())
+        },
+    );
     response.send(proto::Ack {})?;
     Ok(())
 }
@@ -3414,6 +3430,8 @@ async fn update_channel_message(
         notifications,
         reply_to_message_id,
         timestamp,
+        deleted_mention_notification_ids,
+        updated_mention_notifications,
     } = session
         .db()
         .await
@@ -3456,7 +3474,27 @@ async fn update_channel_message(
                     channel_id: channel_id.to_proto(),
                     message: Some(message.clone()),
                 },
-            )
+            )?;
+
+            for notification_id in &deleted_mention_notification_ids {
+                session.peer.send(
+                    connection,
+                    proto::DeleteNotification {
+                        notification_id: (*notification_id).to_proto(),
+                    },
+                )?;
+            }
+
+            for notification in &updated_mention_notifications {
+                session.peer.send(
+                    connection,
+                    proto::UpdateNotification {
+                        notification: Some(notification.clone()),
+                    },
+                )?;
+            }
+
+            Ok(())
         },
     );
 

crates/collab/src/tests/channel_message_tests.rs 🔗

@@ -222,8 +222,18 @@ async fn test_remove_channel_message(
         .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap())
         .await
         .unwrap();
-    channel_chat_a
-        .update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap())
+    let msg_id_2 = channel_chat_a
+        .update(cx_a, |c, cx| {
+            c.send_message(
+                MessageParams {
+                    text: "two @user_b".to_string(),
+                    mentions: vec![(4..12, client_b.id())],
+                    reply_to_message_id: None,
+                },
+                cx,
+            )
+            .unwrap()
+        })
         .await
         .unwrap();
     channel_chat_a
@@ -233,10 +243,24 @@ async fn test_remove_channel_message(
 
     // Clients A and B see all of the messages.
     executor.run_until_parked();
-    let expected_messages = &["one", "two", "three"];
+    let expected_messages = &["one", "two @user_b", "three"];
     assert_messages(&channel_chat_a, expected_messages, cx_a);
     assert_messages(&channel_chat_b, expected_messages, cx_b);
 
+    // Ensure that client B received a notification for the mention.
+    client_b.notification_store().read_with(cx_b, |store, _| {
+        assert_eq!(store.notification_count(), 2);
+        let entry = store.notification_at(0).unwrap();
+        assert_eq!(
+            entry.notification,
+            Notification::ChannelMessageMention {
+                message_id: msg_id_2,
+                sender_id: client_a.id(),
+                channel_id: channel_id.0,
+            }
+        );
+    });
+
     // Client A deletes one of their messages.
     channel_chat_a
         .update(cx_a, |c, cx| {
@@ -261,6 +285,13 @@ async fn test_remove_channel_message(
         .await
         .unwrap();
     assert_messages(&channel_chat_c, expected_messages, cx_c);
+
+    // Ensure we remove the notifications when the message is removed
+    client_b.notification_store().read_with(cx_b, |store, _| {
+        // First notification is the channel invitation, second would be the mention
+        // notification, which should now be removed.
+        assert_eq!(store.notification_count(), 1);
+    });
 }
 
 #[track_caller]
@@ -598,4 +629,97 @@ async fn test_chat_editing(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext)
             }
         );
     });
+
+    // Test update message and keep the mention and check that the body is updated correctly
+
+    channel_chat_a
+        .update(cx_a, |c, cx| {
+            c.update_message(
+                msg_id,
+                MessageParams {
+                    text: "Updated body v2 including a mention for @user_b".into(),
+                    reply_to_message_id: None,
+                    mentions: vec![(37..45, client_b.id())],
+                },
+                cx,
+            )
+            .unwrap()
+        })
+        .await
+        .unwrap();
+
+    cx_a.run_until_parked();
+    cx_b.run_until_parked();
+
+    channel_chat_a.update(cx_a, |channel_chat, _| {
+        assert_eq!(
+            channel_chat.find_loaded_message(msg_id).unwrap().body,
+            "Updated body v2 including a mention for @user_b",
+        )
+    });
+    channel_chat_b.update(cx_b, |channel_chat, _| {
+        assert_eq!(
+            channel_chat.find_loaded_message(msg_id).unwrap().body,
+            "Updated body v2 including a mention for @user_b",
+        )
+    });
+
+    client_b.notification_store().read_with(cx_b, |store, _| {
+        let message = store.channel_message_for_id(msg_id);
+        assert!(message.is_some());
+        assert_eq!(
+            message.unwrap().body,
+            "Updated body v2 including a mention for @user_b"
+        );
+        assert_eq!(store.notification_count(), 2);
+        let entry = store.notification_at(0).unwrap();
+        assert_eq!(
+            entry.notification,
+            Notification::ChannelMessageMention {
+                message_id: msg_id,
+                sender_id: client_a.id(),
+                channel_id: channel_id.0,
+            }
+        );
+    });
+
+    // If we remove a mention from a message the corresponding mention notification
+    // should also be removed.
+
+    channel_chat_a
+        .update(cx_a, |c, cx| {
+            c.update_message(
+                msg_id,
+                MessageParams {
+                    text: "Updated body without a mention".into(),
+                    reply_to_message_id: None,
+                    mentions: vec![],
+                },
+                cx,
+            )
+            .unwrap()
+        })
+        .await
+        .unwrap();
+
+    cx_a.run_until_parked();
+    cx_b.run_until_parked();
+
+    channel_chat_a.update(cx_a, |channel_chat, _| {
+        assert_eq!(
+            channel_chat.find_loaded_message(msg_id).unwrap().body,
+            "Updated body without a mention",
+        )
+    });
+    channel_chat_b.update(cx_b, |channel_chat, _| {
+        assert_eq!(
+            channel_chat.find_loaded_message(msg_id).unwrap().body,
+            "Updated body without a mention",
+        )
+    });
+    client_b.notification_store().read_with(cx_b, |store, _| {
+        // First notification is the channel invitation, second would be the mention
+        // notification, which should now be removed.
+        assert_eq!(store.notification_count(), 1);
+    });
 }

crates/notifications/src/notification_store.rs 🔗

@@ -114,6 +114,7 @@ impl NotificationStore {
             _subscriptions: vec![
                 client.add_message_handler(cx.weak_model(), Self::handle_new_notification),
                 client.add_message_handler(cx.weak_model(), Self::handle_delete_notification),
+                client.add_message_handler(cx.weak_model(), Self::handle_update_notification),
             ],
             user_store,
             client,
@@ -236,6 +237,40 @@ impl NotificationStore {
         })?
     }
 
+    async fn handle_update_notification(
+        this: Model<Self>,
+        envelope: TypedEnvelope<proto::UpdateNotification>,
+        _: Arc<Client>,
+        mut cx: AsyncAppContext,
+    ) -> Result<()> {
+        this.update(&mut cx, |this, cx| {
+            if let Some(notification) = envelope.payload.notification {
+                if let Some(rpc::Notification::ChannelMessageMention {
+                    message_id,
+                    sender_id: _,
+                    channel_id: _,
+                }) = Notification::from_proto(&notification)
+                {
+                    let fetch_message_task = this.channel_store.update(cx, |this, cx| {
+                        this.fetch_channel_messages(vec![message_id], cx)
+                    });
+
+                    cx.spawn(|this, mut cx| async move {
+                        let messages = fetch_message_task.await?;
+                        this.update(&mut cx, move |this, cx| {
+                            for message in messages {
+                                this.channel_messages.insert(message_id, message);
+                            }
+                            cx.notify();
+                        })
+                    })
+                    .detach_and_log_err(cx)
+                }
+            }
+            Ok(())
+        })?
+    }
+
     async fn add_notifications(
         this: Model<Self>,
         notifications: Vec<proto::Notification>,

crates/rpc/proto/zed.proto 🔗

@@ -208,7 +208,9 @@ message Envelope {
         ChannelMessageUpdate channel_message_update = 171;
 
         BlameBuffer blame_buffer = 172;
-        BlameBufferResponse blame_buffer_response = 173; // Current max
+        BlameBufferResponse blame_buffer_response = 173;
+        
+        UpdateNotification update_notification = 174;  // current max
     }
 
     reserved 158 to 161;
@@ -1715,6 +1717,10 @@ message DeleteNotification {
     uint64 notification_id = 1;
 }
 
+message UpdateNotification {
+    Notification notification = 1;
+}
+
 message MarkNotificationRead {
     uint64 notification_id = 1;
 }

crates/rpc/src/proto.rs 🔗

@@ -163,6 +163,7 @@ messages!(
     (DeclineCall, Foreground),
     (DeleteChannel, Foreground),
     (DeleteNotification, Foreground),
+    (UpdateNotification, Foreground),
     (DeleteProjectEntry, Foreground),
     (EndStream, Foreground),
     (Error, Foreground),