Avoid N+1 query for channels with notes changes

Created by Max Brunsfeld and Mikayla

Also, start work on new timing for recording observed notes edits.

Co-authored-by: Mikayla <mikayla@zed.dev>

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql  |   1 
crates/collab/migrations/20230925210437_add_channel_changes.sql |   1 
crates/collab/src/db.rs                                         |   4 
crates/collab/src/db/queries/buffers.rs                         | 269 +-
crates/collab/src/db/queries/channels.rs                        |  12 
crates/collab/src/db/tables/observed_buffer_edits.rs            |   1 
crates/collab/src/db/tests/buffer_tests.rs                      | 190 ++
7 files changed, 381 insertions(+), 97 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -296,6 +296,7 @@ CREATE TABLE "observed_buffer_edits" (
     "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
     "epoch" INTEGER NOT NULL,
     "lamport_timestamp" INTEGER NOT NULL,
+    "replica_id" INTEGER NOT NULL,
     PRIMARY KEY (user_id, buffer_id)
 );
 

crates/collab/migrations/20230925210437_add_channel_changes.sql 🔗

@@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS "observed_buffer_edits" (
     "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
     "epoch" INTEGER NOT NULL,
     "lamport_timestamp" INTEGER NOT NULL,
+    "replica_id" INTEGER NOT NULL,
     PRIMARY KEY (user_id, buffer_id)
 );
 

crates/collab/src/db.rs 🔗

@@ -119,7 +119,7 @@ impl Database {
         Ok(new_migrations)
     }
 
-    async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
+    pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
     where
         F: Send + Fn(TransactionHandle) -> Fut,
         Fut: Send + Future<Output = Result<T>>,
@@ -321,7 +321,7 @@ fn is_serialization_error(error: &Error) -> bool {
     }
 }
 
-struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
+pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
 
 impl Deref for TransactionHandle {
     type Target = DatabaseTransaction;

crates/collab/src/db/queries/buffers.rs 🔗

@@ -79,12 +79,13 @@ impl Database {
                 self.get_buffer_state(&buffer, &tx).await?;
 
             // Save the last observed operation
-            if let Some(max_operation) = max_operation {
+            if let Some(op) = max_operation {
                 observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
                     user_id: ActiveValue::Set(user_id),
                     buffer_id: ActiveValue::Set(buffer.id),
-                    epoch: ActiveValue::Set(max_operation.0),
-                    lamport_timestamp: ActiveValue::Set(max_operation.1),
+                    epoch: ActiveValue::Set(op.epoch),
+                    lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
+                    replica_id: ActiveValue::Set(op.replica_id),
                 })
                 .on_conflict(
                     OnConflict::columns([
@@ -99,37 +100,6 @@ impl Database {
                 )
                 .exec(&*tx)
                 .await?;
-            } else {
-                let buffer_max = buffer_operation::Entity::find()
-                    .filter(buffer_operation::Column::BufferId.eq(buffer.id))
-                    .filter(buffer_operation::Column::Epoch.eq(buffer.epoch.saturating_sub(1)))
-                    .order_by(buffer_operation::Column::Epoch, Desc)
-                    .order_by(buffer_operation::Column::LamportTimestamp, Desc)
-                    .one(&*tx)
-                    .await?
-                    .map(|model| (model.epoch, model.lamport_timestamp));
-
-                if let Some(buffer_max) = buffer_max {
-                    observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
-                        user_id: ActiveValue::Set(user_id),
-                        buffer_id: ActiveValue::Set(buffer.id),
-                        epoch: ActiveValue::Set(buffer_max.0),
-                        lamport_timestamp: ActiveValue::Set(buffer_max.1),
-                    })
-                    .on_conflict(
-                        OnConflict::columns([
-                            observed_buffer_edits::Column::UserId,
-                            observed_buffer_edits::Column::BufferId,
-                        ])
-                        .update_columns([
-                            observed_buffer_edits::Column::Epoch,
-                            observed_buffer_edits::Column::LamportTimestamp,
-                        ])
-                        .to_owned(),
-                    )
-                    .exec(&*tx)
-                    .await?;
-                }
             }
 
             Ok(proto::JoinChannelBufferResponse {
@@ -487,13 +457,8 @@ impl Database {
 
             if !operations.is_empty() {
                 // get current channel participants and save the max operation above
-                self.save_max_operation_for_collaborators(
-                    operations.as_slice(),
-                    channel_id,
-                    buffer.id,
-                    &*tx,
-                )
-                .await?;
+                self.save_max_operation(user, buffer.id, buffer.epoch, operations.as_slice(), &*tx)
+                    .await?;
 
                 channel_members = self.get_channel_members_internal(channel_id, &*tx).await?;
                 let collaborators = self
@@ -539,54 +504,55 @@ impl Database {
         .await
     }
 
-    async fn save_max_operation_for_collaborators(
+    async fn save_max_operation(
         &self,
-        operations: &[buffer_operation::ActiveModel],
-        channel_id: ChannelId,
+        user_id: UserId,
         buffer_id: BufferId,
+        epoch: i32,
+        operations: &[buffer_operation::ActiveModel],
         tx: &DatabaseTransaction,
     ) -> Result<()> {
+        use observed_buffer_edits::Column;
+
         let max_operation = operations
             .iter()
-            .map(|storage_model| {
-                (
-                    storage_model.epoch.clone(),
-                    storage_model.lamport_timestamp.clone(),
-                )
-            })
-            .max_by(
-                |(epoch_a, lamport_timestamp_a), (epoch_b, lamport_timestamp_b)| {
-                    epoch_a.as_ref().cmp(epoch_b.as_ref()).then(
-                        lamport_timestamp_a
-                            .as_ref()
-                            .cmp(lamport_timestamp_b.as_ref()),
-                    )
-                },
-            )
+            .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
             .unwrap();
 
-        let users = self
-            .get_channel_buffer_collaborators_internal(channel_id, tx)
-            .await?;
-
-        observed_buffer_edits::Entity::insert_many(users.iter().map(|id| {
-            observed_buffer_edits::ActiveModel {
-                user_id: ActiveValue::Set(*id),
-                buffer_id: ActiveValue::Set(buffer_id),
-                epoch: max_operation.0.clone(),
-                lamport_timestamp: ActiveValue::Set(*max_operation.1.as_ref()),
-            }
-        }))
+        observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
+            user_id: ActiveValue::Set(user_id),
+            buffer_id: ActiveValue::Set(buffer_id),
+            epoch: ActiveValue::Set(epoch),
+            replica_id: max_operation.replica_id.clone(),
+            lamport_timestamp: max_operation.lamport_timestamp.clone(),
+        })
         .on_conflict(
-            OnConflict::columns([
-                observed_buffer_edits::Column::UserId,
-                observed_buffer_edits::Column::BufferId,
-            ])
-            .update_columns([
-                observed_buffer_edits::Column::Epoch,
-                observed_buffer_edits::Column::LamportTimestamp,
-            ])
-            .to_owned(),
+            OnConflict::columns([Column::UserId, Column::BufferId])
+                .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
+                .target_cond_where(
+                    Condition::any()
+                        .add(Column::Epoch.lt(*max_operation.epoch.as_ref()))
+                        .add(
+                            Condition::all()
+                                .add(Column::Epoch.eq(*max_operation.epoch.as_ref()))
+                                .add(
+                                    Condition::any()
+                                        .add(
+                                            Column::LamportTimestamp
+                                                .lt(*max_operation.lamport_timestamp.as_ref()),
+                                        )
+                                        .add(
+                                            Column::LamportTimestamp
+                                                .eq(*max_operation.lamport_timestamp.as_ref())
+                                                .and(
+                                                    Column::ReplicaId
+                                                        .lt(*max_operation.replica_id.as_ref()),
+                                                ),
+                                        ),
+                                ),
+                        ),
+                )
+                .to_owned(),
         )
         .exec(tx)
         .await?;
@@ -611,7 +577,7 @@ impl Database {
             .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
     }
 
-    async fn get_channel_buffer(
+    pub async fn get_channel_buffer(
         &self,
         channel_id: ChannelId,
         tx: &DatabaseTransaction,
@@ -630,7 +596,11 @@ impl Database {
         &self,
         buffer: &buffer::Model,
         tx: &DatabaseTransaction,
-    ) -> Result<(String, Vec<proto::Operation>, Option<(i32, i32)>)> {
+    ) -> Result<(
+        String,
+        Vec<proto::Operation>,
+        Option<buffer_operation::Model>,
+    )> {
         let id = buffer.id;
         let (base_text, version) = if buffer.epoch > 0 {
             let snapshot = buffer_snapshot::Entity::find()
@@ -655,24 +625,28 @@ impl Database {
                     .eq(id)
                     .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
             )
+            .order_by_asc(buffer_operation::Column::LamportTimestamp)
+            .order_by_asc(buffer_operation::Column::ReplicaId)
             .stream(&*tx)
             .await?;
-        let mut operations = Vec::new();
 
-        let mut max_epoch: Option<i32> = None;
-        let mut max_timestamp: Option<i32> = None;
+        let mut operations = Vec::new();
+        let mut last_row = None;
         while let Some(row) = rows.next().await {
             let row = row?;
-
-            max_assign(&mut max_epoch, row.epoch);
-            max_assign(&mut max_timestamp, row.lamport_timestamp);
-
+            last_row = Some(buffer_operation::Model {
+                buffer_id: row.buffer_id,
+                epoch: row.epoch,
+                lamport_timestamp: row.lamport_timestamp,
+                replica_id: row.replica_id,
+                value: Default::default(),
+            });
             operations.push(proto::Operation {
                 variant: Some(operation_from_storage(row, version)?),
-            })
+            });
         }
 
-        Ok((base_text, operations, max_epoch.zip(max_timestamp)))
+        Ok((base_text, operations, last_row))
     }
 
     async fn snapshot_channel_buffer(
@@ -725,6 +699,119 @@ impl Database {
             .await
     }
 
+    pub async fn channels_with_changed_notes(
+        &self,
+        user_id: UserId,
+        channel_ids: impl IntoIterator<Item = ChannelId>,
+        tx: &DatabaseTransaction,
+    ) -> Result<HashSet<ChannelId>> {
+        #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
+        enum QueryIds {
+            ChannelId,
+            Id,
+        }
+
+        let mut channel_ids_by_buffer_id = HashMap::default();
+        let mut rows = buffer::Entity::find()
+            .filter(buffer::Column::ChannelId.is_in(channel_ids))
+            .stream(&*tx)
+            .await?;
+        while let Some(row) = rows.next().await {
+            let row = row?;
+            channel_ids_by_buffer_id.insert(row.id, row.channel_id);
+        }
+        drop(rows);
+
+        let mut observed_edits_by_buffer_id = HashMap::default();
+        let mut rows = observed_buffer_edits::Entity::find()
+            .filter(observed_buffer_edits::Column::UserId.eq(user_id))
+            .filter(
+                observed_buffer_edits::Column::BufferId
+                    .is_in(channel_ids_by_buffer_id.keys().copied()),
+            )
+            .stream(&*tx)
+            .await?;
+        while let Some(row) = rows.next().await {
+            let row = row?;
+            observed_edits_by_buffer_id.insert(row.buffer_id, row);
+        }
+        drop(rows);
+
+        let last_operations = self
+            .get_last_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
+            .await?;
+
+        let mut channels_with_new_changes = HashSet::default();
+        for last_operation in last_operations {
+            if let Some(observed_edit) = observed_edits_by_buffer_id.get(&last_operation.buffer_id)
+            {
+                if observed_edit.epoch == last_operation.epoch
+                    && observed_edit.lamport_timestamp == last_operation.lamport_timestamp
+                    && observed_edit.replica_id == last_operation.replica_id
+                {
+                    continue;
+                }
+            }
+
+            if let Some(channel_id) = channel_ids_by_buffer_id.get(&last_operation.buffer_id) {
+                channels_with_new_changes.insert(*channel_id);
+            }
+        }
+
+        Ok(channels_with_new_changes)
+    }
+
+    pub async fn get_last_operations_for_buffers(
+        &self,
+        buffer_ids: impl IntoIterator<Item = BufferId>,
+        tx: &DatabaseTransaction,
+    ) -> Result<Vec<buffer_operation::Model>> {
+        let mut values = String::new();
+        for id in buffer_ids {
+            if !values.is_empty() {
+                values.push_str(", ");
+            }
+            write!(&mut values, "({})", id).unwrap();
+        }
+
+        if values.is_empty() {
+            return Ok(Vec::default());
+        }
+
+        let sql = format!(
+            r#"
+            SELECT
+                *
+            FROM (
+                SELECT
+                    buffer_id,
+                    epoch,
+                    lamport_timestamp,
+                    replica_id,
+                    value,
+                    row_number() OVER (
+                        PARTITION BY buffer_id
+                        ORDER BY
+                            epoch DESC,
+                            lamport_timestamp DESC,
+                            replica_id DESC
+                    ) as row_number
+                FROM buffer_operations
+                WHERE
+                    buffer_id in ({values})
+            ) AS operations
+            WHERE
+                row_number = 1
+            "#,
+        );
+
+        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
+        let operations = buffer_operation::Model::find_by_statement(stmt)
+            .all(&*tx)
+            .await?;
+        Ok(operations)
+    }
+
     pub async fn has_note_changed(
         &self,
         user_id: UserId,

crates/collab/src/db/queries/channels.rs 🔗

@@ -463,12 +463,16 @@ impl Database {
             }
         }
 
-        let mut channels_with_changed_notes = HashSet::default();
+        let channels_with_changed_notes = self
+            .channels_with_changed_notes(
+                user_id,
+                graph.channels.iter().map(|channel| channel.id),
+                &*tx,
+            )
+            .await?;
+
         let mut channels_with_new_messages = HashSet::default();
         for channel in graph.channels.iter() {
-            if self.has_note_changed(user_id, channel.id, tx).await? {
-                channels_with_changed_notes.insert(channel.id);
-            }
             if self.has_new_message(channel.id, user_id, tx).await? {
                 channels_with_new_messages.insert(channel.id);
             }

crates/collab/src/db/tests/buffer_tests.rs 🔗

@@ -272,3 +272,193 @@ async fn test_channel_buffers_diffs(db: &Database) {
     assert!(!db.test_has_note_changed(a_id, zed_id).await.unwrap());
     assert!(!db.test_has_note_changed(b_id, zed_id).await.unwrap());
 }
+
+test_both_dbs!(
+    test_channel_buffers_last_operations,
+    test_channel_buffers_last_operations_postgres,
+    test_channel_buffers_last_operations_sqlite
+);
+
+async fn test_channel_buffers_last_operations(db: &Database) {
+    let user_id = db
+        .create_user(
+            "user_a@example.com",
+            false,
+            NewUserParams {
+                github_login: "user_a".into(),
+                github_user_id: 101,
+                invite_count: 0,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
+    let owner_id = db.create_server("production").await.unwrap().0 as u32;
+    let connection_id = ConnectionId {
+        owner_id,
+        id: user_id.0 as u32,
+    };
+
+    let mut buffers = Vec::new();
+    let mut text_buffers = Vec::new();
+    for i in 0..3 {
+        let channel = db
+            .create_root_channel(&format!("channel-{i}"), &format!("room-{i}"), user_id)
+            .await
+            .unwrap();
+
+        db.join_channel_buffer(channel, user_id, connection_id)
+            .await
+            .unwrap();
+
+        buffers.push(
+            db.transaction(|tx| async move { db.get_channel_buffer(channel, &*tx).await })
+                .await
+                .unwrap(),
+        );
+
+        text_buffers.push(Buffer::new(0, 0, "".to_string()));
+    }
+
+    let operations = db
+        .transaction(|tx| {
+            let buffers = &buffers;
+            async move {
+                db.get_last_operations_for_buffers([buffers[0].id, buffers[2].id], &*tx)
+                    .await
+            }
+        })
+        .await
+        .unwrap();
+
+    assert!(operations.is_empty());
+
+    update_buffer(
+        buffers[0].channel_id,
+        user_id,
+        db,
+        vec![
+            text_buffers[0].edit([(0..0, "a")]),
+            text_buffers[0].edit([(0..0, "b")]),
+            text_buffers[0].edit([(0..0, "c")]),
+        ],
+    )
+    .await;
+
+    update_buffer(
+        buffers[1].channel_id,
+        user_id,
+        db,
+        vec![
+            text_buffers[1].edit([(0..0, "d")]),
+            text_buffers[1].edit([(1..1, "e")]),
+            text_buffers[1].edit([(2..2, "f")]),
+        ],
+    )
+    .await;
+
+    // cause buffer 1's epoch to increment.
+    db.leave_channel_buffer(buffers[1].channel_id, connection_id)
+        .await
+        .unwrap();
+    db.join_channel_buffer(buffers[1].channel_id, user_id, connection_id)
+        .await
+        .unwrap();
+    text_buffers[1] = Buffer::new(1, 0, "def".to_string());
+    update_buffer(
+        buffers[1].channel_id,
+        user_id,
+        db,
+        vec![
+            text_buffers[1].edit([(0..0, "g")]),
+            text_buffers[1].edit([(0..0, "h")]),
+        ],
+    )
+    .await;
+
+    update_buffer(
+        buffers[2].channel_id,
+        user_id,
+        db,
+        vec![text_buffers[2].edit([(0..0, "i")])],
+    )
+    .await;
+
+    let operations = db
+        .transaction(|tx| {
+            let buffers = &buffers;
+            async move {
+                db.get_last_operations_for_buffers([buffers[1].id, buffers[2].id], &*tx)
+                    .await
+            }
+        })
+        .await
+        .unwrap();
+    assert_operations(
+        &operations,
+        &[
+            (buffers[1].id, 1, &text_buffers[1]),
+            (buffers[2].id, 0, &text_buffers[2]),
+        ],
+    );
+
+    let operations = db
+        .transaction(|tx| {
+            let buffers = &buffers;
+            async move {
+                db.get_last_operations_for_buffers([buffers[0].id, buffers[1].id], &*tx)
+                    .await
+            }
+        })
+        .await
+        .unwrap();
+    assert_operations(
+        &operations,
+        &[
+            (buffers[0].id, 0, &text_buffers[0]),
+            (buffers[1].id, 1, &text_buffers[1]),
+        ],
+    );
+
+    async fn update_buffer(
+        channel_id: ChannelId,
+        user_id: UserId,
+        db: &Database,
+        operations: Vec<text::Operation>,
+    ) {
+        let operations = operations
+            .into_iter()
+            .map(|op| proto::serialize_operation(&language::Operation::Buffer(op)))
+            .collect::<Vec<_>>();
+        db.update_channel_buffer(channel_id, user_id, &operations)
+            .await
+            .unwrap();
+    }
+
+    fn assert_operations(
+        operations: &[buffer_operation::Model],
+        expected: &[(BufferId, i32, &text::Buffer)],
+    ) {
+        let actual = operations
+            .iter()
+            .map(|op| buffer_operation::Model {
+                buffer_id: op.buffer_id,
+                epoch: op.epoch,
+                lamport_timestamp: op.lamport_timestamp,
+                replica_id: op.replica_id,
+                value: vec![],
+            })
+            .collect::<Vec<_>>();
+        let expected = expected
+            .iter()
+            .map(|(buffer_id, epoch, buffer)| buffer_operation::Model {
+                buffer_id: *buffer_id,
+                epoch: *epoch,
+                lamport_timestamp: buffer.lamport_clock.value as i32 - 1,
+                replica_id: buffer.replica_id() as i32,
+                value: vec![],
+            })
+            .collect::<Vec<_>>();
+        assert_eq!(actual, expected, "unexpected operations")
+    }
+}