Generalize notifications' actor id to entity id

Max Brunsfeld created

This way, we can retrieve channel invite notifications when
responding to the invites.

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql   |  2 
crates/collab/migrations/20231004130100_create_notifications.sql |  2 
crates/collab/src/db.rs                                          |  2 
crates/collab/src/db/queries/channels.rs                         |  7 
crates/collab/src/db/queries/contacts.rs                         |  8 
crates/collab/src/db/queries/notifications.rs                    | 91 +
crates/collab/src/db/tables/notification.rs                      |  2 
crates/collab/src/db/tests.rs                                    |  4 
crates/collab/src/lib.rs                                         |  2 
crates/collab_ui/src/notification_panel.rs                       | 37 
crates/notifications/src/notification_store.rs                   |  6 
crates/rpc/proto/zed.proto                                       |  4 
crates/rpc/src/notification.rs                                   | 46 
13 files changed, 115 insertions(+), 98 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -324,8 +324,8 @@ CREATE TABLE "notifications" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "created_at" TIMESTAMP NOT NULL default CURRENT_TIMESTAMP,
     "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
-    "actor_id" INTEGER REFERENCES users (id) ON DELETE CASCADE,
     "kind" INTEGER NOT NULL REFERENCES notification_kinds (id),
+    "entity_id" INTEGER,
     "content" TEXT,
     "is_read" BOOLEAN NOT NULL DEFAULT FALSE,
     "response" BOOLEAN

crates/collab/migrations/20231004130100_create_notifications.sql 🔗

@@ -9,8 +9,8 @@ CREATE TABLE notifications (
     "id" SERIAL PRIMARY KEY,
     "created_at" TIMESTAMP NOT NULL DEFAULT now(),
     "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
-    "actor_id" INTEGER REFERENCES users (id) ON DELETE CASCADE,
     "kind" INTEGER NOT NULL REFERENCES notification_kinds (id),
+    "entity_id" INTEGER,
     "content" TEXT,
     "is_read" BOOLEAN NOT NULL DEFAULT FALSE,
     "response" BOOLEAN

crates/collab/src/db.rs 🔗

@@ -125,7 +125,7 @@ impl Database {
     }
 
     pub async fn initialize_static_data(&mut self) -> Result<()> {
-        self.initialize_notification_enum().await?;
+        self.initialize_notification_kinds().await?;
         Ok(())
     }
 

crates/collab/src/db/queries/channels.rs 🔗

@@ -166,6 +166,11 @@ impl Database {
             self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
                 .await?;
 
+            let channel = channel::Entity::find_by_id(channel_id)
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such channel"))?;
+
             channel_member::ActiveModel {
                 channel_id: ActiveValue::Set(channel_id),
                 user_id: ActiveValue::Set(invitee_id),
@@ -181,6 +186,7 @@ impl Database {
                     invitee_id,
                     rpc::Notification::ChannelInvitation {
                         channel_id: channel_id.to_proto(),
+                        channel_name: channel.name,
                     },
                     true,
                     &*tx,
@@ -269,6 +275,7 @@ impl Database {
                     user_id,
                     &rpc::Notification::ChannelInvitation {
                         channel_id: channel_id.to_proto(),
+                        channel_name: Default::default(),
                     },
                     accept,
                     &*tx,

crates/collab/src/db/queries/contacts.rs 🔗

@@ -168,7 +168,7 @@ impl Database {
                 .create_notification(
                     receiver_id,
                     rpc::Notification::ContactRequest {
-                        actor_id: sender_id.to_proto(),
+                        sender_id: sender_id.to_proto(),
                     },
                     true,
                     &*tx,
@@ -219,7 +219,7 @@ impl Database {
                     .remove_notification(
                         responder_id,
                         rpc::Notification::ContactRequest {
-                            actor_id: requester_id.to_proto(),
+                            sender_id: requester_id.to_proto(),
                         },
                         &*tx,
                     )
@@ -324,7 +324,7 @@ impl Database {
                 self.respond_to_notification(
                     responder_id,
                     &rpc::Notification::ContactRequest {
-                        actor_id: requester_id.to_proto(),
+                        sender_id: requester_id.to_proto(),
                     },
                     accept,
                     &*tx,
@@ -337,7 +337,7 @@ impl Database {
                     self.create_notification(
                         requester_id,
                         rpc::Notification::ContactRequestAccepted {
-                            actor_id: responder_id.to_proto(),
+                            responder_id: responder_id.to_proto(),
                         },
                         true,
                         &*tx,

crates/collab/src/db/queries/notifications.rs 🔗

@@ -2,7 +2,7 @@ use super::*;
 use rpc::Notification;
 
 impl Database {
-    pub async fn initialize_notification_enum(&mut self) -> Result<()> {
+    pub async fn initialize_notification_kinds(&mut self) -> Result<()> {
         notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map(
             |kind| notification_kind::ActiveModel {
                 name: ActiveValue::Set(kind.to_string()),
@@ -64,6 +64,9 @@ impl Database {
         .await
     }
 
+    /// Create a notification. If `avoid_duplicates` is set to true, then avoid
+    /// creating a new notification if the given recipient already has an
+    /// unread notification with the given kind and entity id.
     pub async fn create_notification(
         &self,
         recipient_id: UserId,
@@ -81,22 +84,14 @@ impl Database {
             }
         }
 
-        let notification_proto = notification.to_proto();
-        let kind = *self
-            .notification_kinds_by_name
-            .get(&notification_proto.kind)
-            .ok_or_else(|| anyhow!("invalid notification kind {:?}", notification_proto.kind))?;
-        let actor_id = notification_proto.actor_id.map(|id| UserId::from_proto(id));
-
+        let proto = notification.to_proto();
+        let kind = notification_kind_from_proto(self, &proto)?;
         let model = notification::ActiveModel {
             recipient_id: ActiveValue::Set(recipient_id),
             kind: ActiveValue::Set(kind),
-            content: ActiveValue::Set(notification_proto.content.clone()),
-            actor_id: ActiveValue::Set(actor_id),
-            is_read: ActiveValue::NotSet,
-            response: ActiveValue::NotSet,
-            created_at: ActiveValue::NotSet,
-            id: ActiveValue::NotSet,
+            entity_id: ActiveValue::Set(proto.entity_id.map(|id| id as i32)),
+            content: ActiveValue::Set(proto.content.clone()),
+            ..Default::default()
         }
         .save(&*tx)
         .await?;
@@ -105,16 +100,18 @@ impl Database {
             recipient_id,
             proto::Notification {
                 id: model.id.as_ref().to_proto(),
-                kind: notification_proto.kind,
+                kind: proto.kind,
                 timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64,
                 is_read: false,
                 response: None,
-                content: notification_proto.content,
-                actor_id: notification_proto.actor_id,
+                content: proto.content,
+                entity_id: proto.entity_id,
             },
         )))
     }
 
+    /// Remove an unread notification with the given recipient, kind and
+    /// entity id.
     pub async fn remove_notification(
         &self,
         recipient_id: UserId,
@@ -130,6 +127,8 @@ impl Database {
         Ok(id)
     }
 
+    /// Populate the response for the notification with the given kind and
+    /// entity id.
     pub async fn respond_to_notification(
         &self,
         recipient_id: UserId,
@@ -156,47 +155,38 @@ impl Database {
         }
     }
 
-    pub async fn find_notification(
+    /// Find an unread notification by its recipient, kind and entity id.
+    async fn find_notification(
         &self,
         recipient_id: UserId,
         notification: &Notification,
         tx: &DatabaseTransaction,
     ) -> Result<Option<NotificationId>> {
         let proto = notification.to_proto();
-        let kind = *self
-            .notification_kinds_by_name
-            .get(&proto.kind)
-            .ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?;
-        let mut rows = notification::Entity::find()
+        let kind = notification_kind_from_proto(self, &proto)?;
+
+        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+        enum QueryIds {
+            Id,
+        }
+
+        Ok(notification::Entity::find()
+            .select_only()
+            .column(notification::Column::Id)
             .filter(
                 Condition::all()
                     .add(notification::Column::RecipientId.eq(recipient_id))
                     .add(notification::Column::IsRead.eq(false))
                     .add(notification::Column::Kind.eq(kind))
-                    .add(if proto.actor_id.is_some() {
-                        notification::Column::ActorId.eq(proto.actor_id)
+                    .add(if proto.entity_id.is_some() {
+                        notification::Column::EntityId.eq(proto.entity_id)
                     } else {
-                        notification::Column::ActorId.is_null()
+                        notification::Column::EntityId.is_null()
                     }),
             )
-            .stream(&*tx)
-            .await?;
-
-        // Don't rely on the JSON serialization being identical, in case the
-        // notification type is changed in backward-compatible ways.
-        while let Some(row) = rows.next().await {
-            let row = row?;
-            let id = row.id;
-            if let Some(proto) = model_to_proto(self, row) {
-                if let Some(existing) = Notification::from_proto(&proto) {
-                    if existing == *notification {
-                        return Ok(Some(id));
-                    }
-                }
-            }
-        }
-
-        Ok(None)
+            .into_values::<_, QueryIds>()
+            .one(&*tx)
+            .await?)
     }
 }
 
@@ -209,6 +199,17 @@ fn model_to_proto(this: &Database, row: notification::Model) -> Option<proto::No
         is_read: row.is_read,
         response: row.response,
         content: row.content,
-        actor_id: row.actor_id.map(|id| id.to_proto()),
+        entity_id: row.entity_id.map(|id| id as u64),
     })
 }
+
+fn notification_kind_from_proto(
+    this: &Database,
+    proto: &proto::Notification,
+) -> Result<NotificationKindId> {
+    Ok(this
+        .notification_kinds_by_name
+        .get(&proto.kind)
+        .copied()
+        .ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?)
+}

crates/collab/src/db/tables/notification.rs 🔗

@@ -9,8 +9,8 @@ pub struct Model {
     pub id: NotificationId,
     pub created_at: PrimitiveDateTime,
     pub recipient_id: UserId,
-    pub actor_id: Option<UserId>,
     pub kind: NotificationKindId,
+    pub entity_id: Option<i32>,
     pub content: String,
     pub is_read: bool,
     pub response: Option<bool>,

crates/collab/src/db/tests.rs 🔗

@@ -45,7 +45,7 @@ impl TestDb {
                 ))
                 .await
                 .unwrap();
-            db.initialize_notification_enum().await.unwrap();
+            db.initialize_notification_kinds().await.unwrap();
             db
         });
 
@@ -85,7 +85,7 @@ impl TestDb {
                 .unwrap();
             let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
             db.migrate(Path::new(migrations_path), false).await.unwrap();
-            db.initialize_notification_enum().await.unwrap();
+            db.initialize_notification_kinds().await.unwrap();
             db
         });
 

crates/collab/src/lib.rs 🔗

@@ -120,7 +120,7 @@ impl AppState {
         let mut db_options = db::ConnectOptions::new(config.database_url.clone());
         db_options.max_connections(config.database_max_connections);
         let mut db = Database::new(db_options, Executor::Production).await?;
-        db.initialize_notification_enum().await?;
+        db.initialize_notification_kinds().await?;
 
         let live_kit_client = if let Some(((server, key), secret)) = config
             .live_kit_server

crates/collab_ui/src/notification_panel.rs 🔗

@@ -192,39 +192,34 @@ impl NotificationPanel {
         let actor;
         let needs_acceptance;
         match notification {
-            Notification::ContactRequest { actor_id } => {
-                let requester = user_store.get_cached_user(actor_id)?;
+            Notification::ContactRequest { sender_id } => {
+                let requester = user_store.get_cached_user(sender_id)?;
                 icon = "icons/plus.svg";
                 text = format!("{} wants to add you as a contact", requester.github_login);
                 needs_acceptance = true;
                 actor = Some(requester);
             }
-            Notification::ContactRequestAccepted { actor_id } => {
-                let responder = user_store.get_cached_user(actor_id)?;
+            Notification::ContactRequestAccepted { responder_id } => {
+                let responder = user_store.get_cached_user(responder_id)?;
                 icon = "icons/plus.svg";
                 text = format!("{} accepted your contact invite", responder.github_login);
                 needs_acceptance = false;
                 actor = Some(responder);
             }
-            Notification::ChannelInvitation { channel_id } => {
+            Notification::ChannelInvitation {
+                ref channel_name, ..
+            } => {
                 actor = None;
-                let channel = channel_store.channel_for_id(channel_id).or_else(|| {
-                    channel_store
-                        .channel_invitations()
-                        .iter()
-                        .find(|c| c.id == channel_id)
-                })?;
-
                 icon = "icons/hash.svg";
-                text = format!("you were invited to join the #{} channel", channel.name);
+                text = format!("you were invited to join the #{channel_name} channel");
                 needs_acceptance = true;
             }
             Notification::ChannelMessageMention {
-                actor_id,
+                sender_id,
                 channel_id,
                 message_id,
             } => {
-                let sender = user_store.get_cached_user(actor_id)?;
+                let sender = user_store.get_cached_user(sender_id)?;
                 let channel = channel_store.channel_for_id(channel_id)?;
                 let message = notification_store.channel_message_for_id(message_id)?;
 
@@ -405,8 +400,12 @@ impl NotificationPanel {
     fn add_toast(&mut self, entry: &NotificationEntry, cx: &mut ViewContext<Self>) {
         let id = entry.id as usize;
         match entry.notification {
-            Notification::ContactRequest { actor_id }
-            | Notification::ContactRequestAccepted { actor_id } => {
+            Notification::ContactRequest {
+                sender_id: actor_id,
+            }
+            | Notification::ContactRequestAccepted {
+                responder_id: actor_id,
+            } => {
                 let user_store = self.user_store.clone();
                 let Some(user) = user_store.read(cx).get_cached_user(actor_id) else {
                     return;
@@ -452,7 +451,9 @@ impl NotificationPanel {
         cx: &mut ViewContext<Self>,
     ) {
         match notification {
-            Notification::ContactRequest { actor_id } => {
+            Notification::ContactRequest {
+                sender_id: actor_id,
+            } => {
                 self.user_store
                     .update(cx, |store, cx| {
                         store.respond_to_contact_request(actor_id, response, cx)

crates/notifications/src/notification_store.rs 🔗

@@ -199,17 +199,17 @@ impl NotificationStore {
             match entry.notification {
                 Notification::ChannelInvitation { .. } => {}
                 Notification::ContactRequest {
-                    actor_id: requester_id,
+                    sender_id: requester_id,
                 } => {
                     user_ids.push(requester_id);
                 }
                 Notification::ContactRequestAccepted {
-                    actor_id: contact_id,
+                    responder_id: contact_id,
                 } => {
                     user_ids.push(contact_id);
                 }
                 Notification::ChannelMessageMention {
-                    actor_id: sender_id,
+                    sender_id,
                     message_id,
                     ..
                 } => {

crates/rpc/proto/zed.proto 🔗

@@ -1599,8 +1599,8 @@ message Notification {
     uint64 id = 1;
     uint64 timestamp = 2;
     string kind = 3;
-    string content = 4;
-    optional uint64 actor_id = 5;
+    optional uint64 entity_id = 4;
+    string content = 5;
     bool is_read = 6;
     optional bool response = 7;
 }

crates/rpc/src/notification.rs 🔗

@@ -4,32 +4,37 @@ use serde_json::{map, Value};
 use strum::{EnumVariantNames, VariantNames as _};
 
 const KIND: &'static str = "kind";
-const ACTOR_ID: &'static str = "actor_id";
+const ENTITY_ID: &'static str = "entity_id";
 
-/// A notification that can be stored, associated with a given user.
+/// A notification that can be stored, associated with a given recipient.
 ///
 /// This struct is stored in the collab database as JSON, so it shouldn't be
 /// changed in a backward-incompatible way. For example, when renaming a
 /// variant, add a serde alias for the old name.
 ///
-/// When a notification is initiated by a user, use the `actor_id` field
-/// to store the user's id. This is value is stored in a dedicated column
-/// in the database, so it can be queried more efficiently.
+/// Most notification types have a special field which is aliased to
+/// `entity_id`. This field is stored in its own database column, and can
+/// be used to query the notification.
 #[derive(Debug, Clone, PartialEq, Eq, EnumVariantNames, Serialize, Deserialize)]
 #[serde(tag = "kind")]
 pub enum Notification {
     ContactRequest {
-        actor_id: u64,
+        #[serde(rename = "entity_id")]
+        sender_id: u64,
     },
     ContactRequestAccepted {
-        actor_id: u64,
+        #[serde(rename = "entity_id")]
+        responder_id: u64,
     },
     ChannelInvitation {
+        #[serde(rename = "entity_id")]
         channel_id: u64,
+        channel_name: String,
     },
     ChannelMessageMention {
-        actor_id: u64,
+        sender_id: u64,
         channel_id: u64,
+        #[serde(rename = "entity_id")]
         message_id: u64,
     },
 }
@@ -37,19 +42,19 @@ pub enum Notification {
 impl Notification {
     pub fn to_proto(&self) -> proto::Notification {
         let mut value = serde_json::to_value(self).unwrap();
-        let mut actor_id = None;
+        let mut entity_id = None;
         let value = value.as_object_mut().unwrap();
         let Some(Value::String(kind)) = value.remove(KIND) else {
             unreachable!("kind is the enum tag")
         };
-        if let map::Entry::Occupied(e) = value.entry(ACTOR_ID) {
+        if let map::Entry::Occupied(e) = value.entry(ENTITY_ID) {
             if e.get().is_u64() {
-                actor_id = e.remove().as_u64();
+                entity_id = e.remove().as_u64();
             }
         }
         proto::Notification {
             kind,
-            actor_id,
+            entity_id,
             content: serde_json::to_string(&value).unwrap(),
             ..Default::default()
         }
@@ -59,8 +64,8 @@ impl Notification {
         let mut value = serde_json::from_str::<Value>(&notification.content).ok()?;
         let object = value.as_object_mut()?;
         object.insert(KIND.into(), notification.kind.to_string().into());
-        if let Some(actor_id) = notification.actor_id {
-            object.insert(ACTOR_ID.into(), actor_id.into());
+        if let Some(entity_id) = notification.entity_id {
+            object.insert(ENTITY_ID.into(), entity_id.into());
         }
         serde_json::from_value(value).ok()
     }
@@ -74,11 +79,14 @@ impl Notification {
 fn test_notification() {
     // Notifications can be serialized and deserialized.
     for notification in [
-        Notification::ContactRequest { actor_id: 1 },
-        Notification::ContactRequestAccepted { actor_id: 2 },
-        Notification::ChannelInvitation { channel_id: 100 },
+        Notification::ContactRequest { sender_id: 1 },
+        Notification::ContactRequestAccepted { responder_id: 2 },
+        Notification::ChannelInvitation {
+            channel_id: 100,
+            channel_name: "the-channel".into(),
+        },
         Notification::ChannelMessageMention {
-            actor_id: 200,
+            sender_id: 200,
             channel_id: 30,
             message_id: 1,
         },
@@ -90,6 +98,6 @@ fn test_notification() {
 
     // When notifications are serialized, the `kind` and `actor_id` fields are
     // stored separately, and do not appear redundantly in the JSON.
-    let notification = Notification::ContactRequest { actor_id: 1 };
+    let notification = Notification::ContactRequest { sender_id: 1 };
     assert_eq!(notification.to_proto().content, "{}");
 }