buffers.rs

   1use super::*;
   2use prost::Message;
   3use text::{EditOperation, UndoOperation};
   4
/// Outcome of removing one connection from a channel buffer's collaborators.
pub struct LeftChannelBuffer {
    /// The channel whose buffer was left.
    pub channel_id: ChannelId,
    /// Collaborators still on the buffer after the departure.
    pub collaborators: Vec<proto::Collaborator>,
    /// Connections of the remaining collaborators, in the same order as `collaborators`.
    pub connections: Vec<ConnectionId>,
}
  10
  11impl Database {
  12    pub async fn join_channel_buffer(
  13        &self,
  14        channel_id: ChannelId,
  15        user_id: UserId,
  16        connection: ConnectionId,
  17    ) -> Result<proto::JoinChannelBufferResponse> {
  18        self.transaction(|tx| async move {
  19            let channel = self.get_channel_internal(channel_id, &*tx).await?;
  20            self.check_user_is_channel_participant(&channel, user_id, &tx)
  21                .await?;
  22
  23            let buffer = channel::Model {
  24                id: channel_id,
  25                ..Default::default()
  26            }
  27            .find_related(buffer::Entity)
  28            .one(&*tx)
  29            .await?;
  30
  31            let buffer = if let Some(buffer) = buffer {
  32                buffer
  33            } else {
  34                let buffer = buffer::ActiveModel {
  35                    channel_id: ActiveValue::Set(channel_id),
  36                    ..Default::default()
  37                }
  38                .insert(&*tx)
  39                .await?;
  40                buffer_snapshot::ActiveModel {
  41                    buffer_id: ActiveValue::Set(buffer.id),
  42                    epoch: ActiveValue::Set(0),
  43                    text: ActiveValue::Set(String::new()),
  44                    operation_serialization_version: ActiveValue::Set(
  45                        storage::SERIALIZATION_VERSION,
  46                    ),
  47                }
  48                .insert(&*tx)
  49                .await?;
  50                buffer
  51            };
  52
  53            // Join the collaborators
  54            let mut collaborators = channel_buffer_collaborator::Entity::find()
  55                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
  56                .all(&*tx)
  57                .await?;
  58            let replica_ids = collaborators
  59                .iter()
  60                .map(|c| c.replica_id)
  61                .collect::<HashSet<_>>();
  62            let mut replica_id = ReplicaId(0);
  63            while replica_ids.contains(&replica_id) {
  64                replica_id.0 += 1;
  65            }
  66            let collaborator = channel_buffer_collaborator::ActiveModel {
  67                channel_id: ActiveValue::Set(channel_id),
  68                connection_id: ActiveValue::Set(connection.id as i32),
  69                connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
  70                user_id: ActiveValue::Set(user_id),
  71                replica_id: ActiveValue::Set(replica_id),
  72                ..Default::default()
  73            }
  74            .insert(&*tx)
  75            .await?;
  76            collaborators.push(collaborator);
  77
  78            let (base_text, operations, max_operation) =
  79                self.get_buffer_state(&buffer, &tx).await?;
  80
  81            // Save the last observed operation
  82            if let Some(op) = max_operation {
  83                observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
  84                    user_id: ActiveValue::Set(user_id),
  85                    buffer_id: ActiveValue::Set(buffer.id),
  86                    epoch: ActiveValue::Set(op.epoch),
  87                    lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
  88                    replica_id: ActiveValue::Set(op.replica_id),
  89                })
  90                .on_conflict(
  91                    OnConflict::columns([
  92                        observed_buffer_edits::Column::UserId,
  93                        observed_buffer_edits::Column::BufferId,
  94                    ])
  95                    .update_columns([
  96                        observed_buffer_edits::Column::Epoch,
  97                        observed_buffer_edits::Column::LamportTimestamp,
  98                    ])
  99                    .to_owned(),
 100                )
 101                .exec(&*tx)
 102                .await?;
 103            }
 104
 105            Ok(proto::JoinChannelBufferResponse {
 106                buffer_id: buffer.id.to_proto(),
 107                replica_id: replica_id.to_proto() as u32,
 108                base_text,
 109                operations,
 110                epoch: buffer.epoch as u64,
 111                collaborators: collaborators
 112                    .into_iter()
 113                    .map(|collaborator| proto::Collaborator {
 114                        peer_id: Some(collaborator.connection().into()),
 115                        user_id: collaborator.user_id.to_proto(),
 116                        replica_id: collaborator.replica_id.0 as u32,
 117                    })
 118                    .collect(),
 119            })
 120        })
 121        .await
 122    }
 123
 124    pub async fn rejoin_channel_buffers(
 125        &self,
 126        buffers: &[proto::ChannelBufferVersion],
 127        user_id: UserId,
 128        connection_id: ConnectionId,
 129    ) -> Result<Vec<RejoinedChannelBuffer>> {
 130        self.transaction(|tx| async move {
 131            let mut results = Vec::new();
 132            for client_buffer in buffers {
 133                let channel = self
 134                    .get_channel_internal(ChannelId::from_proto(client_buffer.channel_id), &*tx)
 135                    .await?;
 136                if self
 137                    .check_user_is_channel_participant(&channel, user_id, &*tx)
 138                    .await
 139                    .is_err()
 140                {
 141                    log::info!("user is not a member of channel");
 142                    continue;
 143                }
 144
 145                let buffer = self.get_channel_buffer(channel.id, &*tx).await?;
 146                let mut collaborators = channel_buffer_collaborator::Entity::find()
 147                    .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel.id))
 148                    .all(&*tx)
 149                    .await?;
 150
 151                // If the buffer epoch hasn't changed since the client lost
 152                // connection, then the client's buffer can be syncronized with
 153                // the server's buffer.
 154                if buffer.epoch as u64 != client_buffer.epoch {
 155                    log::info!("can't rejoin buffer, epoch has changed");
 156                    continue;
 157                }
 158
 159                // Find the collaborator record for this user's previous lost
 160                // connection. Update it with the new connection id.
 161                let server_id = ServerId(connection_id.owner_id as i32);
 162                let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
 163                    c.user_id == user_id
 164                        && (c.connection_lost || c.connection_server_id != server_id)
 165                }) else {
 166                    log::info!("can't rejoin buffer, no previous collaborator found");
 167                    continue;
 168                };
 169                let old_connection_id = self_collaborator.connection();
 170                *self_collaborator = channel_buffer_collaborator::ActiveModel {
 171                    id: ActiveValue::Unchanged(self_collaborator.id),
 172                    connection_id: ActiveValue::Set(connection_id.id as i32),
 173                    connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
 174                    connection_lost: ActiveValue::Set(false),
 175                    ..Default::default()
 176                }
 177                .update(&*tx)
 178                .await?;
 179
 180                let client_version = version_from_wire(&client_buffer.version);
 181                let serialization_version = self
 182                    .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
 183                    .await?;
 184
 185                let mut rows = buffer_operation::Entity::find()
 186                    .filter(
 187                        buffer_operation::Column::BufferId
 188                            .eq(buffer.id)
 189                            .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
 190                    )
 191                    .stream(&*tx)
 192                    .await?;
 193
 194                // Find the server's version vector and any operations
 195                // that the client has not seen.
 196                let mut server_version = clock::Global::new();
 197                let mut operations = Vec::new();
 198                while let Some(row) = rows.next().await {
 199                    let row = row?;
 200                    let timestamp = clock::Lamport {
 201                        replica_id: row.replica_id as u16,
 202                        value: row.lamport_timestamp as u32,
 203                    };
 204                    server_version.observe(timestamp);
 205                    if !client_version.observed(timestamp) {
 206                        operations.push(proto::Operation {
 207                            variant: Some(operation_from_storage(row, serialization_version)?),
 208                        })
 209                    }
 210                }
 211
 212                results.push(RejoinedChannelBuffer {
 213                    old_connection_id,
 214                    buffer: proto::RejoinedChannelBuffer {
 215                        channel_id: client_buffer.channel_id,
 216                        version: version_to_wire(&server_version),
 217                        operations,
 218                        collaborators: collaborators
 219                            .into_iter()
 220                            .map(|collaborator| proto::Collaborator {
 221                                peer_id: Some(collaborator.connection().into()),
 222                                user_id: collaborator.user_id.to_proto(),
 223                                replica_id: collaborator.replica_id.0 as u32,
 224                            })
 225                            .collect(),
 226                    },
 227                });
 228            }
 229
 230            Ok(results)
 231        })
 232        .await
 233    }
 234
 235    pub async fn clear_stale_channel_buffer_collaborators(
 236        &self,
 237        channel_id: ChannelId,
 238        server_id: ServerId,
 239    ) -> Result<RefreshedChannelBuffer> {
 240        self.transaction(|tx| async move {
 241            let db_collaborators = channel_buffer_collaborator::Entity::find()
 242                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 243                .all(&*tx)
 244                .await?;
 245
 246            let mut connection_ids = Vec::new();
 247            let mut collaborators = Vec::new();
 248            let mut collaborator_ids_to_remove = Vec::new();
 249            for db_collaborator in &db_collaborators {
 250                if !db_collaborator.connection_lost
 251                    && db_collaborator.connection_server_id == server_id
 252                {
 253                    connection_ids.push(db_collaborator.connection());
 254                    collaborators.push(proto::Collaborator {
 255                        peer_id: Some(db_collaborator.connection().into()),
 256                        replica_id: db_collaborator.replica_id.0 as u32,
 257                        user_id: db_collaborator.user_id.to_proto(),
 258                    })
 259                } else {
 260                    collaborator_ids_to_remove.push(db_collaborator.id);
 261                }
 262            }
 263
 264            channel_buffer_collaborator::Entity::delete_many()
 265                .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
 266                .exec(&*tx)
 267                .await?;
 268
 269            Ok(RefreshedChannelBuffer {
 270                connection_ids,
 271                collaborators,
 272            })
 273        })
 274        .await
 275    }
 276
 277    pub async fn leave_channel_buffer(
 278        &self,
 279        channel_id: ChannelId,
 280        connection: ConnectionId,
 281    ) -> Result<LeftChannelBuffer> {
 282        self.transaction(|tx| async move {
 283            self.leave_channel_buffer_internal(channel_id, connection, &*tx)
 284                .await
 285        })
 286        .await
 287    }
 288
 289    pub async fn channel_buffer_connection_lost(
 290        &self,
 291        connection: ConnectionId,
 292        tx: &DatabaseTransaction,
 293    ) -> Result<()> {
 294        channel_buffer_collaborator::Entity::update_many()
 295            .filter(
 296                Condition::all()
 297                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
 298                    .add(
 299                        channel_buffer_collaborator::Column::ConnectionServerId
 300                            .eq(connection.owner_id as i32),
 301                    ),
 302            )
 303            .set(channel_buffer_collaborator::ActiveModel {
 304                connection_lost: ActiveValue::set(true),
 305                ..Default::default()
 306            })
 307            .exec(&*tx)
 308            .await?;
 309        Ok(())
 310    }
 311
 312    pub async fn leave_channel_buffers(
 313        &self,
 314        connection: ConnectionId,
 315    ) -> Result<Vec<LeftChannelBuffer>> {
 316        self.transaction(|tx| async move {
 317            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 318            enum QueryChannelIds {
 319                ChannelId,
 320            }
 321
 322            let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
 323                .select_only()
 324                .column(channel_buffer_collaborator::Column::ChannelId)
 325                .filter(Condition::all().add(
 326                    channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
 327                ))
 328                .into_values::<_, QueryChannelIds>()
 329                .all(&*tx)
 330                .await?;
 331
 332            let mut result = Vec::new();
 333            for channel_id in channel_ids {
 334                let left_channel_buffer = self
 335                    .leave_channel_buffer_internal(channel_id, connection, &*tx)
 336                    .await?;
 337                result.push(left_channel_buffer);
 338            }
 339
 340            Ok(result)
 341        })
 342        .await
 343    }
 344
 345    pub async fn leave_channel_buffer_internal(
 346        &self,
 347        channel_id: ChannelId,
 348        connection: ConnectionId,
 349        tx: &DatabaseTransaction,
 350    ) -> Result<LeftChannelBuffer> {
 351        let result = channel_buffer_collaborator::Entity::delete_many()
 352            .filter(
 353                Condition::all()
 354                    .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 355                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
 356                    .add(
 357                        channel_buffer_collaborator::Column::ConnectionServerId
 358                            .eq(connection.owner_id as i32),
 359                    ),
 360            )
 361            .exec(&*tx)
 362            .await?;
 363        if result.rows_affected == 0 {
 364            Err(anyhow!("not a collaborator on this project"))?;
 365        }
 366
 367        let mut collaborators = Vec::new();
 368        let mut connections = Vec::new();
 369        let mut rows = channel_buffer_collaborator::Entity::find()
 370            .filter(
 371                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 372            )
 373            .stream(&*tx)
 374            .await?;
 375        while let Some(row) = rows.next().await {
 376            let row = row?;
 377            let connection = row.connection();
 378            connections.push(connection);
 379            collaborators.push(proto::Collaborator {
 380                peer_id: Some(connection.into()),
 381                replica_id: row.replica_id.0 as u32,
 382                user_id: row.user_id.to_proto(),
 383            });
 384        }
 385
 386        drop(rows);
 387
 388        if collaborators.is_empty() {
 389            self.snapshot_channel_buffer(channel_id, &tx).await?;
 390        }
 391
 392        Ok(LeftChannelBuffer {
 393            channel_id,
 394            collaborators,
 395            connections,
 396        })
 397    }
 398
 399    pub async fn get_channel_buffer_collaborators(
 400        &self,
 401        channel_id: ChannelId,
 402    ) -> Result<Vec<UserId>> {
 403        self.transaction(|tx| async move {
 404            self.get_channel_buffer_collaborators_internal(channel_id, &*tx)
 405                .await
 406        })
 407        .await
 408    }
 409
 410    async fn get_channel_buffer_collaborators_internal(
 411        &self,
 412        channel_id: ChannelId,
 413        tx: &DatabaseTransaction,
 414    ) -> Result<Vec<UserId>> {
 415        #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 416        enum QueryUserIds {
 417            UserId,
 418        }
 419
 420        let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
 421            .select_only()
 422            .column(channel_buffer_collaborator::Column::UserId)
 423            .filter(
 424                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 425            )
 426            .into_values::<_, QueryUserIds>()
 427            .all(&*tx)
 428            .await?;
 429
 430        Ok(users)
 431    }
 432
 433    pub async fn update_channel_buffer(
 434        &self,
 435        channel_id: ChannelId,
 436        user: UserId,
 437        operations: &[proto::Operation],
 438    ) -> Result<(
 439        Vec<ConnectionId>,
 440        Vec<UserId>,
 441        i32,
 442        Vec<proto::VectorClockEntry>,
 443    )> {
 444        self.transaction(move |tx| async move {
 445            let channel = self.get_channel_internal(channel_id, &*tx).await?;
 446            self.check_user_is_channel_member(&channel, user, &*tx)
 447                .await?;
 448
 449            let buffer = buffer::Entity::find()
 450                .filter(buffer::Column::ChannelId.eq(channel_id))
 451                .one(&*tx)
 452                .await?
 453                .ok_or_else(|| anyhow!("no such buffer"))?;
 454
 455            let serialization_version = self
 456                .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
 457                .await?;
 458
 459            let operations = operations
 460                .iter()
 461                .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
 462                .collect::<Vec<_>>();
 463
 464            let mut channel_members;
 465            let max_version;
 466
 467            if !operations.is_empty() {
 468                let max_operation = operations
 469                    .iter()
 470                    .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
 471                    .unwrap();
 472
 473                max_version = vec![proto::VectorClockEntry {
 474                    replica_id: *max_operation.replica_id.as_ref() as u32,
 475                    timestamp: *max_operation.lamport_timestamp.as_ref() as u32,
 476                }];
 477
 478                // get current channel participants and save the max operation above
 479                self.save_max_operation(
 480                    user,
 481                    buffer.id,
 482                    buffer.epoch,
 483                    *max_operation.replica_id.as_ref(),
 484                    *max_operation.lamport_timestamp.as_ref(),
 485                    &*tx,
 486                )
 487                .await?;
 488
 489                channel_members = self.get_channel_participants(&channel, &*tx).await?;
 490                let collaborators = self
 491                    .get_channel_buffer_collaborators_internal(channel_id, &*tx)
 492                    .await?;
 493                channel_members.retain(|member| !collaborators.contains(member));
 494
 495                buffer_operation::Entity::insert_many(operations)
 496                    .on_conflict(
 497                        OnConflict::columns([
 498                            buffer_operation::Column::BufferId,
 499                            buffer_operation::Column::Epoch,
 500                            buffer_operation::Column::LamportTimestamp,
 501                            buffer_operation::Column::ReplicaId,
 502                        ])
 503                        .do_nothing()
 504                        .to_owned(),
 505                    )
 506                    .exec(&*tx)
 507                    .await?;
 508            } else {
 509                channel_members = Vec::new();
 510                max_version = Vec::new();
 511            }
 512
 513            let mut connections = Vec::new();
 514            let mut rows = channel_buffer_collaborator::Entity::find()
 515                .filter(
 516                    Condition::all()
 517                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 518                )
 519                .stream(&*tx)
 520                .await?;
 521            while let Some(row) = rows.next().await {
 522                let row = row?;
 523                connections.push(ConnectionId {
 524                    id: row.connection_id as u32,
 525                    owner_id: row.connection_server_id.0 as u32,
 526                });
 527            }
 528
 529            Ok((connections, channel_members, buffer.epoch, max_version))
 530        })
 531        .await
 532    }
 533
 534    async fn save_max_operation(
 535        &self,
 536        user_id: UserId,
 537        buffer_id: BufferId,
 538        epoch: i32,
 539        replica_id: i32,
 540        lamport_timestamp: i32,
 541        tx: &DatabaseTransaction,
 542    ) -> Result<()> {
 543        use observed_buffer_edits::Column;
 544
 545        observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
 546            user_id: ActiveValue::Set(user_id),
 547            buffer_id: ActiveValue::Set(buffer_id),
 548            epoch: ActiveValue::Set(epoch),
 549            replica_id: ActiveValue::Set(replica_id),
 550            lamport_timestamp: ActiveValue::Set(lamport_timestamp),
 551        })
 552        .on_conflict(
 553            OnConflict::columns([Column::UserId, Column::BufferId])
 554                .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
 555                .action_cond_where(
 556                    Condition::any().add(Column::Epoch.lt(epoch)).add(
 557                        Condition::all().add(Column::Epoch.eq(epoch)).add(
 558                            Condition::any()
 559                                .add(Column::LamportTimestamp.lt(lamport_timestamp))
 560                                .add(
 561                                    Column::LamportTimestamp
 562                                        .eq(lamport_timestamp)
 563                                        .and(Column::ReplicaId.lt(replica_id)),
 564                                ),
 565                        ),
 566                    ),
 567                )
 568                .to_owned(),
 569        )
 570        .exec_without_returning(tx)
 571        .await?;
 572
 573        Ok(())
 574    }
 575
 576    async fn get_buffer_operation_serialization_version(
 577        &self,
 578        buffer_id: BufferId,
 579        epoch: i32,
 580        tx: &DatabaseTransaction,
 581    ) -> Result<i32> {
 582        Ok(buffer_snapshot::Entity::find()
 583            .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
 584            .filter(buffer_snapshot::Column::Epoch.eq(epoch))
 585            .select_only()
 586            .column(buffer_snapshot::Column::OperationSerializationVersion)
 587            .into_values::<_, QueryOperationSerializationVersion>()
 588            .one(&*tx)
 589            .await?
 590            .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
 591    }
 592
 593    pub async fn get_channel_buffer(
 594        &self,
 595        channel_id: ChannelId,
 596        tx: &DatabaseTransaction,
 597    ) -> Result<buffer::Model> {
 598        Ok(channel::Model {
 599            id: channel_id,
 600            ..Default::default()
 601        }
 602        .find_related(buffer::Entity)
 603        .one(&*tx)
 604        .await?
 605        .ok_or_else(|| anyhow!("no such buffer"))?)
 606    }
 607
 608    async fn get_buffer_state(
 609        &self,
 610        buffer: &buffer::Model,
 611        tx: &DatabaseTransaction,
 612    ) -> Result<(
 613        String,
 614        Vec<proto::Operation>,
 615        Option<buffer_operation::Model>,
 616    )> {
 617        let id = buffer.id;
 618        let (base_text, version) = if buffer.epoch > 0 {
 619            let snapshot = buffer_snapshot::Entity::find()
 620                .filter(
 621                    buffer_snapshot::Column::BufferId
 622                        .eq(id)
 623                        .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
 624                )
 625                .one(&*tx)
 626                .await?
 627                .ok_or_else(|| anyhow!("no such snapshot"))?;
 628
 629            let version = snapshot.operation_serialization_version;
 630            (snapshot.text, version)
 631        } else {
 632            (String::new(), storage::SERIALIZATION_VERSION)
 633        };
 634
 635        let mut rows = buffer_operation::Entity::find()
 636            .filter(
 637                buffer_operation::Column::BufferId
 638                    .eq(id)
 639                    .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
 640            )
 641            .order_by_asc(buffer_operation::Column::LamportTimestamp)
 642            .order_by_asc(buffer_operation::Column::ReplicaId)
 643            .stream(&*tx)
 644            .await?;
 645
 646        let mut operations = Vec::new();
 647        let mut last_row = None;
 648        while let Some(row) = rows.next().await {
 649            let row = row?;
 650            last_row = Some(buffer_operation::Model {
 651                buffer_id: row.buffer_id,
 652                epoch: row.epoch,
 653                lamport_timestamp: row.lamport_timestamp,
 654                replica_id: row.lamport_timestamp,
 655                value: Default::default(),
 656            });
 657            operations.push(proto::Operation {
 658                variant: Some(operation_from_storage(row, version)?),
 659            });
 660        }
 661
 662        Ok((base_text, operations, last_row))
 663    }
 664
 665    async fn snapshot_channel_buffer(
 666        &self,
 667        channel_id: ChannelId,
 668        tx: &DatabaseTransaction,
 669    ) -> Result<()> {
 670        let buffer = self.get_channel_buffer(channel_id, tx).await?;
 671        let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
 672        if operations.is_empty() {
 673            return Ok(());
 674        }
 675
 676        let mut text_buffer = text::Buffer::new(0, 0, base_text);
 677        text_buffer
 678            .apply_ops(operations.into_iter().filter_map(operation_from_wire))
 679            .unwrap();
 680
 681        let base_text = text_buffer.text();
 682        let epoch = buffer.epoch + 1;
 683
 684        buffer_snapshot::Model {
 685            buffer_id: buffer.id,
 686            epoch,
 687            text: base_text,
 688            operation_serialization_version: storage::SERIALIZATION_VERSION,
 689        }
 690        .into_active_model()
 691        .insert(tx)
 692        .await?;
 693
 694        buffer::ActiveModel {
 695            id: ActiveValue::Unchanged(buffer.id),
 696            epoch: ActiveValue::Set(epoch),
 697            ..Default::default()
 698        }
 699        .save(tx)
 700        .await?;
 701
 702        Ok(())
 703    }
 704
 705    pub async fn observe_buffer_version(
 706        &self,
 707        buffer_id: BufferId,
 708        user_id: UserId,
 709        epoch: i32,
 710        version: &[proto::VectorClockEntry],
 711    ) -> Result<()> {
 712        self.transaction(|tx| async move {
 713            // For now, combine concurrent operations.
 714            let Some(component) = version.iter().max_by_key(|version| version.timestamp) else {
 715                return Ok(());
 716            };
 717            self.save_max_operation(
 718                user_id,
 719                buffer_id,
 720                epoch,
 721                component.replica_id as i32,
 722                component.timestamp as i32,
 723                &*tx,
 724            )
 725            .await?;
 726            Ok(())
 727        })
 728        .await
 729    }
 730
 731    pub async fn unseen_channel_buffer_changes(
 732        &self,
 733        user_id: UserId,
 734        channel_ids: &[ChannelId],
 735        tx: &DatabaseTransaction,
 736    ) -> Result<Vec<proto::UnseenChannelBufferChange>> {
 737        #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 738        enum QueryIds {
 739            ChannelId,
 740            Id,
 741        }
 742
 743        let mut channel_ids_by_buffer_id = HashMap::default();
 744        let mut rows = buffer::Entity::find()
 745            .filter(buffer::Column::ChannelId.is_in(channel_ids.iter().copied()))
 746            .stream(&*tx)
 747            .await?;
 748        while let Some(row) = rows.next().await {
 749            let row = row?;
 750            channel_ids_by_buffer_id.insert(row.id, row.channel_id);
 751        }
 752        drop(rows);
 753
 754        let mut observed_edits_by_buffer_id = HashMap::default();
 755        let mut rows = observed_buffer_edits::Entity::find()
 756            .filter(observed_buffer_edits::Column::UserId.eq(user_id))
 757            .filter(
 758                observed_buffer_edits::Column::BufferId
 759                    .is_in(channel_ids_by_buffer_id.keys().copied()),
 760            )
 761            .stream(&*tx)
 762            .await?;
 763        while let Some(row) = rows.next().await {
 764            let row = row?;
 765            observed_edits_by_buffer_id.insert(row.buffer_id, row);
 766        }
 767        drop(rows);
 768
 769        let latest_operations = self
 770            .get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
 771            .await?;
 772
 773        let mut changes = Vec::default();
 774        for latest in latest_operations {
 775            if let Some(observed) = observed_edits_by_buffer_id.get(&latest.buffer_id) {
 776                if (
 777                    observed.epoch,
 778                    observed.lamport_timestamp,
 779                    observed.replica_id,
 780                ) >= (latest.epoch, latest.lamport_timestamp, latest.replica_id)
 781                {
 782                    continue;
 783                }
 784            }
 785
 786            if let Some(channel_id) = channel_ids_by_buffer_id.get(&latest.buffer_id) {
 787                changes.push(proto::UnseenChannelBufferChange {
 788                    channel_id: channel_id.to_proto(),
 789                    epoch: latest.epoch as u64,
 790                    version: vec![proto::VectorClockEntry {
 791                        replica_id: latest.replica_id as u32,
 792                        timestamp: latest.lamport_timestamp as u32,
 793                    }],
 794                });
 795            }
 796        }
 797
 798        Ok(changes)
 799    }
 800
 801    pub async fn get_latest_operations_for_buffers(
 802        &self,
 803        buffer_ids: impl IntoIterator<Item = BufferId>,
 804        tx: &DatabaseTransaction,
 805    ) -> Result<Vec<buffer_operation::Model>> {
 806        let mut values = String::new();
 807        for id in buffer_ids {
 808            if !values.is_empty() {
 809                values.push_str(", ");
 810            }
 811            write!(&mut values, "({})", id).unwrap();
 812        }
 813
 814        if values.is_empty() {
 815            return Ok(Vec::default());
 816        }
 817
 818        let sql = format!(
 819            r#"
 820            SELECT
 821                *
 822            FROM
 823            (
 824                SELECT
 825                    *,
 826                    row_number() OVER (
 827                        PARTITION BY buffer_id
 828                        ORDER BY
 829                            epoch DESC,
 830                            lamport_timestamp DESC,
 831                            replica_id DESC
 832                    ) as row_number
 833                FROM buffer_operations
 834                WHERE
 835                    buffer_id in ({values})
 836            ) AS last_operations
 837            WHERE
 838                row_number = 1
 839            "#,
 840        );
 841
 842        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
 843        Ok(buffer_operation::Entity::find()
 844            .from_raw_sql(stmt)
 845            .all(&*tx)
 846            .await?)
 847    }
 848}
 849
 850fn operation_to_storage(
 851    operation: &proto::Operation,
 852    buffer: &buffer::Model,
 853    _format: i32,
 854) -> Option<buffer_operation::ActiveModel> {
 855    let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
 856        proto::operation::Variant::Edit(operation) => (
 857            operation.replica_id,
 858            operation.lamport_timestamp,
 859            storage::Operation {
 860                version: version_to_storage(&operation.version),
 861                is_undo: false,
 862                edit_ranges: operation
 863                    .ranges
 864                    .iter()
 865                    .map(|range| storage::Range {
 866                        start: range.start,
 867                        end: range.end,
 868                    })
 869                    .collect(),
 870                edit_texts: operation.new_text.clone(),
 871                undo_counts: Vec::new(),
 872            },
 873        ),
 874        proto::operation::Variant::Undo(operation) => (
 875            operation.replica_id,
 876            operation.lamport_timestamp,
 877            storage::Operation {
 878                version: version_to_storage(&operation.version),
 879                is_undo: true,
 880                edit_ranges: Vec::new(),
 881                edit_texts: Vec::new(),
 882                undo_counts: operation
 883                    .counts
 884                    .iter()
 885                    .map(|entry| storage::UndoCount {
 886                        replica_id: entry.replica_id,
 887                        lamport_timestamp: entry.lamport_timestamp,
 888                        count: entry.count,
 889                    })
 890                    .collect(),
 891            },
 892        ),
 893        _ => None?,
 894    };
 895
 896    Some(buffer_operation::ActiveModel {
 897        buffer_id: ActiveValue::Set(buffer.id),
 898        epoch: ActiveValue::Set(buffer.epoch),
 899        replica_id: ActiveValue::Set(replica_id as i32),
 900        lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
 901        value: ActiveValue::Set(value.encode_to_vec()),
 902    })
 903}
 904
 905fn operation_from_storage(
 906    row: buffer_operation::Model,
 907    _format_version: i32,
 908) -> Result<proto::operation::Variant, Error> {
 909    let operation =
 910        storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
 911    let version = version_from_storage(&operation.version);
 912    Ok(if operation.is_undo {
 913        proto::operation::Variant::Undo(proto::operation::Undo {
 914            replica_id: row.replica_id as u32,
 915            lamport_timestamp: row.lamport_timestamp as u32,
 916            version,
 917            counts: operation
 918                .undo_counts
 919                .iter()
 920                .map(|entry| proto::UndoCount {
 921                    replica_id: entry.replica_id,
 922                    lamport_timestamp: entry.lamport_timestamp,
 923                    count: entry.count,
 924                })
 925                .collect(),
 926        })
 927    } else {
 928        proto::operation::Variant::Edit(proto::operation::Edit {
 929            replica_id: row.replica_id as u32,
 930            lamport_timestamp: row.lamport_timestamp as u32,
 931            version,
 932            ranges: operation
 933                .edit_ranges
 934                .into_iter()
 935                .map(|range| proto::Range {
 936                    start: range.start,
 937                    end: range.end,
 938                })
 939                .collect(),
 940            new_text: operation.edit_texts,
 941        })
 942    })
 943}
 944
 945fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
 946    version
 947        .iter()
 948        .map(|entry| storage::VectorClockEntry {
 949            replica_id: entry.replica_id,
 950            timestamp: entry.timestamp,
 951        })
 952        .collect()
 953}
 954
 955fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
 956    version
 957        .iter()
 958        .map(|entry| proto::VectorClockEntry {
 959            replica_id: entry.replica_id,
 960            timestamp: entry.timestamp,
 961        })
 962        .collect()
 963}
 964
 965// This is currently a manual copy of the deserialization code in the client's langauge crate
 966pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
 967    match operation.variant? {
 968        proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
 969            timestamp: clock::Lamport {
 970                replica_id: edit.replica_id as text::ReplicaId,
 971                value: edit.lamport_timestamp,
 972            },
 973            version: version_from_wire(&edit.version),
 974            ranges: edit
 975                .ranges
 976                .into_iter()
 977                .map(|range| {
 978                    text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
 979                })
 980                .collect(),
 981            new_text: edit.new_text.into_iter().map(Arc::from).collect(),
 982        })),
 983        proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
 984            timestamp: clock::Lamport {
 985                replica_id: undo.replica_id as text::ReplicaId,
 986                value: undo.lamport_timestamp,
 987            },
 988            version: version_from_wire(&undo.version),
 989            counts: undo
 990                .counts
 991                .into_iter()
 992                .map(|c| {
 993                    (
 994                        clock::Lamport {
 995                            replica_id: c.replica_id as text::ReplicaId,
 996                            value: c.lamport_timestamp,
 997                        },
 998                        c.count,
 999                    )
1000                })
1001                .collect(),
1002        })),
1003        _ => None,
1004    }
1005}
1006
1007fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
1008    let mut version = clock::Global::new();
1009    for entry in message {
1010        version.observe(clock::Lamport {
1011            replica_id: entry.replica_id as text::ReplicaId,
1012            value: entry.timestamp,
1013        });
1014    }
1015    version
1016}
1017
1018fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
1019    let mut message = Vec::new();
1020    for entry in version.iter() {
1021        message.push(proto::VectorClockEntry {
1022            replica_id: entry.replica_id as u32,
1023            timestamp: entry.value,
1024        });
1025    }
1026    message
1027}
1028
/// Single-variant `DeriveColumn` enum.
///
/// NOTE(review): its call sites are outside this chunk. It is presumably used to
/// select only the `operation_serialization_version` column as a typed value —
/// confirm.
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
1033
/// Protobuf messages used to encode buffer operations for the database.
///
/// `operation_to_storage` encodes a `storage::Operation` into the
/// `buffer_operations.value` column, and `operation_from_storage` decodes it.
/// Because these bytes are persisted, existing field tags and types must not be
/// changed or reused.
mod storage {
    // NOTE(review): no non-snake-case identifiers are visible in this module,
    // so this allow may be stale. Confirm before removing it.
    #![allow(non_snake_case)]
    use prost::Message;
    /// Serialization version for the encoding defined in this module.
    ///
    /// It is written to `buffer_snapshot.operation_serialization_version` when a
    /// channel buffer's initial snapshot is created. NOTE(review): presumably it
    /// should be bumped whenever the encoding changes — confirm.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A single stored edit or undo operation.
    ///
    /// Tag 1 is not used. NOTE(review): presumably it belonged to a retired
    /// field, so do not reuse it.
    #[derive(Message)]
    pub struct Operation {
        /// Version vector copied from the wire operation.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for an undo operation, `false` for an edit.
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Edited ranges; populated only for edits.
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// New text taken from the wire edit's `new_text`; populated only for edits.
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts; populated only for undos.
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One version-vector entry: a replica id and its timestamp.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// A `start..end` range copied from a wire `proto::Range`.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// Undo count for the operation identified by `replica_id` and
    /// `lamport_timestamp`.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}