Snapshot channel notes buffers when everyone leaves

Max Brunsfeld and Mikayla created

Co-authored-by: Mikayla <mikayla@zed.dev>

Change summary

Cargo.lock                                                      |   1 
crates/collab/Cargo.toml                                        |   1 
crates/collab/migrations.sqlite/20221109000000_test_schema.sql  |   1 
crates/collab/migrations/20230819154600_add_channel_buffers.sql |   1 
crates/collab/src/db/queries/buffers.rs                         | 352 ++
crates/collab/src/db/tables/buffer_snapshot.rs                  |   1 
crates/collab/src/db/tests/buffer_tests.rs                      |  10 
crates/language/src/proto.rs                                    |   1 
crates/text/src/text.rs                                         |   2 
9 files changed, 273 insertions(+), 97 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -1458,6 +1458,7 @@ dependencies = [
  "channel",
  "clap 3.2.25",
  "client",
+ "clock",
  "collections",
  "ctor",
  "dashmap",

crates/collab/Cargo.toml 🔗

@@ -14,6 +14,7 @@ name = "seed"
 required-features = ["seed-support"]
 
 [dependencies]
+clock = { path = "../clock" }
 collections = { path = "../collections" }
 live_kit_server = { path = "../live_kit_server" }
 text = { path = "../text" }

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -233,6 +233,7 @@ CREATE TABLE "buffer_snapshots" (
     "buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
     "epoch" INTEGER NOT NULL,
     "text" TEXT NOT NULL,
+    "operation_serialization_version" INTEGER NOT NULL,
     PRIMARY KEY(buffer_id, epoch)
 );
 

crates/collab/src/db/queries/buffers.rs 🔗

@@ -1,5 +1,7 @@
 use super::*;
 use prost::Message;
+use std::ops::Range;
+use text::{EditOperation, InsertionTimestamp, UndoOperation};
 
 impl Database {
     pub async fn join_channel_buffer(
@@ -31,6 +33,16 @@ impl Database {
                 }
                 .insert(&*tx)
                 .await?;
+                buffer_snapshot::ActiveModel {
+                    buffer_id: ActiveValue::Set(buffer.id),
+                    epoch: ActiveValue::Set(0),
+                    text: ActiveValue::Set(String::new()),
+                    operation_serialization_version: ActiveValue::Set(
+                        storage::SERIALIZATION_VERSION,
+                    ),
+                }
+                .insert(&*tx)
+                .await?;
                 buffer
             };
 
@@ -60,58 +72,7 @@ impl Database {
             collaborators.push(collaborator);
 
             // Assemble the buffer state
-            let id = buffer.id;
-            let base_text = if buffer.epoch > 0 {
-                buffer_snapshot::Entity::find()
-                    .filter(
-                        buffer_snapshot::Column::BufferId
-                            .eq(id)
-                            .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
-                    )
-                    .one(&*tx)
-                    .await?
-                    .ok_or_else(|| anyhow!("no such snapshot"))?
-                    .text
-            } else {
-                String::new()
-            };
-
-            let mut rows = buffer_operation::Entity::find()
-                .filter(
-                    buffer_operation::Column::BufferId
-                        .eq(id)
-                        .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
-                )
-                .stream(&*tx)
-                .await?;
-            let mut operations = Vec::new();
-            while let Some(row) = rows.next().await {
-                let row = row?;
-                let version = deserialize_version(&row.version)?;
-                let operation = if row.is_undo {
-                    let counts = deserialize_undo_operation(&row.value)?;
-                    proto::operation::Variant::Undo(proto::operation::Undo {
-                        replica_id: row.replica_id as u32,
-                        local_timestamp: row.local_timestamp as u32,
-                        lamport_timestamp: row.lamport_timestamp as u32,
-                        version,
-                        counts,
-                    })
-                } else {
-                    let (ranges, new_text) = deserialize_edit_operation(&row.value)?;
-                    proto::operation::Variant::Edit(proto::operation::Edit {
-                        replica_id: row.replica_id as u32,
-                        local_timestamp: row.local_timestamp as u32,
-                        lamport_timestamp: row.lamport_timestamp as u32,
-                        version,
-                        ranges,
-                        new_text,
-                    })
-                };
-                operations.push(proto::Operation {
-                    variant: Some(operation),
-                })
-            }
+            let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
 
             Ok(proto::JoinChannelBufferResponse {
                 buffer_id: buffer.id.to_proto(),
@@ -180,6 +141,12 @@ impl Database {
             });
         }
 
+        drop(rows);
+
+        if connections.is_empty() {
+            self.snapshot_buffer(channel_id, &tx).await?;
+        }
+
         Ok(connections)
     }
 
@@ -258,42 +225,23 @@ impl Database {
                 .one(&*tx)
                 .await?
                 .ok_or_else(|| anyhow!("no such buffer"))?;
-            let buffer_id = buffer.id;
+
+            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
+            enum QueryVersion {
+                OperationSerializationVersion,
+            }
+
+            let serialization_version: i32 = buffer
+                .find_related(buffer_snapshot::Entity)
+                .select_only()
+                .filter(buffer_snapshot::Column::Epoch.eq(buffer.epoch))
+                .into_values::<_, QueryVersion>()
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("missing buffer snapshot"))?;
+
             buffer_operation::Entity::insert_many(operations.iter().filter_map(|operation| {
-                match operation.variant.as_ref()? {
-                    proto::operation::Variant::Edit(operation) => {
-                        let value =
-                            serialize_edit_operation(&operation.ranges, &operation.new_text);
-                        let version = serialize_version(&operation.version);
-                        Some(buffer_operation::ActiveModel {
-                            buffer_id: ActiveValue::Set(buffer_id),
-                            epoch: ActiveValue::Set(buffer.epoch),
-                            replica_id: ActiveValue::Set(operation.replica_id as i32),
-                            lamport_timestamp: ActiveValue::Set(operation.lamport_timestamp as i32),
-                            local_timestamp: ActiveValue::Set(operation.local_timestamp as i32),
-                            is_undo: ActiveValue::Set(false),
-                            version: ActiveValue::Set(version),
-                            value: ActiveValue::Set(value),
-                        })
-                    }
-                    proto::operation::Variant::Undo(operation) => {
-                        let value = serialize_undo_operation(&operation.counts);
-                        let version = serialize_version(&operation.version);
-                        Some(buffer_operation::ActiveModel {
-                            buffer_id: ActiveValue::Set(buffer_id),
-                            epoch: ActiveValue::Set(buffer.epoch),
-                            replica_id: ActiveValue::Set(operation.replica_id as i32),
-                            lamport_timestamp: ActiveValue::Set(operation.lamport_timestamp as i32),
-                            local_timestamp: ActiveValue::Set(operation.local_timestamp as i32),
-                            is_undo: ActiveValue::Set(true),
-                            version: ActiveValue::Set(version),
-                            value: ActiveValue::Set(value),
-                        })
-                    }
-                    proto::operation::Variant::UpdateSelections(_) => None,
-                    proto::operation::Variant::UpdateDiagnostics(_) => None,
-                    proto::operation::Variant::UpdateCompletionTriggers(_) => None,
-                }
+                operation_to_storage(operation, &buffer, serialization_version)
             }))
             .exec(&*tx)
             .await?;
@@ -318,6 +266,222 @@ impl Database {
         })
         .await
     }
