Add database implementation of channel message change tracking

Mikayla created

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql  |   9 
crates/collab/migrations/20230925210437_add_channel_changes.sql |  18 
crates/collab/migrations/20230925210437_add_observed_notes.sql  |   9 
crates/collab/src/db/queries.rs                                 |  10 
crates/collab/src/db/queries/buffers.rs                         |  10 
crates/collab/src/db/queries/messages.rs                        | 121 ++
crates/collab/src/db/tables.rs                                  |   1 
crates/collab/src/db/tables/observed_channel_messages.rs        |  41 
crates/collab/src/db/tests/message_tests.rs                     | 139 +++
9 files changed, 339 insertions(+), 19 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -300,3 +300,12 @@ CREATE TABLE "observed_buffer_edits" (
 );
 
 CREATE UNIQUE INDEX "index_observed_buffers_user_and_buffer_id" ON "observed_buffer_edits" ("user_id", "buffer_id");
+
+CREATE TABLE IF NOT EXISTS "observed_channel_messages" (
+    "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+    "channel_message_id" INTEGER NOT NULL,
+    PRIMARY KEY (user_id, channel_id)
+);
+
+CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id");

crates/collab/migrations/20230925210437_add_channel_changes.sql 🔗

@@ -0,0 +1,18 @@
+CREATE TABLE IF NOT EXISTS "observed_buffer_edits" (
+    "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
+    "epoch" INTEGER NOT NULL,
+    "lamport_timestamp" INTEGER NOT NULL,
+    PRIMARY KEY (user_id, buffer_id)
+);
+
+CREATE UNIQUE INDEX "index_observed_buffer_user_and_buffer_id" ON "observed_buffer_edits" ("user_id", "buffer_id");
+
+CREATE TABLE IF NOT EXISTS "observed_channel_messages" (
+    "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+    "channel_message_id" INTEGER NOT NULL,
+    PRIMARY KEY (user_id, channel_id)
+);
+
+CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id");

crates/collab/migrations/20230925210437_add_observed_notes.sql 🔗

@@ -1,9 +0,0 @@
-CREATE TABLE "observed_buffer_edits" (
-    "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
-    "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
-    "epoch" INTEGER NOT NULL,
-    "lamport_timestamp" INTEGER NOT NULL,
-    PRIMARY KEY (user_id, buffer_id)
-);
-
-CREATE UNIQUE INDEX "index_observed_buffer_user_and_buffer_id" ON "observed_buffer_edits" ("user_id", "buffer_id");

crates/collab/src/db/queries.rs 🔗

@@ -9,3 +9,13 @@ pub mod projects;
 pub mod rooms;
 pub mod servers;
 pub mod users;
+
+fn max_assign<T: Ord>(max: &mut Option<T>, val: T) {
+    if let Some(max_val) = max {
+        if val > *max_val {
+            *max = Some(val);
+        }
+    } else {
+        *max = Some(val);
+    }
+}

crates/collab/src/db/queries/buffers.rs 🔗

@@ -787,16 +787,6 @@ impl Database {
     }
 }
 
-fn max_assign<T: Ord>(max: &mut Option<T>, val: T) {
-    if let Some(max_val) = max {
-        if val > *max_val {
-            *max = Some(val);
-        }
-    } else {
-        *max = Some(val);
-    }
-}
-
 fn operation_to_storage(
     operation: &proto::Operation,
     buffer: &buffer::Model,

crates/collab/src/db/queries/messages.rs 🔗

@@ -93,9 +93,13 @@ impl Database {
                 .stream(&*tx)
                 .await?;
 
+            let mut max_id = None;
             let mut messages = Vec::new();
             while let Some(row) = rows.next().await {
                 let row = row?;
+                dbg!(&max_id);
+                max_assign(&mut max_id, row.id);
+
                 let nonce = row.nonce.as_u64_pair();
                 messages.push(proto::ChannelMessage {
                     id: row.id.to_proto(),
@@ -108,6 +112,55 @@ impl Database {
                     }),
                 });
             }
+            drop(rows);
+            dbg!(&max_id);
+
+            if let Some(max_id) = max_id {
+                let has_older_message = dbg!(
+                    observed_channel_messages::Entity::find()
+                        .filter(
+                            observed_channel_messages::Column::UserId
+                                .eq(user_id)
+                                .and(observed_channel_messages::Column::ChannelId.eq(channel_id))
+                                .and(
+                                    observed_channel_messages::Column::ChannelMessageId.lt(max_id)
+                                ),
+                        )
+                        .one(&*tx)
+                        .await
+                )?
+                .is_some();
+
+                if has_older_message {
+                    observed_channel_messages::Entity::update(
+                        observed_channel_messages::ActiveModel {
+                            user_id: ActiveValue::Unchanged(user_id),
+                            channel_id: ActiveValue::Unchanged(channel_id),
+                            channel_message_id: ActiveValue::Set(max_id),
+                        },
+                    )
+                    .exec(&*tx)
+                    .await?;
+                } else {
+                    observed_channel_messages::Entity::insert(
+                        observed_channel_messages::ActiveModel {
+                            user_id: ActiveValue::Set(user_id),
+                            channel_id: ActiveValue::Set(channel_id),
+                            channel_message_id: ActiveValue::Set(max_id),
+                        },
+                    )
+                    .on_conflict(
+                        OnConflict::columns([
+                            observed_channel_messages::Column::UserId,
+                            observed_channel_messages::Column::ChannelId,
+                        ])
+                        .update_columns([observed_channel_messages::Column::ChannelMessageId])
+                        .to_owned(),
+                    )
+                    .exec(&*tx)
+                    .await?;
+                }
+            }
 
             Ok(messages)
         })
@@ -130,11 +183,13 @@ impl Database {
 
             let mut is_participant = false;
             let mut participant_connection_ids = Vec::new();
+            let mut participant_user_ids = Vec::new();
             while let Some(row) = rows.next().await {
                 let row = row?;
                 if row.user_id == user_id {
                     is_participant = true;
                 }
+                participant_user_ids.push(row.user_id);
                 participant_connection_ids.push(row.connection());
             }
             drop(rows);
@@ -167,11 +222,77 @@ impl Database {
                 ConnectionId,
             }
 
+            // Observe this message for all participants
+            observed_channel_messages::Entity::insert_many(participant_user_ids.iter().map(
+                |pariticpant_id| observed_channel_messages::ActiveModel {
+                    user_id: ActiveValue::Set(*pariticpant_id),
+                    channel_id: ActiveValue::Set(channel_id),
+                    channel_message_id: ActiveValue::Set(message.last_insert_id),
+                },
+            ))
+            .on_conflict(
+                OnConflict::columns([
+                    observed_channel_messages::Column::ChannelId,
+                    observed_channel_messages::Column::UserId,
+                ])
+                .update_column(observed_channel_messages::Column::ChannelMessageId)
+                .to_owned(),
+            )
+            .exec(&*tx)
+            .await?;
+
             Ok((message.last_insert_id, participant_connection_ids))
         })
         .await
     }
 
+    #[cfg(test)]
+    pub async fn has_new_message_tx(&self, channel_id: ChannelId, user_id: UserId) -> Result<bool> {
+        self.transaction(|tx| async move { self.has_new_message(channel_id, user_id, &*tx).await })
+            .await
+    }
+
+    #[cfg(test)]
+    pub async fn dbg_print_messages(&self) -> Result<()> {
+        self.transaction(|tx| async move {
+            dbg!(observed_channel_messages::Entity::find()
+                .all(&*tx)
+                .await
+                .unwrap());
+            dbg!(channel_message::Entity::find().all(&*tx).await.unwrap());
+
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn has_new_message(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        tx: &DatabaseTransaction,
+    ) -> Result<bool> {
+        self.check_user_is_channel_member(channel_id, user_id, &*tx)
+            .await?;
+
+        let latest_message_id = channel_message::Entity::find()
+            .filter(Condition::all().add(channel_message::Column::ChannelId.eq(channel_id)))
+            .order_by(channel_message::Column::SentAt, sea_query::Order::Desc)
+            .limit(1 as u64)
+            .one(&*tx)
+            .await?
+            .map(|model| model.id);
+
+        let last_message_read = observed_channel_messages::Entity::find()
+            .filter(observed_channel_messages::Column::ChannelId.eq(channel_id))
+            .filter(observed_channel_messages::Column::UserId.eq(user_id))
+            .one(&*tx)
+            .await?
+            .map(|model| model.channel_message_id);
+
+        Ok(dbg!(last_message_read) != dbg!(latest_message_id))
+    }
+
     pub async fn remove_channel_message(
         &self,
         channel_id: ChannelId,

crates/collab/src/db/tables.rs 🔗

@@ -13,6 +13,7 @@ pub mod feature_flag;
 pub mod follower;
 pub mod language_server;
 pub mod observed_buffer_edits;
+pub mod observed_channel_messages;
 pub mod project;
 pub mod project_collaborator;
 pub mod room;

crates/collab/src/db/tables/observed_channel_messages.rs 🔗

@@ -0,0 +1,41 @@
+use crate::db::{ChannelId, MessageId, UserId};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "observed_channel_messages")]
+pub struct Model {
+    #[sea_orm(primary_key)]
+    pub user_id: UserId,
+    pub channel_id: ChannelId,
+    pub channel_message_id: MessageId,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    #[sea_orm(
+        belongs_to = "super::channel::Entity",
+        from = "Column::ChannelId",
+        to = "super::channel::Column::Id"
+    )]
+    Channel,
+    #[sea_orm(
+        belongs_to = "super::user::Entity",
+        from = "Column::UserId",
+        to = "super::user::Column::Id"
+    )]
+    User,
+}
+
+impl Related<super::channel::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::Channel.def()
+    }
+}
+
+impl Related<super::user::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::User.def()
+    }
+}
+
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/db/tests/message_tests.rs 🔗

@@ -57,3 +57,142 @@ async fn test_channel_message_nonces(db: &Arc<Database>) {
     assert_eq!(msg1_id, msg3_id);
     assert_eq!(msg2_id, msg4_id);
 }
+
+test_both_dbs!(
+    test_channel_message_new_notification,
+    test_channel_message_new_notification_postgres,
+    test_channel_message_new_notification_sqlite
+);
+
+async fn test_channel_message_new_notification(db: &Arc<Database>) {
+    let user_a = db
+        .create_user(
+            "user_a@example.com",
+            false,
+            NewUserParams {
+                github_login: "user_a".into(),
+                github_user_id: 1,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
+    let user_b = db
+        .create_user(
+            "user_b@example.com",
+            false,
+            NewUserParams {
+                github_login: "user_b".into(),
+                github_user_id: 1,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
+
+    let channel = db
+        .create_channel("channel", None, "room", user_a)
+        .await
+        .unwrap();
+
+    db.invite_channel_member(channel, user_b, user_a, false)
+        .await
+        .unwrap();
+
+    db.respond_to_channel_invite(channel, user_b, true)
+        .await
+        .unwrap();
+
+    let owner_id = db.create_server("test").await.unwrap().0 as u32;
+
+    // Zero case: no messages at all
+    assert!(!db.has_new_message_tx(channel, user_b).await.unwrap());
+
+    let a_connection_id = rpc::ConnectionId { owner_id, id: 0 };
+    db.join_channel_chat(channel, a_connection_id, user_a)
+        .await
+        .unwrap();
+
+    let _ = db
+        .create_channel_message(channel, user_a, "1", OffsetDateTime::now_utc(), 1)
+        .await
+        .unwrap();
+
+    let (second_message, _) = db
+        .create_channel_message(channel, user_a, "2", OffsetDateTime::now_utc(), 2)
+        .await
+        .unwrap();
+
+    let _ = db
+        .create_channel_message(channel, user_a, "3", OffsetDateTime::now_utc(), 3)
+        .await
+        .unwrap();
+
+    // Smoke test: can we detect a new message?
+    assert!(db.has_new_message_tx(channel, user_b).await.unwrap());
+
+    let b_connection_id = rpc::ConnectionId { owner_id, id: 1 };
+    db.join_channel_chat(channel, b_connection_id, user_b)
+        .await
+        .unwrap();
+
+    // Joining the channel should _not_ update us to the latest message
+    assert!(db.has_new_message_tx(channel, user_b).await.unwrap());
+
+    // Reading the earlier messages should not change that we have new messages
+    let _ = db
+        .get_channel_messages(channel, user_b, 1, Some(second_message))
+        .await
+        .unwrap();
+
+    assert!(db.has_new_message_tx(channel, user_b).await.unwrap());
+
+    // This constraint is currently inexpressible, creating a message implicitly broadcasts
+    // it to all participants
+    //
+    // Creating new messages when we haven't read the latest one should not change the flag
+    // let _ = db
+    //     .create_channel_message(channel, user_a, "4", OffsetDateTime::now_utc(), 4)
+    //     .await
+    //     .unwrap();
+    // assert!(db.has_new_message_tx(channel, user_b).await.unwrap());
+
+    // But reading the latest message should clear the flag
+    let _ = db
+        .get_channel_messages(channel, user_b, 4, None)
+        .await
+        .unwrap();
+
+    assert!(!db.has_new_message_tx(channel, user_b).await.unwrap());
+
+    // And future messages should not reset the flag
+    let _ = db
+        .create_channel_message(channel, user_a, "5", OffsetDateTime::now_utc(), 5)
+        .await
+        .unwrap();
+
+    assert!(!db.has_new_message_tx(channel, user_b).await.unwrap());
+
+    let _ = db
+        .create_channel_message(channel, user_b, "6", OffsetDateTime::now_utc(), 6)
+        .await
+        .unwrap();
+
+    assert!(!db.has_new_message_tx(channel, user_b).await.unwrap());
+
+    // And we should start seeing the flag again after we've left the channel
+    db.leave_channel_chat(channel, b_connection_id, user_b)
+        .await
+        .unwrap();
+
+    assert!(!db.has_new_message_tx(channel, user_b).await.unwrap());
+
+    let _ = db
+        .create_channel_message(channel, user_a, "7", OffsetDateTime::now_utc(), 7)
+        .await
+        .unwrap();
+
+    assert!(db.has_new_message_tx(channel, user_b).await.unwrap());
+}