Detailed changes
@@ -296,6 +296,7 @@ CREATE TABLE "observed_buffer_edits" (
"buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
"epoch" INTEGER NOT NULL,
"lamport_timestamp" INTEGER NOT NULL,
+ "replica_id" INTEGER NOT NULL,
PRIMARY KEY (user_id, buffer_id)
);
@@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS "observed_buffer_edits" (
"buffer_id" INTEGER NOT NULL REFERENCES buffers (id) ON DELETE CASCADE,
"epoch" INTEGER NOT NULL,
"lamport_timestamp" INTEGER NOT NULL,
+ "replica_id" INTEGER NOT NULL,
PRIMARY KEY (user_id, buffer_id)
);
@@ -119,7 +119,7 @@ impl Database {
Ok(new_migrations)
}
- async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
+ pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
where
F: Send + Fn(TransactionHandle) -> Fut,
Fut: Send + Future<Output = Result<T>>,
@@ -321,7 +321,7 @@ fn is_serialization_error(error: &Error) -> bool {
}
}
-struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
+pub struct TransactionHandle(Arc<Option<DatabaseTransaction>>);
impl Deref for TransactionHandle {
type Target = DatabaseTransaction;
@@ -79,12 +79,13 @@ impl Database {
self.get_buffer_state(&buffer, &tx).await?;
// Save the last observed operation
- if let Some(max_operation) = max_operation {
+ if let Some(op) = max_operation {
observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
user_id: ActiveValue::Set(user_id),
buffer_id: ActiveValue::Set(buffer.id),
- epoch: ActiveValue::Set(max_operation.0),
- lamport_timestamp: ActiveValue::Set(max_operation.1),
+ epoch: ActiveValue::Set(op.epoch),
+ lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
+ replica_id: ActiveValue::Set(op.replica_id),
})
.on_conflict(
OnConflict::columns([
@@ -99,37 +100,6 @@ impl Database {
)
.exec(&*tx)
.await?;
- } else {
- let buffer_max = buffer_operation::Entity::find()
- .filter(buffer_operation::Column::BufferId.eq(buffer.id))
- .filter(buffer_operation::Column::Epoch.eq(buffer.epoch.saturating_sub(1)))
- .order_by(buffer_operation::Column::Epoch, Desc)
- .order_by(buffer_operation::Column::LamportTimestamp, Desc)
- .one(&*tx)
- .await?
- .map(|model| (model.epoch, model.lamport_timestamp));
-
- if let Some(buffer_max) = buffer_max {
- observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
- user_id: ActiveValue::Set(user_id),
- buffer_id: ActiveValue::Set(buffer.id),
- epoch: ActiveValue::Set(buffer_max.0),
- lamport_timestamp: ActiveValue::Set(buffer_max.1),
- })
- .on_conflict(
- OnConflict::columns([
- observed_buffer_edits::Column::UserId,
- observed_buffer_edits::Column::BufferId,
- ])
- .update_columns([
- observed_buffer_edits::Column::Epoch,
- observed_buffer_edits::Column::LamportTimestamp,
- ])
- .to_owned(),
- )
- .exec(&*tx)
- .await?;
- }
}
Ok(proto::JoinChannelBufferResponse {
@@ -487,13 +457,8 @@ impl Database {
if !operations.is_empty() {
// get current channel participants and save the max operation above
- self.save_max_operation_for_collaborators(
- operations.as_slice(),
- channel_id,
- buffer.id,
- &*tx,
- )
- .await?;
+ self.save_max_operation(user, buffer.id, buffer.epoch, operations.as_slice(), &*tx)
+ .await?;
channel_members = self.get_channel_members_internal(channel_id, &*tx).await?;
let collaborators = self
@@ -539,54 +504,55 @@ impl Database {
.await
}
- async fn save_max_operation_for_collaborators(
+ async fn save_max_operation(
&self,
- operations: &[buffer_operation::ActiveModel],
- channel_id: ChannelId,
+ user_id: UserId,
buffer_id: BufferId,
+ epoch: i32,
+ operations: &[buffer_operation::ActiveModel],
tx: &DatabaseTransaction,
) -> Result<()> {
+ use observed_buffer_edits::Column;
+
let max_operation = operations
.iter()
- .map(|storage_model| {
- (
- storage_model.epoch.clone(),
- storage_model.lamport_timestamp.clone(),
- )
- })
- .max_by(
- |(epoch_a, lamport_timestamp_a), (epoch_b, lamport_timestamp_b)| {
- epoch_a.as_ref().cmp(epoch_b.as_ref()).then(
- lamport_timestamp_a
- .as_ref()
- .cmp(lamport_timestamp_b.as_ref()),
- )
- },
- )
+ .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
.unwrap();
- let users = self
- .get_channel_buffer_collaborators_internal(channel_id, tx)
- .await?;
-
- observed_buffer_edits::Entity::insert_many(users.iter().map(|id| {
- observed_buffer_edits::ActiveModel {
- user_id: ActiveValue::Set(*id),
- buffer_id: ActiveValue::Set(buffer_id),
- epoch: max_operation.0.clone(),
- lamport_timestamp: ActiveValue::Set(*max_operation.1.as_ref()),
- }
- }))
+ observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
+ user_id: ActiveValue::Set(user_id),
+ buffer_id: ActiveValue::Set(buffer_id),
+ epoch: ActiveValue::Set(epoch),
+ replica_id: max_operation.replica_id.clone(),
+ lamport_timestamp: max_operation.lamport_timestamp.clone(),
+ })
.on_conflict(
- OnConflict::columns([
- observed_buffer_edits::Column::UserId,
- observed_buffer_edits::Column::BufferId,
- ])
- .update_columns([
- observed_buffer_edits::Column::Epoch,
- observed_buffer_edits::Column::LamportTimestamp,
- ])
- .to_owned(),
+ OnConflict::columns([Column::UserId, Column::BufferId])
+ .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
+ .target_cond_where(
+ Condition::any()
+ .add(Column::Epoch.lt(*max_operation.epoch.as_ref()))
+ .add(
+ Condition::all()
+ .add(Column::Epoch.eq(*max_operation.epoch.as_ref()))
+ .add(
+ Condition::any()
+ .add(
+ Column::LamportTimestamp
+ .lt(*max_operation.lamport_timestamp.as_ref()),
+ )
+ .add(
+ Column::LamportTimestamp
+ .eq(*max_operation.lamport_timestamp.as_ref())
+ .and(
+ Column::ReplicaId
+ .lt(*max_operation.replica_id.as_ref()),
+ ),
+ ),
+ ),
+ ),
+ )
+ .to_owned(),
)
.exec(tx)
.await?;
@@ -611,7 +577,7 @@ impl Database {
.ok_or_else(|| anyhow!("missing buffer snapshot"))?)
}
- async fn get_channel_buffer(
+ pub async fn get_channel_buffer(
&self,
channel_id: ChannelId,
tx: &DatabaseTransaction,
@@ -630,7 +596,11 @@ impl Database {
&self,
buffer: &buffer::Model,
tx: &DatabaseTransaction,
- ) -> Result<(String, Vec<proto::Operation>, Option<(i32, i32)>)> {
+ ) -> Result<(
+ String,
+ Vec<proto::Operation>,
+ Option<buffer_operation::Model>,
+ )> {
let id = buffer.id;
let (base_text, version) = if buffer.epoch > 0 {
let snapshot = buffer_snapshot::Entity::find()
@@ -655,24 +625,28 @@ impl Database {
.eq(id)
.and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
)
+ .order_by_asc(buffer_operation::Column::LamportTimestamp)
+ .order_by_asc(buffer_operation::Column::ReplicaId)
.stream(&*tx)
.await?;
- let mut operations = Vec::new();
- let mut max_epoch: Option<i32> = None;
- let mut max_timestamp: Option<i32> = None;
+ let mut operations = Vec::new();
+ let mut last_row = None;
while let Some(row) = rows.next().await {
let row = row?;
-
- max_assign(&mut max_epoch, row.epoch);
- max_assign(&mut max_timestamp, row.lamport_timestamp);
-
+ last_row = Some(buffer_operation::Model {
+ buffer_id: row.buffer_id,
+ epoch: row.epoch,
+ lamport_timestamp: row.lamport_timestamp,
+ replica_id: row.lamport_timestamp,
+ value: Default::default(),
+ });
operations.push(proto::Operation {
variant: Some(operation_from_storage(row, version)?),
- })
+ });
}
- Ok((base_text, operations, max_epoch.zip(max_timestamp)))
+ Ok((base_text, operations, last_row))
}
async fn snapshot_channel_buffer(
@@ -725,6 +699,119 @@ impl Database {
.await
}
+ pub async fn channels_with_changed_notes(
+ &self,
+ user_id: UserId,
+ channel_ids: impl IntoIterator<Item = ChannelId>,
+ tx: &DatabaseTransaction,
+ ) -> Result<HashSet<ChannelId>> {
+ #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
+ enum QueryIds {
+ ChannelId,
+ Id,
+ }
+
+ let mut channel_ids_by_buffer_id = HashMap::default();
+ let mut rows = buffer::Entity::find()
+ .filter(buffer::Column::ChannelId.is_in(channel_ids))
+ .stream(&*tx)
+ .await?;
+ while let Some(row) = rows.next().await {
+ let row = row?;
+ channel_ids_by_buffer_id.insert(row.id, row.channel_id);
+ }
+ drop(rows);
+
+ let mut observed_edits_by_buffer_id = HashMap::default();
+ let mut rows = observed_buffer_edits::Entity::find()
+ .filter(observed_buffer_edits::Column::UserId.eq(user_id))
+ .filter(
+ observed_buffer_edits::Column::BufferId
+ .is_in(channel_ids_by_buffer_id.keys().copied()),
+ )
+ .stream(&*tx)
+ .await?;
+ while let Some(row) = rows.next().await {
+ let row = row?;
+ observed_edits_by_buffer_id.insert(row.buffer_id, row);
+ }
+ drop(rows);
+
+ let last_operations = self
+ .get_last_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
+ .await?;
+
+ let mut channels_with_new_changes = HashSet::default();
+ for last_operation in last_operations {
+ if let Some(observed_edit) = observed_edits_by_buffer_id.get(&last_operation.buffer_id)
+ {
+ if observed_edit.epoch == last_operation.epoch
+ && observed_edit.lamport_timestamp == last_operation.lamport_timestamp
+ && observed_edit.replica_id == last_operation.replica_id
+ {
+ continue;
+ }
+ }
+
+ if let Some(channel_id) = channel_ids_by_buffer_id.get(&last_operation.buffer_id) {
+ channels_with_new_changes.insert(*channel_id);
+ }
+ }
+
+ Ok(channels_with_new_changes)
+ }
+
+ pub async fn get_last_operations_for_buffers(
+ &self,
+ channel_ids: impl IntoIterator<Item = BufferId>,
+ tx: &DatabaseTransaction,
+ ) -> Result<Vec<buffer_operation::Model>> {
+ let mut values = String::new();
+ for id in channel_ids {
+ if !values.is_empty() {
+ values.push_str(", ");
+ }
+ write!(&mut values, "({})", id).unwrap();
+ }
+
+ if values.is_empty() {
+ return Ok(Vec::default());
+ }
+
+ let sql = format!(
+ r#"
+ SELECT
+ *
+ FROM (
+ SELECT
+ buffer_id,
+ epoch,
+ lamport_timestamp,
+ replica_id,
+ value,
+ row_number() OVER (
+ PARTITION BY buffer_id
+ ORDER BY
+ epoch DESC,
+ lamport_timestamp DESC,
+ replica_id DESC
+ ) as row_number
+ FROM buffer_operations
+ WHERE
+ buffer_id in ({values})
+ ) AS operations
+ WHERE
+ row_number = 1
+ "#,
+ );
+
+ let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
+ let operations = buffer_operation::Model::find_by_statement(stmt)
+ .all(&*tx)
+ .await?;
+ Ok(operations)
+ }
+
pub async fn has_note_changed(
&self,
user_id: UserId,
@@ -463,12 +463,16 @@ impl Database {
}
}
- let mut channels_with_changed_notes = HashSet::default();
+ let channels_with_changed_notes = self
+ .channels_with_changed_notes(
+ user_id,
+ graph.channels.iter().map(|channel| channel.id),
+ &*tx,
+ )
+ .await?;
+
let mut channels_with_new_messages = HashSet::default();
for channel in graph.channels.iter() {
- if self.has_note_changed(user_id, channel.id, tx).await? {
- channels_with_changed_notes.insert(channel.id);
- }
if self.has_new_message(channel.id, user_id, tx).await? {
channels_with_new_messages.insert(channel.id);
}
@@ -9,6 +9,7 @@ pub struct Model {
pub buffer_id: BufferId,
pub epoch: i32,
pub lamport_timestamp: i32,
+ pub replica_id: i32,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
@@ -272,3 +272,193 @@ async fn test_channel_buffers_diffs(db: &Database) {
assert!(!db.test_has_note_changed(a_id, zed_id).await.unwrap());
assert!(!db.test_has_note_changed(b_id, zed_id).await.unwrap());
}
+
+test_both_dbs!(
+ test_channel_buffers_last_operations,
+ test_channel_buffers_last_operations_postgres,
+ test_channel_buffers_last_operations_sqlite
+);
+
+async fn test_channel_buffers_last_operations(db: &Database) {
+ let user_id = db
+ .create_user(
+ "user_a@example.com",
+ false,
+ NewUserParams {
+ github_login: "user_a".into(),
+ github_user_id: 101,
+ invite_count: 0,
+ },
+ )
+ .await
+ .unwrap()
+ .user_id;
+ let owner_id = db.create_server("production").await.unwrap().0 as u32;
+ let connection_id = ConnectionId {
+ owner_id,
+ id: user_id.0 as u32,
+ };
+
+ let mut buffers = Vec::new();
+ let mut text_buffers = Vec::new();
+ for i in 0..3 {
+ let channel = db
+ .create_root_channel(&format!("channel-{i}"), &format!("room-{i}"), user_id)
+ .await
+ .unwrap();
+
+ db.join_channel_buffer(channel, user_id, connection_id)
+ .await
+ .unwrap();
+
+ buffers.push(
+ db.transaction(|tx| async move { db.get_channel_buffer(channel, &*tx).await })
+ .await
+ .unwrap(),
+ );
+
+ text_buffers.push(Buffer::new(0, 0, "".to_string()));
+ }
+
+ let operations = db
+ .transaction(|tx| {
+ let buffers = &buffers;
+ async move {
+ db.get_last_operations_for_buffers([buffers[0].id, buffers[2].id], &*tx)
+ .await
+ }
+ })
+ .await
+ .unwrap();
+
+ assert!(operations.is_empty());
+
+ update_buffer(
+ buffers[0].channel_id,
+ user_id,
+ db,
+ vec![
+ text_buffers[0].edit([(0..0, "a")]),
+ text_buffers[0].edit([(0..0, "b")]),
+ text_buffers[0].edit([(0..0, "c")]),
+ ],
+ )
+ .await;
+
+ update_buffer(
+ buffers[1].channel_id,
+ user_id,
+ db,
+ vec![
+ text_buffers[1].edit([(0..0, "d")]),
+ text_buffers[1].edit([(1..1, "e")]),
+ text_buffers[1].edit([(2..2, "f")]),
+ ],
+ )
+ .await;
+
+ // cause buffer 1's epoch to increment.
+ db.leave_channel_buffer(buffers[1].channel_id, connection_id)
+ .await
+ .unwrap();
+ db.join_channel_buffer(buffers[1].channel_id, user_id, connection_id)
+ .await
+ .unwrap();
+ text_buffers[1] = Buffer::new(1, 0, "def".to_string());
+ update_buffer(
+ buffers[1].channel_id,
+ user_id,
+ db,
+ vec![
+ text_buffers[1].edit([(0..0, "g")]),
+ text_buffers[1].edit([(0..0, "h")]),
+ ],
+ )
+ .await;
+
+ update_buffer(
+ buffers[2].channel_id,
+ user_id,
+ db,
+ vec![text_buffers[2].edit([(0..0, "i")])],
+ )
+ .await;
+
+ let operations = db
+ .transaction(|tx| {
+ let buffers = &buffers;
+ async move {
+ db.get_last_operations_for_buffers([buffers[1].id, buffers[2].id], &*tx)
+ .await
+ }
+ })
+ .await
+ .unwrap();
+ assert_operations(
+ &operations,
+ &[
+ (buffers[1].id, 1, &text_buffers[1]),
+ (buffers[2].id, 0, &text_buffers[2]),
+ ],
+ );
+
+ let operations = db
+ .transaction(|tx| {
+ let buffers = &buffers;
+ async move {
+ db.get_last_operations_for_buffers([buffers[0].id, buffers[1].id], &*tx)
+ .await
+ }
+ })
+ .await
+ .unwrap();
+ assert_operations(
+ &operations,
+ &[
+ (buffers[0].id, 0, &text_buffers[0]),
+ (buffers[1].id, 1, &text_buffers[1]),
+ ],
+ );
+
+ async fn update_buffer(
+ channel_id: ChannelId,
+ user_id: UserId,
+ db: &Database,
+ operations: Vec<text::Operation>,
+ ) {
+ let operations = operations
+ .into_iter()
+ .map(|op| proto::serialize_operation(&language::Operation::Buffer(op)))
+ .collect::<Vec<_>>();
+ db.update_channel_buffer(channel_id, user_id, &operations)
+ .await
+ .unwrap();
+ }
+
+ fn assert_operations(
+ operations: &[buffer_operation::Model],
+ expected: &[(BufferId, i32, &text::Buffer)],
+ ) {
+ let actual = operations
+ .iter()
+ .map(|op| buffer_operation::Model {
+ buffer_id: op.buffer_id,
+ epoch: op.epoch,
+ lamport_timestamp: op.lamport_timestamp,
+ replica_id: op.replica_id,
+ value: vec![],
+ })
+ .collect::<Vec<_>>();
+ let expected = expected
+ .iter()
+ .map(|(buffer_id, epoch, buffer)| buffer_operation::Model {
+ buffer_id: *buffer_id,
+ epoch: *epoch,
+ lamport_timestamp: buffer.lamport_clock.value as i32 - 1,
+ replica_id: buffer.replica_id() as i32,
+ value: vec![],
+ })
+ .collect::<Vec<_>>();
+ assert_eq!(actual, expected, "unexpected operations")
+ }
+}