buffers.rs

   1use super::*;
   2use anyhow::Context as _;
   3use prost::Message;
   4use text::{EditOperation, UndoOperation};
   5
   6pub struct LeftChannelBuffer {
   7    pub channel_id: ChannelId,
   8    pub collaborators: Vec<proto::Collaborator>,
   9    pub connections: Vec<ConnectionId>,
  10}
  11
  12impl Database {
  13    /// Open a channel buffer. Returns the current contents, and adds you to the list of people
  14    /// to notify on changes.
  15    pub async fn join_channel_buffer(
  16        &self,
  17        channel_id: ChannelId,
  18        user_id: UserId,
  19        connection: ConnectionId,
  20    ) -> Result<proto::JoinChannelBufferResponse> {
  21        self.transaction(|tx| async move {
  22            let channel = self.get_channel_internal(channel_id, &tx).await?;
  23            self.check_user_is_channel_participant(&channel, user_id, &tx)
  24                .await?;
  25
  26            let buffer = channel::Model {
  27                id: channel_id,
  28                ..Default::default()
  29            }
  30            .find_related(buffer::Entity)
  31            .one(&*tx)
  32            .await?;
  33
  34            let buffer = if let Some(buffer) = buffer {
  35                buffer
  36            } else {
  37                let buffer = buffer::ActiveModel {
  38                    channel_id: ActiveValue::Set(channel_id),
  39                    ..Default::default()
  40                }
  41                .insert(&*tx)
  42                .await?;
  43                buffer_snapshot::ActiveModel {
  44                    buffer_id: ActiveValue::Set(buffer.id),
  45                    epoch: ActiveValue::Set(0),
  46                    text: ActiveValue::Set(String::new()),
  47                    operation_serialization_version: ActiveValue::Set(
  48                        storage::SERIALIZATION_VERSION,
  49                    ),
  50                }
  51                .insert(&*tx)
  52                .await?;
  53                buffer
  54            };
  55
  56            // Join the collaborators
  57            let mut collaborators = channel_buffer_collaborator::Entity::find()
  58                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
  59                .all(&*tx)
  60                .await?;
  61            let replica_ids = collaborators
  62                .iter()
  63                .map(|c| c.replica_id)
  64                .collect::<HashSet<_>>();
  65            let mut replica_id = ReplicaId(clock::ReplicaId::FIRST_COLLAB_ID.as_u16() as i32);
  66            while replica_ids.contains(&replica_id) {
  67                replica_id = ReplicaId(replica_id.0 + 1);
  68            }
  69            let collaborator = channel_buffer_collaborator::ActiveModel {
  70                channel_id: ActiveValue::Set(channel_id),
  71                connection_id: ActiveValue::Set(connection.id as i32),
  72                connection_server_id: ActiveValue::Set(ServerId(connection.owner_id as i32)),
  73                user_id: ActiveValue::Set(user_id),
  74                replica_id: ActiveValue::Set(replica_id),
  75                ..Default::default()
  76            }
  77            .insert(&*tx)
  78            .await?;
  79            collaborators.push(collaborator);
  80
  81            let (base_text, operations, max_operation) =
  82                self.get_buffer_state(&buffer, &tx).await?;
  83
  84            // Save the last observed operation
  85            if let Some(op) = max_operation {
  86                observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
  87                    user_id: ActiveValue::Set(user_id),
  88                    buffer_id: ActiveValue::Set(buffer.id),
  89                    epoch: ActiveValue::Set(op.epoch),
  90                    lamport_timestamp: ActiveValue::Set(op.lamport_timestamp),
  91                    replica_id: ActiveValue::Set(op.replica_id),
  92                })
  93                .on_conflict(
  94                    OnConflict::columns([
  95                        observed_buffer_edits::Column::UserId,
  96                        observed_buffer_edits::Column::BufferId,
  97                    ])
  98                    .update_columns([
  99                        observed_buffer_edits::Column::Epoch,
 100                        observed_buffer_edits::Column::LamportTimestamp,
 101                    ])
 102                    .to_owned(),
 103                )
 104                .exec(&*tx)
 105                .await?;
 106            }
 107
 108            Ok(proto::JoinChannelBufferResponse {
 109                buffer_id: buffer.id.to_proto(),
 110                replica_id: replica_id.to_proto() as u32,
 111                base_text,
 112                operations,
 113                epoch: buffer.epoch as u64,
 114                collaborators: collaborators
 115                    .into_iter()
 116                    .map(|collaborator| proto::Collaborator {
 117                        peer_id: Some(collaborator.connection().into()),
 118                        user_id: collaborator.user_id.to_proto(),
 119                        replica_id: collaborator.replica_id.0 as u32,
 120                        is_host: false,
 121                        committer_name: None,
 122                        committer_email: None,
 123                    })
 124                    .collect(),
 125            })
 126        })
 127        .await
 128    }
 129
 130    /// Rejoin a channel buffer (after a connection interruption)
 131    pub async fn rejoin_channel_buffers(
 132        &self,
 133        buffers: &[proto::ChannelBufferVersion],
 134        user_id: UserId,
 135        connection_id: ConnectionId,
 136    ) -> Result<Vec<RejoinedChannelBuffer>> {
 137        self.transaction(|tx| async move {
 138            let mut results = Vec::new();
 139            for client_buffer in buffers {
 140                let channel = self
 141                    .get_channel_internal(ChannelId::from_proto(client_buffer.channel_id), &tx)
 142                    .await?;
 143                if self
 144                    .check_user_is_channel_participant(&channel, user_id, &tx)
 145                    .await
 146                    .is_err()
 147                {
 148                    log::info!("user is not a member of channel");
 149                    continue;
 150                }
 151
 152                let buffer = self.get_channel_buffer(channel.id, &tx).await?;
 153                let mut collaborators = channel_buffer_collaborator::Entity::find()
 154                    .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel.id))
 155                    .all(&*tx)
 156                    .await?;
 157
 158                // If the buffer epoch hasn't changed since the client lost
 159                // connection, then the client's buffer can be synchronized with
 160                // the server's buffer.
 161                if buffer.epoch as u64 != client_buffer.epoch {
 162                    log::info!("can't rejoin buffer, epoch has changed");
 163                    continue;
 164                }
 165
 166                // Find the collaborator record for this user's previous lost
 167                // connection. Update it with the new connection id.
 168                let Some(self_collaborator) =
 169                    collaborators.iter_mut().find(|c| c.user_id == user_id)
 170                else {
 171                    log::info!("can't rejoin buffer, no previous collaborator found");
 172                    continue;
 173                };
 174                let old_connection_id = self_collaborator.connection();
 175                *self_collaborator = channel_buffer_collaborator::ActiveModel {
 176                    id: ActiveValue::Unchanged(self_collaborator.id),
 177                    connection_id: ActiveValue::Set(connection_id.id as i32),
 178                    connection_server_id: ActiveValue::Set(ServerId(connection_id.owner_id as i32)),
 179                    connection_lost: ActiveValue::Set(false),
 180                    ..Default::default()
 181                }
 182                .update(&*tx)
 183                .await?;
 184
 185                let client_version = version_from_wire(&client_buffer.version);
 186                let serialization_version = self
 187                    .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &tx)
 188                    .await?;
 189
 190                let mut rows = buffer_operation::Entity::find()
 191                    .filter(
 192                        buffer_operation::Column::BufferId
 193                            .eq(buffer.id)
 194                            .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
 195                    )
 196                    .stream(&*tx)
 197                    .await?;
 198
 199                // Find the server's version vector and any operations
 200                // that the client has not seen.
 201                let mut server_version = clock::Global::new();
 202                let mut operations = Vec::new();
 203                while let Some(row) = rows.next().await {
 204                    let row = row?;
 205                    let timestamp = clock::Lamport {
 206                        replica_id: clock::ReplicaId::new(row.replica_id as u16),
 207                        value: row.lamport_timestamp as u32,
 208                    };
 209                    server_version.observe(timestamp);
 210                    if !client_version.observed(timestamp) {
 211                        operations.push(proto::Operation {
 212                            variant: Some(operation_from_storage(row, serialization_version)?),
 213                        })
 214                    }
 215                }
 216
 217                results.push(RejoinedChannelBuffer {
 218                    old_connection_id,
 219                    buffer: proto::RejoinedChannelBuffer {
 220                        channel_id: client_buffer.channel_id,
 221                        version: version_to_wire(&server_version),
 222                        operations,
 223                        collaborators: collaborators
 224                            .into_iter()
 225                            .map(|collaborator| proto::Collaborator {
 226                                peer_id: Some(collaborator.connection().into()),
 227                                user_id: collaborator.user_id.to_proto(),
 228                                replica_id: collaborator.replica_id.0 as u32,
 229                                is_host: false,
 230                                committer_name: None,
 231                                committer_email: None,
 232                            })
 233                            .collect(),
 234                    },
 235                });
 236            }
 237
 238            Ok(results)
 239        })
 240        .await
 241    }
 242
 243    /// Clear out any buffer collaborators who are no longer collaborating.
 244    pub async fn clear_stale_channel_buffer_collaborators(
 245        &self,
 246        channel_id: ChannelId,
 247        server_id: ServerId,
 248    ) -> Result<RefreshedChannelBuffer> {
 249        self.transaction(|tx| async move {
 250            let db_collaborators = channel_buffer_collaborator::Entity::find()
 251                .filter(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 252                .all(&*tx)
 253                .await?;
 254
 255            let mut connection_ids = Vec::new();
 256            let mut collaborators = Vec::new();
 257            let mut collaborator_ids_to_remove = Vec::new();
 258            for db_collaborator in &db_collaborators {
 259                if !db_collaborator.connection_lost
 260                    && db_collaborator.connection_server_id == server_id
 261                {
 262                    connection_ids.push(db_collaborator.connection());
 263                    collaborators.push(proto::Collaborator {
 264                        peer_id: Some(db_collaborator.connection().into()),
 265                        replica_id: db_collaborator.replica_id.0 as u32,
 266                        user_id: db_collaborator.user_id.to_proto(),
 267                        is_host: false,
 268                        committer_name: None,
 269                        committer_email: None,
 270                    })
 271                } else {
 272                    collaborator_ids_to_remove.push(db_collaborator.id);
 273                }
 274            }
 275
 276            channel_buffer_collaborator::Entity::delete_many()
 277                .filter(channel_buffer_collaborator::Column::Id.is_in(collaborator_ids_to_remove))
 278                .exec(&*tx)
 279                .await?;
 280
 281            Ok(RefreshedChannelBuffer {
 282                connection_ids,
 283                collaborators,
 284            })
 285        })
 286        .await
 287    }
 288
 289    /// Close the channel buffer, and stop receiving updates for it.
 290    pub async fn leave_channel_buffer(
 291        &self,
 292        channel_id: ChannelId,
 293        connection: ConnectionId,
 294    ) -> Result<LeftChannelBuffer> {
 295        self.transaction(|tx| async move {
 296            self.leave_channel_buffer_internal(channel_id, connection, &tx)
 297                .await
 298        })
 299        .await
 300    }
 301
 302    /// Close the channel buffer, and stop receiving updates for it.
 303    pub async fn channel_buffer_connection_lost(
 304        &self,
 305        connection: ConnectionId,
 306        tx: &DatabaseTransaction,
 307    ) -> Result<()> {
 308        channel_buffer_collaborator::Entity::update_many()
 309            .filter(
 310                Condition::all()
 311                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
 312                    .add(
 313                        channel_buffer_collaborator::Column::ConnectionServerId
 314                            .eq(connection.owner_id as i32),
 315                    ),
 316            )
 317            .set(channel_buffer_collaborator::ActiveModel {
 318                connection_lost: ActiveValue::set(true),
 319                ..Default::default()
 320            })
 321            .exec(tx)
 322            .await?;
 323        Ok(())
 324    }
 325
 326    /// Close all open channel buffers
 327    pub async fn leave_channel_buffers(
 328        &self,
 329        connection: ConnectionId,
 330    ) -> Result<Vec<LeftChannelBuffer>> {
 331        self.transaction(|tx| async move {
 332            #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 333            enum QueryChannelIds {
 334                ChannelId,
 335            }
 336
 337            let channel_ids: Vec<ChannelId> = channel_buffer_collaborator::Entity::find()
 338                .select_only()
 339                .column(channel_buffer_collaborator::Column::ChannelId)
 340                .filter(Condition::all().add(
 341                    channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32),
 342                ))
 343                .into_values::<_, QueryChannelIds>()
 344                .all(&*tx)
 345                .await?;
 346
 347            let mut result = Vec::new();
 348            for channel_id in channel_ids {
 349                let left_channel_buffer = self
 350                    .leave_channel_buffer_internal(channel_id, connection, &tx)
 351                    .await?;
 352                result.push(left_channel_buffer);
 353            }
 354
 355            Ok(result)
 356        })
 357        .await
 358    }
 359
 360    async fn leave_channel_buffer_internal(
 361        &self,
 362        channel_id: ChannelId,
 363        connection: ConnectionId,
 364        tx: &DatabaseTransaction,
 365    ) -> Result<LeftChannelBuffer> {
 366        let result = channel_buffer_collaborator::Entity::delete_many()
 367            .filter(
 368                Condition::all()
 369                    .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id))
 370                    .add(channel_buffer_collaborator::Column::ConnectionId.eq(connection.id as i32))
 371                    .add(
 372                        channel_buffer_collaborator::Column::ConnectionServerId
 373                            .eq(connection.owner_id as i32),
 374                    ),
 375            )
 376            .exec(tx)
 377            .await?;
 378        if result.rows_affected == 0 {
 379            Err(anyhow!("not a collaborator on this project"))?;
 380        }
 381
 382        let mut collaborators = Vec::new();
 383        let mut connections = Vec::new();
 384        let mut rows = channel_buffer_collaborator::Entity::find()
 385            .filter(
 386                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 387            )
 388            .stream(tx)
 389            .await?;
 390        while let Some(row) = rows.next().await {
 391            let row = row?;
 392            let connection = row.connection();
 393            connections.push(connection);
 394            collaborators.push(proto::Collaborator {
 395                peer_id: Some(connection.into()),
 396                replica_id: row.replica_id.0 as u32,
 397                user_id: row.user_id.to_proto(),
 398                is_host: false,
 399                committer_name: None,
 400                committer_email: None,
 401            });
 402        }
 403
 404        drop(rows);
 405
 406        if collaborators.is_empty() {
 407            self.snapshot_channel_buffer(channel_id, tx).await?;
 408        }
 409
 410        Ok(LeftChannelBuffer {
 411            channel_id,
 412            collaborators,
 413            connections,
 414        })
 415    }
 416
 417    pub async fn get_channel_buffer_collaborators(
 418        &self,
 419        channel_id: ChannelId,
 420    ) -> Result<Vec<UserId>> {
 421        self.transaction(|tx| async move {
 422            self.get_channel_buffer_collaborators_internal(channel_id, &tx)
 423                .await
 424        })
 425        .await
 426    }
 427
 428    async fn get_channel_buffer_collaborators_internal(
 429        &self,
 430        channel_id: ChannelId,
 431        tx: &DatabaseTransaction,
 432    ) -> Result<Vec<UserId>> {
 433        #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
 434        enum QueryUserIds {
 435            UserId,
 436        }
 437
 438        let users: Vec<UserId> = channel_buffer_collaborator::Entity::find()
 439            .select_only()
 440            .column(channel_buffer_collaborator::Column::UserId)
 441            .filter(
 442                Condition::all().add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 443            )
 444            .into_values::<_, QueryUserIds>()
 445            .all(tx)
 446            .await?;
 447
 448        Ok(users)
 449    }
 450
 451    pub async fn update_channel_buffer(
 452        &self,
 453        channel_id: ChannelId,
 454        user: UserId,
 455        operations: &[proto::Operation],
 456    ) -> Result<(HashSet<ConnectionId>, i32, Vec<proto::VectorClockEntry>)> {
 457        self.transaction(move |tx| async move {
 458            let channel = self.get_channel_internal(channel_id, &tx).await?;
 459
 460            let mut requires_write_permission = false;
 461            for op in operations.iter() {
 462                match op.variant {
 463                    None | Some(proto::operation::Variant::UpdateSelections(_)) => {}
 464                    Some(_) => requires_write_permission = true,
 465                }
 466            }
 467            if requires_write_permission {
 468                self.check_user_is_channel_member(&channel, user, &tx)
 469                    .await?;
 470            } else {
 471                self.check_user_is_channel_participant(&channel, user, &tx)
 472                    .await?;
 473            }
 474
 475            let buffer = buffer::Entity::find()
 476                .filter(buffer::Column::ChannelId.eq(channel_id))
 477                .one(&*tx)
 478                .await?
 479                .context("no such buffer")?;
 480
 481            let serialization_version = self
 482                .get_buffer_operation_serialization_version(buffer.id, buffer.epoch, &tx)
 483                .await?;
 484
 485            let operations = operations
 486                .iter()
 487                .filter_map(|op| operation_to_storage(op, &buffer, serialization_version))
 488                .collect::<Vec<_>>();
 489
 490            let max_version;
 491
 492            if !operations.is_empty() {
 493                let max_operation = operations
 494                    .iter()
 495                    .max_by_key(|op| (op.lamport_timestamp.as_ref(), op.replica_id.as_ref()))
 496                    .unwrap();
 497
 498                max_version = vec![proto::VectorClockEntry {
 499                    replica_id: *max_operation.replica_id.as_ref() as u32,
 500                    timestamp: *max_operation.lamport_timestamp.as_ref() as u32,
 501                }];
 502
 503                // get current channel participants and save the max operation above
 504                self.save_max_operation(
 505                    user,
 506                    buffer.id,
 507                    buffer.epoch,
 508                    *max_operation.replica_id.as_ref(),
 509                    *max_operation.lamport_timestamp.as_ref(),
 510                    &tx,
 511                )
 512                .await?;
 513
 514                buffer_operation::Entity::insert_many(operations)
 515                    .on_conflict(
 516                        OnConflict::columns([
 517                            buffer_operation::Column::BufferId,
 518                            buffer_operation::Column::Epoch,
 519                            buffer_operation::Column::LamportTimestamp,
 520                            buffer_operation::Column::ReplicaId,
 521                        ])
 522                        .do_nothing()
 523                        .to_owned(),
 524                    )
 525                    .exec(&*tx)
 526                    .await?;
 527            } else {
 528                max_version = Vec::new();
 529            }
 530
 531            let mut connections = HashSet::default();
 532            let mut rows = channel_buffer_collaborator::Entity::find()
 533                .filter(
 534                    Condition::all()
 535                        .add(channel_buffer_collaborator::Column::ChannelId.eq(channel_id)),
 536                )
 537                .stream(&*tx)
 538                .await?;
 539            while let Some(row) = rows.next().await {
 540                let row = row?;
 541                connections.insert(ConnectionId {
 542                    id: row.connection_id as u32,
 543                    owner_id: row.connection_server_id.0 as u32,
 544                });
 545            }
 546
 547            Ok((connections, buffer.epoch, max_version))
 548        })
 549        .await
 550    }
 551
 552    async fn save_max_operation(
 553        &self,
 554        user_id: UserId,
 555        buffer_id: BufferId,
 556        epoch: i32,
 557        replica_id: i32,
 558        lamport_timestamp: i32,
 559        tx: &DatabaseTransaction,
 560    ) -> Result<()> {
 561        buffer::Entity::update(buffer::ActiveModel {
 562            id: ActiveValue::Unchanged(buffer_id),
 563            epoch: ActiveValue::Unchanged(epoch),
 564            latest_operation_epoch: ActiveValue::Set(Some(epoch)),
 565            latest_operation_replica_id: ActiveValue::Set(Some(replica_id)),
 566            latest_operation_lamport_timestamp: ActiveValue::Set(Some(lamport_timestamp)),
 567            channel_id: ActiveValue::NotSet,
 568        })
 569        .exec(tx)
 570        .await?;
 571
 572        use observed_buffer_edits::Column;
 573        observed_buffer_edits::Entity::insert(observed_buffer_edits::ActiveModel {
 574            user_id: ActiveValue::Set(user_id),
 575            buffer_id: ActiveValue::Set(buffer_id),
 576            epoch: ActiveValue::Set(epoch),
 577            replica_id: ActiveValue::Set(replica_id),
 578            lamport_timestamp: ActiveValue::Set(lamport_timestamp),
 579        })
 580        .on_conflict(
 581            OnConflict::columns([Column::UserId, Column::BufferId])
 582                .update_columns([Column::Epoch, Column::LamportTimestamp, Column::ReplicaId])
 583                .action_cond_where(
 584                    Condition::any().add(Column::Epoch.lt(epoch)).add(
 585                        Condition::all().add(Column::Epoch.eq(epoch)).add(
 586                            Condition::any()
 587                                .add(Column::LamportTimestamp.lt(lamport_timestamp))
 588                                .add(
 589                                    Column::LamportTimestamp
 590                                        .eq(lamport_timestamp)
 591                                        .and(Column::ReplicaId.lt(replica_id)),
 592                                ),
 593                        ),
 594                    ),
 595                )
 596                .to_owned(),
 597        )
 598        .exec_without_returning(tx)
 599        .await?;
 600
 601        Ok(())
 602    }
 603
 604    async fn get_buffer_operation_serialization_version(
 605        &self,
 606        buffer_id: BufferId,
 607        epoch: i32,
 608        tx: &DatabaseTransaction,
 609    ) -> Result<i32> {
 610        Ok(buffer_snapshot::Entity::find()
 611            .filter(buffer_snapshot::Column::BufferId.eq(buffer_id))
 612            .filter(buffer_snapshot::Column::Epoch.eq(epoch))
 613            .select_only()
 614            .column(buffer_snapshot::Column::OperationSerializationVersion)
 615            .into_values::<_, QueryOperationSerializationVersion>()
 616            .one(tx)
 617            .await?
 618            .context("missing buffer snapshot")?)
 619    }
 620
 621    pub async fn get_channel_buffer(
 622        &self,
 623        channel_id: ChannelId,
 624        tx: &DatabaseTransaction,
 625    ) -> Result<buffer::Model> {
 626        Ok(channel::Model {
 627            id: channel_id,
 628            ..Default::default()
 629        }
 630        .find_related(buffer::Entity)
 631        .one(tx)
 632        .await?
 633        .context("no such buffer")?)
 634    }
 635
 636    async fn get_buffer_state(
 637        &self,
 638        buffer: &buffer::Model,
 639        tx: &DatabaseTransaction,
 640    ) -> Result<(
 641        String,
 642        Vec<proto::Operation>,
 643        Option<buffer_operation::Model>,
 644    )> {
 645        let id = buffer.id;
 646        let (base_text, version) = if buffer.epoch > 0 {
 647            let snapshot = buffer_snapshot::Entity::find()
 648                .filter(
 649                    buffer_snapshot::Column::BufferId
 650                        .eq(id)
 651                        .and(buffer_snapshot::Column::Epoch.eq(buffer.epoch)),
 652                )
 653                .one(tx)
 654                .await?
 655                .context("no such snapshot")?;
 656
 657            let version = snapshot.operation_serialization_version;
 658            (snapshot.text, version)
 659        } else {
 660            (String::new(), storage::SERIALIZATION_VERSION)
 661        };
 662
 663        let mut rows = buffer_operation::Entity::find()
 664            .filter(
 665                buffer_operation::Column::BufferId
 666                    .eq(id)
 667                    .and(buffer_operation::Column::Epoch.eq(buffer.epoch)),
 668            )
 669            .order_by_asc(buffer_operation::Column::LamportTimestamp)
 670            .order_by_asc(buffer_operation::Column::ReplicaId)
 671            .stream(tx)
 672            .await?;
 673
 674        let mut operations = Vec::new();
 675        let mut last_row = None;
 676        while let Some(row) = rows.next().await {
 677            let row = row?;
 678            last_row = Some(buffer_operation::Model {
 679                buffer_id: row.buffer_id,
 680                epoch: row.epoch,
 681                lamport_timestamp: row.lamport_timestamp,
 682                replica_id: row.replica_id,
 683                value: Default::default(),
 684            });
 685            operations.push(proto::Operation {
 686                variant: Some(operation_from_storage(row, version)?),
 687            });
 688        }
 689
 690        Ok((base_text, operations, last_row))
 691    }
 692
 693    async fn snapshot_channel_buffer(
 694        &self,
 695        channel_id: ChannelId,
 696        tx: &DatabaseTransaction,
 697    ) -> Result<()> {
 698        let buffer = self.get_channel_buffer(channel_id, tx).await?;
 699        let (base_text, operations, _) = self.get_buffer_state(&buffer, tx).await?;
 700        if operations.is_empty() {
 701            return Ok(());
 702        }
 703
 704        let mut text_buffer = text::Buffer::new_slow(
 705            clock::ReplicaId::LOCAL,
 706            text::BufferId::new(1).unwrap(),
 707            base_text,
 708        );
 709        text_buffer.apply_ops(operations.into_iter().filter_map(operation_from_wire), None);
 710
 711        let base_text = text_buffer.text();
 712        let epoch = buffer.epoch + 1;
 713
 714        buffer_snapshot::Model {
 715            buffer_id: buffer.id,
 716            epoch,
 717            text: base_text,
 718            operation_serialization_version: storage::SERIALIZATION_VERSION,
 719        }
 720        .into_active_model()
 721        .insert(tx)
 722        .await?;
 723
 724        buffer::ActiveModel {
 725            id: ActiveValue::Unchanged(buffer.id),
 726            epoch: ActiveValue::Set(epoch),
 727            latest_operation_epoch: ActiveValue::NotSet,
 728            latest_operation_replica_id: ActiveValue::NotSet,
 729            latest_operation_lamport_timestamp: ActiveValue::NotSet,
 730            channel_id: ActiveValue::NotSet,
 731        }
 732        .save(tx)
 733        .await?;
 734
 735        Ok(())
 736    }
 737
 738    pub async fn observe_buffer_version(
 739        &self,
 740        buffer_id: BufferId,
 741        user_id: UserId,
 742        epoch: i32,
 743        version: &[proto::VectorClockEntry],
 744    ) -> Result<()> {
 745        self.transaction(|tx| async move {
 746            // For now, combine concurrent operations.
 747            let Some(component) = version.iter().max_by_key(|version| version.timestamp) else {
 748                return Ok(());
 749            };
 750            self.save_max_operation(
 751                user_id,
 752                buffer_id,
 753                epoch,
 754                component.replica_id as i32,
 755                component.timestamp as i32,
 756                &tx,
 757            )
 758            .await?;
 759            Ok(())
 760        })
 761        .await
 762    }
 763
 764    pub async fn observed_channel_buffer_changes(
 765        &self,
 766        channel_ids_by_buffer_id: &HashMap<BufferId, ChannelId>,
 767        user_id: UserId,
 768        tx: &DatabaseTransaction,
 769    ) -> Result<Vec<proto::ChannelBufferVersion>> {
 770        let observed_operations = observed_buffer_edits::Entity::find()
 771            .filter(observed_buffer_edits::Column::UserId.eq(user_id))
 772            .filter(
 773                observed_buffer_edits::Column::BufferId
 774                    .is_in(channel_ids_by_buffer_id.keys().copied()),
 775            )
 776            .all(tx)
 777            .await?;
 778
 779        Ok(observed_operations
 780            .iter()
 781            .flat_map(|op| {
 782                Some(proto::ChannelBufferVersion {
 783                    channel_id: channel_ids_by_buffer_id.get(&op.buffer_id)?.to_proto(),
 784                    epoch: op.epoch as u64,
 785                    version: vec![proto::VectorClockEntry {
 786                        replica_id: op.replica_id as u32,
 787                        timestamp: op.lamport_timestamp as u32,
 788                    }],
 789                })
 790            })
 791            .collect())
 792    }
 793
 794    /// Update language server capabilities for a given id.
 795    pub async fn update_server_capabilities(
 796        &self,
 797        project_id: ProjectId,
 798        server_id: u64,
 799        new_capabilities: String,
 800    ) -> Result<()> {
 801        self.transaction(|tx| {
 802            let new_capabilities = new_capabilities.clone();
 803            async move {
 804                Ok(
 805                    language_server::Entity::update(language_server::ActiveModel {
 806                        project_id: ActiveValue::unchanged(project_id),
 807                        id: ActiveValue::unchanged(server_id as i64),
 808                        capabilities: ActiveValue::set(new_capabilities),
 809                        ..Default::default()
 810                    })
 811                    .exec(&*tx)
 812                    .await?,
 813                )
 814            }
 815        })
 816        .await?;
 817        Ok(())
 818    }
 819}
 820
 821fn operation_to_storage(
 822    operation: &proto::Operation,
 823    buffer: &buffer::Model,
 824    _format: i32,
 825) -> Option<buffer_operation::ActiveModel> {
 826    let (replica_id, lamport_timestamp, value) = match operation.variant.as_ref()? {
 827        proto::operation::Variant::Edit(operation) => (
 828            operation.replica_id,
 829            operation.lamport_timestamp,
 830            storage::Operation {
 831                version: version_to_storage(&operation.version),
 832                is_undo: false,
 833                edit_ranges: operation
 834                    .ranges
 835                    .iter()
 836                    .map(|range| storage::Range {
 837                        start: range.start,
 838                        end: range.end,
 839                    })
 840                    .collect(),
 841                edit_texts: operation.new_text.clone(),
 842                undo_counts: Vec::new(),
 843            },
 844        ),
 845        proto::operation::Variant::Undo(operation) => (
 846            operation.replica_id,
 847            operation.lamport_timestamp,
 848            storage::Operation {
 849                version: version_to_storage(&operation.version),
 850                is_undo: true,
 851                edit_ranges: Vec::new(),
 852                edit_texts: Vec::new(),
 853                undo_counts: operation
 854                    .counts
 855                    .iter()
 856                    .map(|entry| storage::UndoCount {
 857                        replica_id: entry.replica_id,
 858                        lamport_timestamp: entry.lamport_timestamp,
 859                        count: entry.count,
 860                    })
 861                    .collect(),
 862            },
 863        ),
 864        _ => None?,
 865    };
 866
 867    Some(buffer_operation::ActiveModel {
 868        buffer_id: ActiveValue::Set(buffer.id),
 869        epoch: ActiveValue::Set(buffer.epoch),
 870        replica_id: ActiveValue::Set(replica_id as i32),
 871        lamport_timestamp: ActiveValue::Set(lamport_timestamp as i32),
 872        value: ActiveValue::Set(value.encode_to_vec()),
 873    })
 874}
 875
 876fn operation_from_storage(
 877    row: buffer_operation::Model,
 878    _format_version: i32,
 879) -> Result<proto::operation::Variant, Error> {
 880    let operation =
 881        storage::Operation::decode(row.value.as_slice()).map_err(|error| anyhow!("{error}"))?;
 882    let version = version_from_storage(&operation.version);
 883    Ok(if operation.is_undo {
 884        proto::operation::Variant::Undo(proto::operation::Undo {
 885            replica_id: row.replica_id as u32,
 886            lamport_timestamp: row.lamport_timestamp as u32,
 887            version,
 888            counts: operation
 889                .undo_counts
 890                .iter()
 891                .map(|entry| proto::UndoCount {
 892                    replica_id: entry.replica_id,
 893                    lamport_timestamp: entry.lamport_timestamp,
 894                    count: entry.count,
 895                })
 896                .collect(),
 897        })
 898    } else {
 899        proto::operation::Variant::Edit(proto::operation::Edit {
 900            replica_id: row.replica_id as u32,
 901            lamport_timestamp: row.lamport_timestamp as u32,
 902            version,
 903            ranges: operation
 904                .edit_ranges
 905                .into_iter()
 906                .map(|range| proto::Range {
 907                    start: range.start,
 908                    end: range.end,
 909                })
 910                .collect(),
 911            new_text: operation.edit_texts,
 912        })
 913    })
 914}
 915
 916fn version_to_storage(version: &[proto::VectorClockEntry]) -> Vec<storage::VectorClockEntry> {
 917    version
 918        .iter()
 919        .map(|entry| storage::VectorClockEntry {
 920            replica_id: entry.replica_id,
 921            timestamp: entry.timestamp,
 922        })
 923        .collect()
 924}
 925
 926fn version_from_storage(version: &[storage::VectorClockEntry]) -> Vec<proto::VectorClockEntry> {
 927    version
 928        .iter()
 929        .map(|entry| proto::VectorClockEntry {
 930            replica_id: entry.replica_id,
 931            timestamp: entry.timestamp,
 932        })
 933        .collect()
 934}
 935
 936// This is currently a manual copy of the deserialization code in the client's language crate
 937pub fn operation_from_wire(operation: proto::Operation) -> Option<text::Operation> {
 938    match operation.variant? {
 939        proto::operation::Variant::Edit(edit) => Some(text::Operation::Edit(EditOperation {
 940            timestamp: clock::Lamport {
 941                replica_id: clock::ReplicaId::new(edit.replica_id as u16),
 942                value: edit.lamport_timestamp,
 943            },
 944            version: version_from_wire(&edit.version),
 945            ranges: edit
 946                .ranges
 947                .into_iter()
 948                .map(|range| {
 949                    text::FullOffset(range.start as usize)..text::FullOffset(range.end as usize)
 950                })
 951                .collect(),
 952            new_text: edit.new_text.into_iter().map(Arc::from).collect(),
 953        })),
 954        proto::operation::Variant::Undo(undo) => Some(text::Operation::Undo(UndoOperation {
 955            timestamp: clock::Lamport {
 956                replica_id: clock::ReplicaId::new(undo.replica_id as u16),
 957                value: undo.lamport_timestamp,
 958            },
 959            version: version_from_wire(&undo.version),
 960            counts: undo
 961                .counts
 962                .into_iter()
 963                .map(|c| {
 964                    (
 965                        clock::Lamport {
 966                            replica_id: clock::ReplicaId::new(c.replica_id as u16),
 967                            value: c.lamport_timestamp,
 968                        },
 969                        c.count,
 970                    )
 971                })
 972                .collect(),
 973        })),
 974        _ => None,
 975    }
 976}
 977
 978fn version_from_wire(message: &[proto::VectorClockEntry]) -> clock::Global {
 979    let mut version = clock::Global::new();
 980    for entry in message {
 981        version.observe(clock::Lamport {
 982            replica_id: clock::ReplicaId::new(entry.replica_id as u16),
 983            value: entry.timestamp,
 984        });
 985    }
 986    version
 987}
 988
 989fn version_to_wire(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
 990    let mut message = Vec::new();
 991    for entry in version.iter() {
 992        message.push(proto::VectorClockEntry {
 993            replica_id: entry.replica_id.as_u16() as u32,
 994            timestamp: entry.value,
 995        });
 996    }
 997    message
 998}
 999
