@@ -1,5 +1,6 @@
use super::*;
use prost::Message;
+use sea_query::Order::Desc;
use text::{EditOperation, UndoOperation};
pub struct LeftChannelBuffer {
@@ -74,7 +75,62 @@ impl Database {
.await?;
collaborators.push(collaborator);
- let (base_text, operations) = self.get_buffer_state(&buffer, &tx).await?;
+ let (base_text, operations, max_operation) =
+ self.get_buffer_state(&buffer, &tx).await?;
+
+ // Save the last observed operation
+ if let Some(max_operation) = max_operation {
+ observed_note_edits::Entity::insert(observed_note_edits::ActiveModel {
+ user_id: ActiveValue::Set(user_id),
+ channel_id: ActiveValue::Set(channel_id),
+ epoch: ActiveValue::Set(max_operation.0),
+ lamport_timestamp: ActiveValue::Set(max_operation.1),
+ })
+ .on_conflict(
+ OnConflict::columns([
+ observed_note_edits::Column::UserId,
+ observed_note_edits::Column::ChannelId,
+ ])
+ .update_columns([
+ observed_note_edits::Column::Epoch,
+ observed_note_edits::Column::LamportTimestamp,
+ ])
+ .to_owned(),
+ )
+ .exec(&*tx)
+ .await?;
+ } else {
+ let buffer_max = buffer_operation::Entity::find()
+ .filter(buffer_operation::Column::BufferId.eq(buffer.id))
+ .filter(buffer_operation::Column::Epoch.eq(buffer.epoch.saturating_sub(1)))
+ .order_by(buffer_operation::Column::Epoch, Desc)
+ .order_by(buffer_operation::Column::LamportTimestamp, Desc)
+ .one(&*tx)
+ .await?
+ .map(|model| (model.epoch, model.lamport_timestamp));
+
+ if let Some(buffer_max) = buffer_max {
+ observed_note_edits::Entity::insert(observed_note_edits::ActiveModel {
+ user_id: ActiveValue::Set(user_id),
+ channel_id: ActiveValue::Set(channel_id),
+ epoch: ActiveValue::Set(buffer_max.0),
+ lamport_timestamp: ActiveValue::Set(buffer_max.1),
+ })
+ .on_conflict(
+ OnConflict::columns([
+ observed_note_edits::Column::UserId,
+ observed_note_edits::Column::ChannelId,
+ ])
+ .update_columns([
+ observed_note_edits::Column::Epoch,
+ observed_note_edits::Column::LamportTimestamp,
+ ])
+ .to_owned(),
+ )
+ .exec(&*tx)
+ .await?;
+ }
+ }
Ok(proto::JoinChannelBufferResponse {
buffer_id: buffer.id.to_proto(),
@@ -373,27 +429,35 @@ impl Database {
channel_id: ChannelId,
) -> Result<Vec<UserId>> {
self.transaction(|tx| async move {
- #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
- enum QueryUserIds {
- UserId,
- }
-
- let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
- .select_only()
- .column(channel_buffer_collaborator::Column::UserId)
- .filter(
- Condition::all()
- .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
- )
- .into_values::<_, QueryUserIds>()
- .all(&*tx)
- .await?;
-
- Ok(users)
+ self.get_channel_buffer_collaborators_internal(channel_id, &*tx)
+ .await
})
.await
}
+ async fn get_channel_buffer_collaborators_internal(
+ &self,
+ channel_id: ChannelId,
+ tx: &DatabaseTransaction,
+ ) -> Result<Vec<UserId>> {
+ #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
+ enum QueryUserIds {
+ UserId,
+ }
+
+ let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
+ .select_only()
+ .column(channel_buffer_collaborator::Column::UserId)
+ .filter(
+ Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
+ )
+ .into_values::<_, QueryUserIds>()
+ .all(&*tx)
+ .await?;
+
+ Ok(users)
+ }
+
pub async fn update_channel_buffer(
&self,
channel_id: ChannelId,
@@ -418,7 +482,12 @@ impl Database {
.iter()
.filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
.collect::<Vec<_>>();
+
if !operations.is_empty() {
+ // get current channel participants and save the max operation above
+ self.save_max_operation_for_collaborators(operations.as_slice(), channel_id, &*tx)
+ .await?;
+
buffer_operation::Entity::insert_many(operations)
.on_conflict(
OnConflict::columns([
@@ -455,6 +524,60 @@ impl Database {
.await
}
+ async fn save_max_operation_for_collaborators(
+ &self,
+ operations: &[buffer_operation::ActiveModel],
+ channel_id: ChannelId,
+ tx: &DatabaseTransaction,
+ ) -> Result<()> {
+ let max_operation = operations
+ .iter()
+ .map(|storage_model| {
+ (
+ storage_model.epoch.clone(),
+ storage_model.lamport_timestamp.clone(),
+ )
+ })
+ .max_by(
+ |(epoch_a, lamport_timestamp_a), (epoch_b, lamport_timestamp_b)| {
+ epoch_a.as_ref().cmp(epoch_b.as_ref()).then(
+ lamport_timestamp_a
+ .as_ref()
+ .cmp(lamport_timestamp_b.as_ref()),
+ )
+ },
+ )
+ .unwrap();
+
+ let users = self
+ .get_channel_buffer_collaborators_internal(channel_id, tx)
+ .await?;
+
+ observed_note_edits::Entity::insert_many(users.iter().map(|id| {
+ observed_note_edits::ActiveModel {
+ user_id: ActiveValue::Set(*id),
+ channel_id: ActiveValue::Set(channel_id),
+ epoch: max_operation.0.clone(),
+ lamport_timestamp: ActiveValue::Set(*max_operation.1.as_ref()),
+ }
+ }))
+ .on_conflict(
+ OnConflict::columns([
+ observed_note_edits::Column::UserId,
+ observed_note_edits::Column::ChannelId,
+ ])
+ .update_columns([
+ observed_note_edits::Column::Epoch,
+ observed_note_edits::Column::LamportTimestamp,
+ ])
+ .to_owned(),
+ )
+ .exec(tx)
+ .await?;
+
+ Ok(())
+ }
+
async fn get_buffer_operation_serialization_version(
&self,
buffer_id: BufferId,
@@ -491,7 +614,7 @@ impl Database {
&self,
buffer: &buffer::Model,
tx: &DatabaseTransaction,
- ) -> Result<(String, Vec<proto::Operation>)> {
+ ) -> Result<(String, Vec<proto::Operation>, Option<(i32, i32)>)> {
let id = buffer.id;
let (base_text, version) = if buffer.epoch > 0 {
let snapshot = buffer_snapshot::Entity::find()
@@ -519,13 +642,21 @@ impl Database {
.stream(&*tx)
.await?;
let mut operations = Vec::new();
+
+ let mut max_epoch: Option<i32> = None;
+ let mut max_timestamp: Option<i32> = None;
while let Some(row) = rows.next().await {
+ let row = row?;
+
+ max_assign(&mut max_epoch, row.epoch);
+ max_assign(&mut max_timestamp, row.lamport_timestamp);
+
operations.push(proto::Operation {
- variant: Some(operation_from_storage(row?, version)?),
+ variant: Some(operation_from_storage(row, version)?),
})
}
- Ok((base_text, operations))
+ Ok((base_text, operations, max_epoch.zip(max_timestamp)))
}
async fn snapshot_channel_buffer(
@@ -534,7 +665,7 @@ impl Database {
tx: &DatabaseTransaction,
) -> Result<()> {
let buffer = self.get_channel_buffer(channel_id, tx).await?;
- let (base_text, operations) = self.get_buffer_state(&buffer, tx).await?;
+ let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
if operations.is_empty() {
return Ok(());
}
@@ -567,6 +698,66 @@ impl Database {
Ok(())
}
+
+ pub async fn has_buffer_changed(&self, user_id: UserId, channel_id: ChannelId) -> Result<bool> {
+ self.transaction(|tx| async move {
+ let user_max = observed_note_edits::Entity::find()
+ .filter(observed_note_edits::Column::UserId.eq(user_id))
+ .filter(observed_note_edits::Column::ChannelId.eq(channel_id))
+ .one(&*tx)
+ .await?
+ .map(|model| (model.epoch, model.lamport_timestamp));
+
+ let channel_buffer = channel::Model {
+ id: channel_id,
+ ..Default::default()
+ }
+ .find_related(buffer::Entity)
+ .one(&*tx)
+ .await?;
+
+ let Some(channel_buffer) = channel_buffer else {
+ return Ok(false);
+ };
+
+ let mut channel_max = buffer_operation::Entity::find()
+ .filter(buffer_operation::Column::BufferId.eq(channel_buffer.id))
+ .filter(buffer_operation::Column::Epoch.eq(channel_buffer.epoch))
+ .order_by(buffer_operation::Column::Epoch, Desc)
+ .order_by(buffer_operation::Column::LamportTimestamp, Desc)
+ .one(&*tx)
+ .await?
+ .map(|model| (model.epoch, model.lamport_timestamp));
+
+ // If there are no edits in this epoch
+ if channel_max.is_none() {
+ // check if this user observed the last edit of the previous epoch
+ channel_max = buffer_operation::Entity::find()
+ .filter(buffer_operation::Column::BufferId.eq(channel_buffer.id))
+ .filter(
+ buffer_operation::Column::Epoch.eq(channel_buffer.epoch.saturating_sub(1)),
+ )
+ .order_by(buffer_operation::Column::Epoch, Desc)
+ .order_by(buffer_operation::Column::LamportTimestamp, Desc)
+ .one(&*tx)
+ .await?
+ .map(|model| (model.epoch, model.lamport_timestamp));
+ }
+
+ Ok(user_max != channel_max)
+ })
+ .await
+ }
+}
+
+fn max_assign<T: Ord>(max: &mut Option<T>, val: T) {
+ if let Some(max_val) = max {
+ if val > *max_val {
+ *max = Some(val);
+ }
+ } else {
+ *max = Some(val);
+ }
}
fn operation_to_storage(
@@ -0,0 +1,42 @@
+use crate::db::{ChannelId, UserId};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "observed_channel_note_edits")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub user_id: UserId,
+ pub channel_id: ChannelId,
+ pub epoch: i32,
+ pub lamport_timestamp: i32,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::channel::Entity",
+ from = "Column::ChannelId",
+ to = "super::channel::Column::Id"
+ )]
+ Channel,
+ #[sea_orm(
+ belongs_to = "super::user::Entity",
+ from = "Column::UserId",
+ to = "super::user::Column::Id"
+ )]
+ User,
+}
+
+impl Related<super::channel::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::Channel.def()
+ }
+}
+
+impl Related<super::user::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::User.def()
+ }
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -163,3 +163,111 @@ async fn test_channel_buffers(db: &Arc<Database>) {
assert_eq!(buffer_response_b.base_text, "hello, cruel world");
assert_eq!(buffer_response_b.operations, &[]);
}
+
+test_both_dbs!(
+ test_channel_buffers_diffs,
+ test_channel_buffers_diffs_postgres,
+ test_channel_buffers_diffs_sqlite
+);
+
+async fn test_channel_buffers_diffs(db: &Database) {
+ let a_id = db
+ .create_user(
+ "user_a@example.com",
+ false,
+ NewUserParams {
+ github_login: "user_a".into(),
+ github_user_id: 101,
+ invite_count: 0,
+ },
+ )
+ .await
+ .unwrap()
+ .user_id;
+ let b_id = db
+ .create_user(
+ "user_b@example.com",
+ false,
+ NewUserParams {
+ github_login: "user_b".into(),
+ github_user_id: 102,
+ invite_count: 0,
+ },
+ )
+ .await
+ .unwrap()
+ .user_id;
+
+ let owner_id = db.create_server("production").await.unwrap().0 as u32;
+
+ let zed_id = db.create_root_channel("zed", "1", a_id).await.unwrap();
+
+ db.invite_channel_member(zed_id, b_id, a_id, false)
+ .await
+ .unwrap();
+
+ db.respond_to_channel_invite(zed_id, b_id, true)
+ .await
+ .unwrap();
+
+ let connection_id_a = ConnectionId {
+ owner_id,
+ id: a_id.0 as u32,
+ };
+ let connection_id_b = ConnectionId {
+ owner_id,
+ id: b_id.0 as u32,
+ };
+
+ // Zero test: A should not register as changed on an unitialized channel buffer
+ assert!(!db.has_buffer_changed(a_id, zed_id).await.unwrap());
+
+ let _ = db
+ .join_channel_buffer(zed_id, a_id, connection_id_a)
+ .await
+ .unwrap();
+
+ // Zero test: A should register as changed on an empty channel buffer
+ assert!(!db.has_buffer_changed(a_id, zed_id).await.unwrap());
+
+ let mut buffer_a = Buffer::new(0, 0, "".to_string());
+ let mut operations = Vec::new();
+ operations.push(buffer_a.edit([(0..0, "hello world")]));
+ assert_eq!(buffer_a.text(), "hello world");
+
+ let operations = operations
+ .into_iter()
+ .map(|op| proto::serialize_operation(&language::Operation::Buffer(op)))
+ .collect::<Vec<_>>();
+
+ db.update_channel_buffer(zed_id, a_id, &operations)
+ .await
+ .unwrap();
+
+ // Smoke test: Does B register as changed, A as unchanged?
+ assert!(db.has_buffer_changed(b_id, zed_id).await.unwrap());
+ assert!(!db.has_buffer_changed(a_id, zed_id).await.unwrap());
+
+ db.leave_channel_buffer(zed_id, connection_id_a)
+ .await
+ .unwrap();
+
+ // Snapshotting from leaving the channel buffer should not have a diff
+ assert!(!db.has_buffer_changed(a_id, zed_id).await.unwrap());
+
+ let _ = db
+ .join_channel_buffer(zed_id, b_id, connection_id_b)
+ .await
+ .unwrap();
+
+ // B has opened the channel buffer, so we shouldn't have any diff
+ assert!(!db.has_buffer_changed(b_id, zed_id).await.unwrap());
+
+ db.leave_channel_buffer(zed_id, connection_id_b)
+ .await
+ .unwrap();
+
+ // Since B just opened and closed the buffer without editing, neither should have a diff
+ assert!(!db.has_buffer_changed(a_id, zed_id).await.unwrap());
+ assert!(!db.has_buffer_changed(b_id, zed_id).await.unwrap());
+}