Added db message and edit operation observation

Mikayla and Max created

Co-authored-by: Max <max@zed.dev>

Change summary

crates/collab/src/db/queries/buffers.rs     | 146 +++-------
crates/collab/src/db/queries/messages.rs    |  61 +++-
crates/collab/src/db/tests/buffer_tests.rs  | 298 +++++++++++-----------
crates/collab/src/db/tests/message_tests.rs | 139 ++++++----
4 files changed, 331 insertions(+), 313 deletions(-)

Detailed changes

crates/collab/src/db/queries/buffers.rs 🔗

@@ -1,6 +1,5 @@
 use super::*;
 use prost::Message;
-use sea_query::Order::Desc;
 use text::{EditOperation, UndoOperation};
 
 pub struct LeftChannelBuffer {
@@ -456,9 +455,21 @@ impl Database {
             let mut channel_members;
 
             if !operations.is_empty() {
+                let max_operation = operations
+                    .iter()
+                    .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
+                    .unwrap();
+
                 // get current channel participants and save the max operation above
-                self.save_max_operation(user, buffer.id, buffer.epoch, operations.as_slice(), &*tx)
-                    .await?;
+                self.save_max_operation(
+                    user,
+                    buffer.id,
+                    buffer.epoch,
+                    *max_operation.replica_id.as_ref(),
+                    *max_operation.lamport_timestamp.as_ref(),
+                    &*tx,
+                )
+                .await?;
 
                 channel_members = self.get_channel_members_internal(channel_id, &*tx).await?;
                 let collaborators = self
@@ -509,52 +520,38 @@ impl Database {
         user_id: UserId,
         buffer_id: BufferId,
         epoch: i32,
-        operations: &[buffer_operation::ActiveModel],
+        replica_id: i32,
+        lamport_timestamp: i32,
         tx: &DatabaseTransaction,
     ) -> Result<()> {
         use observed_buffer_edits::Column;
 
-        let max_operation = operations
-            .iter()
-            .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
-            .unwrap();
-
         observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
             user_id: ActiveValue::Set(user_id),
             buffer_id: ActiveValue::Set(buffer_id),
             epoch: ActiveValue::Set(epoch),
-            replica_id: max_operation.replica_id.clone(),
-            lamport_timestamp: max_operation.lamport_timestamp.clone(),
+            replica_id: ActiveValue::Set(replica_id),
+            lamport_timestamp: ActiveValue::Set(lamport_timestamp),
         })
         .on_conflict(
             OnConflict::columns([Column::UserId, Column::BufferId])
                 .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
                 .action_cond_where(
-                    Condition::any()
-                        .add(Column::Epoch.lt(*max_operation.epoch.as_ref()))
-                        .add(
-                            Condition::all()
-                                .add(Column::Epoch.eq(*max_operation.epoch.as_ref()))
+                    Condition::any().add(Column::Epoch.lt(epoch)).add(
+                        Condition::all().add(Column::Epoch.eq(epoch)).add(
+                            Condition::any()
+                                .add(Column::LamportTimestamp.lt(lamport_timestamp))
                                 .add(
-                                    Condition::any()
-                                        .add(
-                                            Column::LamportTimestamp
-                                                .lt(*max_operation.lamport_timestamp.as_ref()),
-                                        )
-                                        .add(
-                                            Column::LamportTimestamp
-                                                .eq(*max_operation.lamport_timestamp.as_ref())
-                                                .and(
-                                                    Column::ReplicaId
-                                                        .lt(*max_operation.replica_id.as_ref()),
-                                                ),
-                                        ),
+                                    Column::LamportTimestamp
+                                        .eq(lamport_timestamp)
+                                        .and(Column::ReplicaId.lt(replica_id)),
                                 ),
                         ),
+                    ),
                 )
                 .to_owned(),
         )
-        .exec(tx)
+        .exec_without_returning(tx)
         .await?;
 
         Ok(())
@@ -689,14 +686,30 @@ impl Database {
         Ok(())
     }
 
-    #[cfg(test)]
-    pub async fn test_has_note_changed(
+    pub async fn observe_buffer_version(
         &self,
+        buffer_id: BufferId,
         user_id: UserId,
-        channel_id: ChannelId,
-    ) -> Result<bool> {
-        self.transaction(|tx| async move { self.has_note_changed(user_id, channel_id, &*tx).await })
-            .await
+        epoch: i32,
+        version: &[proto::VectorClockEntry],
+    ) -> Result<()> {
+        self.transaction(|tx| async move {
+            // For now, combine concurrent operations.
+            let Some(component) = version.iter().max_by_key(|version| version.timestamp) else {
+                return Ok(());
+            };
+            self.save_max_operation(
+                user_id,
+                buffer_id,
+                epoch,
+                component.replica_id as i32,
+                component.timestamp as i32,
+                &*tx,
+            )
+            .await?;
+            Ok(())
+        })
+        .await
     }
 
     pub async fn channels_with_changed_notes(
@@ -811,67 +824,6 @@ impl Database {
             .await?;
         Ok(operations)
     }
-
-    pub async fn has_note_changed(
-        &self,
-        user_id: UserId,
-        channel_id: ChannelId,
-        tx: &DatabaseTransaction,
-    ) -> Result<bool> {
-        let Some(buffer_id) = channel::Model {
-            id: channel_id,
-            ..Default::default()
-        }
-        .find_related(buffer::Entity)
-        .one(&*tx)
-        .await?
-        .map(|buffer| buffer.id) else {
-            return Ok(false);
-        };
-
-        let user_max = observed_buffer_edits::Entity::find()
-            .filter(observed_buffer_edits::Column::UserId.eq(user_id))
-            .filter(observed_buffer_edits::Column::BufferId.eq(buffer_id))
-            .one(&*tx)
-            .await?
-            .map(|model| (model.epoch, model.lamport_timestamp));
-
-        let channel_buffer = channel::Model {
-            id: channel_id,
-            ..Default::default()
-        }
-        .find_related(buffer::Entity)
-        .one(&*tx)
-        .await?;
-
-        let Some(channel_buffer) = channel_buffer else {
-            return Ok(false);
-        };
-
-        let mut channel_max = buffer_operation::Entity::find()
-            .filter(buffer_operation::Column::BufferId.eq(channel_buffer.id))
-            .filter(buffer_operation::Column::Epoch.eq(channel_buffer.epoch))
-            .order_by(buffer_operation::Column::Epoch, Desc)
-            .order_by(buffer_operation::Column::LamportTimestamp, Desc)
-            .one(&*tx)
-            .await?
-            .map(|model| (model.epoch, model.lamport_timestamp));
-
-        // If there are no edits in this epoch
-        if channel_max.is_none() {
-            // check if this user observed the last edit of the previous epoch
-            channel_max = buffer_operation::Entity::find()
-                .filter(buffer_operation::Column::BufferId.eq(channel_buffer.id))
-                .filter(buffer_operation::Column::Epoch.eq(channel_buffer.epoch.saturating_sub(1)))
-                .order_by(buffer_operation::Column::Epoch, Desc)
-                .order_by(buffer_operation::Column::LamportTimestamp, Desc)
-                .one(&*tx)
-                .await?
-                .map(|model| (model.epoch, model.lamport_timestamp));
-        }
-
-        Ok(user_max != channel_max)
-    }
 }
 
 fn operation_to_storage(

crates/collab/src/db/queries/messages.rs 🔗

@@ -218,20 +218,12 @@ impl Database {
             }
 
             // Observe this message for the sender
-            observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
-                user_id: ActiveValue::Set(user_id),
-                channel_id: ActiveValue::Set(channel_id),
-                channel_message_id: ActiveValue::Set(message.last_insert_id),
-            })
-            .on_conflict(
-                OnConflict::columns([
-                    observed_channel_messages::Column::ChannelId,
-                    observed_channel_messages::Column::UserId,
-                ])
-                .update_column(observed_channel_messages::Column::ChannelMessageId)
-                .to_owned(),
+            self.observe_channel_message_internal(
+                channel_id,
+                user_id,
+                message.last_insert_id,
+                &*tx,
             )
-            .exec(&*tx)
             .await?;
 
             let mut channel_members = self.get_channel_members_internal(channel_id, &*tx).await?;
@@ -246,12 +238,53 @@ impl Database {
         .await
     }
 
+    pub async fn observe_channel_message(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        message_id: MessageId,
+    ) -> Result<()> {
+        self.transaction(|tx| async move {
+            self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
+                .await?;
+            Ok(())
+        })
+        .await
+    }
+
+    async fn observe_channel_message_internal(
+        &self,
+        channel_id: ChannelId,
+        user_id: UserId,
+        message_id: MessageId,
+        tx: &DatabaseTransaction,
+    ) -> Result<()> {
+        observed_channel_messages::Entity::insert(observed_channel_messages::ActiveModel {
+            user_id: ActiveValue::Set(user_id),
+            channel_id: ActiveValue::Set(channel_id),
+            channel_message_id: ActiveValue::Set(message_id),
+        })
+        .on_conflict(
+            OnConflict::columns([
+                observed_channel_messages::Column::ChannelId,
+                observed_channel_messages::Column::UserId,
+            ])
+            .update_column(observed_channel_messages::Column::ChannelMessageId)
+            .action_cond_where(observed_channel_messages::Column::ChannelMessageId.lt(message_id))
+            .to_owned(),
+        )
+        // TODO: Try to upgrade SeaORM so we don't have to do this hack around their bug
+        .exec_without_returning(&*tx)
+        .await?;
+        Ok(())
+    }
+
     pub async fn channels_with_new_messages(
         &self,
         user_id: UserId,
         channel_ids: &[ChannelId],
         tx: &DatabaseTransaction,
-    ) -> Result<HashSet<ChannelId>> {
+    ) -> Result<collections::HashSet<ChannelId>> {
         let mut observed_messages_by_channel_id = HashMap::default();
         let mut rows = observed_channel_messages::Entity::find()
             .filter(observed_channel_messages::Column::UserId.eq(user_id))

crates/collab/src/db/tests/buffer_tests.rs 🔗

@@ -1,6 +1,6 @@
 use super::*;
 use crate::test_both_dbs;
-use language::proto;
+use language::proto::{self, serialize_version};
 use text::Buffer;
 
 test_both_dbs!(
@@ -165,15 +165,13 @@ async fn test_channel_buffers(db: &Arc<Database>) {
 }
 
 test_both_dbs!(
-    test_channel_buffers_diffs,
-    test_channel_buffers_diffs_postgres,
-    test_channel_buffers_diffs_sqlite
+    test_channel_buffers_last_operations,
+    test_channel_buffers_last_operations_postgres,
+    test_channel_buffers_last_operations_sqlite
 );
 
-async fn test_channel_buffers_diffs(db: &Database) {
-    panic!("Rewriting the way this works");
-
-    let a_id = db
+async fn test_channel_buffers_last_operations(db: &Database) {
+    let user_id = db
         .create_user(
             "user_a@example.com",
             false,
@@ -186,7 +184,7 @@ async fn test_channel_buffers_diffs(db: &Database) {
         .await
         .unwrap()
         .user_id;
-    let b_id = db
+    let observer_id = db
         .create_user(
             "user_b@example.com",
             false,
@@ -199,102 +197,6 @@ async fn test_channel_buffers_diffs(db: &Database) {
         .await
         .unwrap()
         .user_id;
-
-    let owner_id = db.create_server("production").await.unwrap().0 as u32;
-
-    let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
-
-    db.invite_channel_member(zed_id, b_id, a_id, false)
-        .await
-        .unwrap();
-
-    db.respond_to_channel_invite(zed_id, b_id, true)
-        .await
-        .unwrap();
-
-    let connection_id_a = ConnectionId {
-        owner_id,
-        id: a_id.0 as u32,
-    };
-    let connection_id_b = ConnectionId {
-        owner_id,
-        id: b_id.0 as u32,
-    };
-
-    // Zero test: A should not register as changed on an unitialized channel buffer
-    assert!(!db.test_has_note_changed(a_id, zed_id).await.unwrap());
-
-    let _ = db
-        .join_channel_buffer(zed_id, a_id, connection_id_a)
-        .await
-        .unwrap();
-
-    // Zero test: A should register as changed on an empty channel buffer
-    assert!(!db.test_has_note_changed(a_id, zed_id).await.unwrap());
-
-    let mut buffer_a = Buffer::new(0, 0, "".to_string());
-    let mut operations = Vec::new();
-    operations.push(buffer_a.edit([(0..0, "hello world")]));
-    assert_eq!(buffer_a.text(), "hello world");
-
-    let operations = operations
-        .into_iter()
-        .map(|op| proto::serialize_operation(&language::Operation::Buffer(op)))
-        .collect::<Vec<_>>();
-
-    db.update_channel_buffer(zed_id, a_id, &operations)
-        .await
-        .unwrap();
-
-    // Smoke test: Does B register as changed, A as unchanged?
-    assert!(db.test_has_note_changed(b_id, zed_id).await.unwrap());
-
-    assert!(!db.test_has_note_changed(a_id, zed_id).await.unwrap());
-
-    db.leave_channel_buffer(zed_id, connection_id_a)
-        .await
-        .unwrap();
-
-    // Snapshotting from leaving the channel buffer should not have a diff
-    assert!(!db.test_has_note_changed(a_id, zed_id).await.unwrap());
-
-    let _ = db
-        .join_channel_buffer(zed_id, b_id, connection_id_b)
-        .await
-        .unwrap();
-
-    // B has opened the channel buffer, so we shouldn't have any diff
-    assert!(!db.test_has_note_changed(b_id, zed_id).await.unwrap());
-
-    db.leave_channel_buffer(zed_id, connection_id_b)
-        .await
-        .unwrap();
-
-    // Since B just opened and closed the buffer without editing, neither should have a diff
-    assert!(!db.test_has_note_changed(a_id, zed_id).await.unwrap());
-    assert!(!db.test_has_note_changed(b_id, zed_id).await.unwrap());
-}
-
-test_both_dbs!(
-    test_channel_buffers_last_operations,
-    test_channel_buffers_last_operations_postgres,
-    test_channel_buffers_last_operations_sqlite
-);
-
-async fn test_channel_buffers_last_operations(db: &Database) {
-    let user_id = db
-        .create_user(
-            "user_a@example.com",
-            false,
-            NewUserParams {
-                github_login: "user_a".into(),
-                github_user_id: 101,
-                invite_count: 0,
-            },
-        )
-        .await
-        .unwrap()
-        .user_id;
     let owner_id = db.create_server("production").await.unwrap().0 as u32;
     let connection_id = ConnectionId {
         owner_id,
@@ -309,6 +211,13 @@ async fn test_channel_buffers_last_operations(db: &Database) {
             .await
             .unwrap();
 
+        db.invite_channel_member(channel, observer_id, user_id, false)
+            .await
+            .unwrap();
+        db.respond_to_channel_invite(channel, observer_id, true)
+            .await
+            .unwrap();
+
         db.join_channel_buffer(channel, user_id, connection_id)
             .await
             .unwrap();
@@ -422,45 +331,146 @@ async fn test_channel_buffers_last_operations(db: &Database) {
         ],
     );
 
-    async fn update_buffer(
-        channel_id: ChannelId,
-        user_id: UserId,
-        db: &Database,
-        operations: Vec<text::Operation>,
-    ) {
-        let operations = operations
+    let changed_channels = db
+        .transaction(|tx| {
+            let buffers = &buffers;
+            async move {
+                db.channels_with_changed_notes(
+                    observer_id,
+                    &[
+                        buffers[0].channel_id,
+                        buffers[1].channel_id,
+                        buffers[2].channel_id,
+                    ],
+                    &*tx,
+                )
+                .await
+            }
+        })
+        .await
+        .unwrap();
+    assert_eq!(
+        changed_channels,
+        [
+            buffers[0].channel_id,
+            buffers[1].channel_id,
+            buffers[2].channel_id,
+        ]
+        .into_iter()
+        .collect::<HashSet<_>>()
+    );
+
+    db.observe_buffer_version(
+        buffers[1].id,
+        observer_id,
+        1,
+        &serialize_version(&text_buffers[1].version()),
+    )
+    .await
+    .unwrap();
+
+    let changed_channels = db
+        .transaction(|tx| {
+            let buffers = &buffers;
+            async move {
+                db.channels_with_changed_notes(
+                    observer_id,
+                    &[
+                        buffers[0].channel_id,
+                        buffers[1].channel_id,
+                        buffers[2].channel_id,
+                    ],
+                    &*tx,
+                )
+                .await
+            }
+        })
+        .await
+        .unwrap();
+    assert_eq!(
+        changed_channels,
+        [buffers[0].channel_id, buffers[2].channel_id,]
+            .into_iter()
+            .collect::<HashSet<_>>()
+    );
+
+    // Observe an earlier version of the buffer.
+    db.observe_buffer_version(
+        buffers[1].id,
+        observer_id,
+        1,
+        &[rpc::proto::VectorClockEntry {
+            replica_id: 0,
+            timestamp: 0,
+        }],
+    )
+    .await
+    .unwrap();
+
+    let changed_channels = db
+        .transaction(|tx| {
+            let buffers = &buffers;
+            async move {
+                db.channels_with_changed_notes(
+                    observer_id,
+                    &[
+                        buffers[0].channel_id,
+                        buffers[1].channel_id,
+                        buffers[2].channel_id,
+                    ],
+                    &*tx,
+                )
+                .await
+            }
+        })
+        .await
+        .unwrap();
+    assert_eq!(
+        changed_channels,
+        [buffers[0].channel_id, buffers[2].channel_id,]
             .into_iter()
-            .map(|op| proto::serialize_operation(&language::Operation::Buffer(op)))
-            .collect::<Vec<_>>();
-        db.update_channel_buffer(channel_id, user_id, &operations)
-            .await
-            .unwrap();
-    }
+            .collect::<HashSet<_>>()
+    );
+}
 
-    fn assert_operations(
-        operations: &[buffer_operation::Model],
-        expected: &[(BufferId, i32, &text::Buffer)],
-    ) {
-        let actual = operations
-            .iter()
-            .map(|op| buffer_operation::Model {
-                buffer_id: op.buffer_id,
-                epoch: op.epoch,
-                lamport_timestamp: op.lamport_timestamp,
-                replica_id: op.replica_id,
-                value: vec![],
-            })
-            .collect::<Vec<_>>();
-        let expected = expected
-            .iter()
-            .map(|(buffer_id, epoch, buffer)| buffer_operation::Model {
-                buffer_id: *buffer_id,
-                epoch: *epoch,
-                lamport_timestamp: buffer.lamport_clock.value as i32 - 1,
-                replica_id: buffer.replica_id() as i32,
-                value: vec![],
-            })
-            .collect::<Vec<_>>();
-        assert_eq!(actual, expected, "unexpected operations")
-    }
+async fn update_buffer(
+    channel_id: ChannelId,
+    user_id: UserId,
+    db: &Database,
+    operations: Vec<text::Operation>,
+) {
+    let operations = operations
+        .into_iter()
+        .map(|op| proto::serialize_operation(&language::Operation::Buffer(op)))
+        .collect::<Vec<_>>();
+    db.update_channel_buffer(channel_id, user_id, &operations)
+        .await
+        .unwrap();
+}
+
+fn assert_operations(
+    operations: &[buffer_operation::Model],
+    expected: &[(BufferId, i32, &text::Buffer)],
+) {
+    let actual = operations
+        .iter()
+        .map(|op| buffer_operation::Model {
+            buffer_id: op.buffer_id,
+            epoch: op.epoch,
+            lamport_timestamp: op.lamport_timestamp,
+            replica_id: op.replica_id,
+            value: vec![],
+        })
+        .collect::<Vec<_>>();
+    let expected = expected
+        .iter()
+        .map(|(buffer_id, epoch, buffer)| buffer_operation::Model {
+            buffer_id: *buffer_id,
+            epoch: *epoch,
+            lamport_timestamp: buffer.lamport_clock.value as i32 - 1,
+            replica_id: buffer.replica_id() as i32,
+            value: vec![],
+        })
+        .collect::<Vec<_>>();
+    assert_eq!(actual, expected, "unexpected operations")
 }

crates/collab/src/db/tests/message_tests.rs 🔗

@@ -65,9 +65,7 @@ test_both_dbs!(
 );
 
 async fn test_channel_message_new_notification(db: &Arc<Database>) {
-    panic!("Rewriting the way this works");
-
-    let user_a = db
+    let user = db
         .create_user(
             "user_a@example.com",
             false,
@@ -80,7 +78,7 @@ async fn test_channel_message_new_notification(db: &Arc<Database>) {
         .await
         .unwrap()
         .user_id;
-    let user_b = db
+    let observer = db
         .create_user(
             "user_b@example.com",
             false,
@@ -94,107 +92,132 @@ async fn test_channel_message_new_notification(db: &Arc<Database>) {
         .unwrap()
         .user_id;
 
-    let channel = db
-        .create_channel("channel", None, "room", user_a)
+    let channel_1 = db
+        .create_channel("channel", None, "room", user)
         .await
         .unwrap();
 
-    db.invite_channel_member(channel, user_b, user_a, false)
+    let channel_2 = db
+        .create_channel("channel-2", None, "room", user)
         .await
         .unwrap();
 
-    db.respond_to_channel_invite(channel, user_b, true)
+    db.invite_channel_member(channel_1, observer, user, false)
         .await
         .unwrap();
 
-    let owner_id = db.create_server("test").await.unwrap().0 as u32;
-
-    // Zero case: no messages at all
-    // assert!(!db.has_new_message_tx(channel, user_b).await.unwrap());
+    db.respond_to_channel_invite(channel_1, observer, true)
+        .await
+        .unwrap();
 
-    let a_connection_id = rpc::ConnectionId { owner_id, id: 0 };
-    db.join_channel_chat(channel, a_connection_id, user_a)
+    db.invite_channel_member(channel_2, observer, user, false)
         .await
         .unwrap();
 
-    let _ = db
-        .create_channel_message(channel, user_a, "1", OffsetDateTime::now_utc(), 1)
+    db.respond_to_channel_invite(channel_2, observer, true)
         .await
         .unwrap();
 
-    let (second_message, _, _) = db
-        .create_channel_message(channel, user_a, "2", OffsetDateTime::now_utc(), 2)
+    let owner_id = db.create_server("test").await.unwrap().0 as u32;
+    let user_connection_id = rpc::ConnectionId { owner_id, id: 0 };
+
+    db.join_channel_chat(channel_1, user_connection_id, user)
         .await
         .unwrap();
 
     let _ = db
-        .create_channel_message(channel, user_a, "3", OffsetDateTime::now_utc(), 3)
+        .create_channel_message(channel_1, user, "1_1", OffsetDateTime::now_utc(), 1)
         .await
         .unwrap();
 
-    // Smoke test: can we detect a new message?
-    // assert!(db.has_new_message_tx(channel, user_b).await.unwrap());
-
-    let b_connection_id = rpc::ConnectionId { owner_id, id: 1 };
-    db.join_channel_chat(channel, b_connection_id, user_b)
+    let (second_message, _, _) = db
+        .create_channel_message(channel_1, user, "1_2", OffsetDateTime::now_utc(), 2)
         .await
         .unwrap();
 
-    // Joining the channel should _not_ update us to the latest message
-    // assert!(db.has_new_message_tx(channel, user_b).await.unwrap());
-
-    // Reading the earlier messages should not change that we have new messages
-    let _ = db
-        .get_channel_messages(channel, user_b, 1, Some(second_message))
+    let (third_message, _, _) = db
+        .create_channel_message(channel_1, user, "1_3", OffsetDateTime::now_utc(), 3)
         .await
         .unwrap();
 
-    // assert!(db.has_new_message_tx(channel, user_b).await.unwrap());
-
-    // This constraint is currently inexpressible, creating a message implicitly broadcasts
-    // it to all participants
-    //
-    // Creating new messages when we haven't read the latest one should not change the flag
-    // let _ = db
-    //     .create_channel_message(channel, user_a, "4", OffsetDateTime::now_utc(), 4)
-    //     .await
-    //     .unwrap();
-    // assert!(db.has_new_message_tx(channel, user_b).await.unwrap());
+    db.join_channel_chat(channel_2, user_connection_id, user)
+        .await
+        .unwrap();
 
-    // But reading the latest message should clear the flag
     let _ = db
-        .get_channel_messages(channel, user_b, 4, None)
+        .create_channel_message(channel_2, user, "2_1", OffsetDateTime::now_utc(), 4)
         .await
         .unwrap();
 
-    // assert!(!db.has_new_message_tx(channel, user_b).await.unwrap());
-
-    // And future messages should not reset the flag
-    let _ = db
-        .create_channel_message(channel, user_a, "5", OffsetDateTime::now_utc(), 5)
+    // Check that observer has new messages
+    let channels_with_new_messages = db
+        .transaction(|tx| async move {
+            db.channels_with_new_messages(observer, &[channel_1, channel_2], &*tx)
+                .await
+        })
         .await
         .unwrap();
 
-    // assert!(!db.has_new_message_tx(channel, user_b).await.unwrap());
+    assert_eq!(
+        channels_with_new_messages,
+        [channel_1, channel_2]
+            .into_iter()
+            .collect::<collections::HashSet<_>>()
+    );
 
-    let _ = db
-        .create_channel_message(channel, user_b, "6", OffsetDateTime::now_utc(), 6)
+    // Observe the second message
+    db.observe_channel_message(channel_1, observer, second_message)
         .await
         .unwrap();
 
-    // assert!(!db.has_new_message_tx(channel, user_b).await.unwrap());
+    // Make sure the observer still has a new message
+    let channels_with_new_messages = db
+        .transaction(|tx| async move {
+            db.channels_with_new_messages(observer, &[channel_1, channel_2], &*tx)
+                .await
+        })
+        .await
+        .unwrap();
+    assert_eq!(
+        channels_with_new_messages,
+        [channel_1, channel_2]
+            .into_iter()
+            .collect::<collections::HashSet<_>>()
+    );
 
-    // And we should start seeing the flag again after we've left the channel
-    db.leave_channel_chat(channel, b_connection_id, user_b)
+    // Observe the third message,
+    db.observe_channel_message(channel_1, observer, third_message)
         .await
         .unwrap();
 
-    // assert!(!db.has_new_message_tx(channel, user_b).await.unwrap());
+    // Make sure the observer does not have a new method
+    let channels_with_new_messages = db
+        .transaction(|tx| async move {
+            db.channels_with_new_messages(observer, &[channel_1, channel_2], &*tx)
+                .await
+        })
+        .await
+        .unwrap();
+    assert_eq!(
+        channels_with_new_messages,
+        [channel_2].into_iter().collect::<collections::HashSet<_>>()
+    );
 
-    let _ = db
-        .create_channel_message(channel, user_a, "7", OffsetDateTime::now_utc(), 7)
+    // Observe the second message again, should not regress our observed state
+    db.observe_channel_message(channel_1, observer, second_message)
         .await
         .unwrap();
 
-    // assert!(db.has_new_message_tx(channel, user_b).await.unwrap());
+    // Make sure the observer does not have a new method
+    let channels_with_new_messages = db
+        .transaction(|tx| async move {
+            db.channels_with_new_messages(observer, &[channel_1, channel_2], &*tx)
+                .await
+        })
+        .await
+        .unwrap();
+    assert_eq!(
+        channels_with_new_messages,
+        [channel_2].into_iter().collect::<collections::HashSet<_>>()
+    );
 }