Start work on storing notifications in the database

Max Brunsfeld created

Change summary

Cargo.lock                                                       |  23 
Cargo.toml                                                       |   1 
crates/collab/migrations.sqlite/20221109000000_test_schema.sql   |  19 
crates/collab/migrations/20231004130100_create_notifications.sql |  18 
crates/collab/src/db.rs                                          |   2 
crates/collab/src/db/ids.rs                                      |   1 
crates/collab/src/db/queries.rs                                  |   1 
crates/collab/src/db/queries/access_tokens.rs                    |   1 
crates/collab/src/db/queries/notifications.rs                    | 140 ++
crates/collab/src/db/tables.rs                                   |   2 
crates/collab/src/db/tables/notification.rs                      |  29 
crates/collab/src/db/tables/notification_kind.rs                 |  14 
crates/rpc/Cargo.toml                                            |   1 
crates/rpc/proto/zed.proto                                       |  41 
crates/rpc/src/notification.rs                                   | 105 +
crates/rpc/src/rpc.rs                                            |   3 
16 files changed, 399 insertions(+), 2 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -6403,6 +6403,7 @@ dependencies = [
  "serde_derive",
  "smol",
  "smol-timeout",
+ "strum",
  "tempdir",
  "tracing",
  "util",
@@ -6623,6 +6624,12 @@ dependencies = [
  "untrusted",
 ]
 
+[[package]]
+name = "rustversion"
+version = "1.0.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
+
 [[package]]
 name = "rustybuzz"
 version = "0.3.0"
@@ -7698,6 +7705,22 @@ name = "strum"
 version = "0.25.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125"
+dependencies = [
+ "strum_macros",
+]
+
+[[package]]
+name = "strum_macros"
+version = "0.25.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ad8d03b598d3d0fff69bf533ee3ef19b8eeb342729596df84bcc7e1f96ec4059"
+dependencies = [
+ "heck 0.4.1",
+ "proc-macro2",
+ "quote",
+ "rustversion",
+ "syn 2.0.37",
+]
 
 [[package]]
 name = "subtle"

Cargo.toml 🔗

@@ -112,6 +112,7 @@ serde_derive = { version = "1.0", features = ["deserialize_in_place"] }
 serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] }
 smallvec = { version = "1.6", features = ["union"] }
 smol = { version = "1.2" }
+strum = { version = "0.25.0", features = ["derive"] }
 sysinfo = "0.29.10"
 tempdir = { version = "0.3.7" }
 thiserror = { version = "1.0.29" }

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -312,3 +312,22 @@ CREATE TABLE IF NOT EXISTS "observed_channel_messages" (
 );
 
 CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id");
+
+CREATE TABLE "notification_kinds" (
+    "id" INTEGER PRIMARY KEY NOT NULL,
+    "name" VARCHAR NOT NULL,
+);
+
+CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name");
+
+CREATE TABLE "notifications" (
+    "id" INTEGER PRIMARY KEY AUTOINCREMENT,
+    "created_at" TIMESTAMP NOT NULL default now,
+    "recipent_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "kind" INTEGER NOT NULL REFERENCES notification_kinds (id),
+    "is_read" BOOLEAN NOT NULL DEFAULT FALSE,
+    "entity_id_1" INTEGER,
+    "entity_id_2" INTEGER
+);
+
+CREATE INDEX "index_notifications_on_recipient_id" ON "notifications" ("recipient_id");

crates/collab/migrations/20231004130100_create_notifications.sql 🔗

@@ -0,0 +1,18 @@
+CREATE TABLE "notification_kinds" (
+    "id" INTEGER PRIMARY KEY NOT NULL,
+    "name" VARCHAR NOT NULL,
+);
+
+CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name");
+
+CREATE TABLE notifications (
+    "id" SERIAL PRIMARY KEY,
+    "created_at" TIMESTAMP NOT NULL DEFAULT now(),
+    "recipent_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "kind" INTEGER NOT NULL REFERENCES notification_kinds (id),
+    "is_read" BOOLEAN NOT NULL DEFAULT FALSE
+    "entity_id_1" INTEGER,
+    "entity_id_2" INTEGER
+);
+
+CREATE INDEX "index_notifications_on_recipient_id" ON "notifications" ("recipient_id");

