buffers.rs

   1use super::*;
   2use prost::Message;
   3use text::{EditOperation, UndoOperation};
   4
/// Result of removing a connection from a channel buffer, as built by
/// `Database::leave_channel_buffer_internal`.
pub struct LeftChannelBuffer {
    /// Channel whose buffer was left.
    pub channel_id: ChannelId,
    /// Collaborators that are still recorded for the channel buffer after
    /// the leaving connection's row was deleted.
    pub collaborators: Vec<proto::Collaborator>,
    /// Connections of the remaining collaborators, in the same order as
    /// `collaborators`.
    pub connections: Vec<ConnectionId>,
}
  10
  11impl Database {
  12    pub async fn join_channel_buffer(
  13        &self,
  14        channel_id: ChannelId,
  15        user_id: UserId,
  16        connection: ConnectionId,
  17    ) -> Result<proto::JoinChannelBufferResponse> {
  18        self.transaction(|tx| async move {
  19            self.check_user_is_channel_member(channel_id, user_id, &tx)
  20                .await?;
  21
  22            let buffer = channel::Model {
  23                id: channel_id,
  24                ..Default::default()
  25            }
  26            .find_related(buffer::Entity)
  27            .one(&*tx)
  28            .await?;
  29
  30            let buffer = if let Some(buffer) = buffer {
  31                buffer
  32            } else {
  33                let buffer = buffer::ActiveModel {
  34                    channel_id: ActiveValue::Set(channel_id),
  35                    ..Default::default()
  36                }
  37                .insert(&*tx)
  38                .await?;
  39                buffer_snapshot::ActiveModel {
  40                    buffer_id: ActiveValue::Set(buffer.id),
  41                    epoch: ActiveValue::Set(0),
  42                    text: ActiveValue::Set(String::new()),
  43                    operation_serialization_version: ActiveValue::Set(
  44                        storage::SERIALIZATION_VERSION,
  45                    ),
  46                }
  47                .insert(&*tx)
  48                .await?;
  49                buffer
  50            };
  51
  52            // Join the collaborators
  53            let mut collaborators = channel_buffer_collaborator::Entity::find()
  54                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
  55                .all(&*tx)
  56                .await?;
  57            let replica_ids = collaborators
  58                .iter()
  59                .map(|c| c.replica_id)
  60                .collect::<HashSet<_>>();
  61            let mut replica_id = ReplicaId(0);
  62            while replica_ids.contains(&replica_id) {
  63                replica_id.0 += 1;
  64            }
  65            let collaborator = channel_buffer_collaborator::ActiveModel {
  66                channel_id: ActiveValue::Set(channel_id),
  67                connection_id: ActiveValue::Set(connection.id as i32),
  68                connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
  69                user_id: ActiveValue::Set(user_id),
  70                replica_id: ActiveValue::Set(replica_id),
  71                ..Default::default()
  72            }
  73            .insert(&*tx)
  74            .await?;
  75            collaborators.push(collaborator);
  76
  77            let (base_text, operations, max_operation) =
  78                self.get_buffer_state(&buffer, &tx).await?;
  79
  80            // Save the last observed operation
  81            if let Some(op) = max_operation {
  82                observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
  83                    user_id: ActiveValue::Set(user_id),
  84                    buffer_id: ActiveValue::Set(buffer.id),
  85                    epoch: ActiveValue::Set(op.epoch),
  86                    lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
  87                    replica_id: ActiveValue::Set(op.replica_id),
  88                })
  89                .on_conflict(
  90                    OnConflict::columns([
  91                        observed_buffer_edits::Column::UserId,
  92                        observed_buffer_edits::Column::BufferId,
  93                    ])
  94                    .update_columns([
  95                        observed_buffer_edits::Column::Epoch,
  96                        observed_buffer_edits::Column::LamportTimestamp,
  97                    ])
  98                    .to_owned(),
  99                )
 100                .exec(&*tx)
 101                .await?;
 102            }
 103
 104            Ok(proto::JoinChannelBufferResponse {
 105                buffer_id: buffer.id.to_proto(),
 106                replica_id: replica_id.to_proto() as u32,
 107                base_text,
 108                operations,
 109                epoch: buffer.epoch as u64,
 110                collaborators: collaborators
 111                    .into_iter()
 112                    .map(|collaborator| proto::Collaborator {
 113                        peer_id: Some(collaborator.connection().into()),
 114                        user_id: collaborator.user_id.to_proto(),
 115                        replica_id: collaborator.replica_id.0 as u32,
 116                    })
 117                    .collect(),
 118            })
 119        })
 120        .await
 121    }
 122
 123    pub async fn rejoin_channel_buffers(
 124        &self,
 125        buffers: &[proto::ChannelBufferVersion],
 126        user_id: UserId,
 127        connection_id: ConnectionId,
 128    ) -> Result<Vec<RejoinedChannelBuffer>> {
 129        self.transaction(|tx| async move {
 130            let mut results = Vec::new();
 131            for client_buffer in buffers {
 132                let channel_id = ChannelId::from_proto(client_buffer.channel_id);
 133                if self
 134                    .check_user_is_channel_member(channel_id, user_id, &*tx)
 135                    .await
 136                    .is_err()
 137                {
 138                    log::info!("user is not a member of channel");
 139                    continue;
 140                }
 141
 142                let buffer = self.get_channel_buffer(channel_id, &*tx).await?;
 143                let mut collaborators = channel_buffer_collaborator::Entity::find()
 144                    .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 145                    .all(&*tx)
 146                    .await?;
 147
 148                // If the buffer epoch hasn't changed since the client lost
 149                // connection, then the client's buffer can be syncronized with
 150                // the server's buffer.
 151                if buffer.epoch as u64 != client_buffer.epoch {
 152                    log::info!("can't rejoin buffer, epoch has changed");
 153                    continue;
 154                }
 155
 156                // Find the collaborator record for this user's previous lost
 157                // connection. Update it with the new connection id.
 158                let server_id = ServerId(connection_id.owner_id as i32);
 159                let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
 160                    c.user_id == user_id
 161                        && (c.connection_lost || c.connection_server_id != server_id)
 162                }) else {
 163                    log::info!("can't rejoin buffer, no previous collaborator found");
 164                    continue;
 165                };
 166                let old_connection_id = self_collaborator.connection();
 167                *self_collaborator = channel_buffer_collaborator::ActiveModel {
 168                    id: ActiveValue::Unchanged(self_collaborator.id),
 169                    connection_id: ActiveValue::Set(connection_id.id as i32),
 170                    connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
 171                    connection_lost: ActiveValue::Set(false),
 172                    ..Default::default()
 173                }
 174                .update(&*tx)
 175                .await?;
 176
 177                let client_version = version_from_wire(&client_buffer.version);
 178                let serialization_version = self
 179                    .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
 180                    .await?;
 181
 182                let mut rows = buffer_operation::Entity::find()
 183                    .filter(
 184                        buffer_operation::Column::BufferId
 185                            .eq(buffer.id)
 186                            .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
 187                    )
 188                    .stream(&*tx)
 189                    .await?;
 190
 191                // Find the server's version vector and any operations
 192                // that the client has not seen.
 193                let mut server_version = clock::Global::new();
 194                let mut operations = Vec::new();
 195                while let Some(row) = rows.next().await {
 196                    let row = row?;
 197                    let timestamp = clock::Lamport {
 198                        replica_id: row.replica_id as u16,
 199                        value: row.lamport_timestamp as u32,
 200                    };
 201                    server_version.observe(timestamp);
 202                    if !client_version.observed(timestamp) {
 203                        operations.push(proto::Operation {
 204                            variant: Some(operation_from_storage(row, serialization_version)?),
 205                        })
 206                    }
 207                }
 208
 209                results.push(RejoinedChannelBuffer {
 210                    old_connection_id,
 211                    buffer: proto::RejoinedChannelBuffer {
 212                        channel_id: client_buffer.channel_id,
 213                        version: version_to_wire(&server_version),
 214                        operations,
 215                        collaborators: collaborators
 216                            .into_iter()
 217                            .map(|collaborator| proto::Collaborator {
 218                                peer_id: Some(collaborator.connection().into()),
 219                                user_id: collaborator.user_id.to_proto(),
 220                                replica_id: collaborator.replica_id.0 as u32,
 221                            })
 222                            .collect(),
 223                    },
 224                });
 225            }
 226
 227            Ok(results)
 228        })
 229        .await
 230    }
 231
 232    pub async fn clear_stale_channel_buffer_collaborators(
 233        &self,
 234        channel_id: ChannelId,
 235        server_id: ServerId,
 236    ) -> Result<RefreshedChannelBuffer> {
 237        self.transaction(|tx| async move {
 238            let db_collaborators = channel_buffer_collaborator::Entity::find()
 239                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 240                .all(&*tx)
 241                .await?;
 242
 243            let mut connection_ids = Vec::new();
 244            let mut collaborators = Vec::new();
 245            let mut collaborator_ids_to_remove = Vec::new();
 246            for db_collaborator in &db_collaborators {
 247                if !db_collaborator.connection_lost
 248                    && db_collaborator.connection_server_id == server_id
 249                {
 250                    connection_ids.push(db_collaborator.connection());
 251                    collaborators.push(proto::Collaborator {
 252                        peer_id: Some(db_collaborator.connection().into()),
 253                        replica_id: db_collaborator.replica_id.0 as u32,
 254                        user_id: db_collaborator.user_id.to_proto(),
 255                    })
 256                } else {
 257                    collaborator_ids_to_remove.push(db_collaborator.id);
 258                }
 259            }
 260
 261            channel_buffer_collaborator::Entity::delete_many()
 262                .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
 263                .exec(&*tx)
 264                .await?;
 265
 266            Ok(RefreshedChannelBuffer {
 267                connection_ids,
 268                collaborators,
 269            })
 270        })
 271        .await
 272    }
 273
 274    pub async fn leave_channel_buffer(
 275        &self,
 276        channel_id: ChannelId,
 277        connection: ConnectionId,
 278    ) -> Result<LeftChannelBuffer> {
 279        self.transaction(|tx| async move {
 280            self.leave_channel_buffer_internal(channel_id, connection, &*tx)
 281                .await
 282        })
 283        .await
 284    }
 285
 286    pub async fn channel_buffer_connection_lost(
 287        &self,
 288        connection: ConnectionId,
 289        tx: &DatabaseTransaction,
 290    ) -> Result<()> {
 291        channel_buffer_collaborator::Entity::update_many()
 292            .filter(
 293                Condition::all()
 294                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
 295                    .add(
 296                        channel_buffer_collaborator::Column::ConnectionServerId
 297                            .eq(connection.owner_id as i32),
 298                    ),
 299            )
 300            .set(channel_buffer_collaborator::ActiveModel {
 301                connection_lost: ActiveValue::set(true),
 302                ..Default::default()
 303            })
 304            .exec(&*tx)
 305            .await?;
 306        Ok(())
 307    }
 308
 309    pub async fn leave_channel_buffers(
 310        &self,
 311        connection: ConnectionId,
 312    ) -> Result<Vec<LeftChannelBuffer>> {
 313        self.transaction(|tx| async move {
 314            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 315            enum QueryChannelIds {
 316                ChannelId,
 317            }
 318
 319            let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
 320                .select_only()
 321                .column(channel_buffer_collaborator::Column::ChannelId)
 322                .filter(Condition::all().add(
 323                    channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
 324                ))
 325                .into_values::<_, QueryChannelIds>()
 326                .all(&*tx)
 327                .await?;
 328
 329            let mut result = Vec::new();
 330            for channel_id in channel_ids {
 331                let left_channel_buffer = self
 332                    .leave_channel_buffer_internal(channel_id, connection, &*tx)
 333                    .await?;
 334                result.push(left_channel_buffer);
 335            }
 336
 337            Ok(result)
 338        })
 339        .await
 340    }
 341
 342    pub async fn leave_channel_buffer_internal(
 343        &self,
 344        channel_id: ChannelId,
 345        connection: ConnectionId,
 346        tx: &DatabaseTransaction,
 347    ) -> Result<LeftChannelBuffer> {
 348        let result = channel_buffer_collaborator::Entity::delete_many()
 349            .filter(
 350                Condition::all()
 351                    .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 352                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
 353                    .add(
 354                        channel_buffer_collaborator::Column::ConnectionServerId
 355                            .eq(connection.owner_id as i32),
 356                    ),
 357            )
 358            .exec(&*tx)
 359            .await?;
 360        if result.rows_affected == 0 {
 361            Err(anyhow!("not a collaborator on this project"))?;
 362        }
 363
 364        let mut collaborators = Vec::new();
 365        let mut connections = Vec::new();
 366        let mut rows = channel_buffer_collaborator::Entity::find()
 367            .filter(
 368                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 369            )
 370            .stream(&*tx)
 371            .await?;
 372        while let Some(row) = rows.next().await {
 373            let row = row?;
 374            let connection = row.connection();
 375            connections.push(connection);
 376            collaborators.push(proto::Collaborator {
 377                peer_id: Some(connection.into()),
 378                replica_id: row.replica_id.0 as u32,
 379                user_id: row.user_id.to_proto(),
 380            });
 381        }
 382
 383        drop(rows);
 384
 385        if collaborators.is_empty() {
 386            self.snapshot_channel_buffer(channel_id, &tx).await?;
 387        }
 388
 389        Ok(LeftChannelBuffer {
 390            channel_id,
 391            collaborators,
 392            connections,
 393        })
 394    }
 395
 396    pub async fn get_channel_buffer_collaborators(
 397        &self,
 398        channel_id: ChannelId,
 399    ) -> Result<Vec<UserId>> {
 400        self.transaction(|tx| async move {
 401            self.get_channel_buffer_collaborators_internal(channel_id, &*tx)
 402                .await
 403        })
 404        .await
 405    }
 406
 407    async fn get_channel_buffer_collaborators_internal(
 408        &self,
 409        channel_id: ChannelId,
 410        tx: &DatabaseTransaction,
 411    ) -> Result<Vec<UserId>> {
 412        #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 413        enum QueryUserIds {
 414            UserId,
 415        }
 416
 417        let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
 418            .select_only()
 419            .column(channel_buffer_collaborator::Column::UserId)
 420            .filter(
 421                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 422            )
 423            .into_values::<_, QueryUserIds>()
 424            .all(&*tx)
 425            .await?;
 426
 427        Ok(users)
 428    }
 429
 430    pub async fn update_channel_buffer(
 431        &self,
 432        channel_id: ChannelId,
 433        user: UserId,
 434        operations: &[proto::Operation],
 435    ) -> Result<(
 436        Vec<ConnectionId>,
 437        Vec<UserId>,
 438        i32,
 439        Vec<proto::VectorClockEntry>,
 440    )> {
 441        self.transaction(move |tx| async move {
 442            self.check_user_is_channel_member(channel_id, user, &*tx)
 443                .await?;
 444
 445            let buffer = buffer::Entity::find()
 446                .filter(buffer::Column::ChannelId.eq(channel_id))
 447                .one(&*tx)
 448                .await?
 449                .ok_or_else(|| anyhow!("no such buffer"))?;
 450
 451            let serialization_version = self
 452                .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
 453                .await?;
 454
 455            let operations = operations
 456                .iter()
 457                .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
 458                .collect::<Vec<_>>();
 459
 460            let mut channel_members;
 461            let max_version;
 462
 463            if !operations.is_empty() {
 464                let max_operation = operations
 465                    .iter()
 466                    .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
 467                    .unwrap();
 468
 469                max_version = vec![proto::VectorClockEntry {
 470                    replica_id: *max_operation.replica_id.as_ref() as u32,
 471                    timestamp: *max_operation.lamport_timestamp.as_ref() as u32,
 472                }];
 473
 474                // get current channel participants and save the max operation above
 475                self.save_max_operation(
 476                    user,
 477                    buffer.id,
 478                    buffer.epoch,
 479                    *max_operation.replica_id.as_ref(),
 480                    *max_operation.lamport_timestamp.as_ref(),
 481                    &*tx,
 482                )
 483                .await?;
 484
 485                channel_members = self.get_channel_members_internal(channel_id, &*tx).await?;
 486                let collaborators = self
 487                    .get_channel_buffer_collaborators_internal(channel_id, &*tx)
 488                    .await?;
 489                channel_members.retain(|member| !collaborators.contains(member));
 490
 491                buffer_operation::Entity::insert_many(operations)
 492                    .on_conflict(
 493                        OnConflict::columns([
 494                            buffer_operation::Column::BufferId,
 495                            buffer_operation::Column::Epoch,
 496                            buffer_operation::Column::LamportTimestamp,
 497                            buffer_operation::Column::ReplicaId,
 498                        ])
 499                        .do_nothing()
 500                        .to_owned(),
 501                    )
 502                    .exec(&*tx)
 503                    .await?;
 504            } else {
 505                channel_members = Vec::new();
 506                max_version = Vec::new();
 507            }
 508
 509            let mut connections = Vec::new();
 510            let mut rows = channel_buffer_collaborator::Entity::find()
 511                .filter(
 512                    Condition::all()
 513                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 514                )
 515                .stream(&*tx)
 516                .await?;
 517            while let Some(row) = rows.next().await {
 518                let row = row?;
 519                connections.push(ConnectionId {
 520                    id: row.connection_id as u32,
 521                    owner_id: row.connection_server_id.0 as u32,
 522                });
 523            }
 524
 525            Ok((connections, channel_members, buffer.epoch, max_version))
 526        })
 527        .await
 528    }
 529
 530    async fn save_max_operation(
 531        &self,
 532        user_id: UserId,
 533        buffer_id: BufferId,
 534        epoch: i32,
 535        replica_id: i32,
 536        lamport_timestamp: i32,
 537        tx: &DatabaseTransaction,
 538    ) -> Result<()> {
 539        use observed_buffer_edits::Column;
 540
 541        observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
 542            user_id: ActiveValue::Set(user_id),
 543            buffer_id: ActiveValue::Set(buffer_id),
 544            epoch: ActiveValue::Set(epoch),
 545            replica_id: ActiveValue::Set(replica_id),
 546            lamport_timestamp: ActiveValue::Set(lamport_timestamp),
 547        })
 548        .on_conflict(
 549            OnConflict::columns([Column::UserId, Column::BufferId])
 550                .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
 551                .action_cond_where(
 552                    Condition::any().add(Column::Epoch.lt(epoch)).add(
 553                        Condition::all().add(Column::Epoch.eq(epoch)).add(
 554                            Condition::any()
 555                                .add(Column::LamportTimestamp.lt(lamport_timestamp))
 556                                .add(
 557                                    Column::LamportTimestamp
 558                                        .eq(lamport_timestamp)
 559                                        .and(Column::ReplicaId.lt(replica_id)),
 560                                ),
 561                        ),
 562                    ),
 563                )
 564                .to_owned(),
 565        )
 566        .exec_without_returning(tx)
 567        .await?;
 568
 569        Ok(())
 570    }
 571
 572    async fn get_buffer_operation_serialization_version(
 573        &self,
 574        buffer_id: BufferId,
 575        epoch: i32,
 576        tx: &DatabaseTransaction,
 577    ) -> Result<i32> {
 578        Ok(buffer_snapshot::Entity::find()
 579            .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
 580            .filter(buffer_snapshot::Column::Epoch.eq(epoch))
 581            .select_only()
 582            .column(buffer_snapshot::Column::OperationSerializationVersion)
 583            .into_values::<_, QueryOperationSerializationVersion>()
 584            .one(&*tx)
 585            .await?
 586            .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
 587    }
 588
 589    pub async fn get_channel_buffer(
 590        &self,
 591        channel_id: ChannelId,
 592        tx: &DatabaseTransaction,
 593    ) -> Result<buffer::Model> {
 594        Ok(channel::Model {
 595            id: channel_id,
 596            ..Default::default()
 597        }
 598        .find_related(buffer::Entity)
 599        .one(&*tx)
 600        .await?
 601        .ok_or_else(|| anyhow!("no such buffer"))?)
 602    }
 603
 604    async fn get_buffer_state(
 605        &self,
 606        buffer: &buffer::Model,
 607        tx: &DatabaseTransaction,
 608    ) -> Result<(
 609        String,
 610        Vec<proto::Operation>,
 611        Option<buffer_operation::Model>,
 612    )> {
 613        let id = buffer.id;
 614        let (base_text, version) = if buffer.epoch > 0 {
 615            let snapshot = buffer_snapshot::Entity::find()
 616                .filter(
 617                    buffer_snapshot::Column::BufferId
 618                        .eq(id)
 619                        .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
 620                )
 621                .one(&*tx)
 622                .await?
 623                .ok_or_else(|| anyhow!("no such snapshot"))?;
 624
 625            let version = snapshot.operation_serialization_version;
 626            (snapshot.text, version)
 627        } else {
 628            (String::new(), storage::SERIALIZATION_VERSION)
 629        };
 630
 631        let mut rows = buffer_operation::Entity::find()
 632            .filter(
 633                buffer_operation::Column::BufferId
 634                    .eq(id)
 635                    .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
 636            )
 637            .order_by_asc(buffer_operation::Column::LamportTimestamp)
 638            .order_by_asc(buffer_operation::Column::ReplicaId)
 639            .stream(&*tx)
 640            .await?;
 641
 642        let mut operations = Vec::new();
 643        let mut last_row = None;
 644        while let Some(row) = rows.next().await {
 645            let row = row?;
 646            last_row = Some(buffer_operation::Model {
 647                buffer_id: row.buffer_id,
 648                epoch: row.epoch,
 649                lamport_timestamp: row.lamport_timestamp,
 650                replica_id: row.lamport_timestamp,
 651                value: Default::default(),
 652            });
 653            operations.push(proto::Operation {
 654                variant: Some(operation_from_storage(row, version)?),
 655            });
 656        }
 657
 658        Ok((base_text, operations, last_row))
 659    }
 660
 661    async fn snapshot_channel_buffer(
 662        &self,
 663        channel_id: ChannelId,
 664        tx: &DatabaseTransaction,
 665    ) -> Result<()> {
 666        let buffer = self.get_channel_buffer(channel_id, tx).await?;
 667        let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
 668        if operations.is_empty() {
 669            return Ok(());
 670        }
 671
 672        let mut text_buffer = text::Buffer::new(0, 0, base_text);
 673        text_buffer
 674            .apply_ops(operations.into_iter().filter_map(operation_from_wire))
 675            .unwrap();
 676
 677        let base_text = text_buffer.text();
 678        let epoch = buffer.epoch + 1;
 679
 680        buffer_snapshot::Model {
 681            buffer_id: buffer.id,
 682            epoch,
 683            text: base_text,
 684            operation_serialization_version: storage::SERIALIZATION_VERSION,
 685        }
 686        .into_active_model()
 687        .insert(tx)
 688        .await?;
 689
 690        buffer::ActiveModel {
 691            id: ActiveValue::Unchanged(buffer.id),
 692            epoch: ActiveValue::Set(epoch),
 693            ..Default::default()
 694        }
 695        .save(tx)
 696        .await?;
 697
 698        Ok(())
 699    }
 700
 701    pub async fn observe_buffer_version(
 702        &self,
 703        buffer_id: BufferId,
 704        user_id: UserId,
 705        epoch: i32,
 706        version: &[proto::VectorClockEntry],
 707    ) -> Result<()> {
 708        self.transaction(|tx| async move {
 709            // For now, combine concurrent operations.
 710            let Some(component) = version.iter().max_by_key(|version| version.timestamp) else {
 711                return Ok(());
 712            };
 713            self.save_max_operation(
 714                user_id,
 715                buffer_id,
 716                epoch,
 717                component.replica_id as i32,
 718                component.timestamp as i32,
 719                &*tx,
 720            )
 721            .await?;
 722            Ok(())
 723        })
 724        .await
 725    }
 726
 727    pub async fn unseen_channel_buffer_changes(
 728        &self,
 729        user_id: UserId,
 730        channel_ids: &[ChannelId],
 731        tx: &DatabaseTransaction,
 732    ) -> Result<Vec<proto::UnseenChannelBufferChange>> {
 733        #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 734        enum QueryIds {
 735            ChannelId,
 736            Id,
 737        }
 738
 739        let mut channel_ids_by_buffer_id = HashMap::default();
 740        let mut rows = buffer::Entity::find()
 741            .filter(buffer::Column::ChannelId.is_in(channel_ids.iter().copied()))
 742            .stream(&*tx)
 743            .await?;
 744        while let Some(row) = rows.next().await {
 745            let row = row?;
 746            channel_ids_by_buffer_id.insert(row.id, row.channel_id);
 747        }
 748        drop(rows);
 749
 750        let mut observed_edits_by_buffer_id = HashMap::default();
 751        let mut rows = observed_buffer_edits::Entity::find()
 752            .filter(observed_buffer_edits::Column::UserId.eq(user_id))
 753            .filter(
 754                observed_buffer_edits::Column::BufferId
 755                    .is_in(channel_ids_by_buffer_id.keys().copied()),
 756            )
 757            .stream(&*tx)
 758            .await?;
 759        while let Some(row) = rows.next().await {
 760            let row = row?;
 761            observed_edits_by_buffer_id.insert(row.buffer_id, row);
 762        }
 763        drop(rows);
 764
 765        let latest_operations = self
 766            .get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
 767            .await?;
 768
 769        let mut changes = Vec::default();
 770        for latest in latest_operations {
 771            if let Some(observed) = observed_edits_by_buffer_id.get(&latest.buffer_id) {
 772                if (
 773                    observed.epoch,
 774                    observed.lamport_timestamp,
 775                    observed.replica_id,
 776                ) >= (latest.epoch, latest.lamport_timestamp, latest.replica_id)
 777                {
 778                    continue;
 779                }
 780            }
 781
 782            if let Some(channel_id) = channel_ids_by_buffer_id.get(&latest.buffer_id) {
 783                changes.push(proto::UnseenChannelBufferChange {
 784                    channel_id: channel_id.to_proto(),
 785                    epoch: latest.epoch as u64,
 786                    version: vec![proto::VectorClockEntry {
 787                        replica_id: latest.replica_id as u32,
 788                        timestamp: latest.lamport_timestamp as u32,
 789                    }],
 790                });
 791            }
 792        }
 793
 794        Ok(changes)
 795    }
 796
 797    pub async fn get_latest_operations_for_buffers(
 798        &self,
 799        buffer_ids: impl IntoIterator<Item = BufferId>,
 800        tx: &DatabaseTransaction,
 801    ) -> Result<Vec<buffer_operation::Model>> {
 802        let mut values = String::new();
 803        for id in buffer_ids {
 804            if !values.is_empty() {
 805                values.push_str(", ");
 806            }
 807            write!(&mut values, "({})", id).unwrap();
 808        }
 809
 810        if values.is_empty() {
 811            return Ok(Vec::default());
 812        }
 813
 814        let sql = format!(
 815            r#"
 816            SELECT
 817                *
 818            FROM
 819            (
 820                SELECT
 821                    *,
 822                    row_number() OVER (
 823                        PARTITION BY buffer_id
 824                        ORDER BY
 825                            epoch DESC,
 826                            lamport_timestamp DESC,
 827                            replica_id DESC
 828                    ) as row_number
 829                FROM buffer_operations
 830                WHERE
 831                    buffer_id in ({values})
 832            ) AS last_operations
 833            WHERE
 834                row_number = 1
 835            "#,
 836        );
 837
 838        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
 839        Ok(buffer_operation::Entity::find()
 840            .from_raw_sql(stmt)
 841            .all(&*tx)
 842            .await?)
 843    }
 844}
 845
 846fn operation_to_storage(
 847    operation: &proto::Operation,
 848    buffer: &buffer::Model,
 849    _format: i32,
 850) -> Option<buffer_operation::ActiveModel> {
 851    let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
 852        proto::operation::Variant::Edit(operation) => (
 853            operation.replica_id,
 854            operation.lamport_timestamp,
 855            storage::Operation {
 856                version: version_to_storage(&operation.version),
 857                is_undo: false,
 858                edit_ranges: operation
 859                    .ranges
 860                    .iter()
 861                    .map(|range| storage::Range {
 862                        start: range.start,
 863                        end: range.end,
 864                    })
 865                    .collect(),
 866                edit_texts: operation.new_text.clone(),
 867                undo_counts: Vec::new(),
 868            },
 869        ),
 870        proto::operation::Variant::Undo(operation) => (
 871            operation.replica_id,
 872            operation.lamport_timestamp,
 873            storage::Operation {
 874                version: version_to_storage(&operation.version),
 875                is_undo: true,
 876                edit_ranges: Vec::new(),
 877                edit_texts: Vec::new(),
 878                undo_counts: operation
 879                    .counts
 880                    .iter()
 881                    .map(|entry| storage::UndoCount {
 882                        replica_id: entry.replica_id,
 883                        lamport_timestamp: entry.lamport_timestamp,
 884                        count: entry.count,
 885                    })
 886                    .collect(),
 887            },
 888        ),
 889        _ => None?,
 890    };
 891
 892    Some(buffer_operation::ActiveModel {
 893        buffer_id: ActiveValue::Set(buffer.id),
 894        epoch: ActiveValue::Set(buffer.epoch),
 895        replica_id: ActiveValue::Set(replica_id as i32),
 896        lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
 897        value: ActiveValue::Set(value.encode_to_vec()),
 898    })
 899}
 900
 901fn operation_from_storage(
 902    row: buffer_operation::Model,
 903    _format_version: i32,
 904) -> Result<proto::operation::Variant, Error> {
 905    let operation =
 906        storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
 907    let version = version_from_storage(&operation.version);
 908    Ok(if operation.is_undo {
 909        proto::operation::Variant::Undo(proto::operation::Undo {
 910            replica_id: row.replica_id as u32,
 911            lamport_timestamp: row.lamport_timestamp as u32,
 912            version,
 913            counts: operation
 914                .undo_counts
 915                .iter()
 916                .map(|entry| proto::UndoCount {
 917                    replica_id: entry.replica_id,
 918                    lamport_timestamp: entry.lamport_timestamp,
 919                    count: entry.count,
 920                })
 921                .collect(),
 922        })
 923    } else {
 924        proto::operation::Variant::Edit(proto::operation::Edit {
 925            replica_id: row.replica_id as u32,
 926            lamport_timestamp: row.lamport_timestamp as u32,
 927            version,
 928            ranges: operation
 929                .edit_ranges
 930                .into_iter()
 931                .map(|range| proto::Range {
 932                    start: range.start,
 933                    end: range.end,
 934                })
 935                .collect(),
 936            new_text: operation.edit_texts,
 937        })
 938    })
 939}
 940
 941fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
 942    version
 943        .iter()
 944        .map(|entry| storage::VectorClockEntry {
 945            replica_id: entry.replica_id,
 946            timestamp: entry.timestamp,
 947        })
 948        .collect()
 949}
 950
 951fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
 952    version
 953        .iter()
 954        .map(|entry| proto::VectorClockEntry {
 955            replica_id: entry.replica_id,
 956            timestamp: entry.timestamp,
 957        })
 958        .collect()
 959}
 960
 961// This is currently a manual copy of the deserialization code in the client's langauge crate
 962pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
 963    match operation.variant? {
 964        proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
 965            timestamp: clock::Lamport {
 966                replica_id: edit.replica_id as text::ReplicaId,
 967                value: edit.lamport_timestamp,
 968            },
 969            version: version_from_wire(&edit.version),
 970            ranges: edit
 971                .ranges
 972                .into_iter()
 973                .map(|range| {
 974                    text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
 975                })
 976                .collect(),
 977            new_text: edit.new_text.into_iter().map(Arc::from).collect(),
 978        })),
 979        proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
 980            timestamp: clock::Lamport {
 981                replica_id: undo.replica_id as text::ReplicaId,
 982                value: undo.lamport_timestamp,
 983            },
 984            version: version_from_wire(&undo.version),
 985            counts: undo
 986                .counts
 987                .into_iter()
 988                .map(|c| {
 989                    (
 990                        clock::Lamport {
 991                            replica_id: c.replica_id as text::ReplicaId,
 992                            value: c.lamport_timestamp,
 993                        },
 994                        c.count,
 995                    )
 996                })
 997                .collect(),
 998        })),
 999        _ => None,
1000    }
1001}
1002
1003fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
1004    let mut version = clock::Global::new();
1005    for entry in message {
1006        version.observe(clock::Lamport {
1007            replica_id: entry.replica_id as text::ReplicaId,
1008            value: entry.timestamp,
1009        });
1010    }
1011    version
1012}
1013
1014fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
1015    let mut message = Vec::new();
1016    for entry in version.iter() {
1017        message.push(proto::VectorClockEntry {
1018            replica_id: entry.replica_id as u32,
1019            timestamp: entry.value,
1020        });
1021    }
1022    message
1023}
1024
// Column-selector enum with a single `OperationSerializationVersion` variant,
// generated via `DeriveColumn`.
// NOTE(review): presumably used to read only the `operation_serialization_version`
// column of a query result — confirm at the use site, which is outside this excerpt.
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
1029
// `prost` message definitions for buffer operations as they are persisted in the
// `value` column of `buffer_operations` rows. The field tags define the stored
// encoding, so they must not be changed or reused for rows that already exist.
mod storage {
    // NOTE(review): presumably silences a lint raised by code that
    // `#[derive(Message)]` expands to — confirm before removing.
    #![allow(non_snake_case)]
    use prost::Message;
    // Version written to `operation_serialization_version` when a buffer snapshot
    // row is created.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A stored edit or undo operation.
    ///
    /// Field tags start at 2; tag 1 is not used by this message. The replica id
    /// and lamport timestamp of the operation itself are not part of the message:
    /// they live in separate columns of the row.
    #[derive(Message)]
    pub struct Operation {
        /// Vector clock attached to the operation.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for undo operations (which fill `undo_counts`), `false` for
        /// edits (which fill `edit_ranges` and `edit_texts`).
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// Ranges of an edit; empty for undo operations.
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// New text of an edit; empty for undo operations.
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// Undo counts of an undo operation; empty for edits.
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a vector clock: a replica and a timestamp for it.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// A `start..end` range, copied to and from `proto::Range`.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// A count associated with a `(replica_id, lamport_timestamp)` pair, copied
    /// to and from `proto::UndoCount`.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}