buffers.rs

   1use super::*;
   2use prost::Message;
   3use text::{EditOperation, UndoOperation};
   4
   5pub struct LeftChannelBuffer {
   6    pub channel_id: ChannelId,
   7    pub collaborators: Vec<proto::Collaborator>,
   8    pub connections: Vec<ConnectionId>,
   9}
  10
  11impl Database {
  12    /// Open a channel buffer. Returns the current contents, and adds you to the list of people
  13    /// to notify on changes.
  14    pub async fn join_channel_buffer(
  15        &self,
  16        channel_id: ChannelId,
  17        user_id: UserId,
  18        connection: ConnectionId,
  19    ) -> Result<proto::JoinChannelBufferResponse> {
  20        self.transaction(|tx| async move {
  21            let channel = self.get_channel_internal(channel_id, &*tx).await?;
  22            self.check_user_is_channel_participant(&channel, user_id, &tx)
  23                .await?;
  24
  25            let buffer = channel::Model {
  26                id: channel_id,
  27                ..Default::default()
  28            }
  29            .find_related(buffer::Entity)
  30            .one(&*tx)
  31            .await?;
  32
  33            let buffer = if let Some(buffer) = buffer {
  34                buffer
  35            } else {
  36                let buffer = buffer::ActiveModel {
  37                    channel_id: ActiveValue::Set(channel_id),
  38                    ..Default::default()
  39                }
  40                .insert(&*tx)
  41                .await?;
  42                buffer_snapshot::ActiveModel {
  43                    buffer_id: ActiveValue::Set(buffer.id),
  44                    epoch: ActiveValue::Set(0),
  45                    text: ActiveValue::Set(String::new()),
  46                    operation_serialization_version: ActiveValue::Set(
  47                        storage::SERIALIZATION_VERSION,
  48                    ),
  49                }
  50                .insert(&*tx)
  51                .await?;
  52                buffer
  53            };
  54
  55            // Join the collaborators
  56            let mut collaborators = channel_buffer_collaborator::Entity::find()
  57                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
  58                .all(&*tx)
  59                .await?;
  60            let replica_ids = collaborators
  61                .iter()
  62                .map(|c| c.replica_id)
  63                .collect::<HashSet<_>>();
  64            let mut replica_id = ReplicaId(0);
  65            while replica_ids.contains(&replica_id) {
  66                replica_id.0 += 1;
  67            }
  68            let collaborator = channel_buffer_collaborator::ActiveModel {
  69                channel_id: ActiveValue::Set(channel_id),
  70                connection_id: ActiveValue::Set(connection.id as i32),
  71                connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
  72                user_id: ActiveValue::Set(user_id),
  73                replica_id: ActiveValue::Set(replica_id),
  74                ..Default::default()
  75            }
  76            .insert(&*tx)
  77            .await?;
  78            collaborators.push(collaborator);
  79
  80            let (base_text, operations, max_operation) =
  81                self.get_buffer_state(&buffer, &tx).await?;
  82
  83            // Save the last observed operation
  84            if let Some(op) = max_operation {
  85                observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
  86                    user_id: ActiveValue::Set(user_id),
  87                    buffer_id: ActiveValue::Set(buffer.id),
  88                    epoch: ActiveValue::Set(op.epoch),
  89                    lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
  90                    replica_id: ActiveValue::Set(op.replica_id),
  91                })
  92                .on_conflict(
  93                    OnConflict::columns([
  94                        observed_buffer_edits::Column::UserId,
  95                        observed_buffer_edits::Column::BufferId,
  96                    ])
  97                    .update_columns([
  98                        observed_buffer_edits::Column::Epoch,
  99                        observed_buffer_edits::Column::LamportTimestamp,
 100                    ])
 101                    .to_owned(),
 102                )
 103                .exec(&*tx)
 104                .await?;
 105            }
 106
 107            Ok(proto::JoinChannelBufferResponse {
 108                buffer_id: buffer.id.to_proto(),
 109                replica_id: replica_id.to_proto() as u32,
 110                base_text,
 111                operations,
 112                epoch: buffer.epoch as u64,
 113                collaborators: collaborators
 114                    .into_iter()
 115                    .map(|collaborator| proto::Collaborator {
 116                        peer_id: Some(collaborator.connection().into()),
 117                        user_id: collaborator.user_id.to_proto(),
 118                        replica_id: collaborator.replica_id.0 as u32,
 119                    })
 120                    .collect(),
 121            })
 122        })
 123        .await
 124    }
 125
 126    /// Rejoin a channel buffer (after a connection interruption)
 127    pub async fn rejoin_channel_buffers(
 128        &self,
 129        buffers: &[proto::ChannelBufferVersion],
 130        user_id: UserId,
 131        connection_id: ConnectionId,
 132    ) -> Result<Vec<RejoinedChannelBuffer>> {
 133        self.transaction(|tx| async move {
 134            let mut results = Vec::new();
 135            for client_buffer in buffers {
 136                let channel = self
 137                    .get_channel_internal(ChannelId::from_proto(client_buffer.channel_id), &*tx)
 138                    .await?;
 139                if self
 140                    .check_user_is_channel_participant(&channel, user_id, &*tx)
 141                    .await
 142                    .is_err()
 143                {
 144                    log::info!("user is not a member of channel");
 145                    continue;
 146                }
 147
 148                let buffer = self.get_channel_buffer(channel.id, &*tx).await?;
 149                let mut collaborators = channel_buffer_collaborator::Entity::find()
 150                    .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel.id))
 151                    .all(&*tx)
 152                    .await?;
 153
 154                // If the buffer epoch hasn't changed since the client lost
 155                // connection, then the client's buffer can be synchronized with
 156                // the server's buffer.
 157                if buffer.epoch as u64 != client_buffer.epoch {
 158                    log::info!("can't rejoin buffer, epoch has changed");
 159                    continue;
 160                }
 161
 162                // Find the collaborator record for this user's previous lost
 163                // connection. Update it with the new connection id.
 164                let server_id = ServerId(connection_id.owner_id as i32);
 165                let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
 166                    c.user_id == user_id
 167                        && (c.connection_lost || c.connection_server_id != server_id)
 168                }) else {
 169                    log::info!("can't rejoin buffer, no previous collaborator found");
 170                    continue;
 171                };
 172                let old_connection_id = self_collaborator.connection();
 173                *self_collaborator = channel_buffer_collaborator::ActiveModel {
 174                    id: ActiveValue::Unchanged(self_collaborator.id),
 175                    connection_id: ActiveValue::Set(connection_id.id as i32),
 176                    connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
 177                    connection_lost: ActiveValue::Set(false),
 178                    ..Default::default()
 179                }
 180                .update(&*tx)
 181                .await?;
 182
 183                let client_version = version_from_wire(&client_buffer.version);
 184                let serialization_version = self
 185                    .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
 186                    .await?;
 187
 188                let mut rows = buffer_operation::Entity::find()
 189                    .filter(
 190                        buffer_operation::Column::BufferId
 191                            .eq(buffer.id)
 192                            .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
 193                    )
 194                    .stream(&*tx)
 195                    .await?;
 196
 197                // Find the server's version vector and any operations
 198                // that the client has not seen.
 199                let mut server_version = clock::Global::new();
 200                let mut operations = Vec::new();
 201                while let Some(row) = rows.next().await {
 202                    let row = row?;
 203                    let timestamp = clock::Lamport {
 204                        replica_id: row.replica_id as u16,
 205                        value: row.lamport_timestamp as u32,
 206                    };
 207                    server_version.observe(timestamp);
 208                    if !client_version.observed(timestamp) {
 209                        operations.push(proto::Operation {
 210                            variant: Some(operation_from_storage(row, serialization_version)?),
 211                        })
 212                    }
 213                }
 214
 215                results.push(RejoinedChannelBuffer {
 216                    old_connection_id,
 217                    buffer: proto::RejoinedChannelBuffer {
 218                        channel_id: client_buffer.channel_id,
 219                        version: version_to_wire(&server_version),
 220                        operations,
 221                        collaborators: collaborators
 222                            .into_iter()
 223                            .map(|collaborator| proto::Collaborator {
 224                                peer_id: Some(collaborator.connection().into()),
 225                                user_id: collaborator.user_id.to_proto(),
 226                                replica_id: collaborator.replica_id.0 as u32,
 227                            })
 228                            .collect(),
 229                    },
 230                });
 231            }
 232
 233            Ok(results)
 234        })
 235        .await
 236    }
 237
 238    /// Clear out any buffer collaborators who are no longer collaborating.
 239    pub async fn clear_stale_channel_buffer_collaborators(
 240        &self,
 241        channel_id: ChannelId,
 242        server_id: ServerId,
 243    ) -> Result<RefreshedChannelBuffer> {
 244        self.transaction(|tx| async move {
 245            let db_collaborators = channel_buffer_collaborator::Entity::find()
 246                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 247                .all(&*tx)
 248                .await?;
 249
 250            let mut connection_ids = Vec::new();
 251            let mut collaborators = Vec::new();
 252            let mut collaborator_ids_to_remove = Vec::new();
 253            for db_collaborator in &db_collaborators {
 254                if !db_collaborator.connection_lost
 255                    && db_collaborator.connection_server_id == server_id
 256                {
 257                    connection_ids.push(db_collaborator.connection());
 258                    collaborators.push(proto::Collaborator {
 259                        peer_id: Some(db_collaborator.connection().into()),
 260                        replica_id: db_collaborator.replica_id.0 as u32,
 261                        user_id: db_collaborator.user_id.to_proto(),
 262                    })
 263                } else {
 264                    collaborator_ids_to_remove.push(db_collaborator.id);
 265                }
 266            }
 267
 268            channel_buffer_collaborator::Entity::delete_many()
 269                .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
 270                .exec(&*tx)
 271                .await?;
 272
 273            Ok(RefreshedChannelBuffer {
 274                connection_ids,
 275                collaborators,
 276            })
 277        })
 278        .await
 279    }
 280
 281    /// Close the channel buffer, and stop receiving updates for it.
 282    pub async fn leave_channel_buffer(
 283        &self,
 284        channel_id: ChannelId,
 285        connection: ConnectionId,
 286    ) -> Result<LeftChannelBuffer> {
 287        self.transaction(|tx| async move {
 288            self.leave_channel_buffer_internal(channel_id, connection, &*tx)
 289                .await
 290        })
 291        .await
 292    }
 293
 294    /// Close the channel buffer, and stop receiving updates for it.
 295    pub async fn channel_buffer_connection_lost(
 296        &self,
 297        connection: ConnectionId,
 298        tx: &DatabaseTransaction,
 299    ) -> Result<()> {
 300        channel_buffer_collaborator::Entity::update_many()
 301            .filter(
 302                Condition::all()
 303                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
 304                    .add(
 305                        channel_buffer_collaborator::Column::ConnectionServerId
 306                            .eq(connection.owner_id as i32),
 307                    ),
 308            )
 309            .set(channel_buffer_collaborator::ActiveModel {
 310                connection_lost: ActiveValue::set(true),
 311                ..Default::default()
 312            })
 313            .exec(&*tx)
 314            .await?;
 315        Ok(())
 316    }
 317
 318    /// Close all open channel buffers
 319    pub async fn leave_channel_buffers(
 320        &self,
 321        connection: ConnectionId,
 322    ) -> Result<Vec<LeftChannelBuffer>> {
 323        self.transaction(|tx| async move {
 324            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 325            enum QueryChannelIds {
 326                ChannelId,
 327            }
 328
 329            let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
 330                .select_only()
 331                .column(channel_buffer_collaborator::Column::ChannelId)
 332                .filter(Condition::all().add(
 333                    channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
 334                ))
 335                .into_values::<_, QueryChannelIds>()
 336                .all(&*tx)
 337                .await?;
 338
 339            let mut result = Vec::new();
 340            for channel_id in channel_ids {
 341                let left_channel_buffer = self
 342                    .leave_channel_buffer_internal(channel_id, connection, &*tx)
 343                    .await?;
 344                result.push(left_channel_buffer);
 345            }
 346
 347            Ok(result)
 348        })
 349        .await
 350    }
 351
 352    async fn leave_channel_buffer_internal(
 353        &self,
 354        channel_id: ChannelId,
 355        connection: ConnectionId,
 356        tx: &DatabaseTransaction,
 357    ) -> Result<LeftChannelBuffer> {
 358        let result = channel_buffer_collaborator::Entity::delete_many()
 359            .filter(
 360                Condition::all()
 361                    .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 362                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
 363                    .add(
 364                        channel_buffer_collaborator::Column::ConnectionServerId
 365                            .eq(connection.owner_id as i32),
 366                    ),
 367            )
 368            .exec(&*tx)
 369            .await?;
 370        if result.rows_affected == 0 {
 371            Err(anyhow!("not a collaborator on this project"))?;
 372        }
 373
 374        let mut collaborators = Vec::new();
 375        let mut connections = Vec::new();
 376        let mut rows = channel_buffer_collaborator::Entity::find()
 377            .filter(
 378                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 379            )
 380            .stream(&*tx)
 381            .await?;
 382        while let Some(row) = rows.next().await {
 383            let row = row?;
 384            let connection = row.connection();
 385            connections.push(connection);
 386            collaborators.push(proto::Collaborator {
 387                peer_id: Some(connection.into()),
 388                replica_id: row.replica_id.0 as u32,
 389                user_id: row.user_id.to_proto(),
 390            });
 391        }
 392
 393        drop(rows);
 394
 395        if collaborators.is_empty() {
 396            self.snapshot_channel_buffer(channel_id, &tx).await?;
 397        }
 398
 399        Ok(LeftChannelBuffer {
 400            channel_id,
 401            collaborators,
 402            connections,
 403        })
 404    }
 405
 406    pub async fn get_channel_buffer_collaborators(
 407        &self,
 408        channel_id: ChannelId,
 409    ) -> Result<Vec<UserId>> {
 410        self.transaction(|tx| async move {
 411            self.get_channel_buffer_collaborators_internal(channel_id, &*tx)
 412                .await
 413        })
 414        .await
 415    }
 416
 417    async fn get_channel_buffer_collaborators_internal(
 418        &self,
 419        channel_id: ChannelId,
 420        tx: &DatabaseTransaction,
 421    ) -> Result<Vec<UserId>> {
 422        #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 423        enum QueryUserIds {
 424            UserId,
 425        }
 426
 427        let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
 428            .select_only()
 429            .column(channel_buffer_collaborator::Column::UserId)
 430            .filter(
 431                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 432            )
 433            .into_values::<_, QueryUserIds>()
 434            .all(&*tx)
 435            .await?;
 436
 437        Ok(users)
 438    }
 439
 440    pub async fn update_channel_buffer(
 441        &self,
 442        channel_id: ChannelId,
 443        user: UserId,
 444        operations: &[proto::Operation],
 445    ) -> Result<(
 446        Vec<ConnectionId>,
 447        Vec<UserId>,
 448        i32,
 449        Vec<proto::VectorClockEntry>,
 450    )> {
 451        self.transaction(move |tx| async move {
 452            let channel = self.get_channel_internal(channel_id, &*tx).await?;
 453
 454            let mut requires_write_permission = false;
 455            for op in operations.iter() {
 456                match op.variant {
 457                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
 458                    Some(_) => requires_write_permission = true,
 459                }
 460            }
 461            if requires_write_permission {
 462                self.check_user_is_channel_member(&channel, user, &*tx)
 463                    .await?;
 464            } else {
 465                self.check_user_is_channel_participant(&channel, user, &*tx)
 466                    .await?;
 467            }
 468
 469            let buffer = buffer::Entity::find()
 470                .filter(buffer::Column::ChannelId.eq(channel_id))
 471                .one(&*tx)
 472                .await?
 473                .ok_or_else(|| anyhow!("no such buffer"))?;
 474
 475            let serialization_version = self
 476                .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
 477                .await?;
 478
 479            let operations = operations
 480                .iter()
 481                .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
 482                .collect::<Vec<_>>();
 483
 484            let mut channel_members;
 485            let max_version;
 486
 487            if !operations.is_empty() {
 488                let max_operation = operations
 489                    .iter()
 490                    .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
 491                    .unwrap();
 492
 493                max_version = vec![proto::VectorClockEntry {
 494                    replica_id: *max_operation.replica_id.as_ref() as u32,
 495                    timestamp: *max_operation.lamport_timestamp.as_ref() as u32,
 496                }];
 497
 498                // get current channel participants and save the max operation above
 499                self.save_max_operation(
 500                    user,
 501                    buffer.id,
 502                    buffer.epoch,
 503                    *max_operation.replica_id.as_ref(),
 504                    *max_operation.lamport_timestamp.as_ref(),
 505                    &*tx,
 506                )
 507                .await?;
 508
 509                channel_members = self.get_channel_participants(&channel, &*tx).await?;
 510                let collaborators = self
 511                    .get_channel_buffer_collaborators_internal(channel_id, &*tx)
 512                    .await?;
 513                channel_members.retain(|member| !collaborators.contains(member));
 514
 515                buffer_operation::Entity::insert_many(operations)
 516                    .on_conflict(
 517                        OnConflict::columns([
 518                            buffer_operation::Column::BufferId,
 519                            buffer_operation::Column::Epoch,
 520                            buffer_operation::Column::LamportTimestamp,
 521                            buffer_operation::Column::ReplicaId,
 522                        ])
 523                        .do_nothing()
 524                        .to_owned(),
 525                    )
 526                    .exec(&*tx)
 527                    .await?;
 528            } else {
 529                channel_members = Vec::new();
 530                max_version = Vec::new();
 531            }
 532
 533            let mut connections = Vec::new();
 534            let mut rows = channel_buffer_collaborator::Entity::find()
 535                .filter(
 536                    Condition::all()
 537                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 538                )
 539                .stream(&*tx)
 540                .await?;
 541            while let Some(row) = rows.next().await {
 542                let row = row?;
 543                connections.push(ConnectionId {
 544                    id: row.connection_id as u32,
 545                    owner_id: row.connection_server_id.0 as u32,
 546                });
 547            }
 548
 549            Ok((connections, channel_members, buffer.epoch, max_version))
 550        })
 551        .await
 552    }
 553
 554    async fn save_max_operation(
 555        &self,
 556        user_id: UserId,
 557        buffer_id: BufferId,
 558        epoch: i32,
 559        replica_id: i32,
 560        lamport_timestamp: i32,
 561        tx: &DatabaseTransaction,
 562    ) -> Result<()> {
 563        use observed_buffer_edits::Column;
 564
 565        observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
 566            user_id: ActiveValue::Set(user_id),
 567            buffer_id: ActiveValue::Set(buffer_id),
 568            epoch: ActiveValue::Set(epoch),
 569            replica_id: ActiveValue::Set(replica_id),
 570            lamport_timestamp: ActiveValue::Set(lamport_timestamp),
 571        })
 572        .on_conflict(
 573            OnConflict::columns([Column::UserId, Column::BufferId])
 574                .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
 575                .action_cond_where(
 576                    Condition::any().add(Column::Epoch.lt(epoch)).add(
 577                        Condition::all().add(Column::Epoch.eq(epoch)).add(
 578                            Condition::any()
 579                                .add(Column::LamportTimestamp.lt(lamport_timestamp))
 580                                .add(
 581                                    Column::LamportTimestamp
 582                                        .eq(lamport_timestamp)
 583                                        .and(Column::ReplicaId.lt(replica_id)),
 584                                ),
 585                        ),
 586                    ),
 587                )
 588                .to_owned(),
 589        )
 590        .exec_without_returning(tx)
 591        .await?;
 592
 593        Ok(())
 594    }
 595
 596    async fn get_buffer_operation_serialization_version(
 597        &self,
 598        buffer_id: BufferId,
 599        epoch: i32,
 600        tx: &DatabaseTransaction,
 601    ) -> Result<i32> {
 602        Ok(buffer_snapshot::Entity::find()
 603            .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
 604            .filter(buffer_snapshot::Column::Epoch.eq(epoch))
 605            .select_only()
 606            .column(buffer_snapshot::Column::OperationSerializationVersion)
 607            .into_values::<_, QueryOperationSerializationVersion>()
 608            .one(&*tx)
 609            .await?
 610            .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
 611    }
 612
 613    pub async fn get_channel_buffer(
 614        &self,
 615        channel_id: ChannelId,
 616        tx: &DatabaseTransaction,
 617    ) -> Result<buffer::Model> {
 618        Ok(channel::Model {
 619            id: channel_id,
 620            ..Default::default()
 621        }
 622        .find_related(buffer::Entity)
 623        .one(&*tx)
 624        .await?
 625        .ok_or_else(|| anyhow!("no such buffer"))?)
 626    }
 627
 628    async fn get_buffer_state(
 629        &self,
 630        buffer: &buffer::Model,
 631        tx: &DatabaseTransaction,
 632    ) -> Result<(
 633        String,
 634        Vec<proto::Operation>,
 635        Option<buffer_operation::Model>,
 636    )> {
 637        let id = buffer.id;
 638        let (base_text, version) = if buffer.epoch > 0 {
 639            let snapshot = buffer_snapshot::Entity::find()
 640                .filter(
 641                    buffer_snapshot::Column::BufferId
 642                        .eq(id)
 643                        .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
 644                )
 645                .one(&*tx)
 646                .await?
 647                .ok_or_else(|| anyhow!("no such snapshot"))?;
 648
 649            let version = snapshot.operation_serialization_version;
 650            (snapshot.text, version)
 651        } else {
 652            (String::new(), storage::SERIALIZATION_VERSION)
 653        };
 654
 655        let mut rows = buffer_operation::Entity::find()
 656            .filter(
 657                buffer_operation::Column::BufferId
 658                    .eq(id)
 659                    .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
 660            )
 661            .order_by_asc(buffer_operation::Column::LamportTimestamp)
 662            .order_by_asc(buffer_operation::Column::ReplicaId)
 663            .stream(&*tx)
 664            .await?;
 665
 666        let mut operations = Vec::new();
 667        let mut last_row = None;
 668        while let Some(row) = rows.next().await {
 669            let row = row?;
 670            last_row = Some(buffer_operation::Model {
 671                buffer_id: row.buffer_id,
 672                epoch: row.epoch,
 673                lamport_timestamp: row.lamport_timestamp,
 674                replica_id: row.lamport_timestamp,
 675                value: Default::default(),
 676            });
 677            operations.push(proto::Operation {
 678                variant: Some(operation_from_storage(row, version)?),
 679            });
 680        }
 681
 682        Ok((base_text, operations, last_row))
 683    }
 684
 685    async fn snapshot_channel_buffer(
 686        &self,
 687        channel_id: ChannelId,
 688        tx: &DatabaseTransaction,
 689    ) -> Result<()> {
 690        let buffer = self.get_channel_buffer(channel_id, tx).await?;
 691        let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
 692        if operations.is_empty() {
 693            return Ok(());
 694        }
 695
 696        let mut text_buffer = text::Buffer::new(0, 0, base_text);
 697        text_buffer
 698            .apply_ops(operations.into_iter().filter_map(operation_from_wire))
 699            .unwrap();
 700
 701        let base_text = text_buffer.text();
 702        let epoch = buffer.epoch + 1;
 703
 704        buffer_snapshot::Model {
 705            buffer_id: buffer.id,
 706            epoch,
 707            text: base_text,
 708            operation_serialization_version: storage::SERIALIZATION_VERSION,
 709        }
 710        .into_active_model()
 711        .insert(tx)
 712        .await?;
 713
 714        buffer::ActiveModel {
 715            id: ActiveValue::Unchanged(buffer.id),
 716            epoch: ActiveValue::Set(epoch),
 717            ..Default::default()
 718        }
 719        .save(tx)
 720        .await?;
 721
 722        Ok(())
 723    }
 724
 725    pub async fn observe_buffer_version(
 726        &self,
 727        buffer_id: BufferId,
 728        user_id: UserId,
 729        epoch: i32,
 730        version: &[proto::VectorClockEntry],
 731    ) -> Result<()> {
 732        self.transaction(|tx| async move {
 733            // For now, combine concurrent operations.
 734            let Some(component) = version.iter().max_by_key(|version| version.timestamp) else {
 735                return Ok(());
 736            };
 737            self.save_max_operation(
 738                user_id,
 739                buffer_id,
 740                epoch,
 741                component.replica_id as i32,
 742                component.timestamp as i32,
 743                &*tx,
 744            )
 745            .await?;
 746            Ok(())
 747        })
 748        .await
 749    }
 750
 751    pub async fn unseen_channel_buffer_changes(
 752        &self,
 753        user_id: UserId,
 754        channel_ids: &[ChannelId],
 755        tx: &DatabaseTransaction,
 756    ) -> Result<Vec<proto::UnseenChannelBufferChange>> {
 757        #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 758        enum QueryIds {
 759            ChannelId,
 760            Id,
 761        }
 762
 763        let mut channel_ids_by_buffer_id = HashMap::default();
 764        let mut rows = buffer::Entity::find()
 765            .filter(buffer::Column::ChannelId.is_in(channel_ids.iter().copied()))
 766            .stream(&*tx)
 767            .await?;
 768        while let Some(row) = rows.next().await {
 769            let row = row?;
 770            channel_ids_by_buffer_id.insert(row.id, row.channel_id);
 771        }
 772        drop(rows);
 773
 774        let mut observed_edits_by_buffer_id = HashMap::default();
 775        let mut rows = observed_buffer_edits::Entity::find()
 776            .filter(observed_buffer_edits::Column::UserId.eq(user_id))
 777            .filter(
 778                observed_buffer_edits::Column::BufferId
 779                    .is_in(channel_ids_by_buffer_id.keys().copied()),
 780            )
 781            .stream(&*tx)
 782            .await?;
 783        while let Some(row) = rows.next().await {
 784            let row = row?;
 785            observed_edits_by_buffer_id.insert(row.buffer_id, row);
 786        }
 787        drop(rows);
 788
 789        let latest_operations = self
 790            .get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
 791            .await?;
 792
 793        let mut changes = Vec::default();
 794        for latest in latest_operations {
 795            if let Some(observed) = observed_edits_by_buffer_id.get(&latest.buffer_id) {
 796                if (
 797                    observed.epoch,
 798                    observed.lamport_timestamp,
 799                    observed.replica_id,
 800                ) >= (latest.epoch, latest.lamport_timestamp, latest.replica_id)
 801                {
 802                    continue;
 803                }
 804            }
 805
 806            if let Some(channel_id) = channel_ids_by_buffer_id.get(&latest.buffer_id) {
 807                changes.push(proto::UnseenChannelBufferChange {
 808                    channel_id: channel_id.to_proto(),
 809                    epoch: latest.epoch as u64,
 810                    version: vec![proto::VectorClockEntry {
 811                        replica_id: latest.replica_id as u32,
 812                        timestamp: latest.lamport_timestamp as u32,
 813                    }],
 814                });
 815            }
 816        }
 817
 818        Ok(changes)
 819    }
 820
 821    /// Returns the latest operations for the buffers with the specified IDs.
 822    pub async fn get_latest_operations_for_buffers(
 823        &self,
 824        buffer_ids: impl IntoIterator<Item = BufferId>,
 825        tx: &DatabaseTransaction,
 826    ) -> Result<Vec<buffer_operation::Model>> {
 827        let mut values = String::new();
 828        for id in buffer_ids {
 829            if !values.is_empty() {
 830                values.push_str(", ");
 831            }
 832            write!(&mut values, "({})", id).unwrap();
 833        }
 834
 835        if values.is_empty() {
 836            return Ok(Vec::default());
 837        }
 838
 839        let sql = format!(
 840            r#"
 841            SELECT
 842                *
 843            FROM
 844            (
 845                SELECT
 846                    *,
 847                    row_number() OVER (
 848                        PARTITION BY buffer_id
 849                        ORDER BY
 850                            epoch DESC,
 851                            lamport_timestamp DESC,
 852                            replica_id DESC
 853                    ) as row_number
 854                FROM buffer_operations
 855                WHERE
 856                    buffer_id in ({values})
 857            ) AS last_operations
 858            WHERE
 859                row_number = 1
 860            "#,
 861        );
 862
 863        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
 864        Ok(buffer_operation::Entity::find()
 865            .from_raw_sql(stmt)
 866            .all(&*tx)
 867            .await?)
 868    }
 869}
 870
 871fn operation_to_storage(
 872    operation: &proto::Operation,
 873    buffer: &buffer::Model,
 874    _format: i32,
 875) -> Option<buffer_operation::ActiveModel> {
 876    let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
 877        proto::operation::Variant::Edit(operation) => (
 878            operation.replica_id,
 879            operation.lamport_timestamp,
 880            storage::Operation {
 881                version: version_to_storage(&operation.version),
 882                is_undo: false,
 883                edit_ranges: operation
 884                    .ranges
 885                    .iter()
 886                    .map(|range| storage::Range {
 887                        start: range.start,
 888                        end: range.end,
 889                    })
 890                    .collect(),
 891                edit_texts: operation.new_text.clone(),
 892                undo_counts: Vec::new(),
 893            },
 894        ),
 895        proto::operation::Variant::Undo(operation) => (
 896            operation.replica_id,
 897            operation.lamport_timestamp,
 898            storage::Operation {
 899                version: version_to_storage(&operation.version),
 900                is_undo: true,
 901                edit_ranges: Vec::new(),
 902                edit_texts: Vec::new(),
 903                undo_counts: operation
 904                    .counts
 905                    .iter()
 906                    .map(|entry| storage::UndoCount {
 907                        replica_id: entry.replica_id,
 908                        lamport_timestamp: entry.lamport_timestamp,
 909                        count: entry.count,
 910                    })
 911                    .collect(),
 912            },
 913        ),
 914        _ => None?,
 915    };
 916
 917    Some(buffer_operation::ActiveModel {
 918        buffer_id: ActiveValue::Set(buffer.id),
 919        epoch: ActiveValue::Set(buffer.epoch),
 920        replica_id: ActiveValue::Set(replica_id as i32),
 921        lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
 922        value: ActiveValue::Set(value.encode_to_vec()),
 923    })
 924}
 925
 926fn operation_from_storage(
 927    row: buffer_operation::Model,
 928    _format_version: i32,
 929) -> Result<proto::operation::Variant, Error> {
 930    let operation =
 931        storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
 932    let version = version_from_storage(&operation.version);
 933    Ok(if operation.is_undo {
 934        proto::operation::Variant::Undo(proto::operation::Undo {
 935            replica_id: row.replica_id as u32,
 936            lamport_timestamp: row.lamport_timestamp as u32,
 937            version,
 938            counts: operation
 939                .undo_counts
 940                .iter()
 941                .map(|entry| proto::UndoCount {
 942                    replica_id: entry.replica_id,
 943                    lamport_timestamp: entry.lamport_timestamp,
 944                    count: entry.count,
 945                })
 946                .collect(),
 947        })
 948    } else {
 949        proto::operation::Variant::Edit(proto::operation::Edit {
 950            replica_id: row.replica_id as u32,
 951            lamport_timestamp: row.lamport_timestamp as u32,
 952            version,
 953            ranges: operation
 954                .edit_ranges
 955                .into_iter()
 956                .map(|range| proto::Range {
 957                    start: range.start,
 958                    end: range.end,
 959                })
 960                .collect(),
 961            new_text: operation.edit_texts,
 962        })
 963    })
 964}
 965
 966fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
 967    version
 968        .iter()
 969        .map(|entry| storage::VectorClockEntry {
 970            replica_id: entry.replica_id,
 971            timestamp: entry.timestamp,
 972        })
 973        .collect()
 974}
 975
 976fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
 977    version
 978        .iter()
 979        .map(|entry| proto::VectorClockEntry {
 980            replica_id: entry.replica_id,
 981            timestamp: entry.timestamp,
 982        })
 983        .collect()
 984}
 985
 986// This is currently a manual copy of the deserialization code in the client's language crate
 987pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
 988    match operation.variant? {
 989        proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
 990            timestamp: clock::Lamport {
 991                replica_id: edit.replica_id as text::ReplicaId,
 992                value: edit.lamport_timestamp,
 993            },
 994            version: version_from_wire(&edit.version),
 995            ranges: edit
 996                .ranges
 997                .into_iter()
 998                .map(|range| {
 999                    text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
1000                })
1001                .collect(),
1002            new_text: edit.new_text.into_iter().map(Arc::from).collect(),
1003        })),
1004        proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
1005            timestamp: clock::Lamport {
1006                replica_id: undo.replica_id as text::ReplicaId,
1007                value: undo.lamport_timestamp,
1008            },
1009            version: version_from_wire(&undo.version),
1010            counts: undo
1011                .counts
1012                .into_iter()
1013                .map(|c| {
1014                    (
1015                        clock::Lamport {
1016                            replica_id: c.replica_id as text::ReplicaId,
1017                            value: c.lamport_timestamp,
1018                        },
1019                        c.count,
1020                    )
1021                })
1022                .collect(),
1023        })),
1024        _ => None,
1025    }
1026}
1027
1028fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
1029    let mut version = clock::Global::new();
1030    for entry in message {
1031        version.observe(clock::Lamport {
1032            replica_id: entry.replica_id as text::ReplicaId,
1033            value: entry.timestamp,
1034        });
1035    }
1036    version
1037}
1038
1039fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
1040    let mut message = Vec::new();
1041    for entry in version.iter() {
1042        message.push(proto::VectorClockEntry {
1043            replica_id: entry.replica_id as u32,
1044            timestamp: entry.value,
1045        });
1046    }
1047    message
1048}
1049
/// Single-column selector for the `operation_serialization_version` column.
///
/// NOTE(review): presumably used elsewhere in this file with a select-only / into-values
/// query to read just that column. If so, `DeriveColumn` derives the selected column from
/// the variant name, so the variant must not be renamed.
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
1054
/// Database storage format for buffer operations.
///
/// `operation_to_storage` prost-encodes an `Operation` into the `value` column of a
/// `buffer_operations` row, and `operation_from_storage` decodes it back. The field tags
/// below are part of that stored format. Changing or reusing a tag would break decoding of
/// rows that are already stored.
mod storage {
    // NOTE(review): the reason for this allow is not visible here. Presumably it silences
    // lints on code produced by `#[derive(Message)]`; confirm before removing.
    #![allow(non_snake_case)]
    use prost::Message;
    /// Version of the stored operation encoding. Written, for example, to the
    /// `operation_serialization_version` column of newly created buffer snapshots.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A single stored buffer operation: either an edit or an undo.
    ///
    /// NOTE(review): tag 1 is unused here. It may belong to an earlier format; do not reuse
    /// it without confirming that no stored rows still contain it.
    #[derive(Message)]
    pub struct Operation {
        /// Version vector copied from the wire operation's `version`.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for undo operations, `false` for edits.
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Edited ranges; populated only for edits (empty for undos).
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// Replacement texts copied from the edit's `new_text`; populated only for edits.
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts; populated only for undos (empty for edits).
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a stored version vector.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// A stored edit range, copied from the wire range's `start` and `end`.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// One entry of an undo operation's counts.
    ///
    /// `replica_id` and `lamport_timestamp` together identify an operation (they are turned
    /// back into a `clock::Lamport` when decoded for the `text` crate), and `count` is the
    /// associated undo count.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}