crates/collab/src/db.rs 🔗

@@ -20,7 +20,7 @@ use rpc::{
 };
 use sea_orm::{
     entity::prelude::*,
-    sea_query::{Alias, Expr, OnConflict, Query},
+    sea_query::{Alias, Expr, OnConflict},
     ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
     FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
     TransactionTrait,

crates/collab/src/db/ids.rs 🔗

@@ -80,3 +80,4 @@ id_type!(SignupId);
 id_type!(UserId);
 id_type!(ChannelBufferCollaboratorId);
 id_type!(FlagId);
+id_type!(NotificationId);

crates/collab/src/db/queries.rs 🔗

@@ -5,6 +5,7 @@ pub mod buffers;
 pub mod channels;
 pub mod contacts;
 pub mod messages;
+pub mod notifications;
 pub mod projects;
 pub mod rooms;
 pub mod servers;

crates/collab/src/db/queries/notifications.rs 🔗

@@ -0,0 +1,140 @@
+use super::*;
+use rpc::{Notification, NotificationEntityKind, NotificationKind};
+
+impl Database {
+    pub async fn ensure_notification_kinds(&self) -> Result<()> {
+        self.transaction(|tx| async move {
+            notification_kind::Entity::insert_many(NotificationKind::all().map(|kind| {
+                notification_kind::ActiveModel {
+                    id: ActiveValue::Set(kind as i32),
+                    name: ActiveValue::Set(kind.to_string()),
+                }
+            }))
+            .on_conflict(OnConflict::new().do_nothing().to_owned())
+            .exec(&*tx)
+            .await?;
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn get_notifications(
+        &self,
+        recipient_id: UserId,
+        limit: usize,
+    ) -> Result<proto::AddNotifications> {
+        self.transaction(|tx| async move {
+            let mut result = proto::AddNotifications::default();
+
+            let mut rows = notification::Entity::find()
+                .filter(notification::Column::RecipientId.eq(recipient_id))
+                .order_by_desc(notification::Column::Id)
+                .limit(limit as u64)
+                .stream(&*tx)
+                .await?;
+
+            let mut user_ids = Vec::new();
+            let mut channel_ids = Vec::new();
+            let mut message_ids = Vec::new();
+            while let Some(row) = rows.next().await {
+                let row = row?;
+
+                let Some(kind) = NotificationKind::from_i32(row.kind) else {
+                    continue;
+                };
+                let Some(notification) = Notification::from_fields(
+                    kind,
+                    [
+                        row.entity_id_1.map(|id| id as u64),
+                        row.entity_id_2.map(|id| id as u64),
+                        row.entity_id_3.map(|id| id as u64),
+                    ],
+                ) else {
+                    continue;
+                };
+
+                // Gather the ids of all associated entities.
+                let (_, associated_entities) = notification.to_fields();
+                for entity in associated_entities {
+                    let Some((id, kind)) = entity else {
+                        break;
+                    };
+                    match kind {
+                        NotificationEntityKind::User => &mut user_ids,
+                        NotificationEntityKind::Channel => &mut channel_ids,
+                        NotificationEntityKind::ChannelMessage => &mut message_ids,
+                    }
+                    .push(id);
+                }
+
+                result.notifications.push(proto::Notification {
+                    kind: row.kind as u32,
+                    timestamp: row.created_at.assume_utc().unix_timestamp() as u64,
+                    is_read: row.is_read,
+                    entity_id_1: row.entity_id_1.map(|id| id as u64),
+                    entity_id_2: row.entity_id_2.map(|id| id as u64),
+                    entity_id_3: row.entity_id_3.map(|id| id as u64),
+                });
+            }
+
+            let users = user::Entity::find()
+                .filter(user::Column::Id.is_in(user_ids))
+                .all(&*tx)
+                .await?;
+            let channels = channel::Entity::find()
+                .filter(user::Column::Id.is_in(channel_ids))
+                .all(&*tx)
+                .await?;
+            let messages = channel_message::Entity::find()
+                .filter(user::Column::Id.is_in(message_ids))
+                .all(&*tx)
+                .await?;
+
+            for user in users {
+                result.users.push(proto::User {
+                    id: user.id.to_proto(),
+                    github_login: user.github_login,
+                    avatar_url: String::new(),
+                });
+            }
+            for channel in channels {
+                result.channels.push(proto::Channel {
+                    id: channel.id.to_proto(),
+                    name: channel.name,
+                });
+            }
+            for message in messages {
+                result.messages.push(proto::ChannelMessage {
+                    id: message.id.to_proto(),
+                    body: message.body,
+                    timestamp: message.sent_at.assume_utc().unix_timestamp() as u64,
+                    sender_id: message.sender_id.to_proto(),
+                    nonce: None,
+                });
+            }
+
+            Ok(result)
+        })
+        .await
+    }
+
+    pub async fn create_notification(
+        &self,
+        recipient_id: UserId,
+        notification: Notification,
+        tx: &DatabaseTransaction,
+    ) -> Result<()> {
+        let (kind, associated_entities) = notification.to_fields();
+        notification::ActiveModel {
+            recipient_id: ActiveValue::Set(recipient_id),
+            kind: ActiveValue::Set(kind as i32),
+            entity_id_1: ActiveValue::Set(associated_entities[0].map(|(id, _)| id as i32)),
+            entity_id_2: ActiveValue::Set(associated_entities[1].map(|(id, _)| id as i32)),
+            entity_id_3: ActiveValue::Set(associated_entities[2].map(|(id, _)| id as i32)),
+            ..Default::default()
+        }
+        .save(&*tx)
+        .await?;
+        Ok(())
+    }
+}

crates/collab/src/db/tables.rs 🔗

@@ -12,6 +12,8 @@ pub mod contact;
 pub mod feature_flag;
 pub mod follower;
 pub mod language_server;
+pub mod notification;
+pub mod notification_kind;
 pub mod observed_buffer_edits;
 pub mod observed_channel_messages;
 pub mod project;

crates/collab/src/db/tables/notification.rs 🔗

@@ -0,0 +1,29 @@
+use crate::db::{NotificationId, UserId};
+use sea_orm::entity::prelude::*;
+use time::PrimitiveDateTime;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "notifications")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: NotificationId,
+    pub recipient_id: UserId,
+    pub kind: i32,
+    pub is_read: bool,
+    pub created_at: PrimitiveDateTime,
+    pub entity_id_1: Option<i32>,
+    pub entity_id_2: Option<i32>,
+    pub entity_id_3: Option<i32>,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(
+        belongs_to = "super::user::Entity",
+        from = "Column::RecipientId",
+        to = "super::user::Column::Id"
+    )]
+    Recipient,
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/db/tables/notification_kind.rs 🔗

@@ -0,0 +1,14 @@
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "notification_kinds")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub id: i32,
+    pub name: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/rpc/Cargo.toml 🔗

@@ -29,6 +29,7 @@ rsa = "0.4"
 serde.workspace = true
 serde_derive.workspace = true
 smol-timeout = "0.6"
+strum.workspace = true
 tracing = { version = "0.1.34", features = ["log"] }
 zstd = "0.11"
 

crates/rpc/proto/zed.proto 🔗

@@ -170,7 +170,9 @@ message Envelope {
 
         LinkChannel link_channel = 140;
         UnlinkChannel unlink_channel = 141;
-        MoveChannel move_channel = 142; // current max: 144
+        MoveChannel move_channel = 142;
+
+        AddNotifications add_notification = 145; // Current max
     }
 }
 