+
+    async fn get_buffer_state(
+        &self,
+        buffer: &buffer::Model,
+        tx: &DatabaseTransaction,
+    ) -> Result<(String, Vec<proto::Operation>)> {
+        let id = buffer.id;
+        let (base_text, version) = if buffer.epoch > 0 {
+            let snapshot = buffer_snapshot::Entity::find()
+                .filter(
+                    buffer_snapshot::Column::BufferId
+                        .eq(id)
+                        .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
+                )
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such snapshot"))?;
+
+            let version = snapshot.operation_serialization_version;
+            (snapshot.text, version)
+        } else {
+            (String::new(), storage::SERIALIZATION_VERSION)
+        };
+
+        let mut rows = buffer_operation::Entity::find()
+            .filter(
+                buffer_operation::Column::BufferId
+                    .eq(id)
+                    .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
+            )
+            .stream(&*tx)
+            .await?;
+        let mut operations = Vec::new();
+        while let Some(row) = rows.next().await {
+            let row = row?;
+
+            let operation = operation_from_storage(row, version)?;
+            operations.push(proto::Operation {
+                variant: Some(operation),
+            })
+        }
+
+        Ok((base_text, operations))
+    }
+
+    async fn snapshot_buffer(&self, channel_id: ChannelId, tx: &DatabaseTransaction) -> Result<()> {
+        let buffer = channel::Model {
+            id: channel_id,
+            ..Default::default()
+        }
+        .find_related(buffer::Entity)
+        .one(&*tx)
+        .await?
+        .ok_or_else(|| anyhow!("no such buffer"))?;
+
+        let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
+
+        let mut text_buffer = text::Buffer::new(0, 0, base_text);
+
+        text_buffer
+            .apply_ops(
+                operations
+                    .into_iter()
+                    .filter_map(deserialize_wire_operation),
+            )
+            .unwrap();
+
+        let base_text = text_buffer.text();
+        let epoch = buffer.epoch + 1;
+
+        buffer_snapshot::Model {
+            buffer_id: buffer.id,
+            epoch,
+            text: base_text,
+            operation_serialization_version: storage::SERIALIZATION_VERSION,
+        }
+        .into_active_model()
+        .insert(tx)
+        .await?;
+
+        buffer::ActiveModel {
+            id: ActiveValue::Unchanged(buffer.id),
+            epoch: ActiveValue::Set(epoch),
+            ..Default::default()
+        }
+        .save(tx)
+        .await?;
+
+        Ok(())
+    }
+}
+
+fn operation_to_storage(
+    operation: &proto::Operation,
+    buffer: &buffer::Model,
+    _format: i32,
+) -> Option<buffer_operation::ActiveModel> {
+    match operation.variant.as_ref()? {
+        proto::operation::Variant::Edit(operation) => {
+            let value = edit_operation_to_storage(&operation.ranges, &operation.new_text);
+            let version = version_to_storage(&operation.version);
+            Some(buffer_operation::ActiveModel {
+                buffer_id: ActiveValue::Set(buffer.id),
+                epoch: ActiveValue::Set(buffer.epoch),
+                replica_id: ActiveValue::Set(operation.replica_id as i32),
+                lamport_timestamp: ActiveValue::Set(operation.lamport_timestamp as i32),
+                local_timestamp: ActiveValue::Set(operation.local_timestamp as i32),
+                is_undo: ActiveValue::Set(false),
+                version: ActiveValue::Set(version),
+                value: ActiveValue::Set(value),
+            })
+        }
+        proto::operation::Variant::Undo(operation) => {
+            let value = undo_operation_to_storage(&operation.counts);
+            let version = version_to_storage(&operation.version);
+            Some(buffer_operation::ActiveModel {
+                buffer_id: ActiveValue::Set(buffer.id),
+                epoch: ActiveValue::Set(buffer.epoch),
+                replica_id: ActiveValue::Set(operation.replica_id as i32),
+                lamport_timestamp: ActiveValue::Set(operation.lamport_timestamp as i32),
+                local_timestamp: ActiveValue::Set(operation.local_timestamp as i32),
+                is_undo: ActiveValue::Set(true),
+                version: ActiveValue::Set(version),
+                value: ActiveValue::Set(value),
+            })
+        }
+        proto::operation::Variant::UpdateSelections(_) => None,
+        proto::operation::Variant::UpdateDiagnostics(_) => None,
+        proto::operation::Variant::UpdateCompletionTriggers(_) => None,
+    }
+}
+
+fn operation_from_storage(
+    row: buffer_operation::Model,
+    _format_version: i32,
+) -> Result<proto::operation::Variant, Error> {
+    let version = version_from_storage(&row.version)?;
+    let operation = if row.is_undo {
+        let counts = undo_operation_from_storage(&row.value)?;
+        proto::operation::Variant::Undo(proto::operation::Undo {
+            replica_id: row.replica_id as u32,
+            local_timestamp: row.local_timestamp as u32,
+            lamport_timestamp: row.lamport_timestamp as u32,
+            version,
+            counts,
+        })
+    } else {
+        let (ranges, new_text) = edit_operation_from_storage(&row.value)?;
+        proto::operation::Variant::Edit(proto::operation::Edit {
+            replica_id: row.replica_id as u32,
+            local_timestamp: row.local_timestamp as u32,
+            lamport_timestamp: row.lamport_timestamp as u32,
+            version,
+            ranges,
+            new_text,
+        })
+    };
+    Ok(operation)
+}
+
+// This is currently a manual copy of the deserialization code in the client's language crate
+pub fn deserialize_wire_operation(operation: proto::Operation) -> Option<text::Operation> {
+    match operation.variant? {
+        proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
+            timestamp: InsertionTimestamp {
+                replica_id: edit.replica_id as text::ReplicaId,
+                local: edit.local_timestamp,
+                lamport: edit.lamport_timestamp,
+            },
+            version: deserialize_wire_version(&edit.version),
+            ranges: edit.ranges.into_iter().map(deserialize_range).collect(),
+            new_text: edit.new_text.into_iter().map(Arc::from).collect(),
+        })),
+        proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo {
+            lamport_timestamp: clock::Lamport {
+                replica_id: undo.replica_id as text::ReplicaId,
+                value: undo.lamport_timestamp,
+            },
+            undo: UndoOperation {
+                id: clock::Local {
+                    replica_id: undo.replica_id as text::ReplicaId,
+                    value: undo.local_timestamp,
+                },
+                version: deserialize_wire_version(&undo.version),
+                counts: undo
+                    .counts
+                    .into_iter()
+                    .map(|c| {
+                        (
+                            clock::Local {
+                                replica_id: c.replica_id as text::ReplicaId,
+                                value: c.local_timestamp,
+                            },
+                            c.count,
+                        )
+                    })
+                    .collect(),
+            },
+        }),
+        _ => None,
+    }
+}
+
+pub fn deserialize_range(range: proto::Range) -> Range<text::FullOffset> {
+    text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
+}
+
+fn deserialize_wire_version(message: &[proto::VectorClockEntry]) -> clock::Global {
+    let mut version = clock::Global::new();
+    for entry in message {
+        version.observe(clock::Local {
+            replica_id: entry.replica_id as text::ReplicaId,
+            value: entry.timestamp,
+        });
+    }
+    version
 }
 
 mod storage {
@@ -325,7 +489,7 @@ mod storage {
 
     use prost::Message;
 
-    pub const VERSION: usize = 1;
+    pub const SERIALIZATION_VERSION: i32 = 1;
 
     #[derive(Message)]
     pub struct VectorClock {
@@ -374,7 +538,7 @@ mod storage {
     }
 }
 
-fn serialize_version(version: &Vec<proto::VectorClockEntry>) -> Vec<u8> {
+fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<u8> {
     storage::VectorClock {
         entries: version
             .iter()
@@ -387,7 +551,7 @@ fn serialize_version(version: &Vec<proto::VectorClockEntry>) -> Vec<u8> {
     .encode_to_vec()
 }
 
-fn deserialize_version(bytes: &[u8]) -> Result<Vec<proto::VectorClockEntry>> {
+fn version_from_storage(bytes: &[u8]) -> Result<Vec<proto::VectorClockEntry>> {
     let clock = storage::VectorClock::decode(bytes).map_err(|error| anyhow!("{}", error))?;
     Ok(clock
         .entries
@@ -399,7 +563,7 @@ fn deserialize_version(bytes: &[u8]) -> Result<Vec<proto::VectorClockEntry>> {
         .collect())
 }
 
-fn serialize_edit_operation(ranges: &[proto::Range], texts: &[String]) -> Vec<u8> {
+fn edit_operation_to_storage(ranges: &[proto::Range], texts: &[String]) -> Vec<u8> {
     storage::TextEdit {
         ranges: ranges
             .iter()
@@ -413,7 +577,7 @@ fn serialize_edit_operation(ranges: &[proto::Range], texts: &[String]) -> Vec<u8
     .encode_to_vec()
 }
 
-fn deserialize_edit_operation(bytes: &[u8]) -> Result<(Vec<proto::Range>, Vec<String>)> {
+fn edit_operation_from_storage(bytes: &[u8]) -> Result<(Vec<proto::Range>, Vec<String>)> {
     let edit = storage::TextEdit::decode(bytes).map_err(|error| anyhow!("{}", error))?;
     let ranges = edit
         .ranges
@@ -426,7 +590,7 @@ fn deserialize_edit_operation(bytes: &[u8]) -> Result<(Vec<proto::Range>, Vec<St
     Ok((ranges, edit.texts))
 }
 
-fn serialize_undo_operation(counts: &Vec<proto::UndoCount>) -> Vec<u8> {
+fn undo_operation_to_storage(counts: &Vec<proto::UndoCount>) -> Vec<u8> {
     storage::Undo {
         entries: counts
             .iter()
@@ -440,7 +604,7 @@ fn serialize_undo_operation(counts: &Vec<proto::UndoCount>) -> Vec<u8> {
     .encode_to_vec()
 }
 
-fn deserialize_undo_operation(bytes: &[u8]) -> Result<Vec<proto::UndoCount>> {
+fn undo_operation_from_storage(bytes: &[u8]) -> Result<Vec<proto::UndoCount>> {
     let undo = storage::Undo::decode(bytes).map_err(|error| anyhow!("{}", error))?;
     Ok(undo
         .entries

crates/collab/src/db/tests/buffer_tests.rs 🔗

@@ -10,7 +10,6 @@ test_both_dbs!(
 );
 
 async fn test_channel_buffers(db: &Arc<Database>) {
-    // Prep database test info
     let a_id = db
         .create_user(
             "user_a@example.com",
@@ -155,5 +154,12 @@ async fn test_channel_buffers(db: &Arc<Database>) {
     assert_eq!(zed_collaborators, &[]);
     assert_eq!(cargo_collaborators, &[]);
 
-    // TODO: test buffer epoch incrementing
+    // When everyone has left the channel, the operations are collapsed into
+    // a new base text.
+    let buffer_response_b = db
+        .join_channel_buffer(zed_id, b_id, connection_id_b)
+        .await
+        .unwrap();
+    assert_eq!(buffer_response_b.base_text, "hello, cruel world");
+    assert_eq!(buffer_response_b.operations, &[]);
 }

crates/language/src/proto.rs 🔗

@@ -207,6 +207,7 @@ pub fn serialize_anchor(anchor: &Anchor) -> proto::Anchor {
     }
 }
 
+// This behavior is currently copied in the collab database, for snapshotting channel notes
 pub fn deserialize_operation(message: proto::Operation) -> Result<crate::Operation> {
     Ok(
         match message

crates/text/src/text.rs 🔗

@@ -12,7 +12,7 @@ mod undo_map;
 
 pub use anchor::*;
 use anyhow::{anyhow, Result};
-use clock::ReplicaId;
+pub use clock::ReplicaId;
 use collections::{HashMap, HashSet};
 use fs::LineEnding;
 use locator::Locator;