Denormalize buffer operations (#9026)

Conrad Irwin, Max, and Nathan created

This should significantly reduce database load on redeploy.

Co-Authored-By: Max <max@zed.dev>
Co-Authored-By: Nathan <nathan@zed.dev>

Release Notes:

- Reduced likelihood of being disconnected during deploys

Co-authored-by: Max <max@zed.dev>
Co-authored-by: Nathan <nathan@zed.dev>

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql     |  5 
crates/collab/migrations/20240307163119_denormalize_buffer_ops.sql | 17 
crates/collab/src/db/queries/buffers.rs                            | 89 
crates/collab/src/db/queries/channels.rs                           | 25 
crates/collab/src/db/tables/buffer.rs                              |  3 
crates/collab/src/db/tests/buffer_tests.rs                         | 94 
6 files changed, 62 insertions(+), 171 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -248,7 +248,10 @@ CREATE UNIQUE INDEX "index_channel_members_on_channel_id_and_user_id" ON "channe
 CREATE TABLE "buffers" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "channel_id" INTEGER NOT NULL REFERENCES channels (id) ON DELETE CASCADE,
-    "epoch" INTEGER NOT NULL DEFAULT 0
+    "epoch" INTEGER NOT NULL DEFAULT 0,
+    "latest_operation_epoch" INTEGER,
+    "latest_operation_replica_id" INTEGER,
+    "latest_operation_lamport_timestamp" INTEGER
 );
 
 CREATE INDEX "index_buffers_on_channel_id" ON "buffers" ("channel_id");

crates/collab/migrations/20240307163119_denormalize_buffer_ops.sql 🔗

@@ -0,0 +1,17 @@
+-- Add migration script here
+
+ALTER TABLE buffers ADD COLUMN latest_operation_epoch INTEGER;
+ALTER TABLE buffers ADD COLUMN latest_operation_lamport_timestamp INTEGER;
+ALTER TABLE buffers ADD COLUMN latest_operation_replica_id INTEGER;
+
+WITH ops AS (
+    SELECT DISTINCT ON (buffer_id) buffer_id, epoch, lamport_timestamp, replica_id
+    FROM buffer_operations
+    ORDER BY buffer_id, epoch DESC, lamport_timestamp DESC, replica_id DESC
+)
+UPDATE buffers
+SET latest_operation_epoch = ops.epoch,
+    latest_operation_lamport_timestamp = ops.lamport_timestamp,
+    latest_operation_replica_id = ops.replica_id
+FROM ops
+WHERE buffers.id = ops.buffer_id;

crates/collab/src/db/queries/buffers.rs 🔗

@@ -558,6 +558,17 @@ impl Database {
         lamport_timestamp: i32,
         tx: &DatabaseTransaction,
     ) -> Result<()> {
+        buffer::Entity::update(buffer::ActiveModel {
+            id: ActiveValue::Unchanged(buffer_id),
+            epoch: ActiveValue::Unchanged(epoch),
+            latest_operation_epoch: ActiveValue::Set(Some(epoch)),
+            latest_operation_replica_id: ActiveValue::Set(Some(replica_id)),
+            latest_operation_lamport_timestamp: ActiveValue::Set(Some(lamport_timestamp)),
+            channel_id: ActiveValue::NotSet,
+        })
+        .exec(tx)
+        .await?;
+
         use observed_buffer_edits::Column;
         observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
             user_id: ActiveValue::Set(user_id),
@@ -711,7 +722,10 @@ impl Database {
         buffer::ActiveModel {
             id: ActiveValue::Unchanged(buffer.id),
             epoch: ActiveValue::Set(epoch),
-            ..Default::default()
+            latest_operation_epoch: ActiveValue::NotSet,
+            latest_operation_replica_id: ActiveValue::NotSet,
+            latest_operation_lamport_timestamp: ActiveValue::NotSet,
+            channel_id: ActiveValue::NotSet,
         }
         .save(tx)
         .await?;
@@ -745,30 +759,6 @@ impl Database {
         .await
     }
 
-    pub async fn latest_channel_buffer_changes(
-        &self,
-        channel_ids_by_buffer_id: &HashMap<BufferId, ChannelId>,
-        tx: &DatabaseTransaction,
-    ) -> Result<Vec<proto::ChannelBufferVersion>> {
-        let latest_operations = self
-            .get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), tx)
-            .await?;
-
-        Ok(latest_operations
-            .iter()
-            .flat_map(|op| {
-                Some(proto::ChannelBufferVersion {
-                    channel_id: channel_ids_by_buffer_id.get(&op.buffer_id)?.to_proto(),
-                    epoch: op.epoch as u64,
-                    version: vec![proto::VectorClockEntry {
-                        replica_id: op.replica_id as u32,
-                        timestamp: op.lamport_timestamp as u32,
-                    }],
-                })
-            })
-            .collect())
-    }
-
     pub async fn observed_channel_buffer_changes(
         &self,
         channel_ids_by_buffer_id: &HashMap<BufferId, ChannelId>,
@@ -798,55 +788,6 @@ impl Database {
             })
             .collect())
     }