@@ -1557,3 +1559,40 @@ message UpdateDiffBase {
     uint64 buffer_id = 2;
     optional string diff_base = 3;
 }
+
+message AddNotifications {
+    repeated Notification notifications = 1;
+    repeated User users = 2;
+    repeated Channel channels = 3;
+    repeated ChannelMessage messages = 4;
+}
+
+message Notification {
+    uint32 kind = 1;
+    uint64 timestamp = 2;
+    bool is_read = 3;
+    optional uint64 entity_id_1 = 4;
+    optional uint64 entity_id_2 = 5;
+    optional uint64 entity_id_3 = 6;
+
+    // oneof variant {
+    //     ContactRequest contact_request = 3;
+    //     ChannelInvitation channel_invitation = 4;
+    //     ChatMessageMention chat_message_mention = 5;
+    // };
+
+    // message ContactRequest {
+    //     uint64 requester_id = 1;
+    // }
+
+    // message ChannelInvitation {
+    //     uint64 inviter_id = 1;
+    //     uint64 channel_id = 2;
+    // }
+
+    // message ChatMessageMention {
+    //     uint64 sender_id = 1;
+    //     uint64 channel_id = 2;
+    //     uint64 message_id = 3;
+    // }
+}

crates/rpc/src/notification.rs 🔗

@@ -0,0 +1,105 @@
+use strum::{Display, EnumIter, EnumString, IntoEnumIterator};
+
+// An integer indicating a type of notification. The variants' numerical
+// values are stored in the database, so they should never be removed
+// or changed.
+#[repr(i32)]
+#[derive(Copy, Clone, Debug, EnumIter, EnumString, Display)]
+pub enum NotificationKind {
+    ContactRequest = 0,
+    ChannelInvitation = 1,
+    ChannelMessageMention = 2,
+}
+
+pub enum Notification {
+    ContactRequest {
+        requester_id: u64,
+    },
+    ChannelInvitation {
+        inviter_id: u64,
+        channel_id: u64,
+    },
+    ChannelMessageMention {
+        sender_id: u64,
+        channel_id: u64,
+        message_id: u64,
+    },
+}
+
+#[derive(Copy, Clone)]
+pub enum NotificationEntityKind {
+    User,
+    Channel,
+    ChannelMessage,
+}
+
+impl Notification {
+    pub fn from_fields(kind: NotificationKind, entity_ids: [Option<u64>; 3]) -> Option<Self> {
+        use NotificationKind::*;
+
+        Some(match kind {
+            ContactRequest => Self::ContactRequest {
+                requester_id: entity_ids[0]?,
+            },
+            ChannelInvitation => Self::ChannelInvitation {
+                inviter_id: entity_ids[0]?,
+                channel_id: entity_ids[1]?,
+            },
+            ChannelMessageMention => Self::ChannelMessageMention {
+                sender_id: entity_ids[0]?,
+                channel_id: entity_ids[1]?,
+                message_id: entity_ids[2]?,
+            },
+        })
+    }
+
+    pub fn to_fields(&self) -> (NotificationKind, [Option<(u64, NotificationEntityKind)>; 3]) {
+        use NotificationKind::*;
+
+        match self {
+            Self::ContactRequest { requester_id } => (
+                ContactRequest,
+                [
+                    Some((*requester_id, NotificationEntityKind::User)),
+                    None,
+                    None,
+                ],
+            ),
+
+            Self::ChannelInvitation {
+                inviter_id,
+                channel_id,
+            } => (
+                ChannelInvitation,
+                [
+                    Some((*inviter_id, NotificationEntityKind::User)),
+                    Some((*channel_id, NotificationEntityKind::User)),
+                    None,
+                ],
+            ),
+
+            Self::ChannelMessageMention {
+                sender_id,
+                channel_id,
+                message_id,
+            } => (
+                ChannelMessageMention,
+                [
+                    Some((*sender_id, NotificationEntityKind::User)),
+                    Some((*channel_id, NotificationEntityKind::ChannelMessage)),
+                    Some((*message_id, NotificationEntityKind::Channel)),
+                ],
+            ),
+        }
+    }
+}
+
+impl NotificationKind {
+    pub fn all() -> impl Iterator<Item = Self> {
+        Self::iter()
+    }
+
+    pub fn from_i32(i: i32) -> Option<Self> {
+        Self::iter().find(|kind| *kind as i32 == i)
+    }
+}

crates/rpc/src/rpc.rs 🔗

@@ -1,9 +1,12 @@
 pub mod auth;
 mod conn;
+mod notification;
 mod peer;
 pub mod proto;
+
 pub use conn::Connection;
 pub use peer::*;
+pub use notification::*;
 mod macros;
 
 pub const PROTOCOL_VERSION: u32 = 64;