Add observed_channel_note_edits table and implement note diffing

Mikayla created

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql |  11 
crates/collab/migrations/20230925210437_add_observed_notes.sql |   9 
crates/collab/src/db/queries/buffers.rs                        | 235 +++
crates/collab/src/db/tables.rs                                 |   1 
crates/collab/src/db/tables/observed_note_edits.rs             |  42 
crates/collab/src/db/tests/buffer_tests.rs                     | 108 +
6 files changed, 384 insertions(+), 22 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -289,3 +289,14 @@ CREATE TABLE "user_features" (
 CREATE UNIQUE INDEX "index_user_features_user_id_and_feature_id" ON "user_features" ("user_id", "feature_id");
 CREATE INDEX "index_user_features_on_user_id" ON "user_features" ("user_id");
 CREATE INDEX "index_user_features_on_feature_id" ON "user_features" ("feature_id");
+
+
+-- Latest channel-note operation (epoch, lamport_timestamp) each user has observed per channel.
+-- The composite primary key already enforces uniqueness of (user_id, channel_id) and is
+-- backed by an implicit unique index, so no separate unique index is created.
+CREATE TABLE "observed_channel_note_edits" (
+    "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+    "epoch" INTEGER NOT NULL,
+    "lamport_timestamp" INTEGER NOT NULL,
+    PRIMARY KEY (user_id, channel_id)
+);

crates/collab/migrations/20230925210437_add_observed_notes.sql 🔗

@@ -0,0 +1,9 @@
+-- Latest channel-note operation (epoch, lamport_timestamp) each user has observed per channel.
+-- The composite primary key already enforces uniqueness of (user_id, channel_id) and is
+-- backed by an implicit unique index, so no separate unique index is created.
+CREATE TABLE "observed_channel_note_edits" (
+    "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+    "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
+    "epoch" INTEGER NOT NULL,
+    "lamport_timestamp" INTEGER NOT NULL,
+    PRIMARY KEY (user_id, channel_id)
+);

crates/collab/src/db/queries/buffers.rs 🔗

@@ -1,5 +1,6 @@
 use super::*;
 use prost::Message;
+use sea_query::Order::Desc;
 use text::{EditOperation, UndoOperation};
 
 pub struct LeftChannelBuffer {
@@ -74,7 +75,62 @@ impl Database {
             .await?;
             collaborators.push(collaborator);
 
-            let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
+            let (base_text, operations, max_operation) =
+                self.get_buffer_state(&buffer, &tx).await?;
+
+            // Save the last observed operation
+            if let Some(max_operation) = max_operation {
+                observed_note_edits::Entity::insert(observed_note_edits::ActiveModel {
+                    user_id: ActiveValue::Set(user_id),
+                    channel_id: ActiveValue::Set(channel_id),
+                    epoch: ActiveValue::Set(max_operation.0),
+                    lamport_timestamp: ActiveValue::Set(max_operation.1),
+                })
+                .on_conflict(
+                    OnConflict::columns([
+                        observed_note_edits::Column::UserId,
+                        observed_note_edits::Column::ChannelId,
+                    ])
+                    .update_columns([
+                        observed_note_edits::Column::Epoch,
+                        observed_note_edits::Column::LamportTimestamp,
+                    ])
+                    .to_owned(),
+                )
+                .exec(&*tx)
+                .await?;
+            } else {
+                let buffer_max = buffer_operation::Entity::find()
+                    .filter(buffer_operation::Column::BufferId.eq(buffer.id))
+                    .filter(buffer_operation::Column::Epoch.eq(buffer.epoch.saturating_sub(1)))
+                    .order_by(buffer_operation::Column::Epoch, Desc)
+                    .order_by(buffer_operation::Column::LamportTimestamp, Desc)
+                    .one(&*tx)
+                    .await?
+                    .map(|model| (model.epoch, model.lamport_timestamp));
+
+                if let Some(buffer_max) = buffer_max {
+                    observed_note_edits::Entity::insert(observed_note_edits::ActiveModel {
+                        user_id: ActiveValue::Set(user_id),
+                        channel_id: ActiveValue::Set(channel_id),
+                        epoch: ActiveValue::Set(buffer_max.0),
+                        lamport_timestamp: ActiveValue::Set(buffer_max.1),
+                    })
+                    .on_conflict(
+                        OnConflict::columns([
+                            observed_note_edits::Column::UserId,
+                            observed_note_edits::Column::ChannelId,
+                        ])
+                        .update_columns([
+                            observed_note_edits::Column::Epoch,
+                            observed_note_edits::Column::LamportTimestamp,
+                        ])
+                        .to_owned(),
+                    )
+                    .exec(&*tx)
+                    .await?;
+                }
+            }
 
             Ok(proto::JoinChannelBufferResponse {
                 buffer_id: buffer.id.to_proto(),
@@ -373,27 +429,35 @@ impl Database {
         channel_id: ChannelId,
     ) -> Result<Vec<UserId>> {
         self.transaction(|tx| async move {
-            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
-            enum QueryUserIds {
-                UserId,
-            }
-
-            let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
-                .select_only()
-                .column(channel_buffer_collaborator::Column::UserId)
-                .filter(
-                    Condition::all()
-                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
-                )
-                .into_values::<_, QueryUserIds>()
-                .all(&*tx)
-                .await?;
-
-            Ok(users)
+            self.get_channel_buffer_collaborators_internal(channel_id, &*tx)
+                .await
         })
         .await
     }
 
+    /// Returns the ids of all users recorded as collaborators on the given channel's buffer.
+    ///
+    /// Takes the transaction as a parameter so callers that already hold one can reuse it.
+    async fn get_channel_buffer_collaborators_internal(
+        &self,
+        channel_id: ChannelId,
+        tx: &DatabaseTransaction,
+    ) -> Result<Vec<UserId>> {
+        // Projection target: only the `user_id` column is selected from each row.
+        #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
+        enum QueryUserIds {
+            UserId,
+        }
+
+        let in_channel = channel_buffer_collaborator::Column::ChannelId.eq(channel_id);
+        let collaborator_ids = channel_buffer_collaborator::Entity::find()
+            .select_only()
+            .column(channel_buffer_collaborator::Column::UserId)
+            .filter(Condition::all().add(in_channel))
+            .into_values::<UserId, QueryUserIds>()
+            .all(tx)
+            .await?;
+
+        Ok(collaborator_ids)
+    }
+
     pub async fn update_channel_buffer(
         &self,
         channel_id: ChannelId,
@@ -418,7 +482,12 @@ impl Database {
                 .iter()
                 .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
                 .collect::<Vec<_>>();
+
             if !operations.is_empty() {
+                // Mark the newest of these operations as observed by all current collaborators
+                self.save_max_operation_for_collaborators(operations.as_slice(), channel_id, &*tx)
+                    .await?;
+
                 buffer_operation::Entity::insert_many(operations)
                     .on_conflict(
                         OnConflict::columns([
@@ -455,6 +524,60 @@ impl Database {
         .await
     }
 
+    /// Records the newest operation in `operations` as the last edit observed by every
+    /// user currently collaborating on the channel's buffer.
+    ///
+    /// "Newest" is the greatest `(epoch, lamport_timestamp)` pair, compared
+    /// lexicographically. An existing row for a `(user_id, channel_id)` pair is
+    /// overwritten. Does nothing when `operations` is empty or when the channel
+    /// has no buffer collaborators.
+    ///
+    /// Every operation is expected to have `epoch` and `lamport_timestamp` set;
+    /// reading an unset `ActiveValue` panics.
+    async fn save_max_operation_for_collaborators(
+        &self,
+        operations: &[buffer_operation::ActiveModel],
+        channel_id: ChannelId,
+        tx: &DatabaseTransaction,
+    ) -> Result<()> {
+        // Tuples order lexicographically: epoch first, then lamport timestamp.
+        let Some((epoch, lamport_timestamp)) = operations
+            .iter()
+            .map(|operation| {
+                (
+                    *operation.epoch.as_ref(),
+                    *operation.lamport_timestamp.as_ref(),
+                )
+            })
+            .max()
+        else {
+            // Nothing was written, so there is nothing to mark as observed.
+            return Ok(());
+        };
+
+        let users = self
+            .get_channel_buffer_collaborators_internal(channel_id, tx)
+            .await?;
+
+        // An insert with zero rows is not a meaningful statement, so bail out early.
+        if users.is_empty() {
+            return Ok(());
+        }
+
+        observed_note_edits::Entity::insert_many(users.into_iter().map(|user_id| {
+            observed_note_edits::ActiveModel {
+                user_id: ActiveValue::Set(user_id),
+                channel_id: ActiveValue::Set(channel_id),
+                epoch: ActiveValue::Set(epoch),
+                lamport_timestamp: ActiveValue::Set(lamport_timestamp),
+            }
+        }))
+        .on_conflict(
+            // Upsert: one row per (user, channel), always holding the latest observation.
+            OnConflict::columns([
+                observed_note_edits::Column::UserId,
+                observed_note_edits::Column::ChannelId,
+            ])
+            .update_columns([
+                observed_note_edits::Column::Epoch,
+                observed_note_edits::Column::LamportTimestamp,
+            ])
+            .to_owned(),
+        )
+        .exec(tx)
+        .await?;
+
+        Ok(())
+    }
+
     async fn get_buffer_operation_serialization_version(
         &self,
         buffer_id: BufferId,
@@ -491,7 +614,7 @@ impl Database {
         &self,
         buffer: &buffer::Model,
         tx: &DatabaseTransaction,
-    ) -> Result<(String, Vec<proto::Operation>)> {
+    ) -> Result<(String, Vec<proto::Operation>, Option<(i32, i32)>)> {
         let id = buffer.id;
         let (base_text, version) = if buffer.epoch > 0 {
             let snapshot = buffer_snapshot::Entity::find()
@@ -519,13 +642,21 @@ impl Database {
             .stream(&*tx)
             .await?;
         let mut operations = Vec::new();
+
+        let mut max_epoch: Option<i32> = None;
+        let mut max_timestamp: Option<i32> = None;
         while let Some(row) = rows.next().await {
+            let row = row?;
+
+            max_assign(&mut max_epoch, row.epoch);
+            max_assign(&mut max_timestamp, row.lamport_timestamp);
+
             operations.push(proto::Operation {
-                variant: Some(operation_from_storage(row?, version)?),
+                variant: Some(operation_from_storage(row, version)?),
             })
         }
 
-        Ok((base_text, operations))
+        Ok((base_text, operations, max_epoch.zip(max_timestamp)))
     }
 
     async fn snapshot_channel_buffer(
@@ -534,7 +665,7 @@ impl Database {
         tx: &DatabaseTransaction,
     ) -> Result<()> {
         let buffer = self.get_channel_buffer(channel_id, tx).await?;
-        let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
+        let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
         if operations.is_empty() {
             return Ok(());
         }
@@ -567,6 +698,66 @@ impl Database {
 
         Ok(())
     }
+
+    /// Reports whether the channel's buffer holds edits that `user_id` has not observed.
+    ///
+    /// Compares the `(epoch, lamport_timestamp)` stored for the user in
+    /// `observed_note_edits` with the newest operation of the buffer's current epoch,
+    /// falling back to the previous epoch when the current one has no operations.
+    /// Returns `false` when the channel has no buffer.
+    pub async fn has_buffer_changed(&self, user_id: UserId, channel_id: ChannelId) -> Result<bool> {
+        self.transaction(|tx| async move {
+            let observed = observed_note_edits::Entity::find()
+                .filter(observed_note_edits::Column::UserId.eq(user_id))
+                .filter(observed_note_edits::Column::ChannelId.eq(channel_id))
+                .one(&*tx)
+                .await?
+                .map(|row| (row.epoch, row.lamport_timestamp));
+
+            // Look the buffer up through the channel relation; only the channel id is filled in.
+            let channel = channel::Model {
+                id: channel_id,
+                ..Default::default()
+            };
+            let buffer = match channel.find_related(buffer::Entity).one(&*tx).await? {
+                Some(buffer) => buffer,
+                None => return Ok(false),
+            };
+
+            // Try the current epoch first; if it has no edits, check whether the user
+            // observed the last edit of the previous epoch instead.
+            let mut latest = None;
+            for epoch in [buffer.epoch, buffer.epoch.saturating_sub(1)] {
+                latest = buffer_operation::Entity::find()
+                    .filter(buffer_operation::Column::BufferId.eq(buffer.id))
+                    .filter(buffer_operation::Column::Epoch.eq(epoch))
+                    .order_by(buffer_operation::Column::Epoch, Desc)
+                    .order_by(buffer_operation::Column::LamportTimestamp, Desc)
+                    .one(&*tx)
+                    .await?
+                    .map(|op| (op.epoch, op.lamport_timestamp));
+                if latest.is_some() {
+                    break;
+                }
+            }
+
+            Ok(observed != latest)
+        })
+        .await
+    }
+}
+
+/// Stores `val` in `max` when `max` is empty or holds a strictly smaller value.
+fn max_assign<T: Ord>(max: &mut Option<T>, val: T) {
+    let is_larger = max.as_ref().map_or(true, |current| val > *current);
+    if is_larger {
+        *max = Some(val);
+    }
+}
 
 fn operation_to_storage(

crates/collab/src/db/tables.rs 🔗

@@ -12,6 +12,7 @@ pub mod contact;
 pub mod feature_flag;
 pub mod follower;
 pub mod language_server;
+pub mod observed_note_edits;
 pub mod project;
 pub mod project_collaborator;
 pub mod room;

crates/collab/src/db/tables/observed_note_edits.rs 🔗

@@ -0,0 +1,42 @@
+use crate::db::{ChannelId, UserId};
+use sea_orm::entity::prelude::*;
+
+/// A row of `observed_channel_note_edits`: the newest channel-note operation, identified
+/// by `(epoch, lamport_timestamp)`, that a user has observed for a channel.
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "observed_channel_note_edits")]
+pub struct Model {
+    // The table's primary key is the composite `(user_id, channel_id)`, so both
+    // columns are marked; marking only `user_id` would declare it the sole key.
+    #[sea_orm(primary_key)]
+    pub user_id: UserId,
+    #[sea_orm(primary_key)]
+    pub channel_id: ChannelId,
+    pub epoch: i32,
+    pub lamport_timestamp: i32,
+}
+
+/// Foreign-key relations from an observed-edit row to its channel and its user.
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+    // `channel_id` points at the `id` column of the channel entity.
+    #[sea_orm(
+        belongs_to = "super::channel::Entity",
+        from = "Column::ChannelId",
+        to = "super::channel::Column::Id"
+    )]
+    Channel,
+    // `user_id` points at the `id` column of the user entity.
+    #[sea_orm(
+        belongs_to = "super::user::Entity",
+        from = "Column::UserId",
+        to = "super::user::Column::Id"
+    )]
+    User,
+}
+
+// Lets queries navigate from an observed-edit row to its channel.
+impl Related<super::channel::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::Channel.def()
+    }
+}
+
+// Lets queries navigate from an observed-edit row to its user.
+impl Related<super::user::Entity> for Entity {
+    fn to() -> RelationDef {
+        Relation::User.def()
+    }
+}
+
+// No custom hooks: the default active-model behaviour is used.
+impl ActiveModelBehavior for ActiveModel {}

crates/collab/src/db/tests/buffer_tests.rs 🔗

@@ -163,3 +163,111 @@ async fn test_channel_buffers(db: &Arc<Database>) {
     assert_eq!(buffer_response_b.base_text, "hello, cruel world");
     assert_eq!(buffer_response_b.operations, &[]);
 }
+
+test_both_dbs!(
+    test_channel_buffers_diffs,
+    test_channel_buffers_diffs_postgres,
+    test_channel_buffers_diffs_sqlite
+);
+
+/// Exercises `has_buffer_changed` for two users across joining, editing and leaving
+/// a channel buffer.
+async fn test_channel_buffers_diffs(db: &Database) {
+    let params_a = NewUserParams {
+        github_login: "user_a".into(),
+        github_user_id: 101,
+        invite_count: 0,
+    };
+    let a_id = db
+        .create_user("user_a@example.com", false, params_a)
+        .await
+        .unwrap()
+        .user_id;
+
+    let params_b = NewUserParams {
+        github_login: "user_b".into(),
+        github_user_id: 102,
+        invite_count: 0,
+    };
+    let b_id = db
+        .create_user("user_b@example.com", false, params_b)
+        .await
+        .unwrap()
+        .user_id;
+
+    let owner_id = db.create_server("production").await.unwrap().0 as u32;
+
+    // A owns the channel; B is invited and accepts.
+    let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
+    db.invite_channel_member(zed_id, b_id, a_id, false)
+        .await
+        .unwrap();
+    db.respond_to_channel_invite(zed_id, b_id, true)
+        .await
+        .unwrap();
+
+    let conn_a = ConnectionId {
+        owner_id,
+        id: a_id.0 as u32,
+    };
+    let conn_b = ConnectionId {
+        owner_id,
+        id: b_id.0 as u32,
+    };
+
+    // Zero test: A should not register as changed on an uninitialized channel buffer
+    assert!(!db.has_buffer_changed(a_id, zed_id).await.unwrap());
+
+    db.join_channel_buffer(zed_id, a_id, conn_a).await.unwrap();
+
+    // Zero test: A should still not register as changed on an empty channel buffer
+    assert!(!db.has_buffer_changed(a_id, zed_id).await.unwrap());
+
+    // A makes a single edit locally and sends it to the server.
+    let mut buffer_a = Buffer::new(0, 0, String::new());
+    let edit = buffer_a.edit([(0..0, "hello world")]);
+    assert_eq!(buffer_a.text(), "hello world");
+
+    let operations = vec![proto::serialize_operation(&language::Operation::Buffer(
+        edit,
+    ))];
+
+    db.update_channel_buffer(zed_id, a_id, &operations)
+        .await
+        .unwrap();
+
+    // Smoke test: B registers as changed, A (the author) as unchanged
+    assert!(db.has_buffer_changed(b_id, zed_id).await.unwrap());
+    assert!(!db.has_buffer_changed(a_id, zed_id).await.unwrap());
+
+    db.leave_channel_buffer(zed_id, conn_a).await.unwrap();
+
+    // Snapshotting from leaving the channel buffer should not have a diff
+    assert!(!db.has_buffer_changed(a_id, zed_id).await.unwrap());
+
+    db.join_channel_buffer(zed_id, b_id, conn_b).await.unwrap();
+
+    // B has opened the channel buffer, so we shouldn't have any diff
+    assert!(!db.has_buffer_changed(b_id, zed_id).await.unwrap());
+
+    db.leave_channel_buffer(zed_id, conn_b).await.unwrap();
+
+    // Since B just opened and closed the buffer without editing, neither should have a diff
+    assert!(!db.has_buffer_changed(a_id, zed_id).await.unwrap());
+    assert!(!db.has_buffer_changed(b_id, zed_id).await.unwrap());
+}