-
-    /// Returns the latest operations for the buffers with the specified IDs.
-    pub async fn get_latest_operations_for_buffers(
-        &self,
-        buffer_ids: impl IntoIterator<Item = BufferId>,
-        tx: &DatabaseTransaction,
-    ) -> Result<Vec<buffer_operation::Model>> {
-        let mut values = String::new();
-        for id in buffer_ids {
-            if !values.is_empty() {
-                values.push_str(", ");
-            }
-            write!(&mut values, "({})", id).unwrap();
-        }
-
-        if values.is_empty() {
-            return Ok(Vec::default());
-        }
-
-        let sql = format!(
-            r#"
-            SELECT
-                *
-            FROM
-            (
-                SELECT
-                    *,
-                    row_number() OVER (
-                        PARTITION BY buffer_id
-                        ORDER BY
-                            epoch DESC,
-                            lamport_timestamp DESC,
-                            replica_id DESC
-                    ) as row_number
-                FROM buffer_operations
-                WHERE
-                    buffer_id in ({values})
-            ) AS last_operations
-            WHERE
-                row_number = 1
-            "#,
-        );
-
-        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
-        Ok(buffer_operation::Entity::find()
-            .from_raw_sql(stmt)
-            .all(tx)
-            .await?)
-    }
 }
 
 fn operation_to_storage(

crates/collab/src/db/queries/channels.rs 🔗

@@ -1,5 +1,8 @@
 use super::*;
-use rpc::{proto::channel_member::Kind, ErrorCode, ErrorCodeExt};
+use rpc::{
+    proto::{channel_member::Kind, ChannelBufferVersion, VectorClockEntry},
+    ErrorCode, ErrorCodeExt,
+};
 use sea_orm::TryGetableMany;
 
 impl Database {
@@ -625,6 +628,7 @@ impl Database {
         let channel_ids = channels.iter().map(|c| c.id).collect::<Vec<_>>();
 
         let mut channel_ids_by_buffer_id = HashMap::default();
+        let mut latest_buffer_versions: Vec<ChannelBufferVersion> = vec![];
         let mut rows = buffer::Entity::find()
             .filter(buffer::Column::ChannelId.is_in(channel_ids.iter().copied()))
             .stream(tx)
@@ -632,13 +636,24 @@ impl Database {
         while let Some(row) = rows.next().await {
             let row = row?;
             channel_ids_by_buffer_id.insert(row.id, row.channel_id);
+            latest_buffer_versions.push(ChannelBufferVersion {
+                channel_id: row.channel_id.0 as u64,
+                epoch: row.latest_operation_epoch.unwrap_or_default() as u64,
+                version: if let Some((latest_lamport_timestamp, latest_replica_id)) = row
+                    .latest_operation_lamport_timestamp
+                    .zip(row.latest_operation_replica_id)
+                {
+                    vec![VectorClockEntry {
+                        timestamp: latest_lamport_timestamp as u32,
+                        replica_id: latest_replica_id as u32,
+                    }]
+                } else {
+                    vec![]
+                },
+            });
         }
         drop(rows);
 
-        let latest_buffer_versions = self
-            .latest_channel_buffer_changes(&channel_ids_by_buffer_id, tx)
-            .await?;
-
         let latest_channel_messages = self.latest_channel_messages(&channel_ids, tx).await?;
 
         let observed_buffer_versions = self

crates/collab/src/db/tables/buffer.rs 🔗

@@ -8,6 +8,9 @@ pub struct Model {
     pub id: BufferId,
     pub epoch: i32,
     pub channel_id: ChannelId,
+    pub latest_operation_epoch: Option<i32>,
+    pub latest_operation_lamport_timestamp: Option<i32>,
+    pub latest_operation_replica_id: Option<i32>,
 }
 
 #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

crates/collab/src/db/tests/buffer_tests.rs 🔗

@@ -235,19 +235,6 @@ async fn test_channel_buffers_last_operations(db: &Database) {
         ));
     }
 
-    let operations = db
-        .transaction(|tx| {
-            let buffers = &buffers;
-            async move {
-                db.get_latest_operations_for_buffers([buffers[0].id, buffers[2].id], &tx)
-                    .await
-            }
-        })
-        .await
-        .unwrap();
-
-    assert!(operations.is_empty());
-
     update_buffer(
         buffers[0].channel_id,
         user_id,
@@ -299,57 +286,10 @@ async fn test_channel_buffers_last_operations(db: &Database) {
     )
     .await;
 
-    let operations = db
-        .transaction(|tx| {
-            let buffers = &buffers;
-            async move {
-                db.get_latest_operations_for_buffers([buffers[1].id, buffers[2].id], &tx)
-                    .await
-            }
-        })
-        .await
-        .unwrap();
-    assert_operations(
-        &operations,
-        &[
-            (buffers[1].id, 1, &text_buffers[1]),
-            (buffers[2].id, 0, &text_buffers[2]),
-        ],
-    );
-
-    let operations = db
-        .transaction(|tx| {
-            let buffers = &buffers;
-            async move {
-                db.get_latest_operations_for_buffers([buffers[0].id, buffers[1].id], &tx)
-                    .await
-            }
-        })
-        .await
-        .unwrap();
-    assert_operations(
-        &operations,
-        &[
-            (buffers[0].id, 0, &text_buffers[0]),
-            (buffers[1].id, 1, &text_buffers[1]),
-        ],
-    );
-
-    let buffer_changes = db
-        .transaction(|tx| {
-            let buffers = &buffers;
-            let mut hash = HashMap::default();
-            hash.insert(buffers[0].id, buffers[0].channel_id);
-            hash.insert(buffers[1].id, buffers[1].channel_id);
-            hash.insert(buffers[2].id, buffers[2].channel_id);
-
-            async move { db.latest_channel_buffer_changes(&hash, &tx).await }
-        })
-        .await
-        .unwrap();
+    let channels_for_user = db.get_channels_for_user(user_id).await.unwrap();
 
     pretty_assertions::assert_eq!(
-        buffer_changes,
+        channels_for_user.latest_buffer_versions,
         [
             rpc::proto::ChannelBufferVersion {
                 channel_id: buffers[0].channel_id.to_proto(),
@@ -361,8 +301,7 @@ async fn test_channel_buffers_last_operations(db: &Database) {
                 epoch: 1,
                 version: serialize_version(&text_buffers[1].version())
                     .into_iter()
-                    .filter(|vector| vector.replica_id
-                        == buffer_changes[1].version.first().unwrap().replica_id)
+                    .filter(|vector| vector.replica_id == text_buffers[1].replica_id() as u32)
                     .collect::<Vec<_>>(),
             },
             rpc::proto::ChannelBufferVersion {
@@ -388,30 +327,3 @@ async fn update_buffer(
         .await
         .unwrap();
 }
-
-fn assert_operations(
-    operations: &[buffer_operation::Model],
-    expected: &[(BufferId, i32, &text::Buffer)],
-) {
-    let actual = operations
-        .iter()
-        .map(|op| buffer_operation::Model {
-            buffer_id: op.buffer_id,
-            epoch: op.epoch,
-            lamport_timestamp: op.lamport_timestamp,
-            replica_id: op.replica_id,
-            value: vec![],
-        })
-        .collect::<Vec<_>>();
-    let expected = expected
-        .iter()
-        .map(|(buffer_id, epoch, buffer)| buffer_operation::Model {
-            buffer_id: *buffer_id,
-            epoch: *epoch,
-            lamport_timestamp: buffer.lamport_clock.value as i32 - 1,
-            replica_id: buffer.replica_id() as i32,
-            value: vec![],
-        })
-        .collect::<Vec<_>>();
-    assert_eq!(actual, expected, "unexpected operations")
-}