buffers.rs

   1use super::*;
   2use prost::Message;
   3use text::{EditOperation, UndoOperation};
   4
   5pub struct LeftChannelBuffer {
   6    pub channel_id: ChannelId,
   7    pub collaborators: Vec<proto::Collaborator>,
   8    pub connections: Vec<ConnectionId>,
   9}
  10
  11impl Database {
  12    pub async fn join_channel_buffer(
  13        &self,
  14        channel_id: ChannelId,
  15        user_id: UserId,
  16        connection: ConnectionId,
  17    ) -> Result<proto::JoinChannelBufferResponse> {
  18        self.transaction(|tx| async move {
  19            self.check_user_is_channel_member(channel_id, user_id, &tx)
  20                .await?;
  21
  22            let buffer = channel::Model {
  23                id: channel_id,
  24                ..Default::default()
  25            }
  26            .find_related(buffer::Entity)
  27            .one(&*tx)
  28            .await?;
  29
  30            let buffer = if let Some(buffer) = buffer {
  31                buffer
  32            } else {
  33                let buffer = buffer::ActiveModel {
  34                    channel_id: ActiveValue::Set(channel_id),
  35                    ..Default::default()
  36                }
  37                .insert(&*tx)
  38                .await?;
  39                buffer_snapshot::ActiveModel {
  40                    buffer_id: ActiveValue::Set(buffer.id),
  41                    epoch: ActiveValue::Set(0),
  42                    text: ActiveValue::Set(String::new()),
  43                    operation_serialization_version: ActiveValue::Set(
  44                        storage::SERIALIZATION_VERSION,
  45                    ),
  46                }
  47                .insert(&*tx)
  48                .await?;
  49                buffer
  50            };
  51
  52            // Join the collaborators
  53            let mut collaborators = channel_buffer_collaborator::Entity::find()
  54                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
  55                .all(&*tx)
  56                .await?;
  57            let replica_ids = collaborators
  58                .iter()
  59                .map(|c| c.replica_id)
  60                .collect::<HashSet<_>>();
  61            let mut replica_id = ReplicaId(0);
  62            while replica_ids.contains(&replica_id) {
  63                replica_id.0 += 1;
  64            }
  65            let collaborator = channel_buffer_collaborator::ActiveModel {
  66                channel_id: ActiveValue::Set(channel_id),
  67                connection_id: ActiveValue::Set(connection.id as i32),
  68                connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
  69                user_id: ActiveValue::Set(user_id),
  70                replica_id: ActiveValue::Set(replica_id),
  71                ..Default::default()
  72            }
  73            .insert(&*tx)
  74            .await?;
  75            collaborators.push(collaborator);
  76
  77            let (base_text, operations, max_operation) =
  78                self.get_buffer_state(&buffer, &tx).await?;
  79
  80            // Save the last observed operation
  81            if let Some(op) = max_operation {
  82                observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
  83                    user_id: ActiveValue::Set(user_id),
  84                    buffer_id: ActiveValue::Set(buffer.id),
  85                    epoch: ActiveValue::Set(op.epoch),
  86                    lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
  87                    replica_id: ActiveValue::Set(op.replica_id),
  88                })
  89                .on_conflict(
  90                    OnConflict::columns([
  91                        observed_buffer_edits::Column::UserId,
  92                        observed_buffer_edits::Column::BufferId,
  93                    ])
  94                    .update_columns([
  95                        observed_buffer_edits::Column::Epoch,
  96                        observed_buffer_edits::Column::LamportTimestamp,
  97                    ])
  98                    .to_owned(),
  99                )
 100                .exec(&*tx)
 101                .await?;
 102            }
 103
 104            Ok(proto::JoinChannelBufferResponse {
 105                buffer_id: buffer.id.to_proto(),
 106                replica_id: replica_id.to_proto() as u32,
 107                base_text,
 108                operations,
 109                epoch: buffer.epoch as u64,
 110                collaborators: collaborators
 111                    .into_iter()
 112                    .map(|collaborator| proto::Collaborator {
 113                        peer_id: Some(collaborator.connection().into()),
 114                        user_id: collaborator.user_id.to_proto(),
 115                        replica_id: collaborator.replica_id.0 as u32,
 116                    })
 117                    .collect(),
 118            })
 119        })
 120        .await
 121    }
 122
 123    pub async fn rejoin_channel_buffers(
 124        &self,
 125        buffers: &[proto::ChannelBufferVersion],
 126        user_id: UserId,
 127        connection_id: ConnectionId,
 128    ) -> Result<Vec<RejoinedChannelBuffer>> {
 129        self.transaction(|tx| async move {
 130            let mut results = Vec::new();
 131            for client_buffer in buffers {
 132                let channel_id = ChannelId::from_proto(client_buffer.channel_id);
 133                if self
 134                    .check_user_is_channel_member(channel_id, user_id, &*tx)
 135                    .await
 136                    .is_err()
 137                {
 138                    log::info!("user is not a member of channel");
 139                    continue;
 140                }
 141
 142                let buffer = self.get_channel_buffer(channel_id, &*tx).await?;
 143                let mut collaborators = channel_buffer_collaborator::Entity::find()
 144                    .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 145                    .all(&*tx)
 146                    .await?;
 147
 148                // If the buffer epoch hasn't changed since the client lost
 149                // connection, then the client's buffer can be syncronized with
 150                // the server's buffer.
 151                if buffer.epoch as u64 != client_buffer.epoch {
 152                    log::info!("can't rejoin buffer, epoch has changed");
 153                    continue;
 154                }
 155
 156                // Find the collaborator record for this user's previous lost
 157                // connection. Update it with the new connection id.
 158                let server_id = ServerId(connection_id.owner_id as i32);
 159                let Some(self_collaborator) = collaborators.iter_mut().find(|c| {
 160                    c.user_id == user_id
 161                        && (c.connection_lost || c.connection_server_id != server_id)
 162                }) else {
 163                    log::info!("can't rejoin buffer, no previous collaborator found");
 164                    continue;
 165                };
 166                let old_connection_id = self_collaborator.connection();
 167                *self_collaborator = channel_buffer_collaborator::ActiveModel {
 168                    id: ActiveValue::Unchanged(self_collaborator.id),
 169                    connection_id: ActiveValue::Set(connection_id.id as i32),
 170                    connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
 171                    connection_lost: ActiveValue::Set(false),
 172                    ..Default::default()
 173                }
 174                .update(&*tx)
 175                .await?;
 176
 177                let client_version = version_from_wire(&client_buffer.version);
 178                let serialization_version = self
 179                    .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
 180                    .await?;
 181
 182                let mut rows = buffer_operation::Entity::find()
 183                    .filter(
 184                        buffer_operation::Column::BufferId
 185                            .eq(buffer.id)
 186                            .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
 187                    )
 188                    .stream(&*tx)
 189                    .await?;
 190
 191                // Find the server's version vector and any operations
 192                // that the client has not seen.
 193                let mut server_version = clock::Global::new();
 194                let mut operations = Vec::new();
 195                while let Some(row) = rows.next().await {
 196                    let row = row?;
 197                    let timestamp = clock::Lamport {
 198                        replica_id: row.replica_id as u16,
 199                        value: row.lamport_timestamp as u32,
 200                    };
 201                    server_version.observe(timestamp);
 202                    if !client_version.observed(timestamp) {
 203                        operations.push(proto::Operation {
 204                            variant: Some(operation_from_storage(row, serialization_version)?),
 205                        })
 206                    }
 207                }
 208
 209                results.push(RejoinedChannelBuffer {
 210                    old_connection_id,
 211                    buffer: proto::RejoinedChannelBuffer {
 212                        channel_id: client_buffer.channel_id,
 213                        version: version_to_wire(&server_version),
 214                        operations,
 215                        collaborators: collaborators
 216                            .into_iter()
 217                            .map(|collaborator| proto::Collaborator {
 218                                peer_id: Some(collaborator.connection().into()),
 219                                user_id: collaborator.user_id.to_proto(),
 220                                replica_id: collaborator.replica_id.0 as u32,
 221                            })
 222                            .collect(),
 223                    },
 224                });
 225            }
 226
 227            Ok(results)
 228        })
 229        .await
 230    }
 231
 232    pub async fn clear_stale_channel_buffer_collaborators(
 233        &self,
 234        channel_id: ChannelId,
 235        server_id: ServerId,
 236    ) -> Result<RefreshedChannelBuffer> {
 237        self.transaction(|tx| async move {
 238            let db_collaborators = channel_buffer_collaborator::Entity::find()
 239                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 240                .all(&*tx)
 241                .await?;
 242
 243            let mut connection_ids = Vec::new();
 244            let mut collaborators = Vec::new();
 245            let mut collaborator_ids_to_remove = Vec::new();
 246            for db_collaborator in &db_collaborators {
 247                if !db_collaborator.connection_lost
 248                    && db_collaborator.connection_server_id == server_id
 249                {
 250                    connection_ids.push(db_collaborator.connection());
 251                    collaborators.push(proto::Collaborator {
 252                        peer_id: Some(db_collaborator.connection().into()),
 253                        replica_id: db_collaborator.replica_id.0 as u32,
 254                        user_id: db_collaborator.user_id.to_proto(),
 255                    })
 256                } else {
 257                    collaborator_ids_to_remove.push(db_collaborator.id);
 258                }
 259            }
 260
 261            channel_buffer_collaborator::Entity::delete_many()
 262                .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
 263                .exec(&*tx)
 264                .await?;
 265
 266            Ok(RefreshedChannelBuffer {
 267                connection_ids,
 268                collaborators,
 269            })
 270        })
 271        .await
 272    }
 273
 274    pub async fn leave_channel_buffer(
 275        &self,
 276        channel_id: ChannelId,
 277        connection: ConnectionId,
 278    ) -> Result<LeftChannelBuffer> {
 279        self.transaction(|tx| async move {
 280            self.leave_channel_buffer_internal(channel_id, connection, &*tx)
 281                .await
 282        })
 283        .await
 284    }
 285
 286    pub async fn channel_buffer_connection_lost(
 287        &self,
 288        connection: ConnectionId,
 289        tx: &DatabaseTransaction,
 290    ) -> Result<()> {
 291        channel_buffer_collaborator::Entity::update_many()
 292            .filter(
 293                Condition::all()
 294                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
 295                    .add(
 296                        channel_buffer_collaborator::Column::ConnectionServerId
 297                            .eq(connection.owner_id as i32),
 298                    ),
 299            )
 300            .set(channel_buffer_collaborator::ActiveModel {
 301                connection_lost: ActiveValue::set(true),
 302                ..Default::default()
 303            })
 304            .exec(&*tx)
 305            .await?;
 306        Ok(())
 307    }
 308
 309    pub async fn leave_channel_buffers(
 310        &self,
 311        connection: ConnectionId,
 312    ) -> Result<Vec<LeftChannelBuffer>> {
 313        self.transaction(|tx| async move {
 314            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 315            enum QueryChannelIds {
 316                ChannelId,
 317            }
 318
 319            let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
 320                .select_only()
 321                .column(channel_buffer_collaborator::Column::ChannelId)
 322                .filter(Condition::all().add(
 323                    channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
 324                ))
 325                .into_values::<_, QueryChannelIds>()
 326                .all(&*tx)
 327                .await?;
 328
 329            let mut result = Vec::new();
 330            for channel_id in channel_ids {
 331                let left_channel_buffer = self
 332                    .leave_channel_buffer_internal(channel_id, connection, &*tx)
 333                    .await?;
 334                result.push(left_channel_buffer);
 335            }
 336
 337            Ok(result)
 338        })
 339        .await
 340    }
 341
 342    pub async fn leave_channel_buffer_internal(
 343        &self,
 344        channel_id: ChannelId,
 345        connection: ConnectionId,
 346        tx: &DatabaseTransaction,
 347    ) -> Result<LeftChannelBuffer> {
 348        let result = channel_buffer_collaborator::Entity::delete_many()
 349            .filter(
 350                Condition::all()
 351                    .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 352                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
 353                    .add(
 354                        channel_buffer_collaborator::Column::ConnectionServerId
 355                            .eq(connection.owner_id as i32),
 356                    ),
 357            )
 358            .exec(&*tx)
 359            .await?;
 360        if result.rows_affected == 0 {
 361            Err(anyhow!("not a collaborator on this project"))?;
 362        }
 363
 364        let mut collaborators = Vec::new();
 365        let mut connections = Vec::new();
 366        let mut rows = channel_buffer_collaborator::Entity::find()
 367            .filter(
 368                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 369            )
 370            .stream(&*tx)
 371            .await?;
 372        while let Some(row) = rows.next().await {
 373            let row = row?;
 374            let connection = row.connection();
 375            connections.push(connection);
 376            collaborators.push(proto::Collaborator {
 377                peer_id: Some(connection.into()),
 378                replica_id: row.replica_id.0 as u32,
 379                user_id: row.user_id.to_proto(),
 380            });
 381        }
 382
 383        drop(rows);
 384
 385        if collaborators.is_empty() {
 386            self.snapshot_channel_buffer(channel_id, &tx).await?;
 387        }
 388
 389        Ok(LeftChannelBuffer {
 390            channel_id,
 391            collaborators,
 392            connections,
 393        })
 394    }
 395
 396    pub async fn get_channel_buffer_collaborators(
 397        &self,
 398        channel_id: ChannelId,
 399    ) -> Result<Vec<UserId>> {
 400        self.transaction(|tx| async move {
 401            self.get_channel_buffer_collaborators_internal(channel_id, &*tx)
 402                .await
 403        })
 404        .await
 405    }
 406
 407    async fn get_channel_buffer_collaborators_internal(
 408        &self,
 409        channel_id: ChannelId,
 410        tx: &DatabaseTransaction,
 411    ) -> Result<Vec<UserId>> {
 412        #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 413        enum QueryUserIds {
 414            UserId,
 415        }
 416
 417        let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
 418            .select_only()
 419            .column(channel_buffer_collaborator::Column::UserId)
 420            .filter(
 421                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 422            )
 423            .into_values::<_, QueryUserIds>()
 424            .all(&*tx)
 425            .await?;
 426
 427        Ok(users)
 428    }
 429
 430    pub async fn update_channel_buffer(
 431        &self,
 432        channel_id: ChannelId,
 433        user: UserId,
 434        operations: &[proto::Operation],
 435    ) -> Result<(
 436        Vec<ConnectionId>,
 437        Vec<UserId>,
 438        i32,
 439        Vec<proto::VectorClockEntry>,
 440    )> {
 441        self.transaction(move |tx| async move {
 442            self.check_user_is_channel_member(channel_id, user, &*tx)
 443                .await?;
 444
 445            let buffer = buffer::Entity::find()
 446                .filter(buffer::Column::ChannelId.eq(channel_id))
 447                .one(&*tx)
 448                .await?
 449                .ok_or_else(|| anyhow!("no such buffer"))?;
 450
 451            let serialization_version = self
 452                .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &*tx)
 453                .await?;
 454
 455            let operations = operations
 456                .iter()
 457                .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
 458                .collect::<Vec<_>>();
 459
 460            let mut channel_members;
 461            let max_version;
 462
 463            if !operations.is_empty() {
 464                let max_operation = operations
 465                    .iter()
 466                    .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
 467                    .unwrap();
 468
 469                max_version = vec![proto::VectorClockEntry {
 470                    replica_id: *max_operation.replica_id.as_ref() as u32,
 471                    timestamp: *max_operation.lamport_timestamp.as_ref() as u32,
 472                }];
 473
 474                // get current channel participants and save the max operation above
 475                self.save_max_operation(
 476                    user,
 477                    buffer.id,
 478                    buffer.epoch,
 479                    *max_operation.replica_id.as_ref(),
 480                    *max_operation.lamport_timestamp.as_ref(),
 481                    &*tx,
 482                )
 483                .await?;
 484
 485                channel_members = self
 486                    .get_channel_participants_internal(channel_id, &*tx)
 487                    .await?;
 488                let collaborators = self
 489                    .get_channel_buffer_collaborators_internal(channel_id, &*tx)
 490                    .await?;
 491                channel_members.retain(|member| !collaborators.contains(member));
 492
 493                buffer_operation::Entity::insert_many(operations)
 494                    .on_conflict(
 495                        OnConflict::columns([
 496                            buffer_operation::Column::BufferId,
 497                            buffer_operation::Column::Epoch,
 498                            buffer_operation::Column::LamportTimestamp,
 499                            buffer_operation::Column::ReplicaId,
 500                        ])
 501                        .do_nothing()
 502                        .to_owned(),
 503                    )
 504                    .exec(&*tx)
 505                    .await?;
 506            } else {
 507                channel_members = Vec::new();
 508                max_version = Vec::new();
 509            }
 510
 511            let mut connections = Vec::new();
 512            let mut rows = channel_buffer_collaborator::Entity::find()
 513                .filter(
 514                    Condition::all()
 515                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 516                )
 517                .stream(&*tx)
 518                .await?;
 519            while let Some(row) = rows.next().await {
 520                let row = row?;
 521                connections.push(ConnectionId {
 522                    id: row.connection_id as u32,
 523                    owner_id: row.connection_server_id.0 as u32,
 524                });
 525            }
 526
 527            Ok((connections, channel_members, buffer.epoch, max_version))
 528        })
 529        .await
 530    }
 531
 532    async fn save_max_operation(
 533        &self,
 534        user_id: UserId,
 535        buffer_id: BufferId,
 536        epoch: i32,
 537        replica_id: i32,
 538        lamport_timestamp: i32,
 539        tx: &DatabaseTransaction,
 540    ) -> Result<()> {
 541        use observed_buffer_edits::Column;
 542
 543        observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
 544            user_id: ActiveValue::Set(user_id),
 545            buffer_id: ActiveValue::Set(buffer_id),
 546            epoch: ActiveValue::Set(epoch),
 547            replica_id: ActiveValue::Set(replica_id),
 548            lamport_timestamp: ActiveValue::Set(lamport_timestamp),
 549        })
 550        .on_conflict(
 551            OnConflict::columns([Column::UserId, Column::BufferId])
 552                .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
 553                .action_cond_where(
 554                    Condition::any().add(Column::Epoch.lt(epoch)).add(
 555                        Condition::all().add(Column::Epoch.eq(epoch)).add(
 556                            Condition::any()
 557                                .add(Column::LamportTimestamp.lt(lamport_timestamp))
 558                                .add(
 559                                    Column::LamportTimestamp
 560                                        .eq(lamport_timestamp)
 561                                        .and(Column::ReplicaId.lt(replica_id)),
 562                                ),
 563                        ),
 564                    ),
 565                )
 566                .to_owned(),
 567        )
 568        .exec_without_returning(tx)
 569        .await?;
 570
 571        Ok(())
 572    }
 573
 574    async fn get_buffer_operation_serialization_version(
 575        &self,
 576        buffer_id: BufferId,
 577        epoch: i32,
 578        tx: &DatabaseTransaction,
 579    ) -> Result<i32> {
 580        Ok(buffer_snapshot::Entity::find()
 581            .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
 582            .filter(buffer_snapshot::Column::Epoch.eq(epoch))
 583            .select_only()
 584            .column(buffer_snapshot::Column::OperationSerializationVersion)
 585            .into_values::<_, QueryOperationSerializationVersion>()
 586            .one(&*tx)
 587            .await?
 588            .ok_or_else(|| anyhow!("missing buffer snapshot"))?)
 589    }
 590
 591    pub async fn get_channel_buffer(
 592        &self,
 593        channel_id: ChannelId,
 594        tx: &DatabaseTransaction,
 595    ) -> Result<buffer::Model> {
 596        Ok(channel::Model {
 597            id: channel_id,
 598            ..Default::default()
 599        }
 600        .find_related(buffer::Entity)
 601        .one(&*tx)
 602        .await?
 603        .ok_or_else(|| anyhow!("no such buffer"))?)
 604    }
 605
 606    async fn get_buffer_state(
 607        &self,
 608        buffer: &buffer::Model,
 609        tx: &DatabaseTransaction,
 610    ) -> Result<(
 611        String,
 612        Vec<proto::Operation>,
 613        Option<buffer_operation::Model>,
 614    )> {
 615        let id = buffer.id;
 616        let (base_text, version) = if buffer.epoch > 0 {
 617            let snapshot = buffer_snapshot::Entity::find()
 618                .filter(
 619                    buffer_snapshot::Column::BufferId
 620                        .eq(id)
 621                        .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
 622                )
 623                .one(&*tx)
 624                .await?
 625                .ok_or_else(|| anyhow!("no such snapshot"))?;
 626
 627            let version = snapshot.operation_serialization_version;
 628            (snapshot.text, version)
 629        } else {
 630            (String::new(), storage::SERIALIZATION_VERSION)
 631        };
 632
 633        let mut rows = buffer_operation::Entity::find()
 634            .filter(
 635                buffer_operation::Column::BufferId
 636                    .eq(id)
 637                    .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
 638            )
 639            .order_by_asc(buffer_operation::Column::LamportTimestamp)
 640            .order_by_asc(buffer_operation::Column::ReplicaId)
 641            .stream(&*tx)
 642            .await?;
 643
 644        let mut operations = Vec::new();
 645        let mut last_row = None;
 646        while let Some(row) = rows.next().await {
 647            let row = row?;
 648            last_row = Some(buffer_operation::Model {
 649                buffer_id: row.buffer_id,
 650                epoch: row.epoch,
 651                lamport_timestamp: row.lamport_timestamp,
 652                replica_id: row.lamport_timestamp,
 653                value: Default::default(),
 654            });
 655            operations.push(proto::Operation {
 656                variant: Some(operation_from_storage(row, version)?),
 657            });
 658        }
 659
 660        Ok((base_text, operations, last_row))
 661    }
 662
 663    async fn snapshot_channel_buffer(
 664        &self,
 665        channel_id: ChannelId,
 666        tx: &DatabaseTransaction,
 667    ) -> Result<()> {
 668        let buffer = self.get_channel_buffer(channel_id, tx).await?;
 669        let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
 670        if operations.is_empty() {
 671            return Ok(());
 672        }
 673
 674        let mut text_buffer = text::Buffer::new(0, 0, base_text);
 675        text_buffer
 676            .apply_ops(operations.into_iter().filter_map(operation_from_wire))
 677            .unwrap();
 678
 679        let base_text = text_buffer.text();
 680        let epoch = buffer.epoch + 1;
 681
 682        buffer_snapshot::Model {
 683            buffer_id: buffer.id,
 684            epoch,
 685            text: base_text,
 686            operation_serialization_version: storage::SERIALIZATION_VERSION,
 687        }
 688        .into_active_model()
 689        .insert(tx)
 690        .await?;
 691
 692        buffer::ActiveModel {
 693            id: ActiveValue::Unchanged(buffer.id),
 694            epoch: ActiveValue::Set(epoch),
 695            ..Default::default()
 696        }
 697        .save(tx)
 698        .await?;
 699
 700        Ok(())
 701    }
 702
 703    pub async fn observe_buffer_version(
 704        &self,
 705        buffer_id: BufferId,
 706        user_id: UserId,
 707        epoch: i32,
 708        version: &[proto::VectorClockEntry],
 709    ) -> Result<()> {
 710        self.transaction(|tx| async move {
 711            // For now, combine concurrent operations.
 712            let Some(component) = version.iter().max_by_key(|version| version.timestamp) else {
 713                return Ok(());
 714            };
 715            self.save_max_operation(
 716                user_id,
 717                buffer_id,
 718                epoch,
 719                component.replica_id as i32,
 720                component.timestamp as i32,
 721                &*tx,
 722            )
 723            .await?;
 724            Ok(())
 725        })
 726        .await
 727    }
 728
 729    pub async fn unseen_channel_buffer_changes(
 730        &self,
 731        user_id: UserId,
 732        channel_ids: &[ChannelId],
 733        tx: &DatabaseTransaction,
 734    ) -> Result<Vec<proto::UnseenChannelBufferChange>> {
 735        #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 736        enum QueryIds {
 737            ChannelId,
 738            Id,
 739        }
 740
 741        let mut channel_ids_by_buffer_id = HashMap::default();
 742        let mut rows = buffer::Entity::find()
 743            .filter(buffer::Column::ChannelId.is_in(channel_ids.iter().copied()))
 744            .stream(&*tx)
 745            .await?;
 746        while let Some(row) = rows.next().await {
 747            let row = row?;
 748            channel_ids_by_buffer_id.insert(row.id, row.channel_id);
 749        }
 750        drop(rows);
 751
 752        let mut observed_edits_by_buffer_id = HashMap::default();
 753        let mut rows = observed_buffer_edits::Entity::find()
 754            .filter(observed_buffer_edits::Column::UserId.eq(user_id))
 755            .filter(
 756                observed_buffer_edits::Column::BufferId
 757                    .is_in(channel_ids_by_buffer_id.keys().copied()),
 758            )
 759            .stream(&*tx)
 760            .await?;
 761        while let Some(row) = rows.next().await {
 762            let row = row?;
 763            observed_edits_by_buffer_id.insert(row.buffer_id, row);
 764        }
 765        drop(rows);
 766
 767        let latest_operations = self
 768            .get_latest_operations_for_buffers(channel_ids_by_buffer_id.keys().copied(), &*tx)
 769            .await?;
 770
 771        let mut changes = Vec::default();
 772        for latest in latest_operations {
 773            if let Some(observed) = observed_edits_by_buffer_id.get(&latest.buffer_id) {
 774                if (
 775                    observed.epoch,
 776                    observed.lamport_timestamp,
 777                    observed.replica_id,
 778                ) >= (latest.epoch, latest.lamport_timestamp, latest.replica_id)
 779                {
 780                    continue;
 781                }
 782            }
 783
 784            if let Some(channel_id) = channel_ids_by_buffer_id.get(&latest.buffer_id) {
 785                changes.push(proto::UnseenChannelBufferChange {
 786                    channel_id: channel_id.to_proto(),
 787                    epoch: latest.epoch as u64,
 788                    version: vec![proto::VectorClockEntry {
 789                        replica_id: latest.replica_id as u32,
 790                        timestamp: latest.lamport_timestamp as u32,
 791                    }],
 792                });
 793            }
 794        }
 795
 796        Ok(changes)
 797    }
 798
 799    pub async fn get_latest_operations_for_buffers(
 800        &self,
 801        buffer_ids: impl IntoIterator<Item = BufferId>,
 802        tx: &DatabaseTransaction,
 803    ) -> Result<Vec<buffer_operation::Model>> {
 804        let mut values = String::new();
 805        for id in buffer_ids {
 806            if !values.is_empty() {
 807                values.push_str(", ");
 808            }
 809            write!(&mut values, "({})", id).unwrap();
 810        }
 811
 812        if values.is_empty() {
 813            return Ok(Vec::default());
 814        }
 815
 816        let sql = format!(
 817            r#"
 818            SELECT
 819                *
 820            FROM
 821            (
 822                SELECT
 823                    *,
 824                    row_number() OVER (
 825                        PARTITION BY buffer_id
 826                        ORDER BY
 827                            epoch DESC,
 828                            lamport_timestamp DESC,
 829                            replica_id DESC
 830                    ) as row_number
 831                FROM buffer_operations
 832                WHERE
 833                    buffer_id in ({values})
 834            ) AS last_operations
 835            WHERE
 836                row_number = 1
 837            "#,
 838        );
 839
 840        let stmt = Statement::from_string(self.pool.get_database_backend(), sql);
 841        Ok(buffer_operation::Entity::find()
 842            .from_raw_sql(stmt)
 843            .all(&*tx)
 844            .await?)
 845    }
 846}
 847
 848fn operation_to_storage(
 849    operation: &proto::Operation,
 850    buffer: &buffer::Model,
 851    _format: i32,
 852) -> Option<buffer_operation::ActiveModel> {
 853    let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
 854        proto::operation::Variant::Edit(operation) => (
 855            operation.replica_id,
 856            operation.lamport_timestamp,
 857            storage::Operation {
 858                version: version_to_storage(&operation.version),
 859                is_undo: false,
 860                edit_ranges: operation
 861                    .ranges
 862                    .iter()
 863                    .map(|range| storage::Range {
 864                        start: range.start,
 865                        end: range.end,
 866                    })
 867                    .collect(),
 868                edit_texts: operation.new_text.clone(),
 869                undo_counts: Vec::new(),
 870            },
 871        ),
 872        proto::operation::Variant::Undo(operation) => (
 873            operation.replica_id,
 874            operation.lamport_timestamp,
 875            storage::Operation {
 876                version: version_to_storage(&operation.version),
 877                is_undo: true,
 878                edit_ranges: Vec::new(),
 879                edit_texts: Vec::new(),
 880                undo_counts: operation
 881                    .counts
 882                    .iter()
 883                    .map(|entry| storage::UndoCount {
 884                        replica_id: entry.replica_id,
 885                        lamport_timestamp: entry.lamport_timestamp,
 886                        count: entry.count,
 887                    })
 888                    .collect(),
 889            },
 890        ),
 891        _ => None?,
 892    };
 893
 894    Some(buffer_operation::ActiveModel {
 895        buffer_id: ActiveValue::Set(buffer.id),
 896        epoch: ActiveValue::Set(buffer.epoch),
 897        replica_id: ActiveValue::Set(replica_id as i32),
 898        lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
 899        value: ActiveValue::Set(value.encode_to_vec()),
 900    })
 901}
 902
 903fn operation_from_storage(
 904    row: buffer_operation::Model,
 905    _format_version: i32,
 906) -> Result<proto::operation::Variant, Error> {
 907    let operation =
 908        storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{}", error))?;
 909    let version = version_from_storage(&operation.version);
 910    Ok(if operation.is_undo {
 911        proto::operation::Variant::Undo(proto::operation::Undo {
 912            replica_id: row.replica_id as u32,
 913            lamport_timestamp: row.lamport_timestamp as u32,
 914            version,
 915            counts: operation
 916                .undo_counts
 917                .iter()
 918                .map(|entry| proto::UndoCount {
 919                    replica_id: entry.replica_id,
 920                    lamport_timestamp: entry.lamport_timestamp,
 921                    count: entry.count,
 922                })
 923                .collect(),
 924        })
 925    } else {
 926        proto::operation::Variant::Edit(proto::operation::Edit {
 927            replica_id: row.replica_id as u32,
 928            lamport_timestamp: row.lamport_timestamp as u32,
 929            version,
 930            ranges: operation
 931                .edit_ranges
 932                .into_iter()
 933                .map(|range| proto::Range {
 934                    start: range.start,
 935                    end: range.end,
 936                })
 937                .collect(),
 938            new_text: operation.edit_texts,
 939        })
 940    })
 941}
 942
 943fn version_to_storage(version: &Vec<proto::VectorClockEntry>) -> Vec<storage::VectorClockEntry> {
 944    version
 945        .iter()
 946        .map(|entry| storage::VectorClockEntry {
 947            replica_id: entry.replica_id,
 948            timestamp: entry.timestamp,
 949        })
 950        .collect()
 951}
 952
 953fn version_from_storage(version: &Vec<storage::VectorClockEntry>) -> Vec<proto::VectorClockEntry> {
 954    version
 955        .iter()
 956        .map(|entry| proto::VectorClockEntry {
 957            replica_id: entry.replica_id,
 958            timestamp: entry.timestamp,
 959        })
 960        .collect()
 961}
 962
 963// This is currently a manual copy of the deserialization code in the client's langauge crate
 964pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
 965    match operation.variant? {
 966        proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
 967            timestamp: clock::Lamport {
 968                replica_id: edit.replica_id as text::ReplicaId,
 969                value: edit.lamport_timestamp,
 970            },
 971            version: version_from_wire(&edit.version),
 972            ranges: edit
 973                .ranges
 974                .into_iter()
 975                .map(|range| {
 976                    text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
 977                })
 978                .collect(),
 979            new_text: edit.new_text.into_iter().map(Arc::from).collect(),
 980        })),
 981        proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
 982            timestamp: clock::Lamport {
 983                replica_id: undo.replica_id as text::ReplicaId,
 984                value: undo.lamport_timestamp,
 985            },
 986            version: version_from_wire(&undo.version),
 987            counts: undo
 988                .counts
 989                .into_iter()
 990                .map(|c| {
 991                    (
 992                        clock::Lamport {
 993                            replica_id: c.replica_id as text::ReplicaId,
 994                            value: c.lamport_timestamp,
 995                        },
 996                        c.count,
 997                    )
 998                })
 999                .collect(),
1000        })),
1001        _ => None,
1002    }
1003}
1004
1005fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
1006    let mut version = clock::Global::new();
1007    for entry in message {
1008        version.observe(clock::Lamport {
1009            replica_id: entry.replica_id as text::ReplicaId,
1010            value: entry.timestamp,
1011        });
1012    }
1013    version
1014}
1015
1016fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
1017    let mut message = Vec::new();
1018    for entry in version.iter() {
1019        message.push(proto::VectorClockEntry {
1020            replica_id: entry.replica_id as u32,
1021            timestamp: entry.value,
1022        });
1023    }
1024    message
1025}
1026
/// Single-column selector for sea-orm queries that read only an operation
/// serialization version.
///
/// NOTE(review): call sites are outside this chunk. Presumably this is used with
/// `into_values` to select the `operation_serialization_version` column. `DeriveColumn`
/// derives the column name from the variant name, so renaming the variant likely changes
/// the selected column — confirm before touching it.
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
1031
/// Protobuf messages defining how buffer operations are persisted in the database.
///
/// `operation_to_storage` encodes a `storage::Operation` with `encode_to_vec` into the
/// `value` column of a `buffer_operations` row, and `operation_from_storage` decodes it
/// again. Because these encoded bytes are stored, the field tags below are part of the
/// on-disk format: changing or reusing a tag would break decoding of existing rows.
mod storage {
    // NOTE(review): the reason for this allow is not visible in this file. Presumably it
    // silences lints on code generated by `#[derive(Message)]` — confirm or remove.
    #![allow(non_snake_case)]
    use prost::Message;
    /// Serialization format version for stored operations. `join_channel_buffer` writes it
    /// to `operation_serialization_version` when it creates a buffer's initial snapshot.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// One persisted buffer operation: an edit or an undo, as selected by `is_undo`.
    ///
    /// Tag 1 is not used by any field. NOTE(review): possibly retired from an earlier
    /// layout — confirm before assigning it to a new field.
    #[derive(Message)]
    pub struct Operation {
        /// The operation's vector-clock `version`, copied from the wire operation.
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        /// `true` for an undo, where only `undo_counts` is populated. `false` for an edit,
        /// where only `edit_ranges` and `edit_texts` are populated.
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        /// An edit's ranges, copied from the wire edit's `ranges`.
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        /// An edit's inserted text, copied from the wire edit's `new_text`.
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        /// An undo's counts, copied from the wire undo's `counts`.
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// Stored form of a wire `VectorClockEntry`.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// Stored form of a wire `Range` (`start`/`end` copied as-is).
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// Stored form of a wire `UndoCount`.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}