/// Single-column selector derived with sea-orm's `DeriveColumn`.
///
/// NOTE(review): presumably used at a call site outside this view to read only the
/// `operation_serialization_version` column (e.g. via `select_only` / `into_values`);
/// the variant name likely determines the selected column name, so do not rename it
/// without checking that call site.
#[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
enum QueryOperationSerializationVersion {
    OperationSerializationVersion,
}
1004
/// Persisted (database) encoding of buffer operations.
///
/// `operation_to_storage` encodes [`storage::Operation`] with prost into the
/// `buffer_operation.value` column, and `operation_from_storage` decodes it back.
/// Because these bytes are stored, existing field tags and types must not be changed
/// or reused; add new fields under fresh tags instead.
mod storage {
    // NOTE(review): every field name below is already snake_case, so the reason for this
    // allow is not visible here; presumably it silences a lint triggered by the prost
    // `Message` derive. Confirm before removing.
    #![allow(non_snake_case)]
    use prost::Message;
    /// Version number written to `operation_serialization_version` columns alongside
    /// persisted buffer data.
    pub const SERIALIZATION_VERSION: i32 = 1;

    /// A stored edit or undo operation. `is_undo` selects which of the edit fields or
    /// `undo_counts` is meaningful.
    // Tag 1 is not used by this message. NOTE(review): presumably retired; do not reuse it.
    #[derive(Message)]
    pub struct Operation {
        #[prost(message, repeated, tag = "2")]
        pub version: Vec<VectorClockEntry>,
        #[prost(bool, tag = "3")]
        pub is_undo: bool,
        #[prost(message, repeated, tag = "4")]
        pub edit_ranges: Vec<Range>,
        #[prost(string, repeated, tag = "5")]
        pub edit_texts: Vec<String>,
        #[prost(message, repeated, tag = "6")]
        pub undo_counts: Vec<UndoCount>,
    }

    /// One entry of a stored vector clock.
    #[derive(Message)]
    pub struct VectorClockEntry {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub timestamp: u32,
    }

    /// A stored edit range, as `start`/`end` offsets.
    #[derive(Message)]
    pub struct Range {
        #[prost(uint64, tag = "1")]
        pub start: u64,
        #[prost(uint64, tag = "2")]
        pub end: u64,
    }

    /// A stored undo count, keyed by the replica id and Lamport timestamp it applies to.
    #[derive(Message)]
    pub struct UndoCount {
        #[prost(uint32, tag = "1")]
        pub replica_id: u32,
        #[prost(uint32, tag = "2")]
        pub lamport_timestamp: u32,
        #[prost(uint32, tag = "3")]
        pub count: u32,
    